Update NATS-Bench (sss version 1.2)

This commit is contained in:
D-X-Y
2020-08-30 08:04:52 +00:00
parent 469a207945
commit 5f151d1970
15 changed files with 317 additions and 229 deletions

View File

@@ -11,7 +11,6 @@
# python exps/NATS-Bench/sss-collect.py #
##############################################################################
import os, re, sys, time, shutil, argparse, collections
import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path
@@ -22,7 +21,7 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time
from config_utils import dict2config
from models import CellStructure, get_cell_based_tiny_net
from nas_201_api import ArchResults, ResultsCount
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from utils import get_md5_file
@@ -193,8 +192,8 @@ def simplify(save_dir, save_name, nets, total):
arch_str = nets[index]
hp2info = OrderedDict()
full_save_path = full_save_dir / '{:06d}.npy'.format(index)
simple_save_path = simple_save_dir / '{:06d}.npy'.format(index)
full_save_path = full_save_dir / '{:06d}.pickle'.format(index)
simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index)
for hp in hps:
sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
@@ -213,13 +212,13 @@ def simplify(save_dir, save_name, nets, total):
to_save_data = OrderedDict({'01': hp2info['01'].state_dict(),
'12': hp2info['12'].state_dict(),
'90': hp2info['90'].state_dict()})
np.save(str(full_save_path), to_save_data)
pickle_save(to_save_data, str(full_save_path))
for hp in hps: hp2info[hp].clear_params()
to_save_data = OrderedDict({'01': hp2info['01'].state_dict(),
'12': hp2info['12'].state_dict(),
'90': hp2info['90'].state_dict()})
np.save(str(simple_save_path), to_save_data)
pickle_save(to_save_data, str(simple_save_path))
arch2infos[index] = to_save_data
# measure elapsed time
arch_time.update(time.time() - end_time)
@@ -231,18 +230,23 @@ def simplify(save_dir, save_name, nets, total):
'total_archs': total,
'arch2infos' : arch2infos,
'evaluated_indexes': evaluated_indexes}
save_file_name = save_dir / '{:}.npy'.format(save_name)
np.save(str(save_file_name), final_infos)
save_file_name = save_dir / '{:}.pickle'.format(save_name)
pickle_save(final_infos, str(save_file_name))
# move the benchmark file to a new path
hd5sum = get_md5_file(save_file_name)
hd5_file_name = save_dir / '{:}-{:}.npy'.format(NATS_TSS_BASE_NAME, hd5sum)
shutil.move(save_file_name, hd5_file_name)
hd5sum = get_md5_file(str(save_file_name) + '.pbz2')
hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(NATS_TSS_BASE_NAME, hd5sum)
shutil.move(str(save_file_name) + '.pbz2', hd5_file_name)
print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name))
# move the directory to a new path
hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum)
hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum)
shutil.move(full_save_dir, hd5_full_save_dir)
shutil.move(simple_save_dir, hd5_simple_save_dir)
# save the meta information for simple and full
final_infos['arch2infos'] = None
final_infos['evaluated_indexes'] = set()
pickle_save(final_infos, str(hd5_full_save_dir / 'meta.pickle'))
pickle_save(final_infos, str(hd5_simple_save_dir / 'meta.pickle'))
def traverse_net(candidates: List[int], N: int):

View File

@@ -0,0 +1,97 @@
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# Usage: python exps/NATS-Bench/test-nats-api.py #
##############################################################################
import os, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
from models import get_cell_based_tiny_net, CellStructure
def test_api(api, is_301=True):
print('{:} start testing the api : {:}'.format(time_string(), api))
api.clear_params(12)
api.reload(index=12)
# Query the informations of 1113-th architecture
info_strs = api.query_info_str_by_arch(1113)
print(info_strs)
info = api.query_by_index(113)
print('{:}\n'.format(info))
info = api.query_by_index(113, 'cifar100')
print('{:}\n'.format(info))
info = api.query_meta_info_by_index(115, '90' if is_301 else '200')
print('{:}\n'.format(info))
for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
for xset in ['train', 'test', 'valid']:
best_index, highest_accuracy = api.find_best(dataset, xset)
print('')
params = api.get_net_param(12, 'cifar10', None)
# Obtain the config and create the network
config = api.get_net_config(12, 'cifar10')
print('{:}\n'.format(config))
network = get_cell_based_tiny_net(config)
network.load_state_dict(next(iter(params.values())))
# Obtain the cost information
info = api.get_cost_info(12, 'cifar10')
print('{:}\n'.format(info))
info = api.get_latency(12, 'cifar10')
print('{:}\n'.format(info))
for index in [13, 15, 19, 200]:
info = api.get_latency(index, 'cifar10')
# Count the number of architectures
info = api.statistics('cifar100', '12')
print('{:} statistics results : {:}\n'.format(time_string(), info))
# Show the information of the 123-th architecture
api.show(123)
# Obtain both cost and performance information
info = api.get_more_info(1234, 'cifar10')
print('{:}\n'.format(info))
print('{:} finish testing the api : {:}'.format(time_string(), api))
if not is_301:
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
matrix = api.str2matrix(arch_str)
print('Compute the adjacency matrix of {:}'.format(arch_str))
print(matrix)
info = api.simulate_train_eval(123, 'cifar10')
print('simulate_train_eval : {:}\n\n'.format(info))
if __name__ == '__main__':
for fast_mode in [True, False]:
for verbose in [True, False]:
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
api301 = create(None, 'size', fast_mode=fast_mode, verbose=True)
print('{:} --->>> {:}'.format(time_string(), api301))
test_api(api301, True)
# api201 = create(None, 'topology', True) # use the default file path
# test_api(api201, False)
# print ('Test {:} done'.format(api201))