rm PD ; update NAS-Bench-102 baselines
This commit is contained in:
@@ -36,6 +36,7 @@ def get_cell_based_tiny_net(config):
|
||||
def get_search_spaces(xtype, name):
|
||||
if xtype == 'cell':
|
||||
from .cell_operations import SearchSpaceNames
|
||||
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
|
||||
return SearchSpaceNames[name]
|
||||
else:
|
||||
raise ValueError('invalid search-space type is {:}'.format(xtype))
|
||||
|
@@ -16,12 +16,13 @@ OPS = {
|
||||
'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
|
||||
}
|
||||
|
||||
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
||||
AA_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
||||
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
|
||||
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
||||
'aa-nas' : AA_NAS_BENCHMARK,
|
||||
'full' : sorted(list(OPS.keys()))}
|
||||
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
||||
'aa-nas' : NAS_BENCH_102,
|
||||
'nas-bench-102': NAS_BENCH_102,
|
||||
'full' : sorted(list(OPS.keys()))}
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
|
@@ -129,6 +129,27 @@ class NASBench102API(object):
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
|
||||
return copy.deepcopy(self.meta_archs[index])
|
||||
|
||||
def get_more_info(self, index, dataset, use_12epochs_result=False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
archresult = arch2infos[index]
|
||||
if dataset == 'cifar10-valid':
|
||||
train_info = archresult.get_metrics(dataset, 'train', is_random=True)
|
||||
valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True)
|
||||
test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True)
|
||||
total = train_info['iepoch'] + 1
|
||||
return {'train-loss' : train_info['loss'],
|
||||
'train-accuracy': train_info['accuracy'],
|
||||
'train-all-time': train_info['all_time'],
|
||||
'valid-loss' : valid_info['loss'],
|
||||
'valid-accuracy': valid_info['accuracy'],
|
||||
'valid-all-time': valid_info['all_time'],
|
||||
'valid-per-time': valid_info['all_time'] / total,
|
||||
'test-loss' : test__info['loss'],
|
||||
'test-accuracy' : test__info['accuracy']}
|
||||
else:
|
||||
raise ValueError('coming soon...')
|
||||
|
||||
def show(self, index=-1):
|
||||
if index < 0: # show all architectures
|
||||
print(self)
|
||||
@@ -367,23 +388,28 @@ class ResultsCount(object):
|
||||
def get_train(self, iepoch=None):
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
if self.train_times is not None: xtime = self.train_times[iepoch]
|
||||
else : xtime = None
|
||||
if self.train_times is not None:
|
||||
xtime = self.train_times[iepoch]
|
||||
atime = sum([self.train_times[i] for i in range(iepoch+1)])
|
||||
else: xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.train_losses[iepoch],
|
||||
'accuracy': self.train_acc1es[iepoch],
|
||||
'time' : xtime}
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
|
||||
def get_eval(self, name, iepoch=None):
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
|
||||
xtime = self.eval_times['{:}@{:}'.format(name,iepoch)]
|
||||
else: xtime = None
|
||||
atime = sum([self.eval_times['{:}@{:}'.format(name,i)] for i in range(iepoch+1)])
|
||||
else: xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.eval_losses['{:}@{:}'.format(name,iepoch)],
|
||||
'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)],
|
||||
'time' : xtime}
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
|
||||
def get_net_param(self):
|
||||
return self.net_state_dict
|
||||
|
Reference in New Issue
Block a user