updates
lib/aa_nas_api/__init__.py (new file)
@@ -0,0 +1,2 @@
from .api import AANASBenchAPI
from .api import ArchResults, ResultsCount
lib/aa_nas_api/api.py (new file)
@@ -0,0 +1,290 @@
import os, sys, copy, torch, numpy as np


def print_information(information, extra_info=None, show=False):
  dataset_names = information.get_dataset_names()
  strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
  def metric2str(loss, acc):
    return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)

  for ida, dataset in enumerate(dataset_names):
    flop, param, latency = information.get_comput_costs(dataset)
    str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency > 0 else 'N/A')
    train_loss, train_acc = information.get_metrics(dataset, 'train')
    if dataset == 'cifar10-valid':
      valid_loss, valid_acc = information.get_metrics(dataset, 'x-valid')
      str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(valid_loss, valid_acc))
    elif dataset == 'cifar10':
      test__loss, test__acc = information.get_metrics(dataset, 'ori-test')
      str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(test__loss, test__acc))
    else:
      valid_loss, valid_acc = information.get_metrics(dataset, 'x-valid')
      test__loss, test__acc = information.get_metrics(dataset, 'x-test')
      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(valid_loss, valid_acc), metric2str(test__loss, test__acc))
    strings += [str1, str2]
  if show: print('\n'.join(strings))
  return strings


class AANASBenchAPI(object):
  # The top-level API: loads the saved benchmark file and answers queries
  # about every evaluated architecture.

  def __init__(self, file_path_or_dict):
    if isinstance(file_path_or_dict, str):
      assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
      file_path_or_dict = torch.load(file_path_or_dict)
    assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
    keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
    for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
    self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
    self.arch2infos = copy.deepcopy( file_path_or_dict['arch2infos'] )
    self.evaluated_indexes = sorted(list( copy.deepcopy( file_path_or_dict['evaluated_indexes'] ) ))
    self.archstr2index = {}
    for idx, arch in enumerate(self.meta_archs):
      assert arch.tostr() not in self.archstr2index, 'The {:}-th arch {:} is already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
      self.archstr2index[ arch.tostr() ] = idx

  def __getitem__(self, index):
    return copy.deepcopy( self.meta_archs[index] )

  def __len__(self):
    return len(self.meta_archs)

  def __repr__(self):
    return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))

  def query_index_by_arch(self, arch):
    if arch.tostr() in self.archstr2index:
      arch_index = self.archstr2index[ arch.tostr() ]
    #else:
    #  arch_str = Structure.str2fullstructure( arch.tostr() ).tostr()
    #  if arch_str in self.archstr2index:
    #    arch_index = self.archstr2index[ arch_str ]
    else: arch_index = -1
    return arch_index

  def query_by_arch(self, arch):
    arch_index = self.query_index_by_arch(arch)
    if arch_index == -1: return None
    if arch_index in self.arch2infos:
      strings = print_information(self.arch2infos[ arch_index ], 'arch-index={:}'.format(arch_index))
      return '\n'.join(strings)
    else:
      print('Found arch-index {:}, but this architecture has not been evaluated.'.format(arch_index))
      return None

  def query_by_index(self, arch_index, dataname):
    assert arch_index in self.arch2infos, 'arch_index [{:}] is not in arch2infos'.format(arch_index)
    archInfo = copy.deepcopy( self.arch2infos[ arch_index ] )
    assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
    info = archInfo.query(dataname)
    return info

  def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None):
    # return the index of the highest-accuracy architecture on `dataset`/`metric_on_set`,
    # optionally restricted by FLOP and parameter budgets
    best_index, highest_accuracy = -1, None
    for i, idx in enumerate(self.evaluated_indexes):
      flop, param, latency = self.arch2infos[idx].get_comput_costs(dataset)
      if FLOP_max  is not None and flop  > FLOP_max : continue
      if Param_max is not None and param > Param_max: continue
      loss, accuracy = self.arch2infos[idx].get_metrics(dataset, metric_on_set)
      if best_index == -1:
        best_index, highest_accuracy = idx, accuracy
      elif highest_accuracy < accuracy:
        best_index, highest_accuracy = idx, accuracy
    return best_index

  def arch(self, index):
    assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
    return copy.deepcopy(self.meta_archs[index])

  def show(self, index=-1):
    if index == -1: # show all architectures
      print(self)
      for i, idx in enumerate(self.evaluated_indexes):
        print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-' * 10)
        print('arch : {:}'.format(self.meta_archs[idx]))
        strings = print_information(self.arch2infos[idx])
        print('>' * 20)
        print('\n'.join(strings))
        print('<' * 20)
    else:
      if 0 <= index < len(self.meta_archs):
        if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
        else:
          strings = print_information(self.arch2infos[index])
          print('\n'.join(strings))
      else:
        print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))


class ArchResults(object):
  # Aggregates all runs (over datasets and random seeds) of a single architecture.

  def __init__(self, arch_index, arch_str):
    self.arch_index   = int(arch_index)
    self.arch_str     = copy.deepcopy(arch_str)
    self.all_results  = dict()  # (dataset, seed) -> ResultsCount
    self.dataset_seed = dict()  # dataset -> sorted list of seeds
    self.clear_net_done = False

  def get_comput_costs(self, dataset):
    # mean FLOPs (M), parameter size (MB), and latency (seconds) over all seeds
    x_seeds = self.dataset_seed[dataset]
    results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
    flops     = [result.flop for result in results]
    params    = [result.params for result in results]
    latencies = [result.get_latency() for result in results]
    return np.mean(flops), np.mean(params), np.mean(latencies)

  def get_metrics(self, dataset, setname, iepoch=None):
    x_seeds = self.dataset_seed[dataset]
    results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
    loss, accuracy = [], []
    for result in results:
      if setname == 'train':
        info = result.get_train(iepoch)
      else:
        info = result.get_eval(setname, iepoch)
      loss.append( info['loss'] )
      accuracy.append( info['accuracy'] )
    return float(np.mean(loss)), float(np.mean(accuracy))

  def show(self, is_print=False):
    return print_information(self, None, is_print)

  def get_dataset_names(self):
    return list(self.dataset_seed.keys())

  def query(self, dataset, seed=None):
    if seed is None:
      x_seeds = self.dataset_seed[dataset]
      return [self.all_results[ (dataset, seed) ] for seed in x_seeds]
    else:
      return self.all_results[ (dataset, seed) ]

  def arch_idx_str(self):
    return '{:06d}'.format(self.arch_index)

  def update(self, dataset_name, seed, result):
    if dataset_name not in self.dataset_seed:
      self.dataset_seed[dataset_name] = []
    assert seed not in self.dataset_seed[dataset_name], 'the {:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
    self.dataset_seed[ dataset_name ].append( seed )
    self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] )
    assert (dataset_name, seed) not in self.all_results
    self.all_results[ (dataset_name, seed) ] = result
    self.clear_net_done = False

  def state_dict(self):
    state_dict = dict()
    for key, value in self.__dict__.items():
      if key == 'all_results': # contains instances of ResultsCount
        xvalue = dict()
        assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
        for _k, _v in value.items():
          assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
          xvalue[_k] = _v.state_dict()
      else:
        xvalue = value
      state_dict[key] = xvalue
    return state_dict

  def load_state_dict(self, state_dict):
    new_state_dict = dict()
    for key, value in state_dict.items():
      if key == 'all_results': # convert back into ResultsCount instances
        xvalue = dict()
        assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
        for _k, _v in value.items():
          xvalue[_k] = ResultsCount.create_from_state_dict(_v)
      else: xvalue = value
      new_state_dict[key] = xvalue
    self.__dict__.update(new_state_dict)

  @staticmethod
  def create_from_state_dict(state_dict_or_file):
    x = ArchResults(-1, -1)
    if isinstance(state_dict_or_file, str): # a file path
      state_dict = torch.load(state_dict_or_file)
    elif isinstance(state_dict_or_file, dict):
      state_dict = state_dict_or_file
    else:
      raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
    x.load_state_dict(state_dict)
    return x

  def clear_params(self):
    # drop the trained network weights to shrink the in-memory / on-disk size
    for key, result in self.all_results.items():
      result.net_state_dict = None
    self.clear_net_done = True

  def __repr__(self):
    return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))


class ResultsCount(object):
  # Records one training run of one architecture: weights, costs, and per-epoch metrics.

  def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
    self.name           = name
    self.net_state_dict = state_dict
    self.train_accs     = copy.deepcopy(train_accs)
    self.train_losses   = copy.deepcopy(train_losses)
    self.arch_config    = copy.deepcopy(arch_config)
    self.params         = params
    self.flop           = flop
    self.seed           = seed
    self.epochs         = epochs
    self.latency        = latency
    # evaluation results
    self.reset_eval()

  def reset_eval(self):
    self.eval_names  = []
    self.eval_accs   = {}
    self.eval_losses = {}

  def update_latency(self, latency):
    self.latency = copy.deepcopy( latency )

  def get_latency(self):
    # average latency in seconds, or -1 if no latency was recorded
    if self.latency is None: return -1
    else: return sum(self.latency) / len(self.latency)

  def update_eval(self, name, accs, losses):
    assert name not in self.eval_names, '{:} has already been added'.format(name)
    self.eval_names.append( name )
    self.eval_accs[name]   = copy.deepcopy( accs )
    self.eval_losses[name] = copy.deepcopy( losses )

  def __repr__(self):
    num_eval = len(self.eval_names)
    return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets)'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval))

  def valid_evaluation_set(self):
    return self.eval_names

  def get_train(self, iepoch=None):
    if iepoch is None: iepoch = self.epochs - 1
    assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} (epochs={:})'.format(iepoch, self.epochs)
    return {'loss': self.train_losses[iepoch], 'accuracy': self.train_accs[iepoch]}

  def get_eval(self, name, iepoch=None):
    if iepoch is None: iepoch = self.epochs - 1
    assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} (epochs={:})'.format(iepoch, self.epochs)
    return {'loss': self.eval_losses[name][iepoch], 'accuracy': self.eval_accs[name][iepoch]}

  def get_net_param(self):
    return self.net_state_dict

  def state_dict(self):
    _state_dict = {key: value for key, value in self.__dict__.items()}
    return _state_dict

  def load_state_dict(self, state_dict):
    self.__dict__.update(state_dict)

  @staticmethod
  def create_from_state_dict(state_dict):
    x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
    x.load_state_dict(state_dict)
    return x
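For reference, a minimal usage sketch of this API (not part of the commit; the benchmark
file path is a hypothetical placeholder):

    from aa_nas_api import AANASBenchAPI

    api = AANASBenchAPI('aa-nas-bench.pth')  # hypothetical path to a saved benchmark dict
    print(api)                               # e.g. "AANASBenchAPI(N/M architectures)"
    api.show(2)                              # print the recorded metrics of architecture 2
    arch  = api.arch(2)                      # deep copy of the 2-nd meta architecture
    index = api.query_index_by_arch(arch)    # map it back to its index (or -1 if unknown)
    best  = api.find_best('cifar10', 'ori-test', FLOP_max=50.0)  # best test accuracy under a FLOP budget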
lib/models/__init__.py (modified)
@@ -13,7 +13,7 @@ from .cell_searchs import CellStructure, CellArchitectures
 
 # Cell-based NAS Models
 def get_cell_based_tiny_net(config):
-  group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS']
+  group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
   from .cell_searchs import nas_super_nets
   if config.name in group_names:
     return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
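To illustrate the new dispatch path (the config object here is an illustrative stand-in
built with SimpleNamespace, and the operation names in `space` are assumptions; real
configs are loaded elsewhere in the repo):

    from types import SimpleNamespace
    from lib.models import get_cell_based_tiny_net

    config = SimpleNamespace(name='RANDOM', C=16, N=5, max_nodes=4, num_classes=10,
                             space=['none', 'skip_connect', 'nor_conv_3x3'])
    net = get_cell_based_tiny_net(config)  # -> nas_super_nets['RANDOM'], i.e. TinyNetworkRANDOM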
lib/models/cell_searchs/__init__.py (modified)
@@ -3,10 +3,12 @@ from .search_model_darts_v2 import TinyNetworkDartsV2
 from .search_model_gdas import TinyNetworkGDAS
 from .search_model_setn import TinyNetworkSETN
 from .search_model_enas import TinyNetworkENAS
+from .search_model_random import TinyNetworkRANDOM
 from .genotypes import Structure as CellStructure, architectures as CellArchitectures
 
 nas_super_nets = {'DARTS-V1': TinyNetworkDartsV1,
                   'DARTS-V2': TinyNetworkDartsV2,
                   'GDAS'    : TinyNetworkGDAS,
                   'SETN'    : TinyNetworkSETN,
-                  'ENAS'    : TinyNetworkENAS}
+                  'ENAS'    : TinyNetworkENAS,
+                  'RANDOM'  : TinyNetworkRANDOM}
lib/models/cell_searchs/genotypes.py (modified)
@@ -60,6 +60,17 @@ class Structure:
       strings.append( string )
     return '+'.join(strings)
 
+  def check_valid(self):
+    nodes = {0: True}
+    for i, node_info in enumerate(self.nodes):
+      sums = []
+      for op, xin in node_info:
+        if op == 'none' or nodes[xin] == False: x = False
+        else: x = True
+        sums.append( x )
+      nodes[i+1] = sum(sums) > 0
+    return nodes[len(self.nodes)]
+
   def to_unique_str(self, consider_zero=False):
     # this is used to identify isomorphic cells, which requires prior knowledge of the operations
     # two operations are special, i.e., none and skip_connect
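A hand-worked check of check_valid: a cell counts as valid when the output node is reachable
from the input through at least one chain of non-'none' operations; dead intermediate nodes
are tolerated. The operation names other than 'none' below are illustrative:

    arch = Structure([(('none', 0),),                              # node 1: only 'none' -> dead
                      (('skip_connect', 0), ('nor_conv_3x3', 1))]) # node 2: live path via node 0
    assert arch.check_valid()                                      # True: the output is reachable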
lib/models/cell_searchs/search_model_random.py (new file)
@@ -0,0 +1,79 @@
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import SearchCell
from .genotypes import Structure


class TinyNetworkRANDOM(nn.Module):

  def __init__(self, C, N, max_nodes, num_classes, search_space):
    super(TinyNetworkRANDOM, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = nn.Sequential(
                  nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
                  nn.BatchNorm2d(C))

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev = cell.out_dim
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.arch_cache = None

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def random_genotype(self, set_cache):
    # uniformly sample one operation per edge to build a random cell genotype
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        op_name = random.choice( self.op_names )
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    arch = Structure( genotypes )
    if set_cache: self.arch_cache = arch
    return arch

  def forward(self, inputs):
    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if isinstance(cell, SearchCell):
        feature = cell.forward_dynamic(feature, self.arch_cache)
      else: feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)
    return out, logits
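A sketch of the random-search loop this model enables (the operation names are assumed, and
a real run would train each sampled cell before measuring its accuracy, following the
UAI 2019 random-search protocol cited in the file header):

    import torch
    from lib.models.cell_searchs.search_model_random import TinyNetworkRANDOM

    space = ['none', 'skip_connect', 'nor_conv_3x3']     # assumed operation names
    net = TinyNetworkRANDOM(C=16, N=5, max_nodes=4, num_classes=10, search_space=space)
    for _ in range(5):                                   # sample a few random cells
      arch = net.random_genotype(set_cache=True)         # cached, so forward() runs this cell
      if not arch.check_valid(): continue                # skip cells with an unreachable output
      features, logits = net(torch.randn(2, 3, 32, 32))  # one CIFAR-sized forward pass
      print(arch.tostr(), logits.shape)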