change batchsize in DARTS-NASNet to 64 ; add some type checking

This commit is contained in:
D-X-Y
2020-02-07 10:15:58 +11:00
parent 923b0fe9cf
commit 1efe3cbccf
4 changed files with 16 additions and 10 deletions

View File

@@ -37,9 +37,12 @@ def print_information(information, extra_info=None, show=False):
if show: print('\n'.join(strings))
return strings
"""
This is the class for API of NAS-Bench-201.
"""
class NASBench201API(object):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict, verbose=True):
if isinstance(file_path_or_dict, str):
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
@@ -49,6 +52,7 @@ class NASBench201API(object):
file_path_or_dict = copy.deepcopy( file_path_or_dict )
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )