Update visualization codes
This commit is contained in:
@@ -94,11 +94,11 @@ def visualize_sss_info(api, dataset, vis_save_dir):
|
||||
params.append(info['params'])
|
||||
flops.append(info['flops'])
|
||||
# accuracy
|
||||
info = api.get_more_info(index, dataset, hp='90')
|
||||
info = api.get_more_info(index, dataset, hp='90', is_random=False)
|
||||
train_accs.append(info['train-accuracy'])
|
||||
test_accs.append(info['test-accuracy'])
|
||||
if dataset == 'cifar10':
|
||||
info = api.get_more_info(index, 'cifar10-valid', hp='90')
|
||||
info = api.get_more_info(index, 'cifar10-valid', hp='90', is_random=False)
|
||||
valid_accs.append(info['valid-accuracy'])
|
||||
else:
|
||||
valid_accs.append(info['valid-accuracy'])
|
||||
@@ -182,11 +182,11 @@ def visualize_tss_info(api, dataset, vis_save_dir):
|
||||
params.append(info['params'])
|
||||
flops.append(info['flops'])
|
||||
# accuracy
|
||||
info = api.get_more_info(index, dataset, hp='200')
|
||||
info = api.get_more_info(index, dataset, hp='200', is_random=False)
|
||||
train_accs.append(info['train-accuracy'])
|
||||
test_accs.append(info['test-accuracy'])
|
||||
if dataset == 'cifar10':
|
||||
info = api.get_more_info(index, 'cifar10-valid', hp='200')
|
||||
info = api.get_more_info(index, 'cifar10-valid', hp='200', is_random=False)
|
||||
valid_accs.append(info['valid-accuracy'])
|
||||
else:
|
||||
valid_accs.append(info['valid-accuracy'])
|
||||
@@ -319,6 +319,68 @@ def visualize_rank_info(api, vis_save_dir, indicator):
|
||||
plt.close('all')
|
||||
|
||||
|
||||
def calculate_correlation(*vectors):
|
||||
matrix = []
|
||||
for i, vectori in enumerate(vectors):
|
||||
x = []
|
||||
for j, vectorj in enumerate(vectors):
|
||||
x.append( np.corrcoef(vectori, vectorj)[0,1] )
|
||||
matrix.append( x )
|
||||
return np.array(matrix)
|
||||
|
||||
|
||||
def visualize_all_rank_info(api, vis_save_dir, indicator):
|
||||
vis_save_dir = vis_save_dir.resolve()
|
||||
# print ('{:} start to visualize {:} information'.format(time_string(), api))
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
|
||||
cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
|
||||
imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
indexes = list(range(len(cifar010_info['params'])))
|
||||
|
||||
print ('{:} start to visualize relative ranking'.format(time_string()))
|
||||
|
||||
|
||||
dpi, width, height = 250, 3200, 1400
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 14, 14
|
||||
|
||||
fig, axs = plt.subplots(1, 2, figsize=figsize)
|
||||
ax1, ax2 = axs
|
||||
|
||||
sns_size = 15
|
||||
CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs'])
|
||||
|
||||
sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5, ax=ax1,
|
||||
xticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'],
|
||||
yticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'])
|
||||
|
||||
selected_indexes, acc_bar = [], 92
|
||||
for i, acc in enumerate(cifar010_info['test_accs']):
|
||||
if acc > acc_bar: selected_indexes.append( i )
|
||||
cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ]
|
||||
cifar010_test_accs = np.array(cifar010_info['test_accs']) [ selected_indexes ]
|
||||
cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ]
|
||||
cifar100_test_accs = np.array(cifar100_info['test_accs']) [ selected_indexes ]
|
||||
imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ]
|
||||
imagenet_test_accs = np.array(imagenet_info['test_accs']) [ selected_indexes ]
|
||||
CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs)
|
||||
|
||||
sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5, ax=ax2,
|
||||
xticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'],
|
||||
yticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'])
|
||||
ax1.set_title('Correlation coefficient over ALL candidates')
|
||||
ax2.set_title('Correlation coefficient over candidates with accuracy > {:}%'.format(acc_bar))
|
||||
save_path = (vis_save_dir / '{:}-all-relative-rank.png'.format(indicator)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
plt.close('all')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
|
||||
@@ -326,20 +388,19 @@ if __name__ == '__main__':
|
||||
# use for train the model
|
||||
args = parser.parse_args()
|
||||
|
||||
visualize_rank_info(None, Path('output/vis-nas-bench/'), 'tss')
|
||||
visualize_rank_info(None, Path('output/vis-nas-bench/'), 'sss')
|
||||
|
||||
datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
|
||||
api201 = NASBench201API(None, verbose=True)
|
||||
visualize_tss_info(api201, 'cifar10', Path('output/vis-nas-bench'))
|
||||
visualize_tss_info(api201, 'cifar100', Path('output/vis-nas-bench'))
|
||||
visualize_tss_info(api201, 'ImageNet16-120', Path('output/vis-nas-bench'))
|
||||
for xdata in datasets:
|
||||
visualize_tss_info(api201, xdata, Path('output/vis-nas-bench'))
|
||||
|
||||
api301 = NASBench301API(None, verbose=True)
|
||||
visualize_sss_info(api301, 'cifar10', Path('output/vis-nas-bench'))
|
||||
visualize_sss_info(api301, 'cifar100', Path('output/vis-nas-bench'))
|
||||
visualize_sss_info(api301, 'ImageNet16-120', Path('output/vis-nas-bench'))
|
||||
for xdata in datasets:
|
||||
visualize_sss_info(api301, xdata, Path('output/vis-nas-bench'))
|
||||
|
||||
visualize_info(None, Path('output/vis-nas-bench/'), 'tss')
|
||||
visualize_info(None, Path('output/vis-nas-bench/'), 'sss')
|
||||
visualize_rank_info(None, Path('output/vis-nas-bench/'), 'tss')
|
||||
visualize_rank_info(None, Path('output/vis-nas-bench/'), 'sss')
|
||||
|
||||
visualize_all_rank_info(None, Path('output/vis-nas-bench/'), 'tss')
|
||||
visualize_all_rank_info(None, Path('output/vis-nas-bench/'), 'sss')
|
||||
|
Reference in New Issue
Block a user