From 13f77963d0d366c43299dca38320bc425534d093 Mon Sep 17 00:00:00 2001 From: mhz Date: Sun, 28 Jul 2024 21:18:55 +0200 Subject: [PATCH] wrote the get_nasbench201_idx_score --- graph_dit/naswot/score_networks.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/graph_dit/naswot/score_networks.py b/graph_dit/naswot/score_networks.py index e5a9c76..90552c9 100644 --- a/graph_dit/naswot/score_networks.py +++ b/graph_dit/naswot/score_networks.py @@ -58,15 +58,9 @@ def get_batch_jacobian(net, x, target, device, args=None): return jacob, target.detach(), y.detach(), out.detach() def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): - op_type = { - 'input': 0, - 'nor_conv_1x1': 1, - 'nor_conv_3x3': 2, - 'avg_pool_3x3': 3, - 'skip_connect': 4, - 'none': 5, - 'output': 6, - } + num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] + + def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device): # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")