update NAS-Bench-102

Author: D-X-Y
Date:   2019-12-21 11:13:08 +11:00
Parent: 69ca0860aa
Commit: 95ec4d328e
3 changed files with 76 additions and 39 deletions


@@ -1,6 +1,8 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
#################################################################################
# NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search #
#################################################################################
import os, sys, copy, random, torch, numpy as np
from collections import OrderedDict, defaultdict
@@ -12,19 +14,21 @@ def print_information(information, extra_info=None, show=False):
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
for ida, dataset in enumerate(dataset_names):
flop, param, latency = information.get_comput_costs(dataset)
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency > 0 else None)
train_loss, train_acc = information.get_metrics(dataset, 'train')
#flop, param, latency = information.get_comput_costs(dataset)
metric = information.get_comput_costs(dataset)
flop, param, latency = metric['flops'], metric['params'], metric['latency']
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
train_info = information.get_metrics(dataset, 'train')
if dataset == 'cifar10-valid':
valid_loss, valid_acc = information.get_metrics(dataset, 'x-valid')
str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(valid_loss, valid_acc))
valid_info = information.get_metrics(dataset, 'x-valid')
str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
elif dataset == 'cifar10':
test__loss, test__acc = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(test__loss, test__acc))
test__info = information.get_metrics(dataset, 'ori-test')
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
else:
valid_loss, valid_acc = information.get_metrics(dataset, 'x-valid')
test__loss, test__acc = information.get_metrics(dataset, 'x-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(valid_loss, valid_acc), metric2str(test__loss, test__acc))
valid_info = information.get_metrics(dataset, 'x-valid')
test__info = information.get_metrics(dataset, 'x-test')
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
strings += [str1, str2]
if show: print('\n'.join(strings))
return strings
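
The hunk above moves print_information from tuple returns to dict returns: get_comput_costs(dataset) is now read through the 'flops', 'params' and 'latency' keys, and get_metrics(dataset, setname) through 'loss' and 'accuracy'. A minimal sketch of how a caller would consume the new dict-style results, assuming an api object already built from the benchmark file (see the construction sketch after the next hunk) and assuming architecture index 0 has been evaluated:

# Sketch only: `api` is a NASBench102API instance, index 0 is assumed to be evaluated.
info = api.query_meta_info_by_index(0)        # ArchResults for the 200-epoch runs
cost = info.get_comput_costs('cifar10')       # returns a dict after this commit
print('FLOPs={:.2f}M, Params={:.3f}MB, latency={:}'.format(
    cost['flops'], cost['params'], cost['latency']))
train = info.get_metrics('cifar10', 'train')      # {'loss': ..., 'accuracy': ...}
test_ = info.get_metrics('cifar10', 'ori-test')
print('train loss={:.3f} top1={:.2f}% | test loss={:.3f} top1={:.2f}%'.format(
    train['loss'], train['accuracy'], test_['loss'], test_['accuracy']))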
@@ -34,19 +38,21 @@ class NASBench102API(object):
def __init__(self, file_path_or_dict, verbose=True):
if isinstance(file_path_or_dict, str):
if verbose: print('try to create NAS-Bench-102 api from {:}'.format(file_path_or_dict))
if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
file_path_or_dict = torch.load(file_path_or_dict)
else:
file_path_or_dict = copy.deepcopy( file_path_or_dict )
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
import pdb; pdb.set_trace() # we will update this api soon
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
self.arch2infos = OrderedDict()
self.arch2infos_less = OrderedDict()
self.arch2infos_full = OrderedDict()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
self.arch2infos[xkey] = ArchResults.create_from_state_dict( file_path_or_dict['arch2infos'][xkey] )
all_info = file_path_or_dict['arch2infos'][xkey]
self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
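
With this change the constructor expects each arch2infos entry to carry two state dicts, 'less' (12-epoch training) and 'full' (200-epoch training), which it keeps in arch2infos_less and arch2infos_full. A rough sketch of the layout it asserts on and of the two ways to build the API (file path or already-loaded dict); the file name below is a placeholder and the import path is an assumption taken from the NAS-Bench-102 docs, so adjust both to your checkout:

import torch
from nas_102_api import NASBench102API   # module name assumed; use wherever this file lives

benchmark_file = 'NAS-Bench-102-v1_0.pth'   # placeholder name for the benchmark archive
data = torch.load(benchmark_file)

# Keys the constructor asserts on:
#   'meta_archs'        -> list of architecture strings
#   'arch2infos'        -> {index: {'less': state_dict, 'full': state_dict}}
#   'evaluated_indexes' -> indexes that actually have results
first = sorted(data['arch2infos'].keys())[0]
print(sorted(data['arch2infos'][first].keys()))   # expect ['full', 'less'] after this commit

api = NASBench102API(benchmark_file)   # equivalently: NASBench102API(data)
# Note: at this commit __init__ still carries a temporary `import pdb; pdb.set_trace()`,
# so constructing the object drops into the debugger (continue with `c`).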
@@ -73,35 +79,46 @@ class NASBench102API(object):
else: arch_index = -1
return arch_index
def query_by_arch(self, arch):
arch_index = self.query_index_by_arch(arch)
if arch_index == -1: return None
if arch_index in self.arch2infos:
strings = print_information(self.arch2infos[ arch_index ], 'arch-index={:}'.format(arch_index))
def query_by_arch(self, arch, use_12epochs_result=False):
if isinstance(arch, int):
arch_index = arch
else:
arch_index = self.query_index_by_arch(arch)
if arch_index == -1: return None # the following two lines are used to support few training epochs
if use_12epochs_result: arch2infos = self.arch2infos_less
else : arch2infos = self.arch2infos_full
if arch_index in arch2infos:
strings = print_information(arch2infos[ arch_index ], 'arch-index={:}'.format(arch_index))
return '\n'.join(strings)
else:
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
return None
def query_by_index(self, arch_index, dataname):
assert arch_index in self.arch2infos, 'arch_index [{:}] does not in arch2info'.format(arch_index)
archInfo = copy.deepcopy( self.arch2infos[ arch_index ] )
def query_by_index(self, arch_index, dataname, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
info = archInfo.query(dataname)
return info
def query_meta_info_by_index(self, arch_index):
assert arch_index in self.arch2infos, 'arch_index [{:}] does not in arch2info'.format(arch_index)
archInfo = copy.deepcopy( self.arch2infos[ arch_index ] )
def query_meta_info_by_index(self, arch_index, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
return archInfo
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None):
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
best_index, highest_accuracy = -1, None
for i, idx in enumerate(self.evaluated_indexes):
flop, param, latency = self.arch2infos[idx].get_comput_costs(dataset)
flop, param, latency = arch2infos[idx].get_comput_costs(dataset)
if FLOP_max is not None and flop > FLOP_max : continue
if Param_max is not None and param > Param_max: continue
loss, accuracy = self.arch2infos[idx].get_metrics(dataset, metric_on_set)
loss, accuracy = arch2infos[idx].get_metrics(dataset, metric_on_set)
if best_index == -1:
best_index, highest_accuracy = idx, accuracy
elif highest_accuracy < accuracy:
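
query_by_arch, query_by_index, query_meta_info_by_index and find_best all gain a use_12epochs_result switch that reads from arch2infos_less (12-epoch training) instead of arch2infos_full (200-epoch training), and query_by_arch now also accepts an integer index. One caveat: find_best still tuple-unpacks get_metrics, which returns a dict after the first hunk, so that path may need the same dict-style access. A usage sketch with the api object assumed above and a hypothetical (but validly formatted) architecture string:

# Hypothetical NAS-Bench-102 cell string, used only for illustration.
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'

print(api.query_by_arch(arch_str))                            # default: 200-epoch results
print(api.query_by_arch(arch_str, use_12epochs_result=True))  # cheaper 12-epoch results
print(api.query_by_arch(0))                                   # integer indexes now work too

# Per-dataset record for one architecture, switchable between the two training budgets.
info_12  = api.query_by_index(0, 'cifar10', use_12epochs_result=True)
meta_200 = api.query_meta_info_by_index(0)                    # use_12epochs_result defaults to False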
@@ -113,21 +130,29 @@ class NASBench102API(object):
return copy.deepcopy(self.meta_archs[index])
def show(self, index=-1):
if index == -1: # show all architectures
if index < 0: # show all architectures
print(self)
for i, idx in enumerate(self.evaluated_indexes):
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
print('arch : {:}'.format(self.meta_archs[idx]))
strings = print_information(self.arch2infos[idx])
print('>' * 20)
strings = print_information(self.arch2infos_full[idx])
print('>' * 40 + ' 200 epochs ' + '>' * 40)
print('\n'.join(strings))
print('<' * 20)
strings = print_information(self.arch2infos_less[idx])
print('>' * 40 + ' 12 epochs ' + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
if 0 <= index < len(self.meta_archs):
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
else:
strings = print_information(self.arch2infos[index])
strings = print_information(self.arch2infos_full[index])
print('>' * 40 + ' 200 epochs ' + '>' * 40)
print('\n'.join(strings))
strings = print_information(self.arch2infos_less[index])
print('>' * 40 + ' 12 epochs ' + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
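
show now prints two blocks per architecture, the 200-epoch results followed by the 12-epoch results, and treats any negative index (not just -1) as "show every evaluated architecture". A short usage sketch with the same assumed api object:

api.show(0)       # one architecture: 200-epoch block, then 12-epoch block (or a notice if unevaluated)
api.show(-1)      # any negative index walks over all evaluated architectures
api.show(10**9)   # out-of-range indexes only print a warning instead of raising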