Support accumulate and reset time function for API

This commit is contained in:
D-X-Y
2020-07-13 02:53:11 +00:00
parent af1be7f740
commit 88a5be1368
6 changed files with 296 additions and 2 deletions

View File

@@ -55,10 +55,16 @@ def get_cell_based_tiny_net(config):
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name) -> List[Text]:
if xtype == 'cell':
if xtype == 'cell' or xtype == 'tss': # The topology search space.
from .cell_operations import SearchSpaceNames
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
return SearchSpaceNames[name]
elif xtype == 'sss': # The size search space.
if name == 'nas-bench-301':
return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64],
'numbers': 5}
else:
raise ValueError('Invalid name : {:}'.format(name))
else:
raise ValueError('invalid search-space type is {:}'.format(xtype))

View File

@@ -26,6 +26,7 @@ DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5',
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
'nas-bench-201': NAS_BENCH_201,
'nas-bench-301': NAS_BENCH_201,
'darts' : DARTS_SPACE}

View File

@@ -58,6 +58,7 @@ class NASBench201API(NASBenchMetaAPI):
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None,
verbose: bool=True):
self.filename = None
self.reset_time()
if file_path_or_dict is None:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
print ('Try to use the default NAS-Bench-201 path from {:}.'.format(file_path_or_dict))

View File

@@ -57,6 +57,7 @@ class NASBench301API(NASBenchMetaAPI):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
self.filename = None
self.reset_time()
if file_path_or_dict is None:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
@@ -128,7 +129,7 @@ class NASBench301API(NASBenchMetaAPI):
"""
if self.verbose:
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
self._query_info_str_by_arch(arch, hp, print_information)
return self._query_info_str_by_arch(arch, hp, print_information)
def get_more_info(self, index: int, dataset: Text, iepoch=None, hp='12', is_random=True):
"""This function will return the metric for the `index`-th architecture

View File

@@ -61,6 +61,25 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def avaliable_hps(self):
return list(copy.deepcopy(self._avaliable_hps))
@property
def used_time(self):
return self._used_time
def reset_time(self):
self._used_time = 0
def simulate_train_eval(self, arch, dataset, hp='12'):
index = self.query_index_by_arch(arch)
all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
if dataset == 'cifar10':
info = self.get_more_info(index, 'cifar10-valid', iepoch=None, hp=hp, is_random=True)
else:
info = self.get_more_info(index, dataset, iepoch=None, hp=hp, is_random=True)
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
self._used_time += time_cost
return valid_acc, time_cost, self._used_time
def random(self):
"""Return a random index of all architectures."""
return random.randint(0, len(self.meta_archs)-1)