Update Warmup

This commit is contained in:
D-X-Y
2020-10-08 10:19:34 +11:00
parent ad5d6e28b9
commit ab801cbf14
7 changed files with 90 additions and 43 deletions

View File

@@ -2,7 +2,13 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
######################################################################################
# In this file, we aims to evaluate three kinds of channel searching strategies:
# -
# - channel-wise interpaltion from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
# For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links:
# - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md
# - FBV2: https://github.com/facebookresearch/mobile-vision
# - TuNAS: https://github.com/google-research/google-research/tree/master/tunas
####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25
####

View File

@@ -26,7 +26,8 @@ from nats_bench import create
from log_utils import time_string
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-AWD0.0-WARMNone'):
# def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'):
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARM0.3'):
ss_dir = '{:}-{:}'.format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
seeds = [777, 888, 999]
@@ -39,9 +40,12 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf
alg2name['ENAS'] = 'enas-affine0_BN0-None'
alg2name['SETN'] = 'setn-affine0_BN0-None'
else:
alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
# alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
alg2name['channel-wise interpaltion'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix)
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
alg2data = OrderedDict()
@@ -98,8 +102,11 @@ def visualize_curve(api, vis_save_dir, search_space):
for idx, (alg, data) in enumerate(alg2data.items()):
print('plot alg : {:}'.format(alg))
xs, accuracies = [], []
for iepoch in range(epochs+1):
structures, accs = [_[iepoch-1] for _ in data], []
for iepoch in range(epochs + 1):
try:
structures, accs = [_[iepoch-1] for _ in data], []
except:
raise ValueError('This alg {:} on {:} has invalid checkpoints.'.format(alg, dataset))
for structure in structures:
info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False)
accs.append(info['test-accuracy'])
@@ -131,5 +138,5 @@ if __name__ == '__main__':
save_dir = Path(args.save_dir)
api = create(None, args.search_space, verbose=False)
api = create(None, args.search_space, fast_mode=True, verbose=False)
visualize_curve(api, save_dir, args.search_space)