Update weight watcher codes

This commit is contained in:
D-X-Y
2020-07-05 22:29:26 +00:00
parent 9659f132be
commit 6facc39a42
9 changed files with 167 additions and 161 deletions

View File

@@ -77,6 +77,7 @@ class NASBench201API(NASBenchMetaAPI):
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set(['12', '200'])
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_info = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()

View File

@@ -75,11 +75,13 @@ class NASBench301API(NASBenchMetaAPI):
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_infos = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
self.archstr2index = {}

View File

@@ -57,6 +57,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def __repr__(self):
return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))
@property
def avaliable_hps(self):
return list(copy.deepcopy(self._avaliable_hps))
def random(self):
"""Return a random index of all architectures."""
return random.randint(0, len(self.meta_archs)-1)