Upgrade NAS-API to v2.0:
we use an abstract class NASBenchMetaAPI to define the spec of an API; it can be inherited to support different kinds of NAS API, while keep the query interface the same.
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
# [2020.02.25] Initialize the API as v1.1
|
||||
# [2020.03.09] Upgrade the API to v1.2
|
||||
# [2020.03.16] Upgrade the API to v1.3
|
||||
# [2020.06.30] Upgrade the API to v2.0
|
||||
import os
|
||||
from setuptools import setup
|
||||
|
||||
@@ -15,7 +16,7 @@ def read(fname='README.md'):
|
||||
|
||||
setup(
|
||||
name = "nas_bench_201",
|
||||
version = "1.3",
|
||||
version = "2.0",
|
||||
author = "Xuanyi Dong",
|
||||
author_email = "dongxuanyi888@gmail.com",
|
||||
description = "API for NAS-Bench-201 (a benchmark for neural architecture search).",
|
||||
|
@@ -22,7 +22,7 @@ def create_result_count(used_seed: int, dataset: Text, arch_config: Dict[Text, A
|
||||
results: Dict[Text, Any], dataloader_dict: Dict[Text, Any]) -> ResultsCount:
|
||||
xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'],
|
||||
results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None)
|
||||
net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes':arch_config['class_num']}, None)
|
||||
net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes': arch_config['class_num']}, None)
|
||||
network = get_cell_based_tiny_net(net_config)
|
||||
network.load_state_dict(xresult.get_net_param())
|
||||
if 'train_times' in results: # new version
|
||||
@@ -126,7 +126,6 @@ def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch
|
||||
arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'ori-test', eval_per_sample * nums['ImageNet16-120-test'])
|
||||
# arch_info_full.debug_test()
|
||||
# arch_info_less.debug_test()
|
||||
# import pdb; pdb.set_trace()
|
||||
return arch_info_full, arch_info_less
|
||||
|
||||
|
||||
|
93
exps/NAS-Bench-201/test-nas-api-vis.py
Normal file
93
exps/NAS-Bench-201/test-nas-api-vis.py
Normal file
@@ -0,0 +1,93 @@
|
||||
###############################################################
|
||||
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
||||
###############################################################
|
||||
# Usage: python exps/NAS-Bench-201/test-nas-api-vis.py
|
||||
###############################################################
|
||||
import os, sys, time, torch, argparse
|
||||
import numpy as np
|
||||
from typing import List, Text, Dict, Any
|
||||
from shutil import copyfile
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import matplotlib
|
||||
import seaborn as sns
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as ticker
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import dict2config, load_config
|
||||
from nas_201_api import NASBench201API, NASBench301API
|
||||
from log_utils import time_string
|
||||
from models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
def visualize_info(api, vis_save_dir, indicator):
|
||||
vis_save_dir = vis_save_dir.resolve()
|
||||
# print ('{:} start to visualize {:} information'.format(time_string(), api))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
|
||||
cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
|
||||
imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
indexes = list(range(len(cifar010_info['params'])))
|
||||
|
||||
print ('{:} start to visualize relative ranking'.format(time_string()))
|
||||
|
||||
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
|
||||
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
|
||||
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
|
||||
|
||||
cifar100_labels, imagenet_labels = [], []
|
||||
for idx in cifar010_ord_indexes:
|
||||
cifar100_labels.append( cifar100_ord_indexes.index(idx) )
|
||||
imagenet_labels.append( imagenet_ord_indexes.index(idx) )
|
||||
print ('{:} prepare data done.'.format(time_string()))
|
||||
|
||||
dpi, width, height = 200, 1400, 800
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 18, 12
|
||||
resnet_scale, resnet_alpha = 120, 0.5
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xlim(min(indexes), max(indexes))
|
||||
plt.ylim(min(indexes), max(indexes))
|
||||
# plt.ylabel('y').set_rotation(30)
|
||||
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
|
||||
ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
|
||||
ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8)
|
||||
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
|
||||
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
|
||||
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
|
||||
ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120')
|
||||
plt.grid(zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
|
||||
ax.set_ylabel('architecture ranking', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--check_N', type=int, default=32768, help='For safety.')
|
||||
# use for train the model
|
||||
args = parser.parse_args()
|
||||
|
||||
visualize_info(None, Path('output/vis-nas-bench/'), 'tss')
|
||||
|
||||
visualize_info(None, Path('output/vis-nas-bench/'), 'sss')
|
283
exps/NAS-Bench-201/test-nas-api.py
Normal file
283
exps/NAS-Bench-201/test-nas-api.py
Normal file
@@ -0,0 +1,283 @@
|
||||
###############################################################
|
||||
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
||||
###############################################################
|
||||
# Usage: python exps/NAS-Bench-201/test-nas-api.py
|
||||
###############################################################
|
||||
import os, sys, time, torch, argparse
|
||||
import numpy as np
|
||||
from typing import List, Text, Dict, Any
|
||||
from shutil import copyfile
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import matplotlib
|
||||
import seaborn as sns
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as ticker
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import dict2config, load_config
|
||||
from nas_201_api import NASBench201API, NASBench301API
|
||||
from log_utils import time_string
|
||||
from models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
def test_api(api, is_301=True):
|
||||
print('{:} start testing the api : {:}'.format(time_string(), api))
|
||||
api.clear_params(12)
|
||||
api.reload(index=12)
|
||||
|
||||
# Query the informations of 1113-th architecture
|
||||
info_strs = api.query_info_str_by_arch(1113)
|
||||
print(info_strs)
|
||||
info = api.query_by_index(113)
|
||||
print('{:}\n'.format(info))
|
||||
info = api.query_by_index(113, 'cifar100')
|
||||
print('{:}\n'.format(info))
|
||||
|
||||
info = api.query_meta_info_by_index(115, '90' if is_301 else '200')
|
||||
print('{:}\n'.format(info))
|
||||
|
||||
for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
|
||||
for xset in ['train', 'test', 'valid']:
|
||||
best_index, highest_accuracy = api.find_best(dataset, xset)
|
||||
print('')
|
||||
params = api.get_net_param(12, 'cifar10', None)
|
||||
|
||||
# obtain the config and create the network
|
||||
config = api.get_net_config(12, 'cifar10')
|
||||
print('{:}\n'.format(config))
|
||||
network = get_cell_based_tiny_net(config)
|
||||
network.load_state_dict(next(iter(params.values())))
|
||||
|
||||
# obtain the cost information
|
||||
info = api.get_cost_info(12, 'cifar10')
|
||||
print('{:}\n'.format(info))
|
||||
info = api.get_latency(12, 'cifar10')
|
||||
print('{:}\n'.format(info))
|
||||
|
||||
# count the number of architectures
|
||||
info = api.statistics('cifar100', '12')
|
||||
print('{:}\n'.format(info))
|
||||
|
||||
# show the information of the 123-th architecture
|
||||
api.show(123)
|
||||
|
||||
# obtain both cost and performance information
|
||||
info = api.get_more_info(1234, 'cifar10')
|
||||
print('{:}\n'.format(info))
|
||||
print('{:} finish testing the api : {:}'.format(time_string(), api))
|
||||
|
||||
|
||||
def visualize_sss_info(api, dataset, vis_save_dir):
|
||||
vis_save_dir = vis_save_dir.resolve()
|
||||
print ('{:} start to visualize {:} information'.format(time_string(), dataset))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
cache_file_path = vis_save_dir / '{:}-cache-sss-info.pth'.format(dataset)
|
||||
if not cache_file_path.exists():
|
||||
print ('Do not find cache file : {:}'.format(cache_file_path))
|
||||
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
|
||||
for index in range(len(api)):
|
||||
info = api.get_cost_info(index, dataset)
|
||||
params.append(info['params'])
|
||||
flops.append(info['flops'])
|
||||
# accuracy
|
||||
info = api.get_more_info(index, dataset, hp='90')
|
||||
train_accs.append(info['train-accuracy'])
|
||||
test_accs.append(info['test-accuracy'])
|
||||
if dataset == 'cifar10':
|
||||
info = api.get_more_info(index, 'cifar10-valid', hp='90')
|
||||
valid_accs.append(info['valid-accuracy'])
|
||||
else:
|
||||
valid_accs.append(info['valid-accuracy'])
|
||||
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs}
|
||||
torch.save(info, cache_file_path)
|
||||
else:
|
||||
print ('Find cache file : {:}'.format(cache_file_path))
|
||||
info = torch.load(cache_file_path)
|
||||
params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs']
|
||||
print ('{:} collect data done.'.format(time_string()))
|
||||
|
||||
pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64']
|
||||
pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid]
|
||||
largest_indexes = [api.query_index_by_arch('64:64:64:64:64')]
|
||||
|
||||
indexes = list(range(len(params)))
|
||||
dpi, width, height = 250, 8500, 1300
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 24, 24
|
||||
# resnet_scale, resnet_alpha = 120, 0.5
|
||||
xscale, xalpha = 120, 0.8
|
||||
|
||||
fig, axs = plt.subplots(1, 4, figsize=figsize)
|
||||
# ax1, ax2, ax3, ax4, ax5 = axs
|
||||
for ax in axs:
|
||||
for tick in ax.xaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(LabelSize)
|
||||
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f'))
|
||||
for tick in ax.yaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(LabelSize)
|
||||
ax2, ax3, ax4, ax5 = axs
|
||||
# ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5))
|
||||
# ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
|
||||
# ax1.set_xlabel('architecture ID', fontsize=LabelSize)
|
||||
# ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize)
|
||||
|
||||
ax2.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax2.scatter([params[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha)
|
||||
ax2.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
|
||||
ax2.set_xlabel('#parameters (MB)', fontsize=LabelSize)
|
||||
ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize)
|
||||
ax2.legend(loc=4, fontsize=LegendFontsize)
|
||||
|
||||
ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax3.scatter([params[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha)
|
||||
ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
|
||||
ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize)
|
||||
ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize)
|
||||
ax3.legend(loc=4, fontsize=LegendFontsize)
|
||||
|
||||
ax4.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax4.scatter([flops[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha)
|
||||
ax4.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
|
||||
ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize)
|
||||
ax4.set_ylabel('train accuracy (%)', fontsize=LabelSize)
|
||||
ax4.legend(loc=4, fontsize=LegendFontsize)
|
||||
|
||||
ax5.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax5.scatter([flops[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha)
|
||||
ax5.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
|
||||
ax5.set_xlabel('#FLOPs (M)', fontsize=LabelSize)
|
||||
ax5.set_ylabel('test accuracy (%)', fontsize=LabelSize)
|
||||
ax5.legend(loc=4, fontsize=LegendFontsize)
|
||||
|
||||
save_path = vis_save_dir / 'sss-{:}.png'.format(dataset)
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
plt.close('all')
|
||||
|
||||
|
||||
def visualize_tss_info(api, dataset, vis_save_dir):
|
||||
vis_save_dir = vis_save_dir.resolve()
|
||||
print ('{:} start to visualize {:} information'.format(time_string(), dataset))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
cache_file_path = vis_save_dir / '{:}-cache-tss-info.pth'.format(dataset)
|
||||
if not cache_file_path.exists():
|
||||
print ('Do not find cache file : {:}'.format(cache_file_path))
|
||||
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
|
||||
for index in range(len(api)):
|
||||
info = api.get_cost_info(index, dataset)
|
||||
params.append(info['params'])
|
||||
flops.append(info['flops'])
|
||||
# accuracy
|
||||
info = api.get_more_info(index, dataset, hp='200')
|
||||
train_accs.append(info['train-accuracy'])
|
||||
test_accs.append(info['test-accuracy'])
|
||||
if dataset == 'cifar10':
|
||||
info = api.get_more_info(index, 'cifar10-valid', hp='200')
|
||||
valid_accs.append(info['valid-accuracy'])
|
||||
else:
|
||||
valid_accs.append(info['valid-accuracy'])
|
||||
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs}
|
||||
torch.save(info, cache_file_path)
|
||||
else:
|
||||
print ('Find cache file : {:}'.format(cache_file_path))
|
||||
info = torch.load(cache_file_path)
|
||||
params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs']
|
||||
print ('{:} collect data done.'.format(time_string()))
|
||||
|
||||
resnet = ['|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|']
|
||||
resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
|
||||
largest_indexes = [api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|')]
|
||||
|
||||
indexes = list(range(len(params)))
|
||||
dpi, width, height = 250, 8500, 1300
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 24, 24
|
||||
# resnet_scale, resnet_alpha = 120, 0.5
|
||||
xscale, xalpha = 120, 0.8
|
||||
|
||||
fig, axs = plt.subplots(1, 4, figsize=figsize)
|
||||
# ax1, ax2, ax3, ax4, ax5 = axs
|
||||
for ax in axs:
|
||||
for tick in ax.xaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(LabelSize)
|
||||
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f'))
|
||||
for tick in ax.yaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(LabelSize)
|
||||
ax2, ax3, ax4, ax5 = axs
|
||||
# ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5))
|
||||
# ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
|
||||
# ax1.set_xlabel('architecture ID', fontsize=LabelSize)
|
||||
# ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize)
|
||||
|
||||
ax2.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax2.scatter([params[x] for x in resnet_indexes] , [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha)
|
||||
ax2.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
|
||||
ax2.set_xlabel('#parameters (MB)', fontsize=LabelSize)
|
||||
ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize)
|
||||
ax2.legend(loc=4, fontsize=LegendFontsize)
|
||||
|
||||
ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax3.scatter([params[x] for x in resnet_indexes] , [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha)
|
||||
ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
|
||||
ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize)
|
||||
ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize)
|
||||
ax3.legend(loc=4, fontsize=LegendFontsize)
|
||||
|
||||
ax4.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax4.scatter([flops[x] for x in resnet_indexes], [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha)
|
||||
ax4.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
|
||||
ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize)
|
||||
ax4.set_ylabel('train accuracy (%)', fontsize=LabelSize)
|
||||
ax4.legend(loc=4, fontsize=LegendFontsize)
|
||||
|
||||
ax5.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax5.scatter([flops[x] for x in resnet_indexes], [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha)
|
||||
ax5.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
|
||||
ax5.set_xlabel('#FLOPs (M)', fontsize=LabelSize)
|
||||
ax5.set_ylabel('test accuracy (%)', fontsize=LabelSize)
|
||||
ax5.legend(loc=4, fontsize=LegendFontsize)
|
||||
|
||||
save_path = vis_save_dir / 'tss-{:}.png'.format(dataset)
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
plt.close('all')
|
||||
|
||||
|
||||
def test_issue_81_82(api):
|
||||
results = api.query_by_index(0, 'cifar10')
|
||||
results = api.query_by_index(0, 'cifar10-valid', hp='200')
|
||||
print(results.keys())
|
||||
print(results[888].get_eval('x-valid'))
|
||||
result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--check_N', type=int, default=32768, help='For safety.')
|
||||
# use for train the model
|
||||
args = parser.parse_args()
|
||||
|
||||
api201 = NASBench201API(os.path.join(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth'), verbose=True)
|
||||
test_issue_81_82(api201)
|
||||
test_api(api201, False)
|
||||
api201 = NASBench201API(None, verbose=True)
|
||||
test_issue_81_82(api201)
|
||||
visualize_tss_info(api201, 'cifar10', Path('output/vis-nas-bench'))
|
||||
visualize_tss_info(api201, 'cifar100', Path('output/vis-nas-bench'))
|
||||
visualize_tss_info(api201, 'ImageNet16-120', Path('output/vis-nas-bench'))
|
||||
test_api(api201, False)
|
||||
|
||||
api301 = NASBench301API(None, verbose=True)
|
||||
visualize_sss_info(api301, 'cifar10', Path('output/vis-nas-bench'))
|
||||
visualize_sss_info(api301, 'cifar100', Path('output/vis-nas-bench'))
|
||||
visualize_sss_info(api301, 'ImageNet16-120', Path('output/vis-nas-bench'))
|
||||
test_api(api301, True)
|
||||
|
||||
# save_dir = '{:}/visual'.format(args.save_dir)
|
@@ -38,7 +38,6 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
for idx in range(len(api)):
|
||||
# info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
|
||||
# import pdb; pdb.set_trace()
|
||||
for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
|
||||
info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
|
||||
if key == 'cifar10-valid':
|
||||
|
242
exps/NAS-Bench-201/xshape-collect.py
Normal file
242
exps/NAS-Bench-201/xshape-collect.py
Normal file
@@ -0,0 +1,242 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
# python exps/NAS-Bench-201/xshape-collect.py
|
||||
#####################################################
|
||||
import os, re, sys, time, argparse, collections
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from collections import defaultdict, OrderedDict
|
||||
from typing import Dict, Any, Text, List
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from config_utils import dict2config
|
||||
# NAS-Bench-201 related module or function
|
||||
from models import CellStructure, get_cell_based_tiny_net
|
||||
from nas_201_api import NASBench301API, ArchResults, ResultsCount
|
||||
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
|
||||
|
||||
|
||||
def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]) -> ArchResults:
|
||||
information = ArchResults(arch_index, arch_str)
|
||||
|
||||
for checkpoint_path in checkpoints:
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
except:
|
||||
raise ValueError('This checkpoint failed to be loaded : {:}'.format(checkpoint_path))
|
||||
used_seed = checkpoint_path.name.split('-')[-1].split('.')[0]
|
||||
ok_dataset = 0
|
||||
for dataset in datasets:
|
||||
if dataset not in checkpoint:
|
||||
print('Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, checkpoint_path))
|
||||
continue
|
||||
else:
|
||||
ok_dataset += 1
|
||||
results = checkpoint[dataset]
|
||||
assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, used_seed, dataset, checkpoint_path)
|
||||
arch_config = {'name': 'infer.shape.tiny', 'channels': arch_str, 'arch_str': arch_str,
|
||||
'genotype': results['arch_config']['genotype'],
|
||||
'class_num': results['arch_config']['num_classes']}
|
||||
xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'],
|
||||
results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None)
|
||||
xresult.update_train_info(results['train_acc1es'], results['train_acc5es'], results['train_losses'], results['train_times'])
|
||||
xresult.update_eval(results['valid_acc1es'], results['valid_losses'], results['valid_times'])
|
||||
information.update(dataset, int(used_seed), xresult)
|
||||
if ok_dataset < len(datasets): raise ValueError('{:} does find enought data : {:} vs {:}'.format(checkpoint_path, ok_dataset, len(datasets)))
|
||||
return information
|
||||
|
||||
|
||||
def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
|
||||
# calibrate the latency based on the number of epochs = 01, since they are trained on the same machine.
|
||||
x1 = hp2info['01'].get_metrics('cifar10-valid', 'x-valid')['all_time'] / 98
|
||||
x2 = hp2info['01'].get_metrics('cifar10-valid', 'ori-test')['all_time'] / 40
|
||||
cifar010_latency = (x1 + x2) / 2
|
||||
for hp, arch_info in hp2info.items():
|
||||
arch_info.reset_latency('cifar10-valid', None, cifar010_latency)
|
||||
arch_info.reset_latency('cifar10', None, cifar010_latency)
|
||||
# hp2info['01'].get_latency('cifar10')
|
||||
|
||||
x1 = hp2info['01'].get_metrics('cifar100', 'ori-test')['all_time'] / 40
|
||||
x2 = hp2info['01'].get_metrics('cifar100', 'x-test')['all_time'] / 20
|
||||
x3 = hp2info['01'].get_metrics('cifar100', 'x-valid')['all_time'] / 20
|
||||
cifar100_latency = (x1 + x2 + x3) / 3
|
||||
for hp, arch_info in hp2info.items():
|
||||
arch_info.reset_latency('cifar100', None, cifar100_latency)
|
||||
|
||||
x1 = hp2info['01'].get_metrics('ImageNet16-120', 'ori-test')['all_time'] / 24
|
||||
x2 = hp2info['01'].get_metrics('ImageNet16-120', 'x-test')['all_time'] / 12
|
||||
x3 = hp2info['01'].get_metrics('ImageNet16-120', 'x-valid')['all_time'] / 12
|
||||
image_latency = (x1 + x2 + x3) / 3
|
||||
for hp, arch_info in hp2info.items():
|
||||
arch_info.reset_latency('ImageNet16-120', None, image_latency)
|
||||
|
||||
# CIFAR10 VALID
|
||||
train_per_epoch_time = list(hp2info['01'].query('cifar10-valid', 777).train_times.values())
|
||||
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
|
||||
eval_ori_test_time, eval_x_valid_time = [], []
|
||||
for key, value in hp2info['01'].query('cifar10-valid', 777).eval_times.items():
|
||||
if key.startswith('ori-test@'):
|
||||
eval_ori_test_time.append(value)
|
||||
elif key.startswith('x-valid@'):
|
||||
eval_x_valid_time.append(value)
|
||||
else: raise ValueError('-- {:} --'.format(key))
|
||||
eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time)
|
||||
eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time)
|
||||
for hp, arch_info in hp2info.items():
|
||||
arch_info.reset_pseudo_train_times('cifar10-valid', None, train_per_epoch_time)
|
||||
arch_info.reset_pseudo_eval_times('cifar10-valid', None, 'x-valid', eval_x_valid_time)
|
||||
arch_info.reset_pseudo_eval_times('cifar10-valid', None, 'ori-test', eval_ori_test_time)
|
||||
|
||||
# CIFAR10
|
||||
train_per_epoch_time = list(hp2info['01'].query('cifar10', 777).train_times.values())
|
||||
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
|
||||
eval_ori_test_time = []
|
||||
for key, value in hp2info['01'].query('cifar10', 777).eval_times.items():
|
||||
if key.startswith('ori-test@'):
|
||||
eval_ori_test_time.append(value)
|
||||
else: raise ValueError('-- {:} --'.format(key))
|
||||
eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time)
|
||||
for hp, arch_info in hp2info.items():
|
||||
arch_info.reset_pseudo_train_times('cifar10', None, train_per_epoch_time)
|
||||
arch_info.reset_pseudo_eval_times('cifar10', None, 'ori-test', eval_ori_test_time)
|
||||
|
||||
# CIFAR100
|
||||
train_per_epoch_time = list(hp2info['01'].query('cifar100', 777).train_times.values())
|
||||
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
|
||||
eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], []
|
||||
for key, value in hp2info['01'].query('cifar100', 777).eval_times.items():
|
||||
if key.startswith('ori-test@'):
|
||||
eval_ori_test_time.append(value)
|
||||
elif key.startswith('x-valid@'):
|
||||
eval_x_valid_time.append(value)
|
||||
elif key.startswith('x-test@'):
|
||||
eval_x_test_time.append(value)
|
||||
else: raise ValueError('-- {:} --'.format(key))
|
||||
eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time)
|
||||
eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time)
|
||||
eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time)
|
||||
for hp, arch_info in hp2info.items():
|
||||
arch_info.reset_pseudo_train_times('cifar100', None, train_per_epoch_time)
|
||||
arch_info.reset_pseudo_eval_times('cifar100', None, 'x-valid', eval_x_valid_time)
|
||||
arch_info.reset_pseudo_eval_times('cifar100', None, 'x-test', eval_x_test_time)
|
||||
arch_info.reset_pseudo_eval_times('cifar100', None, 'ori-test', eval_ori_test_time)
|
||||
|
||||
# ImageNet16-120
|
||||
train_per_epoch_time = list(hp2info['01'].query('ImageNet16-120', 777).train_times.values())
|
||||
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
|
||||
eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], []
|
||||
for key, value in hp2info['01'].query('ImageNet16-120', 777).eval_times.items():
|
||||
if key.startswith('ori-test@'):
|
||||
eval_ori_test_time.append(value)
|
||||
elif key.startswith('x-valid@'):
|
||||
eval_x_valid_time.append(value)
|
||||
elif key.startswith('x-test@'):
|
||||
eval_x_test_time.append(value)
|
||||
else: raise ValueError('-- {:} --'.format(key))
|
||||
eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time)
|
||||
eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time)
|
||||
eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time)
|
||||
for hp, arch_info in hp2info.items():
|
||||
arch_info.reset_pseudo_train_times('ImageNet16-120', None, train_per_epoch_time)
|
||||
arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'x-valid', eval_x_valid_time)
|
||||
arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'x-test', eval_x_test_time)
|
||||
arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'ori-test', eval_ori_test_time)
|
||||
return hp2info
|
||||
|
||||
|
||||
def simplify(save_dir, save_name, nets, total):
|
||||
|
||||
hps, seeds = ['01', '12', '90'], set()
|
||||
for hp in hps:
|
||||
sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
|
||||
ckps = sorted(list(sub_save_dir.glob('arch-*-seed-*.pth')))
|
||||
seed2names = defaultdict(list)
|
||||
for ckp in ckps:
|
||||
parts = re.split('-|\.', ckp.name)
|
||||
seed2names[parts[3]].append(ckp.name)
|
||||
print('DIR : {:}'.format(sub_save_dir))
|
||||
nums = []
|
||||
for seed, xlist in seed2names.items():
|
||||
seeds.add(seed)
|
||||
nums.append(len(xlist))
|
||||
print(' seed={:}, num={:}'.format(seed, len(xlist)))
|
||||
# assert len(nets) == total == max(nums), 'there are some missed files : {:} vs {:}'.format(max(nums), total)
|
||||
print('{:} start simplify the checkpoint.'.format(time_string()))
|
||||
|
||||
datasets = ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120')
|
||||
|
||||
simplify_save_dir, arch2infos, evaluated_indexes = save_dir / save_name, {}, set()
|
||||
simplify_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
end_time, arch_time = time.time(), AverageMeter()
|
||||
# for index, arch_str in enumerate(nets):
|
||||
for index in tqdm(range(total)):
|
||||
arch_str = nets[index]
|
||||
hp2info = OrderedDict()
|
||||
for hp in hps:
|
||||
sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
|
||||
ckps = [sub_save_dir / 'arch-{:06d}-seed-{:}.pth'.format(index, seed) for seed in seeds]
|
||||
ckps = [x for x in ckps if x.exists()]
|
||||
if len(ckps) == 0: raise ValueError('Invalid data : index={:}, hp={:}'.format(index, hp))
|
||||
|
||||
arch_info = account_one_arch(index, arch_str, ckps, datasets)
|
||||
hp2info[hp] = arch_info
|
||||
|
||||
hp2info = correct_time_related_info(hp2info)
|
||||
evaluated_indexes.add(index)
|
||||
|
||||
to_save_data = OrderedDict({'01': hp2info['01'].state_dict(),
|
||||
'12': hp2info['12'].state_dict(),
|
||||
'90': hp2info['90'].state_dict()})
|
||||
torch.save(to_save_data, simplify_save_dir / '{:}-FULL.pth'.format(index))
|
||||
|
||||
for hp in hps: hp2info[hp].clear_params()
|
||||
to_save_data = OrderedDict({'01': hp2info['01'].state_dict(),
|
||||
'12': hp2info['12'].state_dict(),
|
||||
'90': hp2info['90'].state_dict()})
|
||||
torch.save(to_save_data, simplify_save_dir / '{:}-SIMPLE.pth'.format(index))
|
||||
arch2infos[index] = to_save_data
|
||||
# measure elapsed time
|
||||
arch_time.update(time.time() - end_time)
|
||||
end_time = time.time()
|
||||
need_time = '{:}'.format(convert_secs2time(arch_time.avg * (total-index-1), True))
|
||||
# print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
|
||||
print('{:} {:} done.'.format(time_string(), save_name))
|
||||
final_infos = {'meta_archs' : nets,
|
||||
'total_archs': total,
|
||||
'arch2infos' : arch2infos,
|
||||
'evaluated_indexes': evaluated_indexes}
|
||||
save_file_name = save_dir / '{:}.pth'.format(save_name)
|
||||
torch.save(final_infos, save_file_name)
|
||||
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), total, save_file_name))
|
||||
|
||||
|
||||
def traverse_net(candidates: List[int], N: int):
|
||||
nets = ['']
|
||||
for i in range(N):
|
||||
new_nets = []
|
||||
for net in nets:
|
||||
for C in candidates:
|
||||
new_nets.append(str(C) if net == '' else "{:}:{:}".format(net,C))
|
||||
nets = new_nets
|
||||
return nets
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-202', help='The base-name of folder to save checkpoints and log.')
|
||||
parser.add_argument('--candidateC' , type=int, nargs='+', default=[8, 16, 24, 32, 40, 48, 56, 64], help='.')
|
||||
parser.add_argument('--num_layers' , type=int, default=5, help='The number of layers in a network.')
|
||||
parser.add_argument('--check_N' , type=int, default=32768, help='For safety.')
|
||||
parser.add_argument('--save_name' , type=str, default='simplify', help='The save directory.')
|
||||
args = parser.parse_args()
|
||||
|
||||
nets = traverse_net(args.candidateC, args.num_layers)
|
||||
if len(nets) != args.check_N: raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N))
|
||||
|
||||
save_dir = Path(args.base_save_dir)
|
||||
simplify(save_dir, args.save_name, nets, args.check_N)
|
@@ -22,7 +22,7 @@ from log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||
|
||||
|
||||
def obtain_valid_ckp(save_dir: Text, total: int):
|
||||
possible_seeds = [777, 888]
|
||||
possible_seeds = [777, 888, 999]
|
||||
seed2ckps = defaultdict(list)
|
||||
miss2ckps = defaultdict(list)
|
||||
for i in range(total):
|
||||
@@ -33,7 +33,7 @@ def obtain_valid_ckp(save_dir: Text, total: int):
|
||||
else:
|
||||
miss2ckps[seed].append(i)
|
||||
for seed, xlist in seed2ckps.items():
|
||||
print('[{:}] [seed={:}] has {:}/{:}'.format(save_dir, seed, len(xlist), total))
|
||||
print('[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}'.format(save_dir, seed, len(xlist), total, total-len(xlist), total))
|
||||
return dict(seed2ckps), dict(miss2ckps)
|
||||
|
||||
|
||||
|
@@ -65,7 +65,7 @@ class MyWorker(Worker):
|
||||
assert len(self.seen_archs) > 0
|
||||
best_index, best_acc = -1, None
|
||||
for arch_index in self.seen_archs:
|
||||
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True)
|
||||
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
|
||||
vacc = info['valid-accuracy']
|
||||
if best_acc is None or best_acc < vacc:
|
||||
best_acc = vacc
|
||||
@@ -77,7 +77,7 @@ class MyWorker(Worker):
|
||||
start_time = time.time()
|
||||
structure = self.convert_func( config )
|
||||
arch_index = self._nas_bench.query_index_by_arch( structure )
|
||||
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True)
|
||||
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
|
||||
cur_time = info['train-all-time'] + info['valid-per-time']
|
||||
cur_vacc = info['valid-accuracy']
|
||||
self.real_cost_time += (time.time() - start_time)
|
||||
|
@@ -42,7 +42,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01
|
||||
if use_012_epoch_training and nas_bench is not None:
|
||||
arch_index = nas_bench.query_index_by_arch( arch )
|
||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||
info = nas_bench.get_more_info(arch_index, dataname, None, True)
|
||||
info = nas_bench.get_more_info(arch_index, dataname, iepoch=None, hp='12', is_random=True)
|
||||
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
|
||||
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
|
||||
elif not use_012_epoch_training and nas_bench is not None:
|
||||
@@ -51,10 +51,10 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01
|
||||
# It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
|
||||
arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25
|
||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
|
||||
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', False)
|
||||
info = nas_bench.get_more_info(arch_index, dataname, nepoch, False, True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
|
||||
cost = nas_bench.get_cost_info(arch_index, dataname, False)
|
||||
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12')
|
||||
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200')
|
||||
info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
|
||||
cost = nas_bench.get_cost_info(arch_index, dataname, hp='200')
|
||||
# The following codes are used to estimate the time cost.
|
||||
# When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
|
||||
# When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared.
|
||||
|
20
exps/experimental/test-api.py
Normal file
20
exps/experimental/test-api.py
Normal file
@@ -0,0 +1,20 @@
|
||||
#
|
||||
# exps/experimental/test-api.py
|
||||
#
|
||||
import sys, time, random, argparse
|
||||
from copy import deepcopy
|
||||
import torchvision.models as models
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from nas_201_api import NASBench201API as API
|
||||
|
||||
|
||||
def main():
|
||||
api = API(None)
|
||||
info = api.get_more_info(100, 'cifar100', 199, False, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Reference in New Issue
Block a user