Reformulate via black

Author: D-X-Y
Date:   2021-03-17 09:25:58 +00:00
parent a9093e41e1
commit f98edea22a
59 changed files with 12289 additions and 8918 deletions


@@ -5,110 +5,148 @@ from copy import deepcopy
 import torch
 import numpy as np
 from collections import OrderedDict
-lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
-if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
+lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
+if str(lib_dir) not in sys.path:
+    sys.path.insert(0, str(lib_dir))
 from nas_201_api import NASBench201API as API
 
 def test_nas_api():
-  from nas_201_api import ArchResults
-  xdata = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth')
-  for key in ['full', 'less']:
-    print ('\n------------------------- {:} -------------------------'.format(key))
-    archRes = ArchResults.create_from_state_dict(xdata[key])
-    print(archRes)
-    print(archRes.arch_idx_str())
-    print(archRes.get_dataset_names())
-    print(archRes.get_comput_costs('cifar10-valid'))
-    # get the metrics
-    print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False))
-    print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True))
-    print(archRes.query('cifar10-valid', 777))
+    from nas_201_api import ArchResults
+    xdata = torch.load(
+        "/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth"
+    )
+    for key in ["full", "less"]:
+        print("\n------------------------- {:} -------------------------".format(key))
+        archRes = ArchResults.create_from_state_dict(xdata[key])
+        print(archRes)
+        print(archRes.arch_idx_str())
+        print(archRes.get_dataset_names())
+        print(archRes.get_comput_costs("cifar10-valid"))
+        # get the metrics
+        print(archRes.get_metrics("cifar10-valid", "x-valid", None, False))
+        print(archRes.get_metrics("cifar10-valid", "x-valid", None, True))
+        print(archRes.query("cifar10-valid", 777))
 
-OPS = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3']
-COLORS = ['chartreuse' , 'cyan' , 'navyblue', 'chocolate1']
+OPS = ["skip-connect", "conv-1x1", "conv-3x3", "pool-3x3"]
+COLORS = ["chartreuse", "cyan", "navyblue", "chocolate1"]
 
 def plot(filename):
-  from graphviz import Digraph
-  g = Digraph(
-      format='png',
-      edge_attr=dict(fontsize='20', fontname="times"),
-      node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
-      engine='dot')
-  g.body.extend(['rankdir=LR'])
+    from graphviz import Digraph
 
-  steps = 5
-  for i in range(0, steps):
-    if i == 0:
-      g.node(str(i), fillcolor='darkseagreen2')
-    elif i+1 == steps:
-      g.node(str(i), fillcolor='palegoldenrod')
-    else: g.node(str(i), fillcolor='lightblue')
+    g = Digraph(
+        format="png",
+        edge_attr=dict(fontsize="20", fontname="times"),
+        node_attr=dict(
+            style="filled",
+            shape="rect",
+            align="center",
+            fontsize="20",
+            height="0.5",
+            width="0.5",
+            penwidth="2",
+            fontname="times",
+        ),
+        engine="dot",
+    )
+    g.body.extend(["rankdir=LR"])
 
-  for i in range(1, steps):
-    for xin in range(i):
-      op_i = random.randint(0, len(OPS)-1)
-      #g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
-      g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
-  #import pdb; pdb.set_trace()
-  g.render(filename, cleanup=True, view=False)
+    steps = 5
+    for i in range(0, steps):
+        if i == 0:
+            g.node(str(i), fillcolor="darkseagreen2")
+        elif i + 1 == steps:
+            g.node(str(i), fillcolor="palegoldenrod")
+        else:
+            g.node(str(i), fillcolor="lightblue")
+
+    for i in range(1, steps):
+        for xin in range(i):
+            op_i = random.randint(0, len(OPS) - 1)
+            # g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
+            g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
+    # import pdb; pdb.set_trace()
+    g.render(filename, cleanup=True, view=False)
 
 def test_auto_grad():
-  class Net(torch.nn.Module):
-    def __init__(self, iS):
-      super(Net, self).__init__()
-      self.layer = torch.nn.Linear(iS, 1)
-    def forward(self, inputs):
-      outputs = self.layer(inputs)
-      outputs = torch.exp(outputs)
-      return outputs.mean()
-  net = Net(10)
-  inputs = torch.rand(256, 10)
-  loss = net(inputs)
-  first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
-  first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
-  second_order_grads = []
-  for grads in first_order_grads:
-    s_grads = torch.autograd.grad(grads, net.parameters())
-    second_order_grads.append( s_grads )
+    class Net(torch.nn.Module):
+        def __init__(self, iS):
+            super(Net, self).__init__()
+            self.layer = torch.nn.Linear(iS, 1)
+
+        def forward(self, inputs):
+            outputs = self.layer(inputs)
+            outputs = torch.exp(outputs)
+            return outputs.mean()
+
+    net = Net(10)
+    inputs = torch.rand(256, 10)
+    loss = net(inputs)
+    first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
+    first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
+    second_order_grads = []
+    for grads in first_order_grads:
+        s_grads = torch.autograd.grad(grads, net.parameters())
+        second_order_grads.append(s_grads)
 
 def test_one_shot_model(ckpath, use_train):
-  from models import get_cell_based_tiny_net, get_search_spaces
-  from datasets import get_datasets, SearchDataset
-  from config_utils import load_config, dict2config
-  from utils.nas_utils import evaluate_one_shot
-  use_train = int(use_train) > 0
-  #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
-  #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
-  print ('ckpath : {:}'.format(ckpath))
-  ckp = torch.load(ckpath)
-  xargs = ckp['args']
-  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
-  #config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
-  config = load_config('./configs/nas-benchmark/algos/DARTS.config', {'class_num': class_num, 'xshape': xshape}, None)
-  if xargs.dataset == 'cifar10':
-    cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
-    xvalid_data = deepcopy(train_data)
-    xvalid_data.transform = valid_data.transform
-    valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True)
-  else: raise ValueError('invalid dataset : {:}'.format(xargs.dataseet))
-  search_space = get_search_spaces('cell', xargs.search_space_name)
-  model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
-                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
-                              'space' : search_space,
-                              'affine' : False, 'track_running_stats': True}, None)
-  search_model = get_cell_based_tiny_net(model_config)
-  search_model.load_state_dict( ckp['search_model'] )
-  search_model = search_model.cuda()
-  api = API('/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth')
-  archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
+    from models import get_cell_based_tiny_net, get_search_spaces
+    from datasets import get_datasets, SearchDataset
+    from config_utils import load_config, dict2config
+    from utils.nas_utils import evaluate_one_shot
+    use_train = int(use_train) > 0
+    # ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
+    # ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
+    print("ckpath : {:}".format(ckpath))
+    ckp = torch.load(ckpath)
+    xargs = ckp["args"]
+    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
+    # config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
+    config = load_config("./configs/nas-benchmark/algos/DARTS.config", {"class_num": class_num, "xshape": xshape}, None)
+    if xargs.dataset == "cifar10":
+        cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
+        xvalid_data = deepcopy(train_data)
+        xvalid_data.transform = valid_data.transform
+        valid_loader = torch.utils.data.DataLoader(
+            xvalid_data,
+            batch_size=2048,
+            sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid),
+            num_workers=12,
+            pin_memory=True,
+        )
+    else:
+        raise ValueError("invalid dataset : {:}".format(xargs.dataset))
+    search_space = get_search_spaces("cell", xargs.search_space_name)
+    model_config = dict2config(
+        {
+            "name": "SETN",
+            "C": xargs.channel,
+            "N": xargs.num_cells,
+            "max_nodes": xargs.max_nodes,
+            "num_classes": class_num,
+            "space": search_space,
+            "affine": False,
+            "track_running_stats": True,
+        },
+        None,
+    )
+    search_model = get_cell_based_tiny_net(model_config)
+    search_model.load_state_dict(ckp["search_model"])
+    search_model = search_model.cuda()
+    api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
+    archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
 
-if __name__ == '__main__':
-  #test_nas_api()
-  #for i in range(200): plot('{:04d}'.format(i))
-  #test_auto_grad()
-  test_one_shot_model(sys.argv[1], sys.argv[2])
+if __name__ == "__main__":
+    # test_nas_api()
+    # for i in range(200): plot('{:04d}'.format(i))
+    # test_auto_grad()
+    test_one_shot_model(sys.argv[1], sys.argv[2])
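
The changes in this hunk are pure style normalization of the kind black produces: single quotes rewritten to double quotes, spacing normalized, and long calls exploded into one-argument-per-line form. Judging by the long single lines that were kept, a non-default line length was likely configured; the exact command and black version used for this commit are not recorded here. Below is a minimal sketch, not part of this commit, of reproducing one such rewrite with black's Python API (assumes a recent `black` release is installed):

# Minimal sketch: apply black's default formatting to one line from the old file.
# `black` is an assumption here; it is a dependency of the reformatting, not of this script.
import black

old_source = "OPS    = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3']\n"
new_source = black.format_str(old_source, mode=black.Mode())
print(new_source)  # -> OPS = ["skip-connect", "conv-1x1", "conv-3x3", "pool-3x3"]

Running the same normalization over a whole checkout is typically done from the command line, e.g. `python -m black <path>`, with any line-length override supplied via `--line-length` or the project configuration.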