Update visualization codes for NATS-Bench

This commit is contained in:
D-X-Y
2020-11-30 00:48:10 +08:00
parent 550d24ec07
commit 29428bf5a3
6 changed files with 802 additions and 10 deletions

View File

@@ -33,7 +33,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
alg2name['REA'] = 'R-EA-SS3'
alg2name['REINFORCE'] = 'REINFORCE-0.01'
alg2name['RANDOM'] = 'RANDOM'
# alg2name['BOHB'] = 'BOHB'
alg2name['BOHB'] = 'BOHB'
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth')
assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg])
@@ -59,7 +59,26 @@ def query_performance(api, data, dataset, ticket):
accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy']
interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b
results.append(interplate)
return sum(results) / len(results)
# return sum(results) / len(results)
return np.mean(results), np.std(results)
def show_valid_test(api, data, dataset):
valid_accs, test_accs, is_size_space = [], [], api.search_space_name == 'size'
for i, info in data.items():
time, arch = info['time_w_arch'][-1]
if dataset == 'cifar10':
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
test_accs.append(xinfo['test-accuracy'])
xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False)
valid_accs.append(xinfo['valid-accuracy'])
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
valid_accs.append(xinfo['valid-accuracy'])
test_accs.append(xinfo['test-accuracy'])
valid_str = '{:.2f}$\pm${:.2f}'.format(np.mean(valid_accs), np.std(valid_accs))
test_str = '{:.2f}$\pm${:.2f}'.format(np.mean(test_accs), np.std(test_accs))
return valid_str, test_str
y_min_s = {('cifar10', 'tss'): 90,
@@ -69,11 +88,11 @@ y_min_s = {('cifar10', 'tss'): 90,
('ImageNet16-120', 'tss'): 36,
('ImageNet16-120', 'sss'): 40}
y_max_s = {('cifar10', 'tss'): 94.5,
y_max_s = {('cifar10', 'tss'): 94.3,
('cifar10', 'sss'): 93.3,
('cifar100', 'tss'): 72,
('cifar100', 'sss'): 70,
('ImageNet16-120', 'tss'): 44,
('cifar100', 'tss'): 72.5,
('cifar100', 'sss'): 70.5,
('ImageNet16-120', 'tss'): 46,
('ImageNet16-120', 'sss'): 46}
x_axis_s = {('cifar10', 'tss'): 200,
@@ -87,6 +106,7 @@ name2label = {'cifar10': 'CIFAR-10',
'cifar100': 'CIFAR-100',
'ImageNet16-120': 'ImageNet-16-120'}
def visualize_curve(api, vis_save_dir, search_space):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
@@ -106,11 +126,13 @@ def visualize_curve(api, vis_save_dir, search_space):
ax.set_ylim(y_min_s[(xdataset, search_space)],
y_max_s[(xdataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
print('{:} plot alg : {:}'.format(time_string(), alg))
accuracies = []
for ticket in time_tickets:
accuracy = query_performance(api, data, xdataset, ticket)
accuracy, accuracy_std = query_performance(api, data, xdataset, ticket)
accuracies.append(accuracy)
valid_str, test_str = show_valid_test(api, data, xdataset)
# print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std))
print('{:} plot alg : {:10s} | validation = {:} | test = {:}'.format(time_string(), alg, valid_str, test_str))
alg2accuracies[alg] = accuracies
ax.plot([x/100 for x in time_tickets], accuracies, c=colors[idx], label='{:}'.format(alg))
ax.set_xlabel('Estimated wall-clock time (1e2 seconds)', fontsize=LabelSize)