Add get_torch_home func for NATS-Bench
This commit is contained in:
96
exps/NATS-Bench/draw-ranks.py
Normal file
96
exps/NATS-Bench/draw-ranks.py
Normal file
@@ -0,0 +1,96 @@
|
||||
###############################################################
|
||||
# NATS-Bench (https://arxiv.org/pdf/2009.00437.pdf) #
|
||||
# The code to draw Figure 2 / 3 / 4 / 5 in our paper. #
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
||||
###############################################################
|
||||
# Usage: python exps/NATS-Bench/draw-ranks.py #
|
||||
###############################################################
|
||||
import os, sys, time, torch, argparse
|
||||
import scipy
|
||||
import numpy as np
|
||||
from typing import List, Text, Dict, Any
|
||||
from shutil import copyfile
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import matplotlib
|
||||
import seaborn as sns
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as ticker
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import dict2config, load_config
|
||||
from log_utils import time_string
|
||||
from models import get_cell_based_tiny_net
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
def visualize_relative_info(api, vis_save_dir, indicator):
|
||||
vis_save_dir = vis_save_dir.resolve()
|
||||
# print ('{:} start to visualize {:} information'.format(time_string(), api))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
|
||||
cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
|
||||
imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
indexes = list(range(len(cifar010_info['params'])))
|
||||
|
||||
print ('{:} start to visualize relative ranking'.format(time_string()))
|
||||
|
||||
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
|
||||
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
|
||||
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
|
||||
|
||||
cifar100_labels, imagenet_labels = [], []
|
||||
for idx in cifar010_ord_indexes:
|
||||
cifar100_labels.append( cifar100_ord_indexes.index(idx) )
|
||||
imagenet_labels.append( imagenet_ord_indexes.index(idx) )
|
||||
print ('{:} prepare data done.'.format(time_string()))
|
||||
|
||||
dpi, width, height = 200, 1400, 800
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 18, 12
|
||||
resnet_scale, resnet_alpha = 120, 0.5
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xlim(min(indexes), max(indexes))
|
||||
plt.ylim(min(indexes), max(indexes))
|
||||
# plt.ylabel('y').set_rotation(30)
|
||||
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
|
||||
ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
|
||||
ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8)
|
||||
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
|
||||
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
|
||||
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
|
||||
ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120')
|
||||
plt.grid(zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
|
||||
ax.set_ylabel('architecture ranking', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/rank-stability', help='Folder to save checkpoints and log.')
|
||||
# use for train the model
|
||||
args = parser.parse_args()
|
||||
|
||||
to_save_dir = Path(args.save_dir)
|
||||
|
||||
# Figure 2
|
||||
visualize_relative_info(None, to_save_dir, 'tss')
|
||||
visualize_relative_info(None, to_save_dir, 'sss')
|
Reference in New Issue
Block a user