Update test weights and shapes

This commit is contained in:
D-X-Y
2020-03-20 23:38:47 -07:00
parent d8784b3070
commit b702ddf5a2
5 changed files with 115 additions and 64 deletions

View File

@@ -4,18 +4,15 @@
# Before run these commands, the files must be properly put.
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_0-e61699
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
###############################################################################################
import os, gc, sys, time, glob, random, argparse
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from collections import OrderedDict
from tqdm import tqdm
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from nas_201_api import NASBench201API as API
from log_utils import time_string
from models import get_cell_based_tiny_net
@@ -34,19 +31,22 @@ def tostr(accdict, norms):
return ' '.join(xstr)
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool, valid_or_test: bool):
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
print('\nEvaluate dataset={:}'.format(data))
norms, accs = [], []
final_accs = OrderedDict({'cifar10-valid': [], 'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
norms = []
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
for idx in range(len(api)):
info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
if valid_or_test:
accs.append(info['valid-accuracy'])
else:
accs.append(info['test-accuracy'])
for key in final_accs.keys():
for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
final_accs[key].append(info['test-accuracy'])
if key == 'cifar10-valid':
final_val_accs['cifar10'].append(info['valid-accuracy'])
elif key == 'cifar10':
final_test_accs['cifar10'].append(info['test-accuracy'])
else:
final_test_accs[key].append(info['test-accuracy'])
final_val_accs[key].append(info['valid-accuracy'])
config = api.get_net_config(idx, data)
net = get_cell_based_tiny_net(config)
api.reload(weight_dir, idx)
@@ -60,14 +60,15 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool, valid_or_tes
norms.append( float(np.mean(cur_norms)) )
api.clear_params(idx, use_12epochs_result)
if idx % 200 == 199 or idx + 1 == len(api):
correlation = get_cor(norms, accs)
head = '{:05d}/{:05d}'.format(idx, len(api))
stem = tostr(final_accs, norms)
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}. {:}'.format(time_string(), head, data, 12 if use_12epochs_result else 200, 'valid' if valid_or_test else 'test', correlation, stem))
stem_val = tostr(final_val_accs, norms)
stem_test = tostr(final_test_accs, norms)
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200))
print(' -->> {:} || {:}'.format(stem_val, stem_test))
torch.cuda.empty_cache() ; gc.collect()
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result, valid_or_test):
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):
api = API(meta_file)
datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
print(time_string() + ' ' + '='*50)
@@ -83,7 +84,7 @@ def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result, valid
print(time_string() + ' ' + '='*50)
#evaluate(api, weight_dir, 'cifar10-valid', False, True)
evaluate(api, weight_dir, xdata, use_12epochs_result, valid_or_test)
evaluate(api, weight_dir, xdata, use_12epochs_result)
print('{:} finish this test.'.format(time_string()))
@@ -94,7 +95,6 @@ if __name__ == '__main__':
parser.add_argument('--base_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.')
parser.add_argument('--dataset' , type=str, default=None, help='.')
parser.add_argument('--use_12' , type=int, default=None, help='.')
parser.add_argument('--use_valid', type=int, default=None, help='.')
args = parser.parse_args()
save_dir = Path(args.save_dir)
@@ -104,5 +104,5 @@ if __name__ == '__main__':
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)
main(str(meta_file), weight_dir, save_dir, args.dataset, bool(args.use_12), bool(args.use_valid))
main(str(meta_file), weight_dir, save_dir, args.dataset, bool(args.use_12))