Fix minor bugs in test-ww.py

This commit is contained in:
D-X-Y
2020-03-21 12:13:13 -07:00
parent 22025887f1
commit 87545c4477
2 changed files with 7 additions and 4 deletions

View File

@@ -37,7 +37,8 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
for idx in range(len(api)):
info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
# info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
# import pdb; pdb.set_trace()
for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
if key == 'cifar10-valid':
@@ -50,7 +51,7 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
config = api.get_net_config(idx, data)
net = get_cell_based_tiny_net(config)
api.reload(weight_dir, idx)
params = api.get_net_param(idx, data, None)
params = api.get_net_param(idx, data, None, use_12epochs_result=use_12epochs_result)
cur_norms = []
for seed, param in params.items():
with torch.no_grad():