Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
@@ -53,7 +53,7 @@ def evaluate(api, weight_dir, data: str):
|
||||
config = api.get_net_config(arch_index, data)
|
||||
net = get_cell_based_tiny_net(config)
|
||||
meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90')
|
||||
params = meta_info.get_net_param(data, 777)
|
||||
params = meta_info.get_net_param(data, 888 if isinstance(api, NASBench201API) else 777)
|
||||
with torch.no_grad():
|
||||
net.load_state_dict(params)
|
||||
_, summary = weight_watcher.analyze(net, alphas=False)
|
||||
|
Reference in New Issue
Block a user