update code styles
This commit is contained in:
@@ -8,7 +8,6 @@ from tqdm import tqdm
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import matplotlib
|
||||
@@ -498,6 +497,8 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
|
||||
|
||||
def get_accs(xdata):
|
||||
epochs, xresults = xdata['epoch'], []
|
||||
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
|
||||
xresults.append( metrics['accuracy'] )
|
||||
for iepoch in range(epochs):
|
||||
genotype = xdata['genotypes'][iepoch]
|
||||
index = api.query_index_by_arch(genotype)
|
||||
@@ -547,7 +548,6 @@ if __name__ == '__main__':
|
||||
#visualize_relative_ranking(vis_save_dir)
|
||||
|
||||
api = API(args.api_path)
|
||||
"""
|
||||
for x_maxs in [50, 250]:
|
||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
@@ -555,11 +555,12 @@ if __name__ == '__main__':
|
||||
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
just_show(api)
|
||||
"""
|
||||
just_show(api)
|
||||
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||
"""
|
||||
|
Reference in New Issue
Block a user