Fix minor bugs in test-ww.py
This commit is contained in:
@@ -37,7 +37,8 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
for idx in range(len(api)):
|
||||
info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
|
||||
# info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
|
||||
# import pdb; pdb.set_trace()
|
||||
for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
|
||||
info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
|
||||
if key == 'cifar10-valid':
|
||||
@@ -50,7 +51,7 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
config = api.get_net_config(idx, data)
|
||||
net = get_cell_based_tiny_net(config)
|
||||
api.reload(weight_dir, idx)
|
||||
params = api.get_net_param(idx, data, None)
|
||||
params = api.get_net_param(idx, data, None, use_12epochs_result=use_12epochs_result)
|
||||
cur_norms = []
|
||||
for seed, param in params.items():
|
||||
with torch.no_grad():
|
||||
|
Reference in New Issue
Block a user