change batchsize in DARTS-NASNet to 64 ; add some type checking
This commit is contained in:
@@ -37,9 +37,12 @@ def print_information(information, extra_info=None, show=False):
|
||||
if show: print('\n'.join(strings))
|
||||
return strings
|
||||
|
||||
|
||||
"""
|
||||
This is the class for API of NAS-Bench-201.
|
||||
"""
|
||||
class NASBench201API(object):
|
||||
|
||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||
def __init__(self, file_path_or_dict, verbose=True):
|
||||
if isinstance(file_path_or_dict, str):
|
||||
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
|
||||
@@ -49,6 +52,7 @@ class NASBench201API(object):
|
||||
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
||||
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
|
||||
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
|
||||
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
|
||||
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
||||
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
||||
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
|
||||
|
Reference in New Issue
Block a user