Update test-weights of NAS-Bench-201

This commit is contained in:
D-X-Y
2020-03-16 11:11:01 +11:00
parent fb76814369
commit e76451c791
3 changed files with 58 additions and 14 deletions

View File

@@ -3,13 +3,15 @@
###############################################################################################
# Before run these commands, the files must be properly put.
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_0-e61699
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 1
###############################################################################################
import os, sys, time, glob, random, argparse
import os, gc, sys, time, glob, random, argparse
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from collections import OrderedDict
from tqdm import tqdm
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
@@ -24,30 +26,48 @@ def get_cor(A, B):
return float(np.corrcoef(A, B)[0,1])
def tostr(accdict, norms):
xstr = []
for key, accs in accdict.items():
cor = get_cor(accs, norms)
xstr.append('{:}: {:.3f}'.format(key, cor))
return ' '.join(xstr)
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool, valid_or_test: bool):
print('\nEvaluate dataset={:}'.format(data))
norms, accs = [], []
for idx in tqdm(range(len(api))):
final_accs = OrderedDict({'cifar10-valid': [], 'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
for idx in range(len(api)):
info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
if valid_or_test:
accs.append(info['valid-accuracy'])
else:
accs.append(info['test-accuracy'])
for key in final_accs.keys():
info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
final_accs[key].append(info['test-accuracy'])
config = api.get_net_config(idx, data)
net = get_cell_based_tiny_net(config)
api.reload(weight_dir, idx)
params = api.get_net_param(idx, data, None)
cur_norms = []
for seed, param in params.items():
net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append( summary['lognorm'] )
with torch.no_grad():
net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append( summary['lognorm'] )
norms.append( float(np.mean(cur_norms)) )
api.clear_params(idx, use_12epochs_result)
correlation = get_cor(norms, accs)
print('For {:} with {:} epochs on {:} : the correlation is {:}'.format(data, 12 if use_12epochs_result else 200, 'valid' if valid_or_test else 'test', correlation))
if idx % 200 == 199 or idx + 1 == len(api):
correlation = get_cor(norms, accs)
head = '{:05d}/{:05d}'.format(idx, len(api))
stem = tostr(final_accs, norms)
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}. {:}'.format(time_string(), head, data, 12 if use_12epochs_result else 200, 'valid' if valid_or_test else 'test', correlation, stem))
torch.cuda.empty_cache() ; gc.collect()
def main(meta_file: str, weight_dir, save_dir):
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result, valid_or_test):
api = API(meta_file)
datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
print(time_string() + ' ' + '='*50)
@@ -62,7 +82,8 @@ def main(meta_file: str, weight_dir, save_dir):
print('Using 200 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums))
print(time_string() + ' ' + '='*50)
evaluate(api, weight_dir, 'cifar10-valid', False, True)
#evaluate(api, weight_dir, 'cifar10-valid', False, True)
evaluate(api, weight_dir, xdata, use_12epochs_result, valid_or_test)
print('{:} finish this test.'.format(time_string()))
@@ -71,6 +92,9 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--base_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.')
parser.add_argument('--dataset' , type=str, default=None, help='.')
parser.add_argument('--use_12' , type=int, default=None, help='.')
parser.add_argument('--use_valid', type=int, default=None, help='.')
args = parser.parse_args()
save_dir = Path(args.save_dir)
@@ -80,5 +104,5 @@ if __name__ == '__main__':
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)
main(str(meta_file), weight_dir, save_dir)
main(str(meta_file), weight_dir, save_dir, args.dataset, bool(args.use_12), bool(args.use_valid))