Update NATS-Bench API to v1.1

This commit is contained in:
D-X-Y
2020-12-20 00:50:55 +08:00
parent dae387a97d
commit ff989ba814
2 changed files with 12 additions and 4 deletions

View File

@@ -17,7 +17,13 @@ from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names
def get_fake_torch_home_dir():
return os.environ['FAKE_TORCH_HOME']
print('This file is {:}'.format(os.path.abspath(__file__)))
print('The current directory is {:}'.format(os.path.abspath(os.getcwd())))
xname = 'FAKE_TORCH_HOME'
if xname in os.environ:
return os.environ['FAKE_TORCH_HOME']
else:
return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'fake_torch_dir')
class TestNATSBench(object):
@@ -70,8 +76,10 @@ class TestNATSBench(object):
print(xinfo)
print(data[777].train_acc1es)
info_012_epochs = api.get_more_info(284, 'cifar10', hp=200)
print(info_012_epochs['train-accuracy'])
info_012_epochs = api.get_more_info(284, 'cifar10', hp= 12)
print('Train accuracy for 12 epochs is {:}'.format(info_012_epochs['train-accuracy']))
info_200_epochs = api.get_more_info(284, 'cifar10', hp=200)
print('Train accuracy for 200 epochs is {:}'.format(info_200_epochs['train-accuracy']))
def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):