Update weight watcher codes
This commit is contained in:
@@ -77,6 +77,7 @@ class NASBench201API(NASBenchMetaAPI):
|
||||
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
|
||||
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
|
||||
self.arch2infos_dict = OrderedDict()
|
||||
self._avaliable_hps = set(['12', '200'])
|
||||
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
|
||||
all_info = file_path_or_dict['arch2infos'][xkey]
|
||||
hp2archres = OrderedDict()
|
||||
|
@@ -75,11 +75,13 @@ class NASBench301API(NASBenchMetaAPI):
|
||||
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
|
||||
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
|
||||
self.arch2infos_dict = OrderedDict()
|
||||
self._avaliable_hps = set()
|
||||
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
|
||||
all_infos = file_path_or_dict['arch2infos'][xkey]
|
||||
hp2archres = OrderedDict()
|
||||
for hp_key, results in all_infos.items():
|
||||
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
|
||||
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
|
||||
self.arch2infos_dict[xkey] = hp2archres
|
||||
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
|
||||
self.archstr2index = {}
|
||||
|
@@ -57,6 +57,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
def __repr__(self):
|
||||
return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))
|
||||
|
||||
@property
|
||||
def avaliable_hps(self):
|
||||
return list(copy.deepcopy(self._avaliable_hps))
|
||||
|
||||
def random(self):
|
||||
"""Return a random index of all architectures."""
|
||||
return random.randint(0, len(self.meta_archs)-1)
|
||||
|
Reference in New Issue
Block a user