Update visualization codes for NATS-Bench
This commit is contained in:
@@ -33,7 +33,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
|
||||
alg2name['REA'] = 'R-EA-SS3'
|
||||
alg2name['REINFORCE'] = 'REINFORCE-0.01'
|
||||
alg2name['RANDOM'] = 'RANDOM'
|
||||
# alg2name['BOHB'] = 'BOHB'
|
||||
alg2name['BOHB'] = 'BOHB'
|
||||
for alg, name in alg2name.items():
|
||||
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth')
|
||||
assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg])
|
||||
@@ -59,7 +59,26 @@ def query_performance(api, data, dataset, ticket):
|
||||
accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy']
|
||||
interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b
|
||||
results.append(interplate)
|
||||
return sum(results) / len(results)
|
||||
# return sum(results) / len(results)
|
||||
return np.mean(results), np.std(results)
|
||||
|
||||
|
||||
def show_valid_test(api, data, dataset):
|
||||
valid_accs, test_accs, is_size_space = [], [], api.search_space_name == 'size'
|
||||
for i, info in data.items():
|
||||
time, arch = info['time_w_arch'][-1]
|
||||
if dataset == 'cifar10':
|
||||
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
|
||||
test_accs.append(xinfo['test-accuracy'])
|
||||
xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False)
|
||||
valid_accs.append(xinfo['valid-accuracy'])
|
||||
else:
|
||||
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
|
||||
valid_accs.append(xinfo['valid-accuracy'])
|
||||
test_accs.append(xinfo['test-accuracy'])
|
||||
valid_str = '{:.2f}$\pm${:.2f}'.format(np.mean(valid_accs), np.std(valid_accs))
|
||||
test_str = '{:.2f}$\pm${:.2f}'.format(np.mean(test_accs), np.std(test_accs))
|
||||
return valid_str, test_str
|
||||
|
||||
|
||||
y_min_s = {('cifar10', 'tss'): 90,
|
||||
@@ -69,11 +88,11 @@ y_min_s = {('cifar10', 'tss'): 90,
|
||||
('ImageNet16-120', 'tss'): 36,
|
||||
('ImageNet16-120', 'sss'): 40}
|
||||
|
||||
y_max_s = {('cifar10', 'tss'): 94.5,
|
||||
y_max_s = {('cifar10', 'tss'): 94.3,
|
||||
('cifar10', 'sss'): 93.3,
|
||||
('cifar100', 'tss'): 72,
|
||||
('cifar100', 'sss'): 70,
|
||||
('ImageNet16-120', 'tss'): 44,
|
||||
('cifar100', 'tss'): 72.5,
|
||||
('cifar100', 'sss'): 70.5,
|
||||
('ImageNet16-120', 'tss'): 46,
|
||||
('ImageNet16-120', 'sss'): 46}
|
||||
|
||||
x_axis_s = {('cifar10', 'tss'): 200,
|
||||
@@ -87,6 +106,7 @@ name2label = {'cifar10': 'CIFAR-10',
|
||||
'cifar100': 'CIFAR-100',
|
||||
'ImageNet16-120': 'ImageNet-16-120'}
|
||||
|
||||
|
||||
def visualize_curve(api, vis_save_dir, search_space):
|
||||
vis_save_dir = vis_save_dir.resolve()
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -106,11 +126,13 @@ def visualize_curve(api, vis_save_dir, search_space):
|
||||
ax.set_ylim(y_min_s[(xdataset, search_space)],
|
||||
y_max_s[(xdataset, search_space)])
|
||||
for idx, (alg, data) in enumerate(alg2data.items()):
|
||||
print('{:} plot alg : {:}'.format(time_string(), alg))
|
||||
accuracies = []
|
||||
for ticket in time_tickets:
|
||||
accuracy = query_performance(api, data, xdataset, ticket)
|
||||
accuracy, accuracy_std = query_performance(api, data, xdataset, ticket)
|
||||
accuracies.append(accuracy)
|
||||
valid_str, test_str = show_valid_test(api, data, xdataset)
|
||||
# print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std))
|
||||
print('{:} plot alg : {:10s} | validation = {:} | test = {:}'.format(time_string(), alg, valid_str, test_str))
|
||||
alg2accuracies[alg] = accuracies
|
||||
ax.plot([x/100 for x in time_tickets], accuracies, c=colors[idx], label='{:}'.format(alg))
|
||||
ax.set_xlabel('Estimated wall-clock time (1e2 seconds)', fontsize=LabelSize)
|
||||
|
Reference in New Issue
Block a user