Add int search space

This commit is contained in:
D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions

View File

@@ -69,7 +69,13 @@ def plot(filename):
for xin in range(i):
op_i = random.randint(0, len(OPS) - 1)
# g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
g.edge(
str(xin),
str(i),
label=OPS[op_i],
color=COLORS[op_i],
fillcolor=COLORS[op_i],
)
# import pdb; pdb.set_trace()
g.render(filename, cleanup=True, view=False)
@@ -88,7 +94,9 @@ def test_auto_grad():
net = Net(10)
inputs = torch.rand(256, 10)
loss = net(inputs)
first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
first_order_grads = torch.autograd.grad(
loss, net.parameters(), retain_graph=True, create_graph=True
)
first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
second_order_grads = []
for grads in first_order_grads:
@@ -108,9 +116,15 @@ def test_one_shot_model(ckpath, use_train):
print("ckpath : {:}".format(ckpath))
ckp = torch.load(ckpath)
xargs = ckp["args"]
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
# config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
config = load_config("./configs/nas-benchmark/algos/DARTS.config", {"class_num": class_num, "xshape": xshape}, None)
config = load_config(
"./configs/nas-benchmark/algos/DARTS.config",
{"class_num": class_num, "xshape": xshape},
None,
)
if xargs.dataset == "cifar10":
cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
xvalid_data = deepcopy(train_data)
@@ -142,7 +156,9 @@ def test_one_shot_model(ckpath, use_train):
search_model.load_state_dict(ckp["search_model"])
search_model = search_model.cuda()
api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
archs, probs, accuracies = evaluate_one_shot(
search_model, valid_loader, api, use_train
)
if __name__ == "__main__":