update NAS-Bench-102

This commit is contained in:
D-X-Y
2019-12-26 23:29:36 +11:00
parent 1d5e8debad
commit d791622b63
3 changed files with 116 additions and 2 deletions

View File

@@ -78,6 +78,16 @@ class NASBench102API(object):
else : arch_index = -1
else: arch_index = -1
return arch_index
def reload(self, archive_root, index):
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path)
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
def query_by_arch(self, arch, use_12epochs_result=False):
if isinstance(arch, int):
@@ -125,10 +135,18 @@ class NASBench102API(object):
best_index, highest_accuracy = idx, accuracy
return best_index
# return the topology structure of the `index`-th architecture
def arch(self, index):
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
return copy.deepcopy(self.meta_archs[index])
# obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
archresult = arch2infos[index]
return archresult.get_net_param(dataset, seed)
def get_more_info(self, index, dataset, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
@@ -238,6 +256,13 @@ class ArchResults(object):
def get_dataset_names(self):
return list(self.dataset_seed.keys())
def get_net_param(self, dataset, seed=None):
if seed is None:
x_seeds = self.dataset_seed[dataset]
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
else:
return self.all_results[(dataset, seed)].get_net_param()
def query(self, dataset, seed=None):
if seed is None:
x_seeds = self.dataset_seed[dataset]