simplify baselines

This commit is contained in:
D-X-Y
2019-12-31 22:02:11 +11:00
parent f8f44bfb31
commit 9ec25663f1
12 changed files with 338 additions and 124 deletions

View File

@@ -41,8 +41,9 @@ class NASBench102API(object):
if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
file_path_or_dict = torch.load(file_path_or_dict)
else:
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy( file_path_or_dict )
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
@@ -152,26 +153,40 @@ class NASBench102API(object):
archresult = arch2infos[index]
return archresult.get_net_param(dataset, seed)
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False):
# obtain the metric for the `index`-th architecture
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
archresult = arch2infos[index]
if dataset == 'cifar10-valid':
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=True)
valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True)
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True)
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=is_random)
try:
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except:
test__info = None
total = train_info['iepoch'] + 1
return {'train-loss' : train_info['loss'],
xifo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],
'train-all-time': train_info['all_time'],
'valid-loss' : valid_info['loss'],
'valid-accuracy': valid_info['accuracy'],
'valid-all-time': valid_info['all_time'],
'valid-per-time': valid_info['all_time'] / total,
'valid-per-time': None if valid_info['all_time'] is None else valid_info['all_time'] / total}
if test__info is not None:
xifo['test-loss'] = test__info['loss']
xifo['test-accuracy'] = test__info['accuracy']
return xifo
else:
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
if dataset == 'cifar10':
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
return {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],
'test-loss' : test__info['loss'],
'test-accuracy' : test__info['accuracy']}
else:
raise ValueError('coming soon...')
def show(self, index=-1):
if index < 0: # show all architectures
@@ -369,7 +384,7 @@ class ResultsCount(object):
def update_latency(self, latency):
self.latency = copy.deepcopy( latency )
def update_eval(self, accs, losses, times): # old version
def update_eval(self, accs, losses, times): # new version
data_names = set([x.split('@')[0] for x in accs.keys()])
for data_name in data_names:
assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)