Update REA, REINFORCE, and RANDOM

This commit is contained in:
D-X-Y
2020-07-13 11:35:13 +00:00
parent 6dc494be08
commit ebad9197f7
5 changed files with 38 additions and 26 deletions

View File

@@ -3,7 +3,7 @@
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/experimental/vis-bench-algos.py
# Usage: python exps/experimental/vis-bench-algos.py #
###############################################################
import os, sys, time, torch, argparse
import numpy as np
@@ -30,6 +30,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
alg2name, alg2path = OrderedDict(), OrderedDict()
alg2name['REA'] = 'R-EA-SS3'
alg2name['REINFORCE'] = 'REINFORCE-0.001'
# alg2name['RANDOM'] = 'RANDOM'
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth')
assert os.path.isfile(alg2path[alg])
@@ -62,14 +63,15 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 4700, 1500
dpi, width, height = 250, 5100, 1500
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
def sub_plot_fn(ax, dataset):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
time_tickets = [float(i) / 100 * max_time for i in range(100)]
total_tickets = 150
time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)]
colors = ['b', 'g', 'c', 'm', 'y']
for idx, (alg, data) in enumerate(alg2data.items()):
print('plot alg : {:}'.format(alg))
@@ -78,7 +80,10 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
accuracy = query_performance(api, data, dataset, ticket)
accuracies.append(accuracy)
alg2accuracies[alg] = accuracies
ax.plot(time_tickets, accuracies, c=colors[idx], label='{:}'.format(alg))
ax.plot([x/100 for x in time_tickets], accuracies, c=colors[idx], label='{:}'.format(alg))
ax.set_xlabel('Estimated wall-clock time (1e2 seconds)', fontsize=LabelSize)
ax.set_ylabel('Test accuracy on {:}'.format(dataset), fontsize=LabelSize)
ax.set_title('Searching results on {:}'.format(dataset), fontsize=LabelSize+4)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
@@ -104,4 +109,3 @@ if __name__ == '__main__':
visualize_curve(api201, save_dir, 'tss', args.max_time)
api301 = NASBench301API(verbose=False)
visualize_curve(api301, save_dir, 'sss', args.max_time)