Update Warmup
This commit is contained in:
@@ -26,7 +26,8 @@ from nats_bench import create
|
||||
from log_utils import time_string
|
||||
|
||||
|
||||
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-AWD0.0-WARMNone'):
|
||||
# def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'):
|
||||
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARM0.3'):
|
||||
ss_dir = '{:}-{:}'.format(root_dir, search_space)
|
||||
alg2name, alg2path = OrderedDict(), OrderedDict()
|
||||
seeds = [777, 888, 999]
|
||||
@@ -39,9 +40,12 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf
|
||||
alg2name['ENAS'] = 'enas-affine0_BN0-None'
|
||||
alg2name['SETN'] = 'setn-affine0_BN0-None'
|
||||
else:
|
||||
alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
|
||||
alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
|
||||
alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
|
||||
# alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
|
||||
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
|
||||
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
|
||||
alg2name['channel-wise interpaltion'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
|
||||
alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix)
|
||||
alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix)
|
||||
for alg, name in alg2name.items():
|
||||
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
|
||||
alg2data = OrderedDict()
|
||||
@@ -98,8 +102,11 @@ def visualize_curve(api, vis_save_dir, search_space):
|
||||
for idx, (alg, data) in enumerate(alg2data.items()):
|
||||
print('plot alg : {:}'.format(alg))
|
||||
xs, accuracies = [], []
|
||||
for iepoch in range(epochs+1):
|
||||
structures, accs = [_[iepoch-1] for _ in data], []
|
||||
for iepoch in range(epochs + 1):
|
||||
try:
|
||||
structures, accs = [_[iepoch-1] for _ in data], []
|
||||
except:
|
||||
raise ValueError('This alg {:} on {:} has invalid checkpoints.'.format(alg, dataset))
|
||||
for structure in structures:
|
||||
info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False)
|
||||
accs.append(info['test-accuracy'])
|
||||
@@ -131,5 +138,5 @@ if __name__ == '__main__':
|
||||
|
||||
save_dir = Path(args.save_dir)
|
||||
|
||||
api = create(None, args.search_space, verbose=False)
|
||||
api = create(None, args.search_space, fast_mode=True, verbose=False)
|
||||
visualize_curve(api, save_dir, args.search_space)
|
||||
|
Reference in New Issue
Block a user