update README
This commit is contained in:
@@ -147,14 +147,14 @@ class NASBench102API(object):
|
||||
archresult = arch2infos[index]
|
||||
return archresult.get_net_param(dataset, seed)
|
||||
|
||||
def get_more_info(self, index, dataset, use_12epochs_result=False):
|
||||
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
archresult = arch2infos[index]
|
||||
if dataset == 'cifar10-valid':
|
||||
train_info = archresult.get_metrics(dataset, 'train', is_random=True)
|
||||
valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True)
|
||||
test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True)
|
||||
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=True)
|
||||
valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True)
|
||||
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True)
|
||||
total = train_info['iepoch'] + 1
|
||||
return {'train-loss' : train_info['loss'],
|
||||
'train-accuracy': train_info['accuracy'],
|
||||
|
Reference in New Issue
Block a user