Add int search space
This commit is contained in:
@@ -69,7 +69,13 @@ def plot(filename):
|
||||
for xin in range(i):
|
||||
op_i = random.randint(0, len(OPS) - 1)
|
||||
# g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
|
||||
g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
|
||||
g.edge(
|
||||
str(xin),
|
||||
str(i),
|
||||
label=OPS[op_i],
|
||||
color=COLORS[op_i],
|
||||
fillcolor=COLORS[op_i],
|
||||
)
|
||||
# import pdb; pdb.set_trace()
|
||||
g.render(filename, cleanup=True, view=False)
|
||||
|
||||
@@ -88,7 +94,9 @@ def test_auto_grad():
|
||||
net = Net(10)
|
||||
inputs = torch.rand(256, 10)
|
||||
loss = net(inputs)
|
||||
first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
|
||||
first_order_grads = torch.autograd.grad(
|
||||
loss, net.parameters(), retain_graph=True, create_graph=True
|
||||
)
|
||||
first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
|
||||
second_order_grads = []
|
||||
for grads in first_order_grads:
|
||||
@@ -108,9 +116,15 @@ def test_one_shot_model(ckpath, use_train):
|
||||
print("ckpath : {:}".format(ckpath))
|
||||
ckp = torch.load(ckpath)
|
||||
xargs = ckp["args"]
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
train_data, valid_data, xshape, class_num = get_datasets(
|
||||
xargs.dataset, xargs.data_path, -1
|
||||
)
|
||||
# config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
|
||||
config = load_config("./configs/nas-benchmark/algos/DARTS.config", {"class_num": class_num, "xshape": xshape}, None)
|
||||
config = load_config(
|
||||
"./configs/nas-benchmark/algos/DARTS.config",
|
||||
{"class_num": class_num, "xshape": xshape},
|
||||
None,
|
||||
)
|
||||
if xargs.dataset == "cifar10":
|
||||
cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
|
||||
xvalid_data = deepcopy(train_data)
|
||||
@@ -142,7 +156,9 @@ def test_one_shot_model(ckpath, use_train):
|
||||
search_model.load_state_dict(ckp["search_model"])
|
||||
search_model = search_model.cuda()
|
||||
api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
|
||||
archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
|
||||
archs, probs, accuracies = evaluate_one_shot(
|
||||
search_model, valid_loader, api, use_train
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user