update vis

This commit is contained in:
D-X-Y
2020-01-01 22:18:42 +11:00
parent 9ec25663f1
commit 28e4b8406f
12 changed files with 153 additions and 40 deletions

View File

@@ -6,6 +6,7 @@
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import torch
import torch.nn as nn
from pathlib import Path
@@ -142,15 +143,17 @@ def check_unique_arch(meta_file):
print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False):
if isinstance(meta_file, API):
api = meta_file
else:
api = API(str(meta_file))
cifar10_valid = []
cifar10_test = []
cifar100_valid = []
cifar100_test = []
imagenet_test = []
imagenet_valid = []
for idx, arch in enumerate(api):
results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
cifar10_valid.append( results['valid-accuracy'] )
@@ -158,14 +161,16 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
cifar10_test.append( results['test-accuracy'] )
results = api.get_more_info(idx, 'cifar100' , None, False, is_rand)
cifar100_test.append( results['test-accuracy'] )
cifar100_valid.append( results['valid-accuracy'] )
results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand)
imagenet_test.append( results['test-accuracy'] )
imagenet_valid.append( results['valid-accuracy'] )
def get_cor(A, B):
return float(np.corrcoef(A, B)[0,1])
cors = []
for basestr, xlist in zip(['CIFAR-010', 'CIFAR-100', 'ImageNet16'], [cifar10_test,cifar100_test, imagenet_test]):
for basestr, xlist in zip(['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]):
correlation = get_cor(cifar10_valid, xlist)
print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(less_epoch, '012' if use_less_or_not else '200', basestr, correlation))
if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
cors.append( correlation )
#print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
#print('-'*200)
@@ -173,6 +178,19 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
return cors
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
corrs = []
for i in tqdm(range(100)):
x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
corrs.append( x )
xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
correlations = np.array(corrs)
print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200'))
for idx, xstr in enumerate(xstrs):
print ('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(xstr, correlations[:,idx].mean(), correlations[:,idx].std(), correlations[:,idx].mean(), correlations[:,idx].std()))
print('')
if __name__ == '__main__':
parser = argparse.ArgumentParser("Analysis of NAS-Bench-102")
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
@@ -189,5 +207,11 @@ if __name__ == '__main__':
#for iepoch in [11, 25, 50, 100, 150, 175, 200]:
# check_cor_for_bandit(api, 6, iepoch)
# check_cor_for_bandit(api, 12, iepoch)
correlations = check_cor_for_bandit(api, 6, True, True)
import pdb; pdb.set_trace()
check_cor_for_bandit_v2(api, 6, True, True)
check_cor_for_bandit_v2(api, 12, True, True)
check_cor_for_bandit_v2(api, 12, False, True)
check_cor_for_bandit_v2(api, 24, False, True)
check_cor_for_bandit_v2(api, 100, False, True)
check_cor_for_bandit_v2(api, 150, False, True)
check_cor_for_bandit_v2(api, 200, False, True)
print('----')

View File

@@ -383,4 +383,4 @@ if __name__ == '__main__':
#visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
#visualize_info(str(meta_file), 'cifar100', vis_save_dir)
#visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
visualize_relative_ranking(vis_save_dir)
#visualize_relative_ranking(vis_save_dir)