add the results
This commit is contained in:
@@ -3,9 +3,9 @@ import torch.nn.functional as F
|
||||
import pytorch_lightning as pl
|
||||
import time
|
||||
import os
|
||||
from naswot.score_networks import get_nasbench201_nodes_score
|
||||
from naswot import nasspace
|
||||
from naswot import datasets
|
||||
# from naswot.score_networks import get_nasbench201_nodes_score
|
||||
# from naswot import nasspace
|
||||
# from naswot import datasets
|
||||
from models.transformer import Denoiser
|
||||
from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition
|
||||
|
||||
@@ -41,7 +41,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
self.args.batch_size = 128
|
||||
self.args.GPU = '0'
|
||||
self.args.dataset = 'cifar10-valid'
|
||||
self.args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||
self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||
self.args.data_loc = '../cifardata/'
|
||||
self.args.seed = 777
|
||||
self.args.init = ''
|
||||
@@ -59,10 +59,10 @@ class Graph_DiT(pl.LightningModule):
|
||||
if 'valid' in self.args.dataset:
|
||||
self.args.dataset = self.args.dataset.replace('-valid', '')
|
||||
print('graph_dit starts to get searchspace of nasbench201')
|
||||
self.searchspace = nasspace.get_search_space(self.args)
|
||||
# self.searchspace = nasspace.get_search_space(self.args)
|
||||
print('searchspace of nasbench201 is obtained')
|
||||
print('graphdit starts to get train_loader')
|
||||
self.train_loader = datasets.get_data(self.args.dataset, self.args.data_loc, self.args.trainval, self.args.batch_size, self.args.augtype, self.args.repeat, self.args)
|
||||
# self.train_loader = datasets.get_data(self.args.dataset, self.args.data_loc, self.args.trainval, self.args.batch_size, self.args.augtype, self.args.repeat, self.args)
|
||||
print('train_loader is obtained')
|
||||
|
||||
self.cfg = cfg
|
||||
@@ -162,7 +162,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
return pred
|
||||
|
||||
def training_step(self, data, i):
|
||||
data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
|
||||
data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
|
||||
|
||||
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
|
||||
@@ -220,9 +220,9 @@ class Graph_DiT(pl.LightningModule):
|
||||
# self.sampling_metrics.reset()
|
||||
self.val_y_collection = []
|
||||
|
||||
# @torch.no_grad()
|
||||
@torch.no_grad()
|
||||
def validation_step(self, data, i):
|
||||
data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
|
||||
data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
|
||||
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
|
||||
dense_data = dense_data.mask(node_mask, collapse=False)
|
||||
@@ -313,9 +313,9 @@ class Graph_DiT(pl.LightningModule):
|
||||
self.test_E_logp.reset()
|
||||
self.test_y_collection = []
|
||||
|
||||
# @torch.no_grad()
|
||||
@torch.no_grad()
|
||||
def test_step(self, data, i):
|
||||
data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
|
||||
data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
|
||||
|
||||
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
|
||||
@@ -573,7 +573,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
|
||||
return nll
|
||||
|
||||
# @torch.no_grad()
|
||||
@torch.no_grad()
|
||||
def sample_batch(self, batch_id, batch_size, y, keep_chain, number_chain_steps, save_final, num_nodes=None):
|
||||
"""
|
||||
:param batch_id: int
|
||||
@@ -686,130 +686,120 @@ class Graph_DiT(pl.LightningModule):
|
||||
assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
|
||||
assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()
|
||||
|
||||
# sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
|
||||
sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
|
||||
|
||||
# Sample several candidate architectures and keep the one with the best total score.
|
||||
|
||||
# Operation vocabulary: position in the list is the integer code of the operation.
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
# Reverse lookup from operation name to its integer code.
op_type = {
    'input': 0,
    'nor_conv_1x1': 1,
    'nor_conv_3x3': 2,
    'avg_pool_3x3': 3,
    'skip_connect': 4,
    'none': 5,
    'output': 6,
}


def check_valid_graph(nodes, edges):
    """Return True when (nodes, edges) describes a structurally valid cell.

    Args:
        nodes: iterable of integer operation codes, one per node, each used
            as an index into ``num_to_op``.
        edges: 2-D array-like exposing ``.shape`` and ``edges[i, j]``
            indexing (the caller passes NumPy arrays); an entry equal to 1
            marks an edge from node ``i`` to node ``j``.

    Returns:
        bool: False as soon as any of the following rules is violated:
            * every operation code must resolve to an entry of ``num_to_op``;
            * ``edges`` must be square with one row/column per node;
            * the first node is 'input' and the last node is 'output';
            * no node has a self-loop;
            * no interior node is 'input' or 'output';
            * on or above the diagonal, no edge enters an 'input' node or
              leaves an 'output' node;
            * at least one edge points at the last node.
    """
    # An unknown operation code means the sample is not a valid graph; report
    # that instead of letting IndexError abort the whole sampling batch.
    # NOTE(review): negative codes still wrap around via Python list indexing
    # (e.g. -1 -> 'output'), exactly as before - confirm whether masked nodes
    # can reach this function.
    try:
        nodes = [num_to_op[i] for i in nodes]
    except IndexError:
        return False

    n = len(nodes)
    if n == 0:
        # Nothing to validate; previously this crashed on nodes[0].
        return False
    if n != edges.shape[0] or n != edges.shape[1]:
        return False
    if nodes[0] != 'input' or nodes[-1] != 'output':
        return False

    # No self-loops.
    if any(edges[i][i] == 1 for i in range(n)):
        return False

    # Interior nodes must be real operations, never the input/output markers.
    for op in nodes[1:-1]:
        if op not in op_type or op == 'input' or op == 'output':
            return False

    # Only entries on or above the diagonal are inspected here.
    for i in range(n):
        for j in range(i, n):
            if edges[i, j] == 1 and (nodes[j] == 'input' or nodes[i] == 'output'):
                return False

    # The output node must be reachable from at least one node.
    return any(edges[i, -1] == 1 for i in range(n))
|
||||
# num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
|
||||
# op_type = {
|
||||
# 'input': 0,
|
||||
# 'nor_conv_1x1': 1,
|
||||
# 'nor_conv_3x3': 2,
|
||||
# 'avg_pool_3x3': 3,
|
||||
# 'skip_connect': 4,
|
||||
# 'none': 5,
|
||||
# 'output': 6,
|
||||
# }
|
||||
# def check_valid_graph(nodes, edges):
|
||||
# nodes = [num_to_op[i] for i in nodes]
|
||||
# if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
|
||||
# return False
|
||||
# if nodes[0] != 'input' or nodes[-1] != 'output':
|
||||
# return False
|
||||
# for i in range(0, len(nodes)):
|
||||
# if edges[i][i] == 1:
|
||||
# return False
|
||||
# for i in range(1, len(nodes) - 1):
|
||||
# if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
|
||||
# return False
|
||||
# for i in range(0, len(nodes)):
|
||||
# for j in range(i, len(nodes)):
|
||||
# if edges[i, j] == 1 and nodes[j] == 'input':
|
||||
# return False
|
||||
# for i in range(0, len(nodes)):
|
||||
# for j in range(i, len(nodes)):
|
||||
# if edges[i, j] == 1 and nodes[i] == 'output':
|
||||
# return False
|
||||
# flag = 0
|
||||
# for i in range(0,len(nodes)):
|
||||
# if edges[i,-1] == 1:
|
||||
# flag = 1
|
||||
# break
|
||||
# if flag == 0: return False
|
||||
# return True
|
||||
|
||||
# Empty placeholder class: it declares no attributes or methods of its own.
# NOTE(review): its use is not visible in this excerpt - presumably a bare
# namespace for ad-hoc attributes; confirm against the code that instantiates it.
class Args:
|
||||
pass
|
||||
# class Args:
|
||||
# pass
|
||||
|
||||
def get_score(sampled_s):
    """Score every sampled graph in the batch and build a matching target.

    Args:
        sampled_s: object exposing tensors ``X`` (node codes) and ``E``
            (edges), both batched along dim 0.

    Returns:
        tuple: ``(score, target_score)``, two float32 tensors on
        ``sampled_s.X.device`` with one entry per graph in the batch.
        ``score`` holds the value returned by ``get_nasbench201_nodes_score``
        for graphs accepted by ``check_valid_graph`` and ``-1`` for rejected
        ones; ``target_score`` is filled with ``2000.0``.
    """
    # Project-local import kept function-scoped, as in the original code.
    from graph_dit.naswot.naswot.score_networks import get_nasbench201_nodes_score

    device = sampled_s.X.device
    x_list = sampled_s.X.unbind(dim=0)
    e_list = sampled_s.E.unbind(dim=0)

    score = []
    for x, e in zip(x_list, e_list):
        # Convert once and reuse for both the validity check and the lookup.
        x_np = x.cpu().numpy()
        if check_valid_graph(x_np, e.cpu().numpy()):
            nodes = [num_to_op[j] for j in x_np]
            score.append(
                get_nasbench201_nodes_score(
                    nodes,
                    train_loader=self.train_loader,
                    searchspace=self.searchspace,
                    device=device,
                    args=self.args,
                )
            )
        else:
            # Invalid graphs get a sentinel score.
            score.append(-1)

    # Size the target from the actual batch instead of a hard-coded 100 so
    # the two returned tensors always have the same shape.
    target_score = torch.ones(len(score), dtype=torch.float32, device=device, requires_grad=True) * 2000.0

    # NOTE(review): torch.tensor(...) builds a brand-new leaf tensor from
    # Python numbers, so this score has no autograd connection to the model
    # outputs that produced sampled_s - confirm this is what the caller's
    # backward pass expects.
    return torch.tensor(score, device=device, dtype=torch.float32, requires_grad=True), target_score
|
||||
# for i in range(len(x_list)):
|
||||
# if valid_rlt[i]:
|
||||
# nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
|
||||
# # edges = e_list[i].cpu().numpy()
|
||||
# score.append(get_nasbench201_nodes_score(nodes,train_loader=self.train_loader,searchspace=self.searchspace,device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") , args=self.args))
|
||||
# else:
|
||||
# score.append(-1)
|
||||
# return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
|
||||
|
||||
sample_num = 10
|
||||
best_arch = None
|
||||
best_score_int = -1e8
|
||||
score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8
|
||||
print(f'score.requires_grad: {score.requires_grad}')
|
||||
# sample_num = 10
|
||||
# best_arch = None
|
||||
# best_score_int = -1e8
|
||||
# score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8
|
||||
|
||||
for i in range(sample_num):
|
||||
sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
|
||||
score, target_score = get_score(sampled_s)
|
||||
print(f'score: {score}')
|
||||
print(f'score.shape: {score.shape}')
|
||||
print(f'torch.sum(score): {torch.sum(score)}')
|
||||
sum_score = torch.sum(score)
|
||||
print(f'sum_score: {sum_score}')
|
||||
if sum_score > best_score_int:
|
||||
best_score_int = sum_score
|
||||
best_score = score
|
||||
best_arch = sampled_s
|
||||
# for i in range(sample_num):
|
||||
# sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
|
||||
# score = get_score(sampled_s)
|
||||
# print(f'score: {score}')
|
||||
# print(f'score.shape: {score.shape}')
|
||||
# print(f'torch.sum(score): {torch.sum(score)}')
|
||||
# sum_score = torch.sum(score)
|
||||
# print(f'sum_score: {sum_score}')
|
||||
# if sum_score > best_score_int:
|
||||
# best_score_int = sum_score
|
||||
# best_score = score
|
||||
# best_arch = sampled_s
|
||||
|
||||
# print(f'prob_X: {prob_X.shape}, prob_E: {prob_E.shape}')
|
||||
|
||||
# best_arch = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
|
||||
|
||||
# X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
|
||||
# E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
|
||||
print(f'best_arch.X: {best_arch.X.shape}, best_arch.E: {best_arch.E.shape}') # 100 8 8, bs n n, 100 8 8 2, bs n n 2
|
||||
X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
|
||||
E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
|
||||
# print(f'best_arch.X: {best_arch.X.shape}, best_arch.E: {best_arch.E.shape}') # 100 8 8, bs n n, 100 8 8 2, bs n n 2
|
||||
|
||||
print(f'best_arch.X: {best_arch.X}, best_arch.E: {best_arch.E}')
|
||||
X_s = F.one_hot(best_arch.X, num_classes=self.Xdim_output).float()
|
||||
E_s = F.one_hot(best_arch.E, num_classes=self.Edim_output).float()
|
||||
print(f'X_s: {X_s}, E_s: {E_s}')
|
||||
# print(f'best_arch.X: {best_arch.X}, best_arch.E: {best_arch.E}')
|
||||
# X_s = F.one_hot(best_arch.X, num_classes=self.Xdim_output).float()
|
||||
# E_s = F.one_hot(best_arch.E, num_classes=self.Edim_output).float()
|
||||
# print(f'X_s: {X_s}, E_s: {E_s}')
|
||||
|
||||
# NASWOT score
|
||||
# target_score = torch.ones(100, requires_grad=True, device=X_s.device) * 2000.0
|
||||
# # NASWOT score
|
||||
# target_score = torch.ones(100, requires_grad=True) * 2000.0
|
||||
print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
|
||||
print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
|
||||
print(f'best_score.device: {best_score.device}, target_score.device: {target_score.device}')
|
||||
# target_score = target_score.to(X_s.device)
|
||||
|
||||
# # compute loss mse(cur_score - target_score)
|
||||
# mse_loss = torch.nn.MSELoss()
|
||||
# print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
|
||||
# print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
|
||||
|
||||
# compute loss mse(cur_score - target_score)
|
||||
mse_loss = torch.nn.MSELoss()
|
||||
loss = mse_loss(target_score, best_score)
|
||||
print(f'loss: {loss.requires_grad}')
|
||||
loss.backward(retain_graph=True)
|
||||
# loss = mse_loss(best_score, target_score)
|
||||
# loss.backward(retain_graph=True)
|
||||
|
||||
# loss backward = gradient
|
||||
|
||||
# get prob.X, prob_E gradient
|
||||
x_grad = pred.X.grad
|
||||
e_grad = pred.E.grad
|
||||
# x_grad = pred.X.grad
|
||||
# e_grad = pred.E.grad
|
||||
|
||||
beta_ratio = 0.5
|
||||
# x_current = pred.X - beta_ratio * x_grad
|
||||
# e_current = pred.E - beta_ratio * e_grad
|
||||
X_s = pred.X - beta_ratio * x_grad
|
||||
E_s = pred.E - beta_ratio * e_grad
|
||||
# beta_ratio = 0.5
|
||||
# # x_current = pred.X - beta_ratio * x_grad
|
||||
# # e_current = pred.E - beta_ratio * e_grad
|
||||
# E_s = pred.X - beta_ratio * x_grad
|
||||
# X_s = pred.E - beta_ratio * e_grad
|
||||
|
||||
# update prob_X and prob_E using the gradient
|
||||
|
||||
|
Reference in New Issue
Block a user