update baseline NAS algos
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import os, sys, copy, torch, numpy as np
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def print_information(information, extra_info=None, show=False):
|
||||
@@ -29,20 +29,26 @@ def print_information(information, extra_info=None, show=False):
|
||||
|
||||
class AANASBenchAPI(object):
|
||||
|
||||
def __init__(self, file_path_or_dict):
|
||||
def __init__(self, file_path_or_dict, verbose=True):
|
||||
if isinstance(file_path_or_dict, str):
|
||||
if verbose: print('try to create AA-NAS-Bench api from {:}'.format(file_path_or_dict))
|
||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||
file_path_or_dict = torch.load(file_path_or_dict)
|
||||
else:
|
||||
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
||||
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
|
||||
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
||||
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
||||
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
|
||||
self.arch2infos = copy.deepcopy( file_path_or_dict['arch2infos'] )
|
||||
self.evaluated_indexes = sorted(list( copy.deepcopy( file_path_or_dict['evaluated_indexes'] ) ))
|
||||
self.arch2infos = OrderedDict()
|
||||
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
|
||||
self.arch2infos[xkey] = ArchResults.create_from_state_dict( file_path_or_dict['arch2infos'][xkey] )
|
||||
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
|
||||
self.archstr2index = {}
|
||||
for idx, arch in enumerate(self.meta_archs):
|
||||
assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
|
||||
self.archstr2index[ arch.tostr() ] = idx
|
||||
#assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
|
||||
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
|
||||
self.archstr2index[ arch ] = idx
|
||||
|
||||
def __getitem__(self, index):
|
||||
return copy.deepcopy( self.meta_archs[index] )
|
||||
@@ -54,12 +60,12 @@ class AANASBenchAPI(object):
|
||||
return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
|
||||
|
||||
def query_index_by_arch(self, arch):
|
||||
if arch.tostr() in self.archstr2index:
|
||||
arch_index = self.archstr2index[ arch.tostr() ]
|
||||
#else:
|
||||
# arch_str = Structure.str2fullstructure( arch.tostr() ).tostr()
|
||||
# if arch_str in self.archstr2index:
|
||||
# arch_index = self.archstr2index[ arch_str ]
|
||||
if isinstance(arch, str):
|
||||
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
|
||||
else : arch_index = -1
|
||||
elif hasattr(arch, 'tostr'):
|
||||
if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
|
||||
else : arch_index = -1
|
||||
else: arch_index = -1
|
||||
return arch_index
|
||||
|
||||
@@ -80,6 +86,11 @@ class AANASBenchAPI(object):
|
||||
info = archInfo.query(dataname)
|
||||
return info
|
||||
|
||||
def query_meta_info_by_index(self, arch_index):
|
||||
assert arch_index in self.arch2infos, 'arch_index [{:}] does not in arch2info'.format(arch_index)
|
||||
archInfo = copy.deepcopy( self.arch2infos[ arch_index ] )
|
||||
return archInfo
|
||||
|
||||
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None):
|
||||
best_index, highest_accuracy = -1, None
|
||||
for i, idx in enumerate(self.evaluated_indexes):
|
||||
|
Reference in New Issue
Block a user