update API
This commit is contained in:
@@ -170,10 +170,28 @@ class NASBench201API(object):
|
||||
return archresult.get_comput_costs(dataset)
|
||||
|
||||
# obtain the metric for the `index`-th architecture
|
||||
# `dataset` indicates the dataset:
|
||||
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
|
||||
# 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
|
||||
# 'cifar100' : using the proposed train set of CIFAR-100 as the training set
|
||||
# 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
|
||||
# `iepoch` indicates the index of training epochs from 0 to 11/199.
|
||||
# When iepoch=None, it will return the metric for the last training epoch
|
||||
# When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
|
||||
# `use_12epochs_result` indicates different hyper-parameters for training
|
||||
# When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
|
||||
# When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
|
||||
# `is_random`
|
||||
# When is_random=True, the performance of a random architecture will be returned
|
||||
# When is_random=False, the performanceo of all trials will be averaged.
|
||||
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
archresult = arch2infos[index]
|
||||
# if randomly select one trial, select the seed at first
|
||||
if isinstance(is_random, bool) and is_random:
|
||||
seeds = archresult.get_dataset_seeds(dataset)
|
||||
is_random = random.choice(seeds)
|
||||
if dataset == 'cifar10-valid':
|
||||
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
|
||||
valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=is_random)
|
||||
@@ -202,7 +220,7 @@ class NASBench201API(object):
|
||||
else:
|
||||
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
valid_info = None
|
||||
test__info = None
|
||||
try:
|
||||
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
@@ -213,7 +231,7 @@ class NASBench201API(object):
|
||||
est_valid_info = None
|
||||
xifo = {'train-loss' : train_info['loss'],
|
||||
'train-accuracy': train_info['accuracy']}
|
||||
if valid_info is not None:
|
||||
if test__info is not None:
|
||||
xifo['test-loss'] = test__info['loss'],
|
||||
xifo['test-accuracy'] = test__info['accuracy']
|
||||
if valid_info is not None:
|
||||
@@ -347,14 +365,20 @@ class ArchResults(object):
|
||||
info = result.get_eval(setname, iepoch)
|
||||
for key, value in info.items(): infos[key].append( value )
|
||||
return_info = dict()
|
||||
if is_random:
|
||||
if isinstance(is_random, bool) and is_random: # randomly select one
|
||||
index = random.randint(0, len(results)-1)
|
||||
for key, value in infos.items(): return_info[key] = value[index]
|
||||
else:
|
||||
elif isinstance(is_random, bool) and not is_random: # average
|
||||
for key, value in infos.items():
|
||||
if len(value) > 0 and value[0] is not None:
|
||||
return_info[key] = np.mean(value)
|
||||
else: return_info[key] = None
|
||||
elif isinstance(is_random, int): # specify the seed
|
||||
if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds))
|
||||
index = x_seeds.index(is_random)
|
||||
for key, value in infos.items(): return_info[key] = value[index]
|
||||
else:
|
||||
raise ValueError('invalid value for is_random: {:}'.format(is_random))
|
||||
return return_info
|
||||
|
||||
def show(self, is_print=False):
|
||||
@@ -363,6 +387,9 @@ class ArchResults(object):
|
||||
def get_dataset_names(self):
|
||||
return list(self.dataset_seed.keys())
|
||||
|
||||
def get_dataset_seeds(self, dataset):
|
||||
return copy.deepcopy( self.dataset_seed[dataset] )
|
||||
|
||||
def get_net_param(self, dataset, seed=None):
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
|
Reference in New Issue
Block a user