update for NAS-Bench-102
@@ -1,5 +1,5 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
-from .api import AANASBenchAPI
+from .api import NASBench102API
 from .api import ArchResults, ResultsCount
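For orientation, a minimal usage sketch of the renamed entry point (the package name and the benchmark file name below are placeholders, not something this diff specifies):

```python
# Hedged sketch: package and file names are placeholders;
# only the class name NASBench102API comes from this commit.
from nas_102_api import NASBench102API

api = NASBench102API('NAS-Bench-102-v1_0.pth', verbose=True)
print(len(api.meta_archs))   # architectures stored under the 'meta_archs' key
```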
@@ -2,7 +2,7 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
 import os, sys, copy, random, torch, numpy as np
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict


 def print_information(information, extra_info=None, show=False):
@@ -30,16 +30,17 @@ def print_information(information, extra_info=None, show=False):
   return strings


-class AANASBenchAPI(object):
+class NASBench102API(object):

   def __init__(self, file_path_or_dict, verbose=True):
     if isinstance(file_path_or_dict, str):
-      if verbose: print('try to create AA-NAS-Bench api from {:}'.format(file_path_or_dict))
+      if verbose: print('try to create NAS-Bench-102 api from {:}'.format(file_path_or_dict))
       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
       file_path_or_dict = torch.load(file_path_or_dict)
     else:
       file_path_or_dict = copy.deepcopy( file_path_or_dict )
     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
+    import pdb; pdb.set_trace() # we will update this api soon
     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
     self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
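The constructor takes either a path to a torch-saved file or an already-loaded dict, and asserts the three top-level keys shown above; note the temporary `import pdb; pdb.set_trace()` line will pause any call until it is removed. A hedged sketch of both entry routes (the file name is a placeholder, and `NASBench102API` is imported as in the earlier sketch):

```python
import torch

# The three top-level keys __init__ asserts on (see the diff above).
required_keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')

benchmark = torch.load('NAS-Bench-102-v1_0.pth')   # placeholder path
for key in required_keys:
  assert key in benchmark, 'Can not find key[{:}] in the dict'.format(key)

# Passing the already-loaded dict skips the os.path.isfile/torch.load branch.
api = NASBench102API(benchmark, verbose=False)
```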
@@ -144,27 +145,46 @@ class ArchResults(object):
   def get_comput_costs(self, dataset):
     x_seeds = self.dataset_seed[dataset]
     results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
-    flops  = [result.flop for result in results]
-    params = [result.params for result in results]
+    flops      = [result.flop for result in results]
+    params     = [result.params for result in results]
     lantencies = [result.get_latency() for result in results]
-    return np.mean(flops), np.mean(params), np.mean(lantencies)
+    lantencies = [x for x in lantencies if x > 0]
+    mean_latency = np.mean(lantencies) if len(lantencies) > 0 else None
+    time_infos = defaultdict(list)
+    for result in results:
+      time_info = result.get_times()
+      for key, value in time_info.items(): time_infos[key].append( value )
+
+    info = {'flops'  : np.mean(flops),
+            'params' : np.mean(params),
+            'latency': mean_latency}
+    for key, value in time_infos.items():
+      if len(value) > 0 and value[0] is not None:
+        info[key] = np.mean(value)
+      else: info[key] = None
+    return info

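Callers that previously unpacked a `(flops, params, latency)` tuple from `get_comput_costs` now receive a dict. A sketch of reading it, assuming `arch_result` is an `ArchResults` instance and `'cifar10'` is one of its dataset keys:

```python
# arch_result: an ArchResults instance; 'cifar10' is an assumed dataset key.
info = arch_result.get_comput_costs('cifar10')
print(info['flops'], info['params'])      # mean FLOPs (in M) and parameter size (in MB)
if info['latency'] is not None:           # None when no positive latency was recorded
  print('mean latency: {:}'.format(info['latency']))
# timing keys such as 'T-train@total' from ResultsCount.get_times() are merged in as well
```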
   def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
     x_seeds = self.dataset_seed[dataset]
     results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
-    loss, accuracy = [], []
+    infos = defaultdict(list)
     for result in results:
       if setname == 'train':
         info = result.get_train(iepoch)
       else:
         info = result.get_eval(setname, iepoch)
-      loss.append( info['loss'] )
-      accuracy.append( info['accuracy'] )
+      for key, value in info.items(): infos[key].append( value )
+    return_info = dict()
     if is_random:
-      index = random.randint(0, len(loss)-1)
-      return loss[index], accuracy[index]
+      index = random.randint(0, len(results)-1)
+      for key, value in infos.items(): return_info[key] = value[index]
     else:
-      return float(np.mean(loss)), float(np.mean(accuracy))
+      for key, value in infos.items():
+        if len(value) > 0 and value[0] is not None:
+          return_info[key] = np.mean(value)
+        else: return_info[key] = None
+    return return_info

   def show(self, is_print=False):
     return print_information(self, None, is_print)
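`get_metrics` likewise now returns a dict: averaged over seeds by default, or one random seed's values when `is_random=True`. A sketch under the same assumptions (dataset and set names are illustrative):

```python
# Dataset and set names below are assumptions, not fixed by this diff.
mean_info = arch_result.get_metrics('cifar10', 'train')            # averaged over seeds
print(mean_info['loss'], mean_info['accuracy'])

rand_info = arch_result.get_metrics('cifar10', 'ori-test', iepoch=11, is_random=True)
print(rand_info['iepoch'], rand_info['accuracy'], rand_info['time'])  # one random seed
```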
@@ -245,8 +265,10 @@ class ResultsCount(object):
   def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
     self.name = name
     self.net_state_dict = state_dict
-    self.train_accs = copy.deepcopy(train_accs)
+    self.train_acc1es = copy.deepcopy(train_accs)
+    self.train_acc5es = None
     self.train_losses = copy.deepcopy(train_losses)
+    self.train_times = None
     self.arch_config = copy.deepcopy(arch_config)
     self.params = params
     self.flop = flop
@@ -256,44 +278,97 @@ class ResultsCount(object):
     # evaluation results
     self.reset_eval()

+  def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times):
+    self.train_acc1es = train_acc1es
+    self.train_acc5es = train_acc5es
+    self.train_losses = train_losses
+    self.train_times = train_times
+
   def reset_eval(self):
     self.eval_names = []
-    self.eval_accs = {}
+    self.eval_acc1es = {}
+    self.eval_times = {}
     self.eval_losses = {}

   def update_latency(self, latency):
     self.latency = copy.deepcopy( latency )

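The new `update_train_info` stores per-epoch records that `get_train` and `get_times` later read back as `train_times[iepoch]` and `train_times.values()`. A sketch of how a training loop might fill them (values are dummies):

```python
# Illustrative only: epoch-indexed dicts, matching how get_train() and
# get_times() index them later; all numbers are dummy values.
train_acc1es, train_acc5es, train_losses, train_times = {}, {}, {}, {}
for iepoch in range(result.epochs):        # result: a ResultsCount instance
  train_acc1es[iepoch] = 91.3              # top-1 accuracy for this epoch
  train_acc5es[iepoch] = 99.1              # top-5 accuracy for this epoch
  train_losses[iepoch] = 0.42              # average training loss
  train_times [iepoch] = 12.7              # seconds spent in this epoch
result.update_train_info(train_acc1es, train_acc5es, train_losses, train_times)
```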
-  def update_eval(self, name, accs, losses):
-    assert name not in self.eval_names, '{:} has already added'.format(name)
-    self.eval_names.append( name )
-    self.eval_accs[name] = copy.deepcopy( accs )
-    self.eval_losses[name] = copy.deepcopy( losses )
+  def update_eval(self, accs, losses, times): # new version
+    data_names = set([x.split('@')[0] for x in accs.keys()])
+    for data_name in data_names:
+      assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
+      self.eval_names.append( data_name )
+      for iepoch in range(self.epochs):
+        xkey = '{:}@{:}'.format(data_name, iepoch)
+        self.eval_acc1es[ xkey ] = accs[ xkey ]
+        self.eval_losses[ xkey ] = losses[ xkey ]
+        self.eval_times [ xkey ] = times[ xkey ]
+
+  def update_OLD_eval(self, name, accs, losses): # old version
+    assert name not in self.eval_names, '{:} has already added'.format(name)
+    self.eval_names.append( name )
+    for iepoch in range(self.epochs):
+      if iepoch in accs:
+        self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch]
+        self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch]

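The rewritten `update_eval` takes flat dicts keyed as `'{set-name}@{epoch}'` and derives the evaluation-set names from the part before the `@`. A sketch with a made-up set name and dummy values:

```python
# Made-up evaluation-set name and dummy values; the keys follow the
# '{:}@{:}'.format(data_name, iepoch) pattern that update_eval expects.
accs, losses, times = {}, {}, {}
for iepoch in range(result.epochs):
  key = 'x-valid@{:}'.format(iepoch)
  accs[key], losses[key], times[key] = 85.0, 0.55, 3.2
result.update_eval(accs, losses, times)
print(result.get_eval_set())   # -> ['x-valid']
```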
   def __repr__(self):
     num_eval = len(self.eval_names)
-    return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets)'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval))
+    set_name = '[' + ', '.join(self.eval_names) + ']'
+    return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
+
+  def get_latency(self):
+    if self.latency is None: return -1
+    else: return sum(self.latency) / len(self.latency)
+
+  def get_times(self):
+    if self.train_times is not None and isinstance(self.train_times, dict):
+      train_times = list( self.train_times.values() )
+      time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
+      for name in self.eval_names:
+        xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)]
+        time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
+        time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
+    else:
+      time_info = {'T-train@epoch': None, 'T-train@total': None }
+      for name in self.eval_names:
+        time_info['T-{:}@epoch'.format(name)] = None
+        time_info['T-{:}@total'.format(name)] = None
+    return time_info

-  def valid_evaluation_set(self):
+  def get_eval_set(self):
     return self.eval_names

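Once times are recorded, the new `get_times` reports per-epoch means and totals, prefixed with `T-`. A sketch of reading it back:

```python
time_info = result.get_times()
# e.g. {'T-train@epoch': ..., 'T-train@total': ...,
#       'T-x-valid@epoch': ..., 'T-x-valid@total': ...}   # one pair per recorded eval set
for key, value in time_info.items():
  print('{:} : {:}'.format(key, value))   # values are None when no times were stored
```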
   def get_train(self, iepoch=None):
     if iepoch is None: iepoch = self.epochs-1
     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
-    return {'loss': self.train_losses[iepoch], 'accuracy': self.train_accs[iepoch]}
+    if self.train_times is not None: xtime = self.train_times[iepoch]
+    else                           : xtime = None
+    return {'iepoch'  : iepoch,
+            'loss'    : self.train_losses[iepoch],
+            'accuracy': self.train_acc1es[iepoch],
+            'time'    : xtime}

   def get_eval(self, name, iepoch=None):
     if iepoch is None: iepoch = self.epochs-1
     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
-    return {'loss': self.eval_losses[name][iepoch], 'accuracy': self.eval_accs[name][iepoch]}
+    if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
+      xtime = self.eval_times['{:}@{:}'.format(name,iepoch)]
+    else: xtime = None
+    return {'iepoch'  : iepoch,
+            'loss'    : self.eval_losses['{:}@{:}'.format(name,iepoch)],
+            'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)],
+            'time'    : xtime}

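`get_train` and `get_eval` now return a small dict that adds the epoch index and the recorded time (or `None`) to loss and accuracy. A sketch continuing the example above:

```python
last_train = result.get_train()                   # defaults to the final epoch
print(last_train['iepoch'], last_train['loss'], last_train['accuracy'], last_train['time'])

eval_info = result.get_eval('x-valid', iepoch=0)  # 'x-valid' as registered in the sketch above
print(eval_info['accuracy'], eval_info['time'])
```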
   def get_net_param(self):
     return self.net_state_dict

   def get_config(self, str2structure):
     #return copy.deepcopy(self.arch_config)
     return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \
             'N'   : self.arch_config['num_cells'], \
             'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}

   def state_dict(self):
     _state_dict = {key: value for key, value in self.__dict__.items()}
     return _state_dict
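Finally, `get_config` turns the stored `arch_config` into a network-construction dict, delegating genotype parsing to whatever `str2structure` callable is passed in. A hedged sketch; the import path and parser name are assumptions, not part of this diff:

```python
# Assumption: CellStructure.str2structure (from the author's accompanying model
# code, not part of this diff) parses the architecture string into a genotype.
from models import CellStructure   # hypothetical import path

net_config = result.get_config(CellStructure.str2structure)
# -> {'name': 'infer.tiny', 'C': ..., 'N': ..., 'genotype': <Structure>, 'num_classes': ...}
weights = result.get_net_param()   # trained parameters to load into the built network
```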