update vis

This commit is contained in:
D-X-Y
2020-01-01 22:18:42 +11:00
parent 9ec25663f1
commit 28e4b8406f
12 changed files with 153 additions and 40 deletions

View File

@@ -131,15 +131,17 @@ class NASBench102API(object):
else : basestr, arch2infos = '200epochs', self.arch2infos_full
best_index, highest_accuracy = -1, None
for i, idx in enumerate(self.evaluated_indexes):
flop, param, latency = arch2infos[idx].get_comput_costs(dataset)
info = arch2infos[idx].get_comput_costs(dataset)
flop, param, latency = info['flops'], info['params'], info['latency']
if FLOP_max is not None and flop > FLOP_max : continue
if Param_max is not None and param > Param_max: continue
loss, accuracy = arch2infos[idx].get_metrics(dataset, metric_on_set)
xinfo = arch2infos[idx].get_metrics(dataset, metric_on_set)
loss, accuracy = xinfo['loss'], xinfo['accuracy']
if best_index == -1:
best_index, highest_accuracy = idx, accuracy
elif highest_accuracy < accuracy:
best_index, highest_accuracy = idx, accuracy
return best_index
return best_index, highest_accuracy
# return the topology structure of the `index`-th architecture
def arch(self, index):
@@ -183,10 +185,18 @@ class NASBench102API(object):
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else:
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
return {'train-loss' : train_info['loss'],
try:
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except:
valid_info = None
xifo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'],
'test-loss' : test__info['loss'],
'test-accuracy' : test__info['accuracy']}
if valid_info is not None:
xifo['valid-loss'] = valid_info['loss']
xifo['valid-accuracy'] = valid_info['accuracy']
return xifo
def show(self, index=-1):
if index < 0: # show all architectures