first commit
This commit is contained in:
0
NAS-Bench-201/models/__init__.py
Executable file
0
NAS-Bench-201/models/__init__.py
Executable file
391
NAS-Bench-201/models/cate.py
Normal file
391
NAS-Bench-201/models/cate.py
Normal file
@@ -0,0 +1,391 @@
|
||||
# Most of this code is from https://github.com/AIoT-MLSys-Lab/CATE.git
|
||||
# which was authored by Shen Yan, Kaiqiang Song, Fei Liu, Mi Zhang, 2021
|
||||
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from . import utils
|
||||
from .transformer import Encoder, SemanticEmbedding
|
||||
from .set_encoder.setenc_models import SetPool
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
|
||||
"""
|
||||
num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to linear model.
|
||||
input_dim: dimensionality of input features
|
||||
hidden_dim: dimensionality of hidden units at ALL layers
|
||||
output_dim: number of classes for prediction
|
||||
num_classes: the number of classes of input, to be treated with different gains and biases,
|
||||
(see the definition of class `ConditionalLayer1d`)
|
||||
"""
|
||||
|
||||
super(MLP, self).__init__()
|
||||
|
||||
self.linear_or_not = True # default is linear model
|
||||
self.num_layers = num_layers
|
||||
self.use_bn = use_bn
|
||||
self.activate_func = activate_func
|
||||
|
||||
if num_layers < 1:
|
||||
raise ValueError("number of layers should be positive!")
|
||||
elif num_layers == 1:
|
||||
# Linear model
|
||||
self.linear = torch.nn.Linear(input_dim, output_dim)
|
||||
else:
|
||||
# Multi-layer model
|
||||
self.linear_or_not = False
|
||||
self.linears = torch.nn.ModuleList()
|
||||
|
||||
self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
|
||||
for layer in range(num_layers - 2):
|
||||
self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
|
||||
self.linears.append(torch.nn.Linear(hidden_dim, output_dim))
|
||||
|
||||
if self.use_bn:
|
||||
self.batch_norms = torch.nn.ModuleList()
|
||||
for layer in range(num_layers - 1):
|
||||
self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
:param x: [num_classes * batch_size, N, F_i], batch of node features
|
||||
note that in self.cond_layers[layer],
|
||||
`x` is splited into `num_classes` groups in dim=0,
|
||||
and then treated with different gains and biases
|
||||
"""
|
||||
if self.linear_or_not:
|
||||
# If linear model
|
||||
return self.linear(x)
|
||||
else:
|
||||
# If MLP
|
||||
h = x
|
||||
for layer in range(self.num_layers - 1):
|
||||
h = self.linears[layer](h)
|
||||
if self.use_bn:
|
||||
h = self.batch_norms[layer](h)
|
||||
h = self.activate_func(h)
|
||||
return self.linears[self.num_layers - 1](h)
|
||||
|
||||
|
||||
""" Transformer Encoder """
|
||||
class GraphEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(GraphEncoder, self).__init__()
|
||||
# Forward Transformers
|
||||
self.encoder_f = Encoder(config)
|
||||
|
||||
def forward(self, x, mask):
|
||||
h_f, hs_f, attns_f = self.encoder_f(x, mask)
|
||||
h = torch.cat(hs_f, dim=-1)
|
||||
return h
|
||||
|
||||
@staticmethod
|
||||
def get_embeddings(h_x):
|
||||
h_x = h_x.cpu()
|
||||
return h_x[:, -1]
|
||||
|
||||
|
||||
class CLSHead(nn.Module):
|
||||
def __init__(self, config, init_weights=None):
|
||||
super(CLSHead, self).__init__()
|
||||
self.layer_1 = nn.Linear(config.d_model, config.d_model)
|
||||
self.dropout = nn.Dropout(p=config.dropout)
|
||||
self.layer_2 = nn.Linear(config.d_model, config.n_vocab)
|
||||
if init_weights is not None:
|
||||
self.layer_2.weight = init_weights
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dropout(torch.tanh(self.layer_1(x)))
|
||||
return F.log_softmax(self.layer_2(x), dim=-1)
|
||||
|
||||
|
||||
@utils.register_model(name='CATE')
|
||||
class CATE(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(CATE, self).__init__()
|
||||
# Shared Embedding Layer
|
||||
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
|
||||
self.dropout_op = nn.Dropout(p=config.model.dropout)
|
||||
self.d_model = config.model.graph_encoder.d_model
|
||||
self.act = act = get_act(config)
|
||||
# Time
|
||||
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
|
||||
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
|
||||
|
||||
# 2 GraphEncoder for X and Y
|
||||
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
|
||||
|
||||
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
|
||||
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
if 'pos_enc_type' in config.model:
|
||||
self.pos_enc_type = config.model.pos_enc_type
|
||||
if self.pos_enc_type == 1:
|
||||
raise NotImplementedError
|
||||
elif self.pos_enc_type == 2:
|
||||
if config.data.name == 'NASBench201':
|
||||
self.pos_encoder = PositionalEncoding_Cell(d_model=self.d_model, max_len=config.data.max_node)
|
||||
else:
|
||||
self.pos_encoder = PositionalEncoding_StageWise(d_model=self.d_model, max_len=config.data.max_node)
|
||||
elif self.pos_enc_type == 3:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
self.pos_encoder = None
|
||||
else:
|
||||
self.pos_encoder = None
|
||||
|
||||
|
||||
def forward(self, X, time_cond, maskX):
|
||||
|
||||
emb_x = self.dropout_op(self.opEmb(X))
|
||||
|
||||
if self.pos_encoder is not None:
|
||||
emb_p = self.pos_encoder(emb_x)
|
||||
emb_x = emb_x + emb_p
|
||||
|
||||
# Time embedding
|
||||
timesteps = time_cond
|
||||
emb_t = get_timestep_embedding(timesteps, self.d_model)
|
||||
emb_t = self.timeEmb1(emb_t) # [32, 512]
|
||||
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
|
||||
emb_t = emb_t.unsqueeze(1)
|
||||
emb = emb_x + emb_t
|
||||
|
||||
h_x = self.graph_encoder(emb, maskX)
|
||||
h_x = self.final(h_x)
|
||||
|
||||
return h_x
|
||||
|
||||
|
||||
@utils.register_model(name='PredictorCATE')
|
||||
class PredictorCATE(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PredictorCATE, self).__init__()
|
||||
# Shared Embedding Layer
|
||||
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
|
||||
self.dropout_op = nn.Dropout(p=config.model.dropout)
|
||||
self.d_model = config.model.graph_encoder.d_model
|
||||
self.act = act = get_act(config)
|
||||
# Time
|
||||
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
|
||||
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
|
||||
|
||||
# 2 GraphEncoder for X and Y
|
||||
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
|
||||
|
||||
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
|
||||
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
self.rdim = int(config.data.max_node * config.data.n_vocab)
|
||||
self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=1,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
def forward(self, X, time_cond, maskX):
|
||||
|
||||
emb_x = self.dropout_op(self.opEmb(X))
|
||||
|
||||
# Time embedding
|
||||
timesteps = time_cond
|
||||
emb_t = get_timestep_embedding(timesteps, self.d_model)
|
||||
emb_t = self.timeEmb1(emb_t)
|
||||
emb_t = self.timeEmb2(self.act(emb_t))
|
||||
emb_t = emb_t.unsqueeze(1)
|
||||
emb = emb_x + emb_t
|
||||
|
||||
h_x = self.graph_encoder(emb, maskX)
|
||||
h_x = self.final(h_x)
|
||||
h_x = h_x.reshape(h_x.size(0), -1)
|
||||
h_x = self.regeress(h_x)
|
||||
|
||||
return h_x
|
||||
|
||||
|
||||
class PositionalEncoding_StageWise(nn.Module):
|
||||
|
||||
def __init__(self, d_model, max_len):
|
||||
super(PositionalEncoding_StageWise, self).__init__()
|
||||
NUM_STAGE = 5
|
||||
max_len = int(max_len / NUM_STAGE)
|
||||
self.encoding = torch.zeros(max_len, d_model)
|
||||
self.encoding.requires_grad = False
|
||||
pos = torch.arange(0, max_len)
|
||||
pos = pos.float().unsqueeze(dim=1)
|
||||
_2i = torch.arange(0, d_model, step=2).float()
|
||||
self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
|
||||
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
|
||||
self.encoding = torch.cat([self.encoding] * NUM_STAGE, dim=0)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, seq_len, _ = x.size()
|
||||
return self.encoding[:seq_len, :].to(x.device)
|
||||
|
||||
|
||||
class PositionalEncoding_Cell(nn.Module):
|
||||
|
||||
def __init__(self, d_model, max_len):
|
||||
super(PositionalEncoding_Cell, self).__init__()
|
||||
NUM_STAGE = 1
|
||||
max_len = int(max_len / NUM_STAGE)
|
||||
self.encoding = torch.zeros(max_len, d_model)
|
||||
self.encoding.requires_grad = False
|
||||
pos = torch.arange(0, max_len)
|
||||
pos = pos.float().unsqueeze(dim=1)
|
||||
_2i = torch.arange(0, d_model, step=2).float()
|
||||
self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
|
||||
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
|
||||
self.encoding = torch.cat([self.encoding] * NUM_STAGE, dim=0)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, seq_len, _ = x.size()
|
||||
return self.encoding[:seq_len, :].to(x.device)
|
||||
|
||||
|
||||
@utils.register_model(name='MetaPredictorCATE')
|
||||
class MetaPredictorCATE(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(MetaPredictorCATE, self).__init__()
|
||||
|
||||
self.input_type= config.model.input_type
|
||||
self.hs = config.model.hs
|
||||
|
||||
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
|
||||
self.dropout_op = nn.Dropout(p=config.model.dropout)
|
||||
self.d_model = config.model.graph_encoder.d_model
|
||||
self.act = act = get_act(config)
|
||||
|
||||
# Time
|
||||
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
|
||||
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
|
||||
|
||||
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
|
||||
|
||||
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
|
||||
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
self.rdim = int(config.data.max_node * config.data.n_vocab)
|
||||
self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=2*self.rdim,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
# Set
|
||||
self.nz = config.model.nz
|
||||
self.num_sample = config.model.num_sample
|
||||
self.intra_setpool = SetPool(dim_input=512,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.inter_setpool = SetPool(dim_input=self.nz,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.set_fc = nn.Sequential(
|
||||
nn.Linear(512, self.nz),
|
||||
nn.ReLU())
|
||||
|
||||
input_dim = 0
|
||||
if 'D' in self.input_type:
|
||||
input_dim += self.nz
|
||||
if 'A' in self.input_type:
|
||||
input_dim += 2*self.rdim
|
||||
|
||||
self.pred_fc = nn.Sequential(
|
||||
nn.Linear(input_dim, self.hs),
|
||||
nn.Tanh(),
|
||||
nn.Linear(self.hs, 1)
|
||||
)
|
||||
|
||||
self.sample_state = False
|
||||
self.D_mu = None
|
||||
|
||||
|
||||
def arch_encode(self, X, time_cond, maskX):
|
||||
emb_x = self.dropout_op(self.opEmb(X))
|
||||
|
||||
# Time embedding
|
||||
timesteps = time_cond
|
||||
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
|
||||
emb_t = self.timeEmb1(emb_t) # [32, 512]
|
||||
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
|
||||
emb_t = emb_t.unsqueeze(1)
|
||||
emb = emb_x + emb_t
|
||||
|
||||
h_x = self.graph_encoder(emb, maskX)
|
||||
h_x = self.final(h_x)
|
||||
|
||||
h_x = h_x.reshape(h_x.size(0), -1)
|
||||
h_x = self.regeress(h_x)
|
||||
return h_x
|
||||
|
||||
|
||||
def set_encode(self, task):
|
||||
proto_batch = []
|
||||
for x in task:
|
||||
cls_protos = self.intra_setpool(
|
||||
x.view(-1, self.num_sample, 512)).squeeze(1)
|
||||
proto_batch.append(
|
||||
self.inter_setpool(cls_protos.unsqueeze(0)))
|
||||
v = torch.stack(proto_batch).squeeze()
|
||||
return v
|
||||
|
||||
|
||||
def predict(self, D_mu, A_mu):
|
||||
input_vec = []
|
||||
if 'D' in self.input_type:
|
||||
input_vec.append(D_mu)
|
||||
if 'A' in self.input_type:
|
||||
input_vec.append(A_mu)
|
||||
input_vec = torch.cat(input_vec, dim=1)
|
||||
return self.pred_fc(input_vec)
|
||||
|
||||
|
||||
def forward(self, X, time_cond, maskX, task):
|
||||
if self.sample_state:
|
||||
if self.D_mu is None:
|
||||
self.D_mu = self.set_encode(task)
|
||||
D_mu = self.D_mu
|
||||
else:
|
||||
D_mu = self.set_encode(task)
|
||||
A_mu = self.arch_encode(X, time_cond, maskX)
|
||||
y_pred = self.predict(D_mu, A_mu)
|
||||
return y_pred
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
assert len(timesteps.shape) == 1
|
||||
half_dim = embedding_dim // 2
|
||||
# magic number 10000 is from transformers
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = F.pad(emb, (0, 1), mode='constant')
|
||||
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
|
||||
def get_act(config):
|
||||
"""Get actiuvation functions from the config file."""
|
||||
|
||||
if config.model.nonlinearity.lower() == 'elu':
|
||||
return nn.ELU()
|
||||
elif config.model.nonlinearity.lower() == 'relu':
|
||||
return nn.ReLU()
|
||||
elif config.model.nonlinearity.lower() == 'lrelu':
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif config.model.nonlinearity.lower() == 'swish':
|
||||
return nn.SiLU()
|
||||
elif config.model.nonlinearity.lower() == 'tanh':
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplementedError('activation function does not exist!')
|
125
NAS-Bench-201/models/digcn.py
Normal file
125
NAS-Bench-201/models/digcn.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
|
||||
# which was authored by Yuge Zhang, 2020
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
from . import utils
|
||||
from models.cate import PositionalEncoding_StageWise
|
||||
|
||||
|
||||
def normalize_adj(adj):
|
||||
# Row-normalize matrix
|
||||
last_dim = adj.size(-1)
|
||||
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
|
||||
return torch.div(adj, rowsum)
|
||||
|
||||
|
||||
def graph_pooling(inputs, num_vertices):
|
||||
num_vertices = num_vertices.to(inputs.device)
|
||||
out = inputs.sum(1)
|
||||
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
|
||||
|
||||
|
||||
class DirectedGraphConvolution(nn.Module):
|
||||
def __init__(self, in_features, out_features):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
|
||||
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.xavier_uniform_(self.weight1.data)
|
||||
nn.init.xavier_uniform_(self.weight2.data)
|
||||
|
||||
def forward(self, inputs, adj):
|
||||
inputs = inputs.to(self.weight1.device)
|
||||
adj = adj.to(self.weight1.device)
|
||||
norm_adj = normalize_adj(adj)
|
||||
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
|
||||
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
|
||||
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
|
||||
out = (output1 + output2) / 2
|
||||
out = self.dropout(out)
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + ' (' \
|
||||
+ str(self.in_features) + ' -> ' \
|
||||
+ str(self.out_features) + ')'
|
||||
|
||||
|
||||
@utils.register_model(name='NeuralPredictor')
|
||||
class NeuralPredictor(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
|
||||
config.model.graph_encoder.gcn_hidden)
|
||||
for i in range(config.model.graph_encoder.gcn_layers)]
|
||||
self.gcn = nn.ModuleList(self.gcn)
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
|
||||
self.fc2 = nn.Linear(config.model.graph_encoder.linear_hidden, 1, bias=False)
|
||||
# Time
|
||||
self.d_model = config.model.graph_encoder.gcn_hidden
|
||||
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
|
||||
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
|
||||
self.act = act = get_act(config)
|
||||
|
||||
def forward(self, X, time_cond, maskX):
|
||||
out = X
|
||||
adj = maskX
|
||||
|
||||
numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device) # 20
|
||||
gs = adj.size(1) # graph node number
|
||||
|
||||
timesteps = time_cond
|
||||
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
|
||||
emb_t = self.timeEmb1(emb_t)
|
||||
emb_t = self.timeEmb2(self.act(emb_t)) # (5, 144)
|
||||
|
||||
adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device)) # assuming diagonal is not 1
|
||||
for layer in self.gcn:
|
||||
out = layer(out, adj_with_diag)
|
||||
out = graph_pooling(out, numv)
|
||||
# time
|
||||
out = out + emb_t
|
||||
out = self.fc1(out)
|
||||
out = self.dropout(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
assert len(timesteps.shape) == 1
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = F.pad(emb, (0, 1), mode='constant')
|
||||
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
|
||||
def get_act(config):
|
||||
"""Get actiuvation functions from the config file."""
|
||||
|
||||
if config.model.nonlinearity.lower() == 'elu':
|
||||
return nn.ELU()
|
||||
elif config.model.nonlinearity.lower() == 'relu':
|
||||
return nn.ReLU()
|
||||
elif config.model.nonlinearity.lower() == 'lrelu':
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif config.model.nonlinearity.lower() == 'swish':
|
||||
return nn.SiLU()
|
||||
elif config.model.nonlinearity.lower() == 'tanh':
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplementedError('activation function does not exist!')
|
190
NAS-Bench-201/models/digcn_meta.py
Normal file
190
NAS-Bench-201/models/digcn_meta.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
|
||||
# which was authored by Yuge Zhang, 2020
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from . import utils
|
||||
from .set_encoder.setenc_models import SetPool
|
||||
|
||||
|
||||
def normalize_adj(adj):
|
||||
# Row-normalize matrix
|
||||
last_dim = adj.size(-1)
|
||||
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
|
||||
return torch.div(adj, rowsum)
|
||||
|
||||
|
||||
def graph_pooling(inputs, num_vertices):
|
||||
num_vertices = num_vertices.to(inputs.device)
|
||||
out = inputs.sum(1)
|
||||
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
|
||||
|
||||
|
||||
class DirectedGraphConvolution(nn.Module):
|
||||
def __init__(self, in_features, out_features):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
|
||||
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.xavier_uniform_(self.weight1.data)
|
||||
nn.init.xavier_uniform_(self.weight2.data)
|
||||
|
||||
def forward(self, inputs, adj):
|
||||
inputs = inputs.to(self.weight1.device)
|
||||
adj = adj.to(self.weight1.device)
|
||||
norm_adj = normalize_adj(adj)
|
||||
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
|
||||
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
|
||||
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
|
||||
out = (output1 + output2) / 2
|
||||
out = self.dropout(out)
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + ' (' \
|
||||
+ str(self.in_features) + ' -> ' \
|
||||
+ str(self.out_features) + ')'
|
||||
|
||||
|
||||
@utils.register_model(name='MetaNeuralPredictor')
|
||||
class MetaeuralPredictor(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
# Arch
|
||||
self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
|
||||
config.model.graph_encoder.gcn_hidden)
|
||||
for i in range(config.model.graph_encoder.gcn_layers)]
|
||||
self.gcn = nn.ModuleList(self.gcn)
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
|
||||
|
||||
# Time
|
||||
self.d_model = config.model.graph_encoder.gcn_hidden
|
||||
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
|
||||
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
|
||||
|
||||
self.act = act = get_act(config)
|
||||
self.input_type = config.model.input_type
|
||||
self.hs = config.model.hs
|
||||
|
||||
# Set
|
||||
self.nz = config.model.nz
|
||||
self.num_sample = config.model.num_sample
|
||||
self.intra_setpool = SetPool(dim_input=512,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.inter_setpool = SetPool(dim_input=self.nz,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.set_fc = nn.Sequential(
|
||||
nn.Linear(512, self.nz),
|
||||
nn.ReLU())
|
||||
|
||||
input_dim = 0
|
||||
if 'D' in self.input_type:
|
||||
input_dim += self.nz
|
||||
if 'A' in self.input_type:
|
||||
input_dim += config.model.graph_encoder.linear_hidden
|
||||
|
||||
self.pred_fc = nn.Sequential(
|
||||
nn.Linear(input_dim, self.hs),
|
||||
nn.Tanh(),
|
||||
nn.Linear(self.hs, 1)
|
||||
)
|
||||
|
||||
self.sample_state = False
|
||||
self.D_mu = None
|
||||
|
||||
def arch_encode(self, X, time_cond, maskX):
|
||||
out = X
|
||||
adj = maskX
|
||||
numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device)
|
||||
gs = adj.size(1) # graph node number
|
||||
|
||||
timesteps = time_cond
|
||||
emb_t = get_timestep_embedding(timesteps, self.d_model)
|
||||
emb_t = self.timeEmb1(emb_t)
|
||||
emb_t = self.timeEmb2(self.act(emb_t))
|
||||
|
||||
adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device))
|
||||
for layer in self.gcn:
|
||||
out = layer(out, adj_with_diag)
|
||||
out = graph_pooling(out, numv)
|
||||
# time
|
||||
out = out + emb_t
|
||||
out = self.fc1(out)
|
||||
out = self.dropout(out)
|
||||
|
||||
return out
|
||||
|
||||
def set_encode(self, task):
|
||||
proto_batch = []
|
||||
for x in task:
|
||||
cls_protos = self.intra_setpool(
|
||||
x.view(-1, self.num_sample, 512)).squeeze(1)
|
||||
proto_batch.append(
|
||||
self.inter_setpool(cls_protos.unsqueeze(0)))
|
||||
v = torch.stack(proto_batch).squeeze()
|
||||
return v
|
||||
|
||||
def predict(self, D_mu, A_mu):
|
||||
input_vec = []
|
||||
if 'D' in self.input_type:
|
||||
input_vec.append(D_mu)
|
||||
if 'A' in self.input_type:
|
||||
input_vec.append(A_mu)
|
||||
input_vec = torch.cat(input_vec, dim=1)
|
||||
return self.pred_fc(input_vec)
|
||||
|
||||
def forward(self, X, time_cond, maskX, task):
|
||||
if self.sample_state:
|
||||
if self.D_mu is None:
|
||||
self.D_mu = self.set_encode(task)
|
||||
D_mu = self.D_mu
|
||||
else:
|
||||
D_mu = self.set_encode(task)
|
||||
A_mu = self.arch_encode(X, time_cond, maskX)
|
||||
y_pred = self.predict(D_mu, A_mu)
|
||||
return y_pred
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
assert len(timesteps.shape) == 1
|
||||
half_dim = embedding_dim // 2
|
||||
# magic number 10000 is from transformers
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = F.pad(emb, (0, 1), mode='constant')
|
||||
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
|
||||
def get_act(config):
|
||||
"""Get actiuvation functions from the config file."""
|
||||
|
||||
if config.model.nonlinearity.lower() == 'elu':
|
||||
return nn.ELU()
|
||||
elif config.model.nonlinearity.lower() == 'relu':
|
||||
return nn.ReLU()
|
||||
elif config.model.nonlinearity.lower() == 'lrelu':
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif config.model.nonlinearity.lower() == 'swish':
|
||||
return nn.SiLU()
|
||||
elif config.model.nonlinearity.lower() == 'tanh':
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplementedError('activation function does not exist!')
|
85
NAS-Bench-201/models/ema.py
Normal file
85
NAS-Bench-201/models/ema.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import torch
|
||||
|
||||
|
||||
class ExponentialMovingAverage:
|
||||
"""
|
||||
Maintains (exponential) moving average of a set of parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, parameters, decay, use_num_updates=True):
|
||||
"""
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; usually the result of `model.parameters()`.
|
||||
decay: The exponential decay.
|
||||
use_num_updates: Whether to use number of updates when computing averages.
|
||||
"""
|
||||
if decay < 0.0 or decay > 1.0:
|
||||
raise ValueError('Decay must be between 0 and 1')
|
||||
self.decay = decay
|
||||
self.num_updates = 0 if use_num_updates else None
|
||||
self.shadow_params = [p.clone().detach()
|
||||
for p in parameters if p.requires_grad]
|
||||
self.collected_params = []
|
||||
|
||||
def update(self, parameters):
|
||||
"""
|
||||
Update currently maintained parameters.
|
||||
|
||||
Call this every time the parameters are updated, such as the result of the `optimizer.step()` call.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to
|
||||
initialize this object.
|
||||
"""
|
||||
decay = self.decay
|
||||
if self.num_updates is not None:
|
||||
self.num_updates += 1
|
||||
decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
|
||||
one_minus_decay = 1.0 - decay
|
||||
with torch.no_grad():
|
||||
parameters = [p for p in parameters if p.requires_grad]
|
||||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
s_param.sub_(one_minus_decay * (s_param - param))
|
||||
|
||||
def copy_to(self, parameters):
|
||||
"""
|
||||
Copy current parameters into given collection of parameters.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored moving averages.
|
||||
"""
|
||||
parameters = [p for p in parameters if p.requires_grad]
|
||||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
if param.requires_grad:
|
||||
param.data.copy_(s_param.data)
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the current parameters for restoring later.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters stored with the `store` method.
|
||||
Useful to validate the model with EMA parameters without affecting the original optimization process.
|
||||
Store the parameters before the `copy_to` method.
|
||||
After validation (or model saving), use this to restore the former parameters.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
||||
|
||||
def state_dict(self):
|
||||
return dict(decay=self.decay, num_updates=self.num_updates, shadow_params=self.shadow_params)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.decay = state_dict['decay']
|
||||
self.num_updates = state_dict['num_updates']
|
||||
self.shadow_params = state_dict['shadow_params']
|
82
NAS-Bench-201/models/gnns.py
Normal file
82
NAS-Bench-201/models/gnns.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from .trans_layers import *
|
||||
|
||||
|
||||
class pos_gnn(nn.Module):
|
||||
def __init__(self, act, x_ch, pos_ch, out_ch, max_node, graph_layer, n_layers=3, edge_dim=None, heads=4,
|
||||
temb_dim=None, dropout=0.1, attn_clamp=False):
|
||||
super().__init__()
|
||||
self.out_ch = out_ch
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.act = act
|
||||
self.max_node = max_node
|
||||
self.n_layers = n_layers
|
||||
|
||||
if temb_dim is not None:
|
||||
self.Dense_node0 = nn.Linear(temb_dim, x_ch)
|
||||
self.Dense_node1 = nn.Linear(temb_dim, pos_ch)
|
||||
self.Dense_edge0 = nn.Linear(temb_dim, edge_dim)
|
||||
self.Dense_edge1 = nn.Linear(temb_dim, edge_dim)
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.edge_convs = nn.ModuleList()
|
||||
self.edge_layer = nn.Linear(edge_dim * 2 + self.out_ch, edge_dim)
|
||||
|
||||
for i in range(n_layers):
|
||||
if i == 0:
|
||||
self.convs.append(eval(graph_layer)(x_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2,
|
||||
act=act, attn_clamp=attn_clamp))
|
||||
else:
|
||||
self.convs.append(eval(graph_layer)
|
||||
(self.out_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2, act=act,
|
||||
attn_clamp=attn_clamp))
|
||||
self.edge_convs.append(nn.Linear(self.out_ch, edge_dim*2))
|
||||
|
||||
def forward(self, x_degree, x_pos, edge_index, dense_ori, dense_spd, dense_index, temb=None):
|
||||
"""
|
||||
Args:
|
||||
x_degree: node degree feature [B*N, x_ch]
|
||||
x_pos: node rwpe feature [B*N, pos_ch]
|
||||
edge_index: [2, edge_length]
|
||||
dense_ori: edge feature [B, N, N, nf//2]
|
||||
dense_spd: edge shortest path distance feature [B, N, N, nf//2] # Do we need this part? # TODO
|
||||
dense_index
|
||||
temb: [B, temb_dim]
|
||||
"""
|
||||
|
||||
B, N, _, _ = dense_ori.shape
|
||||
|
||||
if temb is not None:
|
||||
dense_ori = dense_ori + self.Dense_edge0(self.act(temb))[:, None, None, :]
|
||||
dense_spd = dense_spd + self.Dense_edge1(self.act(temb))[:, None, None, :]
|
||||
|
||||
temb = temb.unsqueeze(1).repeat(1, self.max_node, 1)
|
||||
temb = temb.reshape(-1, temb.shape[-1])
|
||||
x_degree = x_degree + self.Dense_node0(self.act(temb))
|
||||
x_pos = x_pos + self.Dense_node1(self.act(temb))
|
||||
|
||||
dense_edge = torch.cat([dense_ori, dense_spd], dim=-1)
|
||||
|
||||
ori_edge_attr = dense_edge
|
||||
h = x_degree
|
||||
h_pos = x_pos
|
||||
|
||||
for i_layer in range(self.n_layers):
|
||||
h_edge = dense_edge[dense_index]
|
||||
# update node feature
|
||||
h, h_pos = self.convs[i_layer](h, h_pos, edge_index, h_edge)
|
||||
h = self.Dropout_0(h)
|
||||
h_pos = self.Dropout_0(h_pos)
|
||||
|
||||
# update dense edge feature
|
||||
h_dense_node = h.reshape(B, N, -1)
|
||||
cur_edge_attr = h_dense_node.unsqueeze(1) + h_dense_node.unsqueeze(2) # [B, N, N, nf]
|
||||
dense_edge = (dense_edge + self.act(self.edge_convs[i_layer](cur_edge_attr))) / math.sqrt(2.)
|
||||
dense_edge = self.Dropout_0(dense_edge)
|
||||
|
||||
# Concat edge attribute
|
||||
h_dense_edge = torch.cat([ori_edge_attr, dense_edge], dim=-1)
|
||||
h_dense_edge = self.edge_layer(h_dense_edge).permute(0, 3, 1, 2)
|
||||
|
||||
return h_dense_edge
|
44
NAS-Bench-201/models/layers.py
Normal file
44
NAS-Bench-201/models/layers.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Common layers"""
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
def get_act(config):
|
||||
"""Get actiuvation functions from the config file."""
|
||||
|
||||
if config.model.nonlinearity.lower() == 'elu':
|
||||
return nn.ELU()
|
||||
elif config.model.nonlinearity.lower() == 'relu':
|
||||
return nn.ReLU()
|
||||
elif config.model.nonlinearity.lower() == 'lrelu':
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif config.model.nonlinearity.lower() == 'swish':
|
||||
return nn.SiLU()
|
||||
elif config.model.nonlinearity.lower() == 'tanh':
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplementedError('activation function does not exist!')
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, padding=0):
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
|
||||
padding=padding)
|
||||
return conv
|
||||
|
||||
|
||||
# from DDPM
|
||||
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
assert len(timesteps.shape) == 1
|
||||
half_dim = embedding_dim // 2
|
||||
# magic number 10000 is from transformers
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = F.pad(emb, (0, 1), mode='constant')
|
||||
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
||||
return emb
|
38
NAS-Bench-201/models/set_encoder/setenc_models.py
Normal file
38
NAS-Bench-201/models/set_encoder/setenc_models.py
Normal file
@@ -0,0 +1,38 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from .setenc_modules import *
|
||||
|
||||
|
||||
class SetPool(nn.Module):
|
||||
def __init__(self, dim_input, num_outputs, dim_output,
|
||||
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
|
||||
super(SetPool, self).__init__()
|
||||
if 'sab' in mode: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
|
||||
else: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
|
||||
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
|
||||
if 'PF' in mode: # [32, 1, 501]
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
elif 'P' in mode:
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
|
||||
else: # torch.Size([32, 1, 501])
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
# "", sm, sab, sabsm
|
||||
|
||||
def forward(self, X):
|
||||
x1 = self.enc(X)
|
||||
x2 = self.dec(x1)
|
||||
return x2
|
67
NAS-Bench-201/models/set_encoder/setenc_modules.py
Normal file
67
NAS-Bench-201/models/set_encoder/setenc_modules.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#####################################################################################
|
||||
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
|
||||
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
|
||||
######################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class MAB(nn.Module):
|
||||
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
|
||||
super(MAB, self).__init__()
|
||||
self.dim_V = dim_V
|
||||
self.num_heads = num_heads
|
||||
self.fc_q = nn.Linear(dim_Q, dim_V)
|
||||
self.fc_k = nn.Linear(dim_K, dim_V)
|
||||
self.fc_v = nn.Linear(dim_K, dim_V)
|
||||
if ln:
|
||||
self.ln0 = nn.LayerNorm(dim_V)
|
||||
self.ln1 = nn.LayerNorm(dim_V)
|
||||
self.fc_o = nn.Linear(dim_V, dim_V)
|
||||
|
||||
def forward(self, Q, K):
|
||||
Q = self.fc_q(Q)
|
||||
K, V = self.fc_k(K), self.fc_v(K)
|
||||
|
||||
dim_split = self.dim_V // self.num_heads
|
||||
Q_ = torch.cat(Q.split(dim_split, 2), 0)
|
||||
K_ = torch.cat(K.split(dim_split, 2), 0)
|
||||
V_ = torch.cat(V.split(dim_split, 2), 0)
|
||||
|
||||
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
|
||||
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
|
||||
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
|
||||
O = O + F.relu(self.fc_o(O))
|
||||
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
|
||||
return O
|
||||
|
||||
class SAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, ln=False):
|
||||
super(SAB, self).__init__()
|
||||
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(X, X)
|
||||
|
||||
class ISAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
|
||||
super(ISAB, self).__init__()
|
||||
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
|
||||
nn.init.xavier_uniform_(self.I)
|
||||
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
|
||||
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
|
||||
return self.mab1(X, H)
|
||||
|
||||
class PMA(nn.Module):
|
||||
def __init__(self, dim, num_heads, num_seeds, ln=False):
|
||||
super(PMA, self).__init__()
|
||||
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
|
||||
nn.init.xavier_uniform_(self.S)
|
||||
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(self.S.repeat(X.size(0), 1, 1), X)
|
144
NAS-Bench-201/models/trans_layers.py
Normal file
144
NAS-Bench-201/models/trans_layers.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import math
|
||||
from typing import Union, Tuple, Optional
|
||||
from torch_geometric.typing import PairTensor, Adj, OptTensor
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Linear
|
||||
from torch_scatter import scatter
|
||||
from torch_geometric.nn.conv import MessagePassing
|
||||
from torch_geometric.utils import softmax
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PosTransLayer(MessagePassing):
|
||||
"""Involving the edge feature and updating position feature. Multiply Msg."""
|
||||
|
||||
_alpha: OptTensor
|
||||
|
||||
def __init__(self, x_channels: int, pos_channels: int, out_channels: int,
|
||||
heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
|
||||
bias: bool = True, act=None, attn_clamp: bool = False, **kwargs):
|
||||
kwargs.setdefault('aggr', 'add')
|
||||
super(PosTransLayer, self).__init__(node_dim=0, **kwargs)
|
||||
|
||||
self.x_channels = x_channels
|
||||
self.pos_channels = pos_channels
|
||||
self.in_channels = in_channels = x_channels + pos_channels
|
||||
self.out_channels = out_channels
|
||||
self.heads = heads
|
||||
self.dropout = dropout
|
||||
self.edge_dim = edge_dim
|
||||
self.attn_clamp = attn_clamp
|
||||
|
||||
if act is None:
|
||||
self.act = nn.LeakyReLU(negative_slope=0.2)
|
||||
else:
|
||||
self.act = act
|
||||
|
||||
self.lin_key = Linear(in_channels, heads * out_channels)
|
||||
self.lin_query = Linear(in_channels, heads * out_channels)
|
||||
self.lin_value = Linear(in_channels, heads * out_channels)
|
||||
|
||||
self.lin_edge0 = Linear(edge_dim, heads * out_channels, bias=False)
|
||||
self.lin_edge1 = Linear(edge_dim, heads * out_channels, bias=False)
|
||||
|
||||
self.lin_pos = Linear(heads * out_channels, pos_channels, bias=False)
|
||||
|
||||
self.lin_skip = Linear(x_channels, heads * out_channels, bias=bias)
|
||||
self.norm1 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
|
||||
num_channels=heads * out_channels, eps=1e-6)
|
||||
self.norm2 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
|
||||
num_channels=heads * out_channels, eps=1e-6)
|
||||
# FFN
|
||||
self.FFN = nn.Sequential(Linear(heads * out_channels, heads * out_channels),
|
||||
self.act,
|
||||
Linear(heads * out_channels, heads * out_channels))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
self.lin_key.reset_parameters()
|
||||
self.lin_query.reset_parameters()
|
||||
self.lin_value.reset_parameters()
|
||||
self.lin_skip.reset_parameters()
|
||||
self.lin_edge0.reset_parameters()
|
||||
self.lin_edge1.reset_parameters()
|
||||
self.lin_pos.reset_parameters()
|
||||
|
||||
def forward(self, x: OptTensor,
|
||||
pos: Tensor,
|
||||
edge_index: Adj,
|
||||
edge_attr: OptTensor = None
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
""""""
|
||||
|
||||
H, C = self.heads, self.out_channels
|
||||
|
||||
x_feat = torch.cat([x, pos], -1)
|
||||
query = self.lin_query(x_feat).view(-1, H, C)
|
||||
key = self.lin_key(x_feat).view(-1, H, C)
|
||||
value = self.lin_value(x_feat).view(-1, H, C)
|
||||
|
||||
# propagate_type: (x: PairTensor, edge_attr: OptTensor)
|
||||
out_x, out_pos = self.propagate(edge_index, query=query, key=key, value=value, pos=pos, edge_attr=edge_attr,
|
||||
size=None)
|
||||
|
||||
out_x = out_x.view(-1, self.heads * self.out_channels)
|
||||
|
||||
# skip connection for x
|
||||
x_r = self.lin_skip(x)
|
||||
out_x = (out_x + x_r) / math.sqrt(2)
|
||||
out_x = self.norm1(out_x)
|
||||
|
||||
# FFN
|
||||
out_x = (out_x + self.FFN(out_x)) / math.sqrt(2)
|
||||
out_x = self.norm2(out_x)
|
||||
|
||||
# skip connection for pos
|
||||
out_pos = pos + torch.tanh(pos + out_pos)
|
||||
|
||||
return out_x, out_pos
|
||||
|
||||
def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
|
||||
pos_j: Tensor,
|
||||
edge_attr: OptTensor,
|
||||
index: Tensor, ptr: OptTensor,
|
||||
size_i: Optional[int]) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
edge_attn = self.lin_edge0(edge_attr).view(-1, self.heads, self.out_channels)
|
||||
alpha = (query_i * key_j * edge_attn).sum(dim=-1) / math.sqrt(self.out_channels)
|
||||
if self.attn_clamp:
|
||||
alpha = alpha.clamp(min=-5., max=5.)
|
||||
|
||||
alpha = softmax(alpha, index, ptr, size_i)
|
||||
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
|
||||
|
||||
# node feature message
|
||||
msg = value_j
|
||||
msg = msg * self.lin_edge1(edge_attr).view(-1, self.heads, self.out_channels)
|
||||
msg = msg * alpha.view(-1, self.heads, 1)
|
||||
|
||||
# node position message
|
||||
pos_msg = pos_j * self.lin_pos(msg.reshape(-1, self.heads * self.out_channels))
|
||||
|
||||
return msg, pos_msg
|
||||
|
||||
def aggregate(self, inputs: Tuple[Tensor, Tensor], index: Tensor,
|
||||
ptr: Optional[Tensor] = None,
|
||||
dim_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
|
||||
if ptr is not None:
|
||||
raise NotImplementedError("Not implement Ptr in aggregate")
|
||||
else:
|
||||
return (scatter(inputs[0], index, 0, dim_size=dim_size, reduce=self.aggr),
|
||||
scatter(inputs[1], index, 0, dim_size=dim_size, reduce="mean"))
|
||||
|
||||
def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
|
||||
return inputs
|
||||
|
||||
def __repr__(self):
|
||||
return '{}({}, {}, heads={})'.format(self.__class__.__name__,
|
||||
self.in_channels,
|
||||
self.out_channels, self.heads)
|
255
NAS-Bench-201/models/transformer.py
Executable file
255
NAS-Bench-201/models/transformer.py
Executable file
@@ -0,0 +1,255 @@
|
||||
from copy import deepcopy as cp
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def clones(module, N):
|
||||
return nn.ModuleList([cp(module) for _ in range(N)])
|
||||
|
||||
|
||||
def attention(query, key, value, mask = None, dropout = None):
|
||||
d_k = query.size(-1)
|
||||
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask == 0, -1e9)
|
||||
attn = F.softmax(scores, dim = -1)
|
||||
if dropout is not None:
|
||||
attn = dropout(attn)
|
||||
return torch.matmul(attn, value), attn
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
|
||||
self.d_model = config.d_model
|
||||
self.n_head = config.n_head
|
||||
self.d_k = config.d_model // config.n_head
|
||||
|
||||
self.linears = clones(nn.Linear(self.d_model, self.d_model), 4)
|
||||
self.dropout = nn.Dropout(p=config.dropout)
|
||||
|
||||
def forward(self, query, key, value, mask = None):
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1)
|
||||
batch_size = query.size(0)
|
||||
|
||||
query, key , value = [l(x).view(batch_size, -1, self.n_head, self.d_k).transpose(1,2) for l, x in zip(self.linears, (query, key, value))]
|
||||
x, attn = attention(query, key, value, mask = mask, dropout = self.dropout)
|
||||
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)
|
||||
return self.linears[3](x), attn
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
|
||||
self.w_1 = nn.Linear(config.d_model, config.d_ff)
|
||||
self.w_2 = nn.Linear(config.d_ff, config.d_model)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
||||
|
||||
|
||||
class PositionwiseFeedForwardLast(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PositionwiseFeedForwardLast, self).__init__()
|
||||
|
||||
self.w_1 = nn.Linear(config.d_model, config.d_ff)
|
||||
self.w_2 = nn.Linear(config.d_ff, config.n_vocab)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
||||
|
||||
|
||||
class SelfAttentionBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(SelfAttentionBlock, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.attn = MultiHeadAttention(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x, mask):
|
||||
x_ = self.norm(x)
|
||||
x_ , attn = self.attn(x_, x_, x_, mask)
|
||||
return self.dropout(x_) + x, attn
|
||||
|
||||
|
||||
class SourceAttentionBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(SourceAttentionBlock, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.attn = MultiHeadAttention(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x, m, mask):
|
||||
x_ = self.norm(x)
|
||||
x_, attn = self.attn(x_, m, m, mask)
|
||||
return self.dropout(x_) + x, attn
|
||||
|
||||
|
||||
class FeedForwardBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(FeedForwardBlock, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.feed_forward = PositionwiseFeedForward(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
x_ = self.norm(x)
|
||||
x_ = self.feed_forward(x_)
|
||||
return self.dropout(x_) + x
|
||||
|
||||
|
||||
class FeedForwardBlockLast(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(FeedForwardBlockLast, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.feed_forward = PositionwiseFeedForwardLast(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
# Only for the last layer
|
||||
self.proj_fc = nn.Linear(config.d_model, config.n_vocab)
|
||||
|
||||
def forward(self, x):
|
||||
x_ = self.norm(x)
|
||||
x_ = self.feed_forward(x_)
|
||||
return self.dropout(x_) + self.proj_fc(x)
|
||||
|
||||
|
||||
class EncoderBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(EncoderBlock, self).__init__()
|
||||
self.self_attn = SelfAttentionBlock(config)
|
||||
self.feed_forward = FeedForwardBlock(config)
|
||||
|
||||
def forward(self, x, mask):
|
||||
x, attn = self.self_attn(x, mask)
|
||||
x = self.feed_forward(x)
|
||||
return x, attn
|
||||
|
||||
|
||||
class EncoderBlockLast(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(EncoderBlockLast, self).__init__()
|
||||
self.self_attn = SelfAttentionBlock(config)
|
||||
self.feed_forward = FeedForwardBlockLast(config)
|
||||
|
||||
def forward(self, x, mask):
|
||||
x, attn = self.self_attn(x, mask)
|
||||
x = self.feed_forward(x)
|
||||
return x, attn
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(DecoderBlock, self).__init__()
|
||||
|
||||
self.self_attn = SelfAttentionBlock(config)
|
||||
self.src_attn = SourceAttentionBlock(config)
|
||||
self.feed_forward = FeedForwardBlock(config)
|
||||
|
||||
def forward(self, x, m, src_mask, tgt_mask):
|
||||
x, attn_tgt = self.self_attn(x, tgt_mask)
|
||||
x, attn_src = self.src_attn(x, m, src_mask)
|
||||
x = self.feed_forward(x)
|
||||
return x, attn_src, attn_tgt
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Encoder, self).__init__()
|
||||
self.layers = clones(EncoderBlock(config), config.n_layers)
|
||||
self.norms = clones(nn.LayerNorm(config.d_model), config.n_layers)
|
||||
|
||||
def forward(self, x, mask):
|
||||
outputs = []
|
||||
attns = []
|
||||
for layer, norm in zip(self.layers, self.norms):
|
||||
x, attn = layer(x, mask)
|
||||
outputs.append(norm(x))
|
||||
attns.append(attn)
|
||||
return outputs[-1], outputs, attns
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PositionalEmbedding, self).__init__()
|
||||
|
||||
p2e = torch.zeros(config.max_len, config.d_model)
|
||||
position = torch.arange(0.0, config.max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0.0, config.d_model, 2) * (- math.log(10000.0) / config.d_model))
|
||||
p2e[:, 0::2] = torch.sin(position * div_term)
|
||||
p2e[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
self.register_buffer('p2e', p2e)
|
||||
|
||||
def forward(self, x):
|
||||
shp = x.size()
|
||||
with torch.no_grad():
|
||||
emb = torch.index_select(self.p2e, 0, x.view(-1)).view(shp + (-1,))
|
||||
return emb
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Transformer, self).__init__()
|
||||
self.p2e = PositionalEmbedding(config)
|
||||
self.encoder = Encoder(config)
|
||||
|
||||
def forward(self, input_emb, position_ids, attention_mask):
|
||||
# position embedding projection
|
||||
projection = self.p2e(position_ids) + input_emb
|
||||
return self.encoder(projection, attention_mask)
|
||||
|
||||
|
||||
class TokenTypeEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(TokenTypeEmbedding, self).__init__()
|
||||
self.t2e = nn.Embedding(config.n_token_type, config.d_model)
|
||||
self.d_model = config.d_model
|
||||
|
||||
def forward(self, x):
|
||||
return self.t2e(x) * math.sqrt(self.d_model)
|
||||
|
||||
|
||||
class SemanticEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(SemanticEmbedding, self).__init__()
|
||||
self.d_model = config.d_model
|
||||
self.fc = nn.Linear(config.n_vocab, config.d_model)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc(x) * math.sqrt(self.d_model)
|
||||
|
||||
|
||||
class Embeddings(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Embeddings, self).__init__()
|
||||
|
||||
self.w2e = SemanticEmbedding(config)
|
||||
self.p2e = PositionalEmbedding(config)
|
||||
self.t2e = TokenTypeEmbedding(config)
|
||||
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, input_ids, position_ids = None, token_type_ids = None):
|
||||
if position_ids is None:
|
||||
batch_size, length = input_ids.size()
|
||||
with torch.no_grad():
|
||||
position_ids = torch.arange(0, length).repeat(batch_size, 1)
|
||||
if torch.cuda.is_available():
|
||||
position_ids = position_ids.cuda(device=input_ids.device)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
embeddings = self.w2e(input_ids) + self.p2e(position_ids) + self.t2e(token_type_ids)
|
||||
return self.dropout(embeddings)
|
289
NAS-Bench-201/models/utils.py
Normal file
289
NAS-Bench-201/models/utils.py
Normal file
@@ -0,0 +1,289 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import sde_lib
|
||||
|
||||
_MODELS = {}
|
||||
|
||||
|
||||
def register_model(cls=None, *, name=None):
|
||||
"""A decorator for registering model classes."""
|
||||
|
||||
def _register(cls):
|
||||
if name is None:
|
||||
local_name = cls.__name__
|
||||
else:
|
||||
local_name = name
|
||||
if local_name in _MODELS:
|
||||
raise ValueError(
|
||||
f'Already registered model with name: {local_name}')
|
||||
_MODELS[local_name] = cls
|
||||
return cls
|
||||
|
||||
if cls is None:
|
||||
return _register
|
||||
else:
|
||||
return _register(cls)
|
||||
|
||||
|
||||
def get_model(name):
|
||||
return _MODELS[name]
|
||||
|
||||
|
||||
def create_model(config):
|
||||
"""Create the model."""
|
||||
model_name = config.model.name
|
||||
model = get_model(model_name)(config)
|
||||
model = model.to(config.device)
|
||||
return model
|
||||
|
||||
|
||||
def get_model_fn(model, train=False):
|
||||
"""Create a function to give the output of the score-based model.
|
||||
|
||||
Args:
|
||||
model: The score model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
|
||||
Returns:
|
||||
A model function.
|
||||
"""
|
||||
|
||||
def model_fn(x, labels, *args, **kwargs):
|
||||
"""Compute the output of the score-based model.
|
||||
|
||||
Args:
|
||||
x: A mini-batch of input data (Adjacency matrices).
|
||||
labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
|
||||
for different models.
|
||||
mask: Mask for adjacency matrices.
|
||||
|
||||
Returns:
|
||||
A tuple of (model output, new mutable states)
|
||||
"""
|
||||
if not train:
|
||||
model.eval()
|
||||
return model(x, labels, *args, **kwargs)
|
||||
else:
|
||||
model.train()
|
||||
return model(x, labels, *args, **kwargs)
|
||||
|
||||
return model_fn
|
||||
|
||||
|
||||
def get_score_fn(sde, model, train=False, continuous=False):
|
||||
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
model: A score model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
|
||||
|
||||
Returns:
|
||||
A score function.
|
||||
"""
|
||||
model_fn = get_model_fn(model, train=train)
|
||||
|
||||
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
|
||||
def score_fn(x, t, *args, **kwargs):
|
||||
# Scale neural network output by standard deviation and flip sign
|
||||
if continuous or isinstance(sde, sde_lib.subVPSDE):
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
|
||||
labels = t * 999
|
||||
score = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
labels = t * (sde.N - 1)
|
||||
score = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
|
||||
labels.long()]
|
||||
|
||||
score = -score / std[:, None, None]
|
||||
return score
|
||||
|
||||
elif isinstance(sde, sde_lib.VESDE):
|
||||
def score_fn(x, t, *args, **kwargs):
|
||||
if continuous:
|
||||
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VE-trained models, t=0 corresponds to the highest noise level
|
||||
labels = sde.T - t
|
||||
labels *= sde.N - 1
|
||||
labels = torch.round(labels).long()
|
||||
|
||||
score = model_fn(x, labels, *args, **kwargs)
|
||||
return score
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"SDE class {sde.__class__.__name__} not yet supported.")
|
||||
|
||||
return score_fn
|
||||
|
||||
|
||||
def get_classifier_grad_fn(sde, classifier, train=False, continuous=False,
|
||||
regress=True, labels='max'):
|
||||
logit_fn = get_logit_fn(sde, classifier, train, continuous)
|
||||
|
||||
def classifier_grad_fn(x, t, *args, **kwargs):
|
||||
with torch.enable_grad():
|
||||
x_in = x.detach().requires_grad_(True)
|
||||
if regress:
|
||||
assert labels in ['max', 'min']
|
||||
logit = logit_fn(x_in, t, *args, **kwargs)
|
||||
if labels == 'max':
|
||||
prob = logit.sum()
|
||||
elif labels == 'min':
|
||||
prob = -logit.sum()
|
||||
else:
|
||||
logit = logit_fn(x_in, t, *args, **kwargs)
|
||||
log_prob = F.log_softmax(logit, dim=-1)
|
||||
prob = log_prob[range(len(logit)), labels.view(-1)].sum()
|
||||
classifier_grad = torch.autograd.grad(prob, x_in)[0]
|
||||
return classifier_grad
|
||||
|
||||
return classifier_grad_fn
|
||||
|
||||
|
||||
def get_logit_fn(sde, classifier, train=False, continuous=False):
|
||||
classifier_fn = get_model_fn(classifier, train=train)
|
||||
|
||||
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
|
||||
def logit_fn(x, t, *args, **kwargs):
|
||||
# Scale neural network output by standard deviation and flip sign
|
||||
if continuous or isinstance(sde, sde_lib.subVPSDE):
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
|
||||
labels = t * 999
|
||||
logit = classifier_fn(x, labels, *args, **kwargs)
|
||||
else:
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
labels = t * (sde.N - 1)
|
||||
logit = classifier_fn(x, labels, *args, **kwargs)
|
||||
return logit
|
||||
|
||||
elif isinstance(sde, sde_lib.VESDE):
|
||||
def logit_fn(x, t, *args, **kwargs):
|
||||
if continuous:
|
||||
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VE-trained models, t=0 corresponds to the highest noise level
|
||||
labels = sde.T - t
|
||||
labels *= sde.N - 1
|
||||
labels = torch.round(labels).long()
|
||||
logit = classifier_fn(x, labels, *args, **kwargs)
|
||||
return logit
|
||||
|
||||
return logit_fn
|
||||
|
||||
|
||||
def get_predictor_fn(sde, model, train=False, continuous=False):
|
||||
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
model: A predictor model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
|
||||
|
||||
Returns:
|
||||
A score function.
|
||||
"""
|
||||
model_fn = get_model_fn(model, train=train)
|
||||
|
||||
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
|
||||
def predictor_fn(x, t, *args, **kwargs):
|
||||
# Scale neural network output by standard deviation and flip sign
|
||||
if continuous or isinstance(sde, sde_lib.subVPSDE):
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
|
||||
labels = t * 999
|
||||
pred = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
labels = t * (sde.N - 1)
|
||||
pred = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
|
||||
labels.long()]
|
||||
|
||||
return pred
|
||||
|
||||
elif isinstance(sde, sde_lib.VESDE):
|
||||
def predictor_fn(x, t, *args, **kwargs):
|
||||
if continuous:
|
||||
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VE-trained models, t=0 corresponds to the highest noise level
|
||||
labels = sde.T - t
|
||||
labels *= sde.N - 1
|
||||
labels = torch.round(labels).long()
|
||||
|
||||
pred = model_fn(x, labels, *args, **kwargs)
|
||||
return pred
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"SDE class {sde.__class__.__name__} not yet supported.")
|
||||
|
||||
return predictor_fn
|
||||
|
||||
|
||||
def to_flattened_numpy(x):
|
||||
"""Flatten a torch tensor `x` and convert it to numpy."""
|
||||
return x.detach().cpu().numpy().reshape((-1,))
|
||||
|
||||
|
||||
def from_flattened_numpy(x, shape):
|
||||
"""Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
|
||||
return torch.from_numpy(x.reshape(shape))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def mask_adj2node(adj_mask):
|
||||
"""Convert batched adjacency mask matrices to batched node mask matrices.
|
||||
|
||||
Args:
|
||||
adj_mask: [B, N, N] Batched adjacency mask matrices without self-loop edge.
|
||||
|
||||
Output:
|
||||
node_mask: [B, N] Batched node mask matrices indicating the valid nodes.
|
||||
"""
|
||||
|
||||
batch_size, max_num_nodes, _ = adj_mask.shape
|
||||
|
||||
node_mask = adj_mask[:, 0, :].clone()
|
||||
node_mask[:, 0] = 1
|
||||
|
||||
return node_mask
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_rw_feat(k_step, dense_adj):
|
||||
"""Compute k_step Random Walk for given dense adjacency matrix."""
|
||||
|
||||
rw_list = []
|
||||
deg = dense_adj.sum(-1, keepdims=True)
|
||||
AD = dense_adj / (deg + 1e-8)
|
||||
rw_list.append(AD)
|
||||
|
||||
for _ in range(k_step):
|
||||
rw = torch.bmm(rw_list[-1], AD)
|
||||
rw_list.append(rw)
|
||||
rw_map = torch.stack(rw_list[1:], dim=1) # [B, k_step, N, N]
|
||||
|
||||
rw_landing = torch.diagonal(
|
||||
rw_map, offset=0, dim1=2, dim2=3) # [B, k_step, N]
|
||||
rw_landing = rw_landing.permute(0, 2, 1) # [B, N, rw_depth]
|
||||
|
||||
# get the shortest path distance indices
|
||||
tmp_rw = rw_map.sort(dim=1)[0]
|
||||
spd_ind = (tmp_rw <= 0).sum(dim=1) # [B, N, N]
|
||||
|
||||
spd_onehot = torch.nn.functional.one_hot(
|
||||
spd_ind, num_classes=k_step+1).to(torch.float)
|
||||
spd_onehot = spd_onehot.permute(0, 3, 1, 2) # [B, kstep, N, N]
|
||||
|
||||
return rw_landing, spd_onehot
|
Reference in New Issue
Block a user