update README

This commit is contained in:
D-X-Y
2019-12-28 15:42:36 +11:00
parent d791622b63
commit 4c144b7437
6 changed files with 59 additions and 28 deletions

View File

@@ -147,14 +147,14 @@ class NASBench102API(object):
archresult = arch2infos[index]
return archresult.get_net_param(dataset, seed)
def get_more_info(self, index, dataset, use_12epochs_result=False):
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
archresult = arch2infos[index]
if dataset == 'cifar10-valid':
train_info = archresult.get_metrics(dataset, 'train', is_random=True)
valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True)
test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True)
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=True)
valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True)
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True)
total = train_info['iepoch'] + 1
return {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],