update NAS-Bench-102
This commit is contained in:
@@ -78,6 +78,16 @@ class NASBench102API(object):
|
||||
else : arch_index = -1
|
||||
else: arch_index = -1
|
||||
return arch_index
|
||||
|
||||
def reload(self, archive_root, index):
|
||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
|
||||
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
|
||||
xdata = torch.load(xfile_path)
|
||||
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
|
||||
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
|
||||
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
|
||||
|
||||
def query_by_arch(self, arch, use_12epochs_result=False):
|
||||
if isinstance(arch, int):
|
||||
@@ -125,10 +135,18 @@ class NASBench102API(object):
|
||||
best_index, highest_accuracy = idx, accuracy
|
||||
return best_index
|
||||
|
||||
# return the topology structure of the `index`-th architecture
|
||||
def arch(self, index):
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
|
||||
return copy.deepcopy(self.meta_archs[index])
|
||||
|
||||
# obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
|
||||
def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
archresult = arch2infos[index]
|
||||
return archresult.get_net_param(dataset, seed)
|
||||
|
||||
def get_more_info(self, index, dataset, use_12epochs_result=False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
@@ -238,6 +256,13 @@ class ArchResults(object):
|
||||
def get_dataset_names(self):
|
||||
return list(self.dataset_seed.keys())
|
||||
|
||||
def get_net_param(self, dataset, seed=None):
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
|
||||
else:
|
||||
return self.all_results[(dataset, seed)].get_net_param()
|
||||
|
||||
def query(self, dataset, seed=None):
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
|
Reference in New Issue
Block a user