add autodl
This commit is contained in:
168
AutoDL-Projects/exps/experimental/test-nas-plot.py
Normal file
168
AutoDL-Projects/exps/experimental/test-nas-plot.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# python ./exps/vis/test.py
|
||||
import os, sys, random
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from nas_201_api import NASBench201API as API
|
||||
|
||||
|
||||
def test_nas_api():
|
||||
from nas_201_api import ArchResults
|
||||
|
||||
xdata = torch.load(
|
||||
"/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth"
|
||||
)
|
||||
for key in ["full", "less"]:
|
||||
print("\n------------------------- {:} -------------------------".format(key))
|
||||
archRes = ArchResults.create_from_state_dict(xdata[key])
|
||||
print(archRes)
|
||||
print(archRes.arch_idx_str())
|
||||
print(archRes.get_dataset_names())
|
||||
print(archRes.get_comput_costs("cifar10-valid"))
|
||||
# get the metrics
|
||||
print(archRes.get_metrics("cifar10-valid", "x-valid", None, False))
|
||||
print(archRes.get_metrics("cifar10-valid", "x-valid", None, True))
|
||||
print(archRes.query("cifar10-valid", 777))
|
||||
|
||||
|
||||
OPS = ["skip-connect", "conv-1x1", "conv-3x3", "pool-3x3"]
|
||||
COLORS = ["chartreuse", "cyan", "navyblue", "chocolate1"]
|
||||
|
||||
|
||||
def plot(filename):
|
||||
from graphviz import Digraph
|
||||
|
||||
g = Digraph(
|
||||
format="png",
|
||||
edge_attr=dict(fontsize="20", fontname="times"),
|
||||
node_attr=dict(
|
||||
style="filled",
|
||||
shape="rect",
|
||||
align="center",
|
||||
fontsize="20",
|
||||
height="0.5",
|
||||
width="0.5",
|
||||
penwidth="2",
|
||||
fontname="times",
|
||||
),
|
||||
engine="dot",
|
||||
)
|
||||
g.body.extend(["rankdir=LR"])
|
||||
|
||||
steps = 5
|
||||
for i in range(0, steps):
|
||||
if i == 0:
|
||||
g.node(str(i), fillcolor="darkseagreen2")
|
||||
elif i + 1 == steps:
|
||||
g.node(str(i), fillcolor="palegoldenrod")
|
||||
else:
|
||||
g.node(str(i), fillcolor="lightblue")
|
||||
|
||||
for i in range(1, steps):
|
||||
for xin in range(i):
|
||||
op_i = random.randint(0, len(OPS) - 1)
|
||||
# g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
|
||||
g.edge(
|
||||
str(xin),
|
||||
str(i),
|
||||
label=OPS[op_i],
|
||||
color=COLORS[op_i],
|
||||
fillcolor=COLORS[op_i],
|
||||
)
|
||||
# import pdb; pdb.set_trace()
|
||||
g.render(filename, cleanup=True, view=False)
|
||||
|
||||
|
||||
def test_auto_grad():
|
||||
class Net(torch.nn.Module):
|
||||
def __init__(self, iS):
|
||||
super(Net, self).__init__()
|
||||
self.layer = torch.nn.Linear(iS, 1)
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.layer(inputs)
|
||||
outputs = torch.exp(outputs)
|
||||
return outputs.mean()
|
||||
|
||||
net = Net(10)
|
||||
inputs = torch.rand(256, 10)
|
||||
loss = net(inputs)
|
||||
first_order_grads = torch.autograd.grad(
|
||||
loss, net.parameters(), retain_graph=True, create_graph=True
|
||||
)
|
||||
first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
|
||||
second_order_grads = []
|
||||
for grads in first_order_grads:
|
||||
s_grads = torch.autograd.grad(grads, net.parameters())
|
||||
second_order_grads.append(s_grads)
|
||||
|
||||
|
||||
def test_one_shot_model(ckpath, use_train):
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from config_utils import load_config, dict2config
|
||||
from utils.nas_utils import evaluate_one_shot
|
||||
|
||||
use_train = int(use_train) > 0
|
||||
# ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
|
||||
# ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
|
||||
print("ckpath : {:}".format(ckpath))
|
||||
ckp = torch.load(ckpath)
|
||||
xargs = ckp["args"]
|
||||
train_data, valid_data, xshape, class_num = get_datasets(
|
||||
xargs.dataset, xargs.data_path, -1
|
||||
)
|
||||
# config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
|
||||
config = load_config(
|
||||
"./configs/nas-benchmark/algos/DARTS.config",
|
||||
{"class_num": class_num, "xshape": xshape},
|
||||
None,
|
||||
)
|
||||
if xargs.dataset == "cifar10":
|
||||
cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
|
||||
xvalid_data = deepcopy(train_data)
|
||||
xvalid_data.transform = valid_data.transform
|
||||
valid_loader = torch.utils.data.DataLoader(
|
||||
xvalid_data,
|
||||
batch_size=2048,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid),
|
||||
num_workers=12,
|
||||
pin_memory=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid dataset : {:}".format(xargs.dataseet))
|
||||
search_space = get_search_spaces("cell", xargs.search_space_name)
|
||||
model_config = dict2config(
|
||||
{
|
||||
"name": "SETN",
|
||||
"C": xargs.channel,
|
||||
"N": xargs.num_cells,
|
||||
"max_nodes": xargs.max_nodes,
|
||||
"num_classes": class_num,
|
||||
"space": search_space,
|
||||
"affine": False,
|
||||
"track_running_stats": True,
|
||||
},
|
||||
None,
|
||||
)
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
search_model.load_state_dict(ckp["search_model"])
|
||||
search_model = search_model.cuda()
|
||||
api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
|
||||
archs, probs, accuracies = evaluate_one_shot(
|
||||
search_model, valid_loader, api, use_train
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_nas_api()
|
||||
# for i in range(200): plot('{:04d}'.format(i))
|
||||
# test_auto_grad()
|
||||
test_one_shot_model(sys.argv[1], sys.argv[2])
|
Reference in New Issue
Block a user