update README

This commit is contained in:
D-X-Y
2020-01-16 01:43:07 +11:00
parent 4be2a0000c
commit b299945b23
8 changed files with 205 additions and 50 deletions

View File

@@ -162,6 +162,13 @@ class NASBench201API(object):
archresult = arch2infos[index]
return archresult.get_net_param(dataset, seed)
# obtain the cost metric for the `index`-th architecture on a dataset
def get_cost_info(self, index, dataset, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
archresult = arch2infos[index]
return archresult.get_comput_costs(dataset)
# obtain the metric for the `index`-th architecture
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
@@ -177,6 +184,7 @@ class NASBench201API(object):
total = train_info['iepoch'] + 1
xifo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],
'train-per-time': None if train_info['all_time'] is None else train_info['all_time'] / total,
'train-all-time': train_info['all_time'],
'valid-loss' : valid_info['loss'],
'valid-accuracy': valid_info['accuracy'],
@@ -188,21 +196,32 @@ class NASBench201API(object):
return xifo
else:
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
if dataset == 'cifar10':
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
try:
if dataset == 'cifar10':
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
except:
valid_info = None
try:
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except:
valid_info = None
try:
est_valid_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
except:
est_valid_info = None
xifo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],
'test-loss' : test__info['loss'],
'test-accuracy' : test__info['accuracy']}
'train-accuracy': train_info['accuracy']}
if valid_info is not None:
xifo['test-loss'] = test__info['loss'],
xifo['test-accuracy'] = test__info['accuracy']
if valid_info is not None:
xifo['valid-loss'] = valid_info['loss']
xifo['valid-accuracy'] = valid_info['accuracy']
if est_valid_info is not None:
xifo['est-valid-loss'] = est_valid_info['loss']
xifo['est-valid-accuracy'] = est_valid_info['accuracy']
return xifo
def show(self, index=-1):