Update visualization codes for NATS-Bench
exps/NATS-Bench/draw-table.py (new file, 85 lines)
@@ -0,0 +1,85 @@
###############################################################
# NATS-Bench (https://arxiv.org/pdf/2009.00437.pdf)           #
# The code to draw some results in Table 4 in our paper.      #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           #
###############################################################
# Usage: python exps/NATS-Bench/draw-table.py                 #
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

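# Make the repository's `lib` directory importable regardless of the current working directory.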
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string

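# Return the validation / test accuracy of `arch` on `dataset`, plus a formatted summary string.
# The 90-epoch setting is queried for the size space and the 200-epoch setting for the topology
# space. For CIFAR-10, the validation accuracy comes from the 'cifar10-valid' split, while the
# test accuracy is taken from the model trained on the full 'cifar10' training set.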
def get_valid_test_acc(api, arch, dataset):
  is_size_space = api.search_space_name == 'size'
  if dataset == 'cifar10':
    xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
    test_acc = xinfo['test-accuracy']
    xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False)
    valid_acc = xinfo['valid-accuracy']
  else:
    xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
    valid_acc = xinfo['valid-accuracy']
    test_acc = xinfo['test-accuracy']
  return valid_acc, test_acc, 'validation = {:.2f}, test = {:.2f}\n'.format(valid_acc, test_acc)

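# Build a multi-line report of the validation / test accuracy of `arch` on all three datasets.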
def show_valid_test(api, arch):
  is_size_space = api.search_space_name == 'size'
  final_str = ''
  for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
    valid_acc, test_acc, perf_str = get_valid_test_acc(api, arch, dataset)
    final_str += '{:} : {:}\n'.format(dataset, perf_str)
  return final_str

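# Scan every architecture in the benchmark and print the index and accuracies of the best
# architecture by validation accuracy and by test accuracy on `dataset`.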
def find_best_valid(api, dataset):
  all_valid_accs, all_test_accs = [], []
  for index, arch in enumerate(api):
    # import pdb; pdb.set_trace()
    valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset)
    all_valid_accs.append((index, valid_acc))
    all_test_accs.append((index, test_acc))
  best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0]
  best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0]

  print('-' * 50 + '{:10s}'.format(dataset) + '-' * 50)
  print('Best ({:}) architecture on validation: {:}'.format(best_valid_index, api[best_valid_index]))
  print('Best ({:}) architecture on test: {:}'.format(best_test_index, api[best_test_index]))
  _, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset)
  print('using validation ::: {:}'.format(perf_str))
  _, _, perf_str = get_valid_test_acc(api, best_test_index, dataset)
  print('using test ::: {:}'.format(perf_str))

if __name__ == '__main__':

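  # Topology search space ('tss'): look up the ResNet-like cell, report its accuracies,
  # then find the best-on-validation and best-on-test architectures for each dataset.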
  api_tss = create(None, 'tss', fast_mode=False, verbose=False)
  resnet = '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|'
  resnet_index = api_tss.query_index_by_arch(resnet)
  print(show_valid_test(api_tss, resnet_index))

  for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
    find_best_valid(api_tss, dataset)

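  # Size search space ('sss'): `api_sss` is created analogously to `api_tss` above; the
  # largest candidate uses 64 channels in every layer.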
  api_sss = create(None, 'sss', fast_mode=False, verbose=False)
  largest = '64:64:64:64:64'
  largest_index = api_sss.query_index_by_arch(largest)
  print(show_valid_test(api_sss, largest_index))
  for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
    find_best_valid(api_sss, dataset)