Update Warmup

This commit is contained in:
D-X-Y
2020-10-08 10:19:34 +11:00
parent ad5d6e28b9
commit ab801cbf14
7 changed files with 90 additions and 43 deletions

View File

@@ -26,7 +26,8 @@ from nats_bench import create
from log_utils import time_string
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-AWD0.0-WARMNone'):
# def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'):
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARM0.3'):
ss_dir = '{:}-{:}'.format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
seeds = [777, 888, 999]
@@ -39,9 +40,12 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf
alg2name['ENAS'] = 'enas-affine0_BN0-None'
alg2name['SETN'] = 'setn-affine0_BN0-None'
else:
alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
# alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
alg2name['channel-wise interpaltion'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix)
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
alg2data = OrderedDict()
@@ -98,8 +102,11 @@ def visualize_curve(api, vis_save_dir, search_space):
for idx, (alg, data) in enumerate(alg2data.items()):
print('plot alg : {:}'.format(alg))
xs, accuracies = [], []
for iepoch in range(epochs+1):
structures, accs = [_[iepoch-1] for _ in data], []
for iepoch in range(epochs + 1):
try:
structures, accs = [_[iepoch-1] for _ in data], []
except:
raise ValueError('This alg {:} on {:} has invalid checkpoints.'.format(alg, dataset))
for structure in structures:
info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False)
accs.append(info['test-accuracy'])
@@ -131,5 +138,5 @@ if __name__ == '__main__':
save_dir = Path(args.save_dir)
api = create(None, args.search_space, verbose=False)
api = create(None, args.search_space, fast_mode=True, verbose=False)
visualize_curve(api, save_dir, args.search_space)