Update weight watcher codes

This commit is contained in:
D-X-Y
2020-07-05 22:29:26 +00:00
parent 9659f132be
commit 6facc39a42
9 changed files with 167 additions and 161 deletions

View File

@@ -90,9 +90,9 @@ def visualize_sss_info(api, dataset, vis_save_dir):
print ('Do not find cache file : {:}'.format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
info = api.get_cost_info(index, dataset)
params.append(info['params'])
flops.append(info['flops'])
cost_info = api.get_cost_info(index, dataset, hp='90')
params.append(cost_info['params'])
flops.append(cost_info['flops'])
# accuracy
info = api.get_more_info(index, dataset, hp='90', is_random=False)
train_accs.append(info['train-accuracy'])
@@ -178,9 +178,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
print ('Do not find cache file : {:}'.format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
info = api.get_cost_info(index, dataset)
params.append(info['params'])
flops.append(info['flops'])
cost_info = api.get_cost_info(index, dataset, hp='12')
params.append(cost_info['params'])
flops.append(cost_info['flops'])
# accuracy
info = api.get_more_info(index, dataset, hp='200', is_random=False)
train_accs.append(info['train-accuracy'])
@@ -190,6 +190,7 @@ def visualize_tss_info(api, dataset, vis_save_dir):
valid_accs.append(info['valid-accuracy'])
else:
valid_accs.append(info['valid-accuracy'])
print('')
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs}
torch.save(info, cache_file_path)
else: