update baseline NAS algos

This commit is contained in:
D-X-Y
2019-11-14 13:55:42 +11:00
parent 5c73aeb50b
commit 7843940846
13 changed files with 924 additions and 33 deletions

View File

@@ -1,5 +1,5 @@
import os, sys, copy, torch, numpy as np
from collections import OrderedDict
def print_information(information, extra_info=None, show=False):
@@ -29,20 +29,26 @@ def print_information(information, extra_info=None, show=False):
class AANASBenchAPI(object):
def __init__(self, file_path_or_dict):
def __init__(self, file_path_or_dict, verbose=True):
if isinstance(file_path_or_dict, str):
if verbose: print('try to create AA-NAS-Bench api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
file_path_or_dict = torch.load(file_path_or_dict)
else:
file_path_or_dict = copy.deepcopy( file_path_or_dict )
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
self.arch2infos = copy.deepcopy( file_path_or_dict['arch2infos'] )
self.evaluated_indexes = sorted(list( copy.deepcopy( file_path_or_dict['evaluated_indexes'] ) ))
self.arch2infos = OrderedDict()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
self.arch2infos[xkey] = ArchResults.create_from_state_dict( file_path_or_dict['arch2infos'][xkey] )
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
self.archstr2index[ arch.tostr() ] = idx
#assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[ arch ] = idx
def __getitem__(self, index):
return copy.deepcopy( self.meta_archs[index] )
@@ -54,12 +60,12 @@ class AANASBenchAPI(object):
return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
def query_index_by_arch(self, arch):
if arch.tostr() in self.archstr2index:
arch_index = self.archstr2index[ arch.tostr() ]
#else:
# arch_str = Structure.str2fullstructure( arch.tostr() ).tostr()
# if arch_str in self.archstr2index:
# arch_index = self.archstr2index[ arch_str ]
if isinstance(arch, str):
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
else : arch_index = -1
elif hasattr(arch, 'tostr'):
if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
else : arch_index = -1
else: arch_index = -1
return arch_index
@@ -80,6 +86,11 @@ class AANASBenchAPI(object):
info = archInfo.query(dataname)
return info
def query_meta_info_by_index(self, arch_index):
assert arch_index in self.arch2infos, 'arch_index [{:}] does not in arch2info'.format(arch_index)
archInfo = copy.deepcopy( self.arch2infos[ arch_index ] )
return archInfo
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None):
best_index, highest_accuracy = -1, None
for i, idx in enumerate(self.evaluated_indexes):