first commit

This commit is contained in:
CownowAn
2024-03-15 14:38:51 +00:00
commit bc2ed1304f
321 changed files with 44802 additions and 0 deletions

View File

View File

@@ -0,0 +1,391 @@
# Most of this code is from https://github.com/AIoT-MLSys-Lab/CATE.git
# which was authored by Shen Yan, Kaiqiang Song, Fei Liu, Mi Zhang, 2021
import torch.nn as nn
import torch
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import utils
from .transformer import Encoder, SemanticEmbedding
from .set_encoder.setenc_models import SetPool
class MLP(torch.nn.Module):
def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
"""
num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to linear model.
input_dim: dimensionality of input features
hidden_dim: dimensionality of hidden units at ALL layers
output_dim: number of classes for prediction
num_classes: the number of classes of input, to be treated with different gains and biases,
(see the definition of class `ConditionalLayer1d`)
"""
super(MLP, self).__init__()
self.linear_or_not = True # default is linear model
self.num_layers = num_layers
self.use_bn = use_bn
self.activate_func = activate_func
if num_layers < 1:
raise ValueError("number of layers should be positive!")
elif num_layers == 1:
# Linear model
self.linear = torch.nn.Linear(input_dim, output_dim)
else:
# Multi-layer model
self.linear_or_not = False
self.linears = torch.nn.ModuleList()
self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
for layer in range(num_layers - 2):
self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
self.linears.append(torch.nn.Linear(hidden_dim, output_dim))
if self.use_bn:
self.batch_norms = torch.nn.ModuleList()
for layer in range(num_layers - 1):
self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))
def forward(self, x):
"""
:param x: [num_classes * batch_size, N, F_i], batch of node features
note that in self.cond_layers[layer],
`x` is splited into `num_classes` groups in dim=0,
and then treated with different gains and biases
"""
if self.linear_or_not:
# If linear model
return self.linear(x)
else:
# If MLP
h = x
for layer in range(self.num_layers - 1):
h = self.linears[layer](h)
if self.use_bn:
h = self.batch_norms[layer](h)
h = self.activate_func(h)
return self.linears[self.num_layers - 1](h)
""" Transformer Encoder """
class GraphEncoder(nn.Module):
def __init__(self, config):
super(GraphEncoder, self).__init__()
# Forward Transformers
self.encoder_f = Encoder(config)
def forward(self, x, mask):
h_f, hs_f, attns_f = self.encoder_f(x, mask)
h = torch.cat(hs_f, dim=-1)
return h
@staticmethod
def get_embeddings(h_x):
h_x = h_x.cpu()
return h_x[:, -1]
class CLSHead(nn.Module):
def __init__(self, config, init_weights=None):
super(CLSHead, self).__init__()
self.layer_1 = nn.Linear(config.d_model, config.d_model)
self.dropout = nn.Dropout(p=config.dropout)
self.layer_2 = nn.Linear(config.d_model, config.n_vocab)
if init_weights is not None:
self.layer_2.weight = init_weights
def forward(self, x):
x = self.dropout(torch.tanh(self.layer_1(x)))
return F.log_softmax(self.layer_2(x), dim=-1)
@utils.register_model(name='CATE')
class CATE(nn.Module):
def __init__(self, config):
super(CATE, self).__init__()
# Shared Embedding Layer
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
self.dropout_op = nn.Dropout(p=config.model.dropout)
self.d_model = config.model.graph_encoder.d_model
self.act = act = get_act(config)
# Time
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
# 2 GraphEncoder for X and Y
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
use_bn=False, activate_func=F.elu)
if 'pos_enc_type' in config.model:
self.pos_enc_type = config.model.pos_enc_type
if self.pos_enc_type == 1:
raise NotImplementedError
elif self.pos_enc_type == 2:
if config.data.name == 'NASBench201':
self.pos_encoder = PositionalEncoding_Cell(d_model=self.d_model, max_len=config.data.max_node)
else:
self.pos_encoder = PositionalEncoding_StageWise(d_model=self.d_model, max_len=config.data.max_node)
elif self.pos_enc_type == 3:
raise NotImplementedError
else:
self.pos_encoder = None
else:
self.pos_encoder = None
def forward(self, X, time_cond, maskX):
emb_x = self.dropout_op(self.opEmb(X))
if self.pos_encoder is not None:
emb_p = self.pos_encoder(emb_x)
emb_x = emb_x + emb_p
# Time embedding
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)
emb_t = self.timeEmb1(emb_t) # [32, 512]
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
emb_t = emb_t.unsqueeze(1)
emb = emb_x + emb_t
h_x = self.graph_encoder(emb, maskX)
h_x = self.final(h_x)
return h_x
@utils.register_model(name='PredictorCATE')
class PredictorCATE(nn.Module):
def __init__(self, config):
super(PredictorCATE, self).__init__()
# Shared Embedding Layer
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
self.dropout_op = nn.Dropout(p=config.model.dropout)
self.d_model = config.model.graph_encoder.d_model
self.act = act = get_act(config)
# Time
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
# 2 GraphEncoder for X and Y
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
use_bn=False, activate_func=F.elu)
self.rdim = int(config.data.max_node * config.data.n_vocab)
self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=1,
use_bn=False, activate_func=F.elu)
def forward(self, X, time_cond, maskX):
emb_x = self.dropout_op(self.opEmb(X))
# Time embedding
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)
emb_t = self.timeEmb1(emb_t)
emb_t = self.timeEmb2(self.act(emb_t))
emb_t = emb_t.unsqueeze(1)
emb = emb_x + emb_t
h_x = self.graph_encoder(emb, maskX)
h_x = self.final(h_x)
h_x = h_x.reshape(h_x.size(0), -1)
h_x = self.regeress(h_x)
return h_x
class PositionalEncoding_StageWise(nn.Module):
def __init__(self, d_model, max_len):
super(PositionalEncoding_StageWise, self).__init__()
NUM_STAGE = 5
max_len = int(max_len / NUM_STAGE)
self.encoding = torch.zeros(max_len, d_model)
self.encoding.requires_grad = False
pos = torch.arange(0, max_len)
pos = pos.float().unsqueeze(dim=1)
_2i = torch.arange(0, d_model, step=2).float()
self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
self.encoding = torch.cat([self.encoding] * NUM_STAGE, dim=0)
def forward(self, x):
batch_size, seq_len, _ = x.size()
return self.encoding[:seq_len, :].to(x.device)
class PositionalEncoding_Cell(nn.Module):
def __init__(self, d_model, max_len):
super(PositionalEncoding_Cell, self).__init__()
NUM_STAGE = 1
max_len = int(max_len / NUM_STAGE)
self.encoding = torch.zeros(max_len, d_model)
self.encoding.requires_grad = False
pos = torch.arange(0, max_len)
pos = pos.float().unsqueeze(dim=1)
_2i = torch.arange(0, d_model, step=2).float()
self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
self.encoding = torch.cat([self.encoding] * NUM_STAGE, dim=0)
def forward(self, x):
batch_size, seq_len, _ = x.size()
return self.encoding[:seq_len, :].to(x.device)
@utils.register_model(name='MetaPredictorCATE')
class MetaPredictorCATE(nn.Module):
def __init__(self, config):
super(MetaPredictorCATE, self).__init__()
self.input_type= config.model.input_type
self.hs = config.model.hs
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
self.dropout_op = nn.Dropout(p=config.model.dropout)
self.d_model = config.model.graph_encoder.d_model
self.act = act = get_act(config)
# Time
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
use_bn=False, activate_func=F.elu)
self.rdim = int(config.data.max_node * config.data.n_vocab)
self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=2*self.rdim,
use_bn=False, activate_func=F.elu)
# Set
self.nz = config.model.nz
self.num_sample = config.model.num_sample
self.intra_setpool = SetPool(dim_input=512,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.inter_setpool = SetPool(dim_input=self.nz,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.set_fc = nn.Sequential(
nn.Linear(512, self.nz),
nn.ReLU())
input_dim = 0
if 'D' in self.input_type:
input_dim += self.nz
if 'A' in self.input_type:
input_dim += 2*self.rdim
self.pred_fc = nn.Sequential(
nn.Linear(input_dim, self.hs),
nn.Tanh(),
nn.Linear(self.hs, 1)
)
self.sample_state = False
self.D_mu = None
def arch_encode(self, X, time_cond, maskX):
emb_x = self.dropout_op(self.opEmb(X))
# Time embedding
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
emb_t = self.timeEmb1(emb_t) # [32, 512]
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
emb_t = emb_t.unsqueeze(1)
emb = emb_x + emb_t
h_x = self.graph_encoder(emb, maskX)
h_x = self.final(h_x)
h_x = h_x.reshape(h_x.size(0), -1)
h_x = self.regeress(h_x)
return h_x
def set_encode(self, task):
proto_batch = []
for x in task:
cls_protos = self.intra_setpool(
x.view(-1, self.num_sample, 512)).squeeze(1)
proto_batch.append(
self.inter_setpool(cls_protos.unsqueeze(0)))
v = torch.stack(proto_batch).squeeze()
return v
def predict(self, D_mu, A_mu):
input_vec = []
if 'D' in self.input_type:
input_vec.append(D_mu)
if 'A' in self.input_type:
input_vec.append(A_mu)
input_vec = torch.cat(input_vec, dim=1)
return self.pred_fc(input_vec)
def forward(self, X, time_cond, maskX, task):
if self.sample_state:
if self.D_mu is None:
self.D_mu = self.set_encode(task)
D_mu = self.D_mu
else:
D_mu = self.set_encode(task)
A_mu = self.arch_encode(X, time_cond, maskX)
y_pred = self.predict(D_mu, A_mu)
return y_pred
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')

View File

@@ -0,0 +1,125 @@
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
# which was authored by Yuge Zhang, 2020
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from . import utils
from models.cate import PositionalEncoding_StageWise
def normalize_adj(adj):
# Row-normalize matrix
last_dim = adj.size(-1)
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
return torch.div(adj, rowsum)
def graph_pooling(inputs, num_vertices):
num_vertices = num_vertices.to(inputs.device)
out = inputs.sum(1)
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
class DirectedGraphConvolution(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
self.dropout = nn.Dropout(0.1)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight1.data)
nn.init.xavier_uniform_(self.weight2.data)
def forward(self, inputs, adj):
inputs = inputs.to(self.weight1.device)
adj = adj.to(self.weight1.device)
norm_adj = normalize_adj(adj)
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
out = (output1 + output2) / 2
out = self.dropout(out)
return out
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
@utils.register_model(name='NeuralPredictor')
class NeuralPredictor(nn.Module):
def __init__(self, config):
super().__init__()
self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
config.model.graph_encoder.gcn_hidden)
for i in range(config.model.graph_encoder.gcn_layers)]
self.gcn = nn.ModuleList(self.gcn)
self.dropout = nn.Dropout(0.1)
self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
self.fc2 = nn.Linear(config.model.graph_encoder.linear_hidden, 1, bias=False)
# Time
self.d_model = config.model.graph_encoder.gcn_hidden
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
self.act = act = get_act(config)
def forward(self, X, time_cond, maskX):
out = X
adj = maskX
numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device) # 20
gs = adj.size(1) # graph node number
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
emb_t = self.timeEmb1(emb_t)
emb_t = self.timeEmb2(self.act(emb_t)) # (5, 144)
adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device)) # assuming diagonal is not 1
for layer in self.gcn:
out = layer(out, adj_with_diag)
out = graph_pooling(out, numv)
# time
out = out + emb_t
out = self.fc1(out)
out = self.dropout(out)
out = self.fc2(out)
return out
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')

View File

@@ -0,0 +1,190 @@
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
# which was authored by Yuge Zhang, 2020
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from . import utils
from .set_encoder.setenc_models import SetPool
def normalize_adj(adj):
# Row-normalize matrix
last_dim = adj.size(-1)
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
return torch.div(adj, rowsum)
def graph_pooling(inputs, num_vertices):
num_vertices = num_vertices.to(inputs.device)
out = inputs.sum(1)
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
class DirectedGraphConvolution(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
self.dropout = nn.Dropout(0.1)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight1.data)
nn.init.xavier_uniform_(self.weight2.data)
def forward(self, inputs, adj):
inputs = inputs.to(self.weight1.device)
adj = adj.to(self.weight1.device)
norm_adj = normalize_adj(adj)
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
out = (output1 + output2) / 2
out = self.dropout(out)
return out
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
@utils.register_model(name='MetaNeuralPredictor')
class MetaeuralPredictor(nn.Module):
def __init__(self, config):
super().__init__()
# Arch
self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
config.model.graph_encoder.gcn_hidden)
for i in range(config.model.graph_encoder.gcn_layers)]
self.gcn = nn.ModuleList(self.gcn)
self.dropout = nn.Dropout(0.1)
self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
# Time
self.d_model = config.model.graph_encoder.gcn_hidden
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
self.act = act = get_act(config)
self.input_type = config.model.input_type
self.hs = config.model.hs
# Set
self.nz = config.model.nz
self.num_sample = config.model.num_sample
self.intra_setpool = SetPool(dim_input=512,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.inter_setpool = SetPool(dim_input=self.nz,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.set_fc = nn.Sequential(
nn.Linear(512, self.nz),
nn.ReLU())
input_dim = 0
if 'D' in self.input_type:
input_dim += self.nz
if 'A' in self.input_type:
input_dim += config.model.graph_encoder.linear_hidden
self.pred_fc = nn.Sequential(
nn.Linear(input_dim, self.hs),
nn.Tanh(),
nn.Linear(self.hs, 1)
)
self.sample_state = False
self.D_mu = None
def arch_encode(self, X, time_cond, maskX):
out = X
adj = maskX
numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device)
gs = adj.size(1) # graph node number
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)
emb_t = self.timeEmb1(emb_t)
emb_t = self.timeEmb2(self.act(emb_t))
adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device))
for layer in self.gcn:
out = layer(out, adj_with_diag)
out = graph_pooling(out, numv)
# time
out = out + emb_t
out = self.fc1(out)
out = self.dropout(out)
return out
def set_encode(self, task):
proto_batch = []
for x in task:
cls_protos = self.intra_setpool(
x.view(-1, self.num_sample, 512)).squeeze(1)
proto_batch.append(
self.inter_setpool(cls_protos.unsqueeze(0)))
v = torch.stack(proto_batch).squeeze()
return v
def predict(self, D_mu, A_mu):
input_vec = []
if 'D' in self.input_type:
input_vec.append(D_mu)
if 'A' in self.input_type:
input_vec.append(A_mu)
input_vec = torch.cat(input_vec, dim=1)
return self.pred_fc(input_vec)
def forward(self, X, time_cond, maskX, task):
if self.sample_state:
if self.D_mu is None:
self.D_mu = self.set_encode(task)
D_mu = self.D_mu
else:
D_mu = self.set_encode(task)
A_mu = self.arch_encode(X, time_cond, maskX)
y_pred = self.predict(D_mu, A_mu)
return y_pred
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')

View File

@@ -0,0 +1,85 @@
import torch
class ExponentialMovingAverage:
"""
Maintains (exponential) moving average of a set of parameters.
"""
def __init__(self, parameters, decay, use_num_updates=True):
"""
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the result of `model.parameters()`.
decay: The exponential decay.
use_num_updates: Whether to use number of updates when computing averages.
"""
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
self.num_updates = 0 if use_num_updates else None
self.shadow_params = [p.clone().detach()
for p in parameters if p.requires_grad]
self.collected_params = []
def update(self, parameters):
"""
Update currently maintained parameters.
Call this every time the parameters are updated, such as the result of the `optimizer.step()` call.
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to
initialize this object.
"""
decay = self.decay
if self.num_updates is not None:
self.num_updates += 1
decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
s_param.sub_(one_minus_decay * (s_param - param))
def copy_to(self, parameters):
"""
Copy current parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages.
"""
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
if param.requires_grad:
param.data.copy_(s_param.data)
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the original optimization process.
Store the parameters before the `copy_to` method.
After validation (or model saving), use this to restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
def state_dict(self):
return dict(decay=self.decay, num_updates=self.num_updates, shadow_params=self.shadow_params)
def load_state_dict(self, state_dict):
self.decay = state_dict['decay']
self.num_updates = state_dict['num_updates']
self.shadow_params = state_dict['shadow_params']

View File

@@ -0,0 +1,82 @@
import torch.nn as nn
import torch
from .trans_layers import *
class pos_gnn(nn.Module):
def __init__(self, act, x_ch, pos_ch, out_ch, max_node, graph_layer, n_layers=3, edge_dim=None, heads=4,
temb_dim=None, dropout=0.1, attn_clamp=False):
super().__init__()
self.out_ch = out_ch
self.Dropout_0 = nn.Dropout(dropout)
self.act = act
self.max_node = max_node
self.n_layers = n_layers
if temb_dim is not None:
self.Dense_node0 = nn.Linear(temb_dim, x_ch)
self.Dense_node1 = nn.Linear(temb_dim, pos_ch)
self.Dense_edge0 = nn.Linear(temb_dim, edge_dim)
self.Dense_edge1 = nn.Linear(temb_dim, edge_dim)
self.convs = nn.ModuleList()
self.edge_convs = nn.ModuleList()
self.edge_layer = nn.Linear(edge_dim * 2 + self.out_ch, edge_dim)
for i in range(n_layers):
if i == 0:
self.convs.append(eval(graph_layer)(x_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2,
act=act, attn_clamp=attn_clamp))
else:
self.convs.append(eval(graph_layer)
(self.out_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2, act=act,
attn_clamp=attn_clamp))
self.edge_convs.append(nn.Linear(self.out_ch, edge_dim*2))
def forward(self, x_degree, x_pos, edge_index, dense_ori, dense_spd, dense_index, temb=None):
"""
Args:
x_degree: node degree feature [B*N, x_ch]
x_pos: node rwpe feature [B*N, pos_ch]
edge_index: [2, edge_length]
dense_ori: edge feature [B, N, N, nf//2]
dense_spd: edge shortest path distance feature [B, N, N, nf//2] # Do we need this part? # TODO
dense_index
temb: [B, temb_dim]
"""
B, N, _, _ = dense_ori.shape
if temb is not None:
dense_ori = dense_ori + self.Dense_edge0(self.act(temb))[:, None, None, :]
dense_spd = dense_spd + self.Dense_edge1(self.act(temb))[:, None, None, :]
temb = temb.unsqueeze(1).repeat(1, self.max_node, 1)
temb = temb.reshape(-1, temb.shape[-1])
x_degree = x_degree + self.Dense_node0(self.act(temb))
x_pos = x_pos + self.Dense_node1(self.act(temb))
dense_edge = torch.cat([dense_ori, dense_spd], dim=-1)
ori_edge_attr = dense_edge
h = x_degree
h_pos = x_pos
for i_layer in range(self.n_layers):
h_edge = dense_edge[dense_index]
# update node feature
h, h_pos = self.convs[i_layer](h, h_pos, edge_index, h_edge)
h = self.Dropout_0(h)
h_pos = self.Dropout_0(h_pos)
# update dense edge feature
h_dense_node = h.reshape(B, N, -1)
cur_edge_attr = h_dense_node.unsqueeze(1) + h_dense_node.unsqueeze(2) # [B, N, N, nf]
dense_edge = (dense_edge + self.act(self.edge_convs[i_layer](cur_edge_attr))) / math.sqrt(2.)
dense_edge = self.Dropout_0(dense_edge)
# Concat edge attribute
h_dense_edge = torch.cat([ori_edge_attr, dense_edge], dim=-1)
h_dense_edge = self.edge_layer(h_dense_edge).permute(0, 3, 1, 2)
return h_dense_edge

View File

@@ -0,0 +1,44 @@
"""Common layers"""
import torch.nn as nn
import torch
import torch.nn.functional as F
import math
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')
def conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, padding=0):
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
padding=padding)
return conv
# from DDPM
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb

View File

@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .setenc_modules import *
class SetPool(nn.Module):
def __init__(self, dim_input, num_outputs, dim_output,
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
super(SetPool, self).__init__()
if 'sab' in mode: # [32, 400, 128]
self.enc = nn.Sequential(
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
else: # [32, 400, 128]
self.enc = nn.Sequential(
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
if 'PF' in mode: # [32, 1, 501]
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
nn.Linear(dim_hidden, dim_output))
elif 'P' in mode:
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
else: # torch.Size([32, 1, 501])
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
nn.Linear(dim_hidden, dim_output))
# "", sm, sab, sabsm
def forward(self, X):
x1 = self.enc(X)
x2 = self.dec(x1)
return x2

View File

@@ -0,0 +1,67 @@
#####################################################################################
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MAB(nn.Module):
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
super(MAB, self).__init__()
self.dim_V = dim_V
self.num_heads = num_heads
self.fc_q = nn.Linear(dim_Q, dim_V)
self.fc_k = nn.Linear(dim_K, dim_V)
self.fc_v = nn.Linear(dim_K, dim_V)
if ln:
self.ln0 = nn.LayerNorm(dim_V)
self.ln1 = nn.LayerNorm(dim_V)
self.fc_o = nn.Linear(dim_V, dim_V)
def forward(self, Q, K):
Q = self.fc_q(Q)
K, V = self.fc_k(K), self.fc_v(K)
dim_split = self.dim_V // self.num_heads
Q_ = torch.cat(Q.split(dim_split, 2), 0)
K_ = torch.cat(K.split(dim_split, 2), 0)
V_ = torch.cat(V.split(dim_split, 2), 0)
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
O = O + F.relu(self.fc_o(O))
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
return O
class SAB(nn.Module):
def __init__(self, dim_in, dim_out, num_heads, ln=False):
super(SAB, self).__init__()
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
def forward(self, X):
return self.mab(X, X)
class ISAB(nn.Module):
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
super(ISAB, self).__init__()
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
nn.init.xavier_uniform_(self.I)
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
def forward(self, X):
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
return self.mab1(X, H)
class PMA(nn.Module):
def __init__(self, dim, num_heads, num_seeds, ln=False):
super(PMA, self).__init__()
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
nn.init.xavier_uniform_(self.S)
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
def forward(self, X):
return self.mab(self.S.repeat(X.size(0), 1, 1), X)

View File

@@ -0,0 +1,144 @@
import math
from typing import Union, Tuple, Optional
from torch_geometric.typing import PairTensor, Adj, OptTensor
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
import numpy as np
class PosTransLayer(MessagePassing):
"""Involving the edge feature and updating position feature. Multiply Msg."""
_alpha: OptTensor
def __init__(self, x_channels: int, pos_channels: int, out_channels: int,
heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
bias: bool = True, act=None, attn_clamp: bool = False, **kwargs):
kwargs.setdefault('aggr', 'add')
super(PosTransLayer, self).__init__(node_dim=0, **kwargs)
self.x_channels = x_channels
self.pos_channels = pos_channels
self.in_channels = in_channels = x_channels + pos_channels
self.out_channels = out_channels
self.heads = heads
self.dropout = dropout
self.edge_dim = edge_dim
self.attn_clamp = attn_clamp
if act is None:
self.act = nn.LeakyReLU(negative_slope=0.2)
else:
self.act = act
self.lin_key = Linear(in_channels, heads * out_channels)
self.lin_query = Linear(in_channels, heads * out_channels)
self.lin_value = Linear(in_channels, heads * out_channels)
self.lin_edge0 = Linear(edge_dim, heads * out_channels, bias=False)
self.lin_edge1 = Linear(edge_dim, heads * out_channels, bias=False)
self.lin_pos = Linear(heads * out_channels, pos_channels, bias=False)
self.lin_skip = Linear(x_channels, heads * out_channels, bias=bias)
self.norm1 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
num_channels=heads * out_channels, eps=1e-6)
self.norm2 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
num_channels=heads * out_channels, eps=1e-6)
# FFN
self.FFN = nn.Sequential(Linear(heads * out_channels, heads * out_channels),
self.act,
Linear(heads * out_channels, heads * out_channels))
self.reset_parameters()
def reset_parameters(self):
self.lin_key.reset_parameters()
self.lin_query.reset_parameters()
self.lin_value.reset_parameters()
self.lin_skip.reset_parameters()
self.lin_edge0.reset_parameters()
self.lin_edge1.reset_parameters()
self.lin_pos.reset_parameters()
def forward(self, x: OptTensor,
pos: Tensor,
edge_index: Adj,
edge_attr: OptTensor = None
) -> Tuple[Tensor, Tensor]:
""""""
H, C = self.heads, self.out_channels
x_feat = torch.cat([x, pos], -1)
query = self.lin_query(x_feat).view(-1, H, C)
key = self.lin_key(x_feat).view(-1, H, C)
value = self.lin_value(x_feat).view(-1, H, C)
# propagate_type: (x: PairTensor, edge_attr: OptTensor)
out_x, out_pos = self.propagate(edge_index, query=query, key=key, value=value, pos=pos, edge_attr=edge_attr,
size=None)
out_x = out_x.view(-1, self.heads * self.out_channels)
# skip connection for x
x_r = self.lin_skip(x)
out_x = (out_x + x_r) / math.sqrt(2)
out_x = self.norm1(out_x)
# FFN
out_x = (out_x + self.FFN(out_x)) / math.sqrt(2)
out_x = self.norm2(out_x)
# skip connection for pos
out_pos = pos + torch.tanh(pos + out_pos)
return out_x, out_pos
def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
pos_j: Tensor,
edge_attr: OptTensor,
index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tuple[Tensor, Tensor]:
edge_attn = self.lin_edge0(edge_attr).view(-1, self.heads, self.out_channels)
alpha = (query_i * key_j * edge_attn).sum(dim=-1) / math.sqrt(self.out_channels)
if self.attn_clamp:
alpha = alpha.clamp(min=-5., max=5.)
alpha = softmax(alpha, index, ptr, size_i)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# node feature message
msg = value_j
msg = msg * self.lin_edge1(edge_attr).view(-1, self.heads, self.out_channels)
msg = msg * alpha.view(-1, self.heads, 1)
# node position message
pos_msg = pos_j * self.lin_pos(msg.reshape(-1, self.heads * self.out_channels))
return msg, pos_msg
def aggregate(self, inputs: Tuple[Tensor, Tensor], index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
if ptr is not None:
raise NotImplementedError("Not implement Ptr in aggregate")
else:
return (scatter(inputs[0], index, 0, dim_size=dim_size, reduce=self.aggr),
scatter(inputs[1], index, 0, dim_size=dim_size, reduce="mean"))
def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
return inputs
def __repr__(self):
return '{}({}, {}, heads={})'.format(self.__class__.__name__,
self.in_channels,
self.out_channels, self.heads)

View File

@@ -0,0 +1,255 @@
from copy import deepcopy as cp
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def clones(module, N):
return nn.ModuleList([cp(module) for _ in range(N)])
def attention(query, key, value, mask = None, dropout = None):
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn = F.softmax(scores, dim = -1)
if dropout is not None:
attn = dropout(attn)
return torch.matmul(attn, value), attn
class MultiHeadAttention(nn.Module):
def __init__(self, config):
super(MultiHeadAttention, self).__init__()
self.d_model = config.d_model
self.n_head = config.n_head
self.d_k = config.d_model // config.n_head
self.linears = clones(nn.Linear(self.d_model, self.d_model), 4)
self.dropout = nn.Dropout(p=config.dropout)
def forward(self, query, key, value, mask = None):
if mask is not None:
mask = mask.unsqueeze(1)
batch_size = query.size(0)
query, key , value = [l(x).view(batch_size, -1, self.n_head, self.d_k).transpose(1,2) for l, x in zip(self.linears, (query, key, value))]
x, attn = attention(query, key, value, mask = mask, dropout = self.dropout)
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)
return self.linears[3](x), attn
class PositionwiseFeedForward(nn.Module):
def __init__(self, config):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(config.d_model, config.d_ff)
self.w_2 = nn.Linear(config.d_ff, config.d_model)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class PositionwiseFeedForwardLast(nn.Module):
def __init__(self, config):
super(PositionwiseFeedForwardLast, self).__init__()
self.w_1 = nn.Linear(config.d_model, config.d_ff)
self.w_2 = nn.Linear(config.d_ff, config.n_vocab)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class SelfAttentionBlock(nn.Module):
def __init__(self, config):
super(SelfAttentionBlock, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.attn = MultiHeadAttention(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x, mask):
x_ = self.norm(x)
x_ , attn = self.attn(x_, x_, x_, mask)
return self.dropout(x_) + x, attn
class SourceAttentionBlock(nn.Module):
def __init__(self, config):
super(SourceAttentionBlock, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.attn = MultiHeadAttention(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x, m, mask):
x_ = self.norm(x)
x_, attn = self.attn(x_, m, m, mask)
return self.dropout(x_) + x, attn
class FeedForwardBlock(nn.Module):
def __init__(self, config):
super(FeedForwardBlock, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.feed_forward = PositionwiseFeedForward(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x):
x_ = self.norm(x)
x_ = self.feed_forward(x_)
return self.dropout(x_) + x
class FeedForwardBlockLast(nn.Module):
def __init__(self, config):
super(FeedForwardBlockLast, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.feed_forward = PositionwiseFeedForwardLast(config)
self.dropout = nn.Dropout(p = config.dropout)
# Only for the last layer
self.proj_fc = nn.Linear(config.d_model, config.n_vocab)
def forward(self, x):
x_ = self.norm(x)
x_ = self.feed_forward(x_)
return self.dropout(x_) + self.proj_fc(x)
class EncoderBlock(nn.Module):
def __init__(self, config):
super(EncoderBlock, self).__init__()
self.self_attn = SelfAttentionBlock(config)
self.feed_forward = FeedForwardBlock(config)
def forward(self, x, mask):
x, attn = self.self_attn(x, mask)
x = self.feed_forward(x)
return x, attn
class EncoderBlockLast(nn.Module):
def __init__(self, config):
super(EncoderBlockLast, self).__init__()
self.self_attn = SelfAttentionBlock(config)
self.feed_forward = FeedForwardBlockLast(config)
def forward(self, x, mask):
x, attn = self.self_attn(x, mask)
x = self.feed_forward(x)
return x, attn
class DecoderBlock(nn.Module):
def __init__(self, config):
super(DecoderBlock, self).__init__()
self.self_attn = SelfAttentionBlock(config)
self.src_attn = SourceAttentionBlock(config)
self.feed_forward = FeedForwardBlock(config)
def forward(self, x, m, src_mask, tgt_mask):
x, attn_tgt = self.self_attn(x, tgt_mask)
x, attn_src = self.src_attn(x, m, src_mask)
x = self.feed_forward(x)
return x, attn_src, attn_tgt
class Encoder(nn.Module):
def __init__(self, config):
super(Encoder, self).__init__()
self.layers = clones(EncoderBlock(config), config.n_layers)
self.norms = clones(nn.LayerNorm(config.d_model), config.n_layers)
def forward(self, x, mask):
outputs = []
attns = []
for layer, norm in zip(self.layers, self.norms):
x, attn = layer(x, mask)
outputs.append(norm(x))
attns.append(attn)
return outputs[-1], outputs, attns
class PositionalEmbedding(nn.Module):
def __init__(self, config):
super(PositionalEmbedding, self).__init__()
p2e = torch.zeros(config.max_len, config.d_model)
position = torch.arange(0.0, config.max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0.0, config.d_model, 2) * (- math.log(10000.0) / config.d_model))
p2e[:, 0::2] = torch.sin(position * div_term)
p2e[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('p2e', p2e)
def forward(self, x):
shp = x.size()
with torch.no_grad():
emb = torch.index_select(self.p2e, 0, x.view(-1)).view(shp + (-1,))
return emb
class Transformer(nn.Module):
def __init__(self, config):
super(Transformer, self).__init__()
self.p2e = PositionalEmbedding(config)
self.encoder = Encoder(config)
def forward(self, input_emb, position_ids, attention_mask):
# position embedding projection
projection = self.p2e(position_ids) + input_emb
return self.encoder(projection, attention_mask)
class TokenTypeEmbedding(nn.Module):
def __init__(self, config):
super(TokenTypeEmbedding, self).__init__()
self.t2e = nn.Embedding(config.n_token_type, config.d_model)
self.d_model = config.d_model
def forward(self, x):
return self.t2e(x) * math.sqrt(self.d_model)
class SemanticEmbedding(nn.Module):
def __init__(self, config):
super(SemanticEmbedding, self).__init__()
self.d_model = config.d_model
self.fc = nn.Linear(config.n_vocab, config.d_model)
def forward(self, x):
return self.fc(x) * math.sqrt(self.d_model)
class Embeddings(nn.Module):
def __init__(self, config):
super(Embeddings, self).__init__()
self.w2e = SemanticEmbedding(config)
self.p2e = PositionalEmbedding(config)
self.t2e = TokenTypeEmbedding(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, input_ids, position_ids = None, token_type_ids = None):
if position_ids is None:
batch_size, length = input_ids.size()
with torch.no_grad():
position_ids = torch.arange(0, length).repeat(batch_size, 1)
if torch.cuda.is_available():
position_ids = position_ids.cuda(device=input_ids.device)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
embeddings = self.w2e(input_ids) + self.p2e(position_ids) + self.t2e(token_type_ids)
return self.dropout(embeddings)

View File

@@ -0,0 +1,289 @@
import torch
import torch.nn.functional as F
import sde_lib
_MODELS = {}
def register_model(cls=None, *, name=None):
"""A decorator for registering model classes."""
def _register(cls):
if name is None:
local_name = cls.__name__
else:
local_name = name
if local_name in _MODELS:
raise ValueError(
f'Already registered model with name: {local_name}')
_MODELS[local_name] = cls
return cls
if cls is None:
return _register
else:
return _register(cls)
def get_model(name):
return _MODELS[name]
def create_model(config):
"""Create the model."""
model_name = config.model.name
model = get_model(model_name)(config)
model = model.to(config.device)
return model
def get_model_fn(model, train=False):
"""Create a function to give the output of the score-based model.
Args:
model: The score model.
train: `True` for training and `False` for evaluation.
Returns:
A model function.
"""
def model_fn(x, labels, *args, **kwargs):
"""Compute the output of the score-based model.
Args:
x: A mini-batch of input data (Adjacency matrices).
labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
for different models.
mask: Mask for adjacency matrices.
Returns:
A tuple of (model output, new mutable states)
"""
if not train:
model.eval()
return model(x, labels, *args, **kwargs)
else:
model.train()
return model(x, labels, *args, **kwargs)
return model_fn
def get_score_fn(sde, model, train=False, continuous=False):
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
model: A score model.
train: `True` for training and `False` for evaluation.
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
Returns:
A score function.
"""
model_fn = get_model_fn(model, train=train)
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
def score_fn(x, t, *args, **kwargs):
# Scale neural network output by standard deviation and flip sign
if continuous or isinstance(sde, sde_lib.subVPSDE):
# For VP-trained models, t=0 corresponds to the lowest noise level
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
labels = t * 999
score = model_fn(x, labels, *args, **kwargs)
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VP-trained models, t=0 corresponds to the lowest noise level
labels = t * (sde.N - 1)
score = model_fn(x, labels, *args, **kwargs)
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
labels.long()]
score = -score / std[:, None, None]
return score
elif isinstance(sde, sde_lib.VESDE):
def score_fn(x, t, *args, **kwargs):
if continuous:
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VE-trained models, t=0 corresponds to the highest noise level
labels = sde.T - t
labels *= sde.N - 1
labels = torch.round(labels).long()
score = model_fn(x, labels, *args, **kwargs)
return score
else:
raise NotImplementedError(
f"SDE class {sde.__class__.__name__} not yet supported.")
return score_fn
def get_classifier_grad_fn(sde, classifier, train=False, continuous=False,
regress=True, labels='max'):
logit_fn = get_logit_fn(sde, classifier, train, continuous)
def classifier_grad_fn(x, t, *args, **kwargs):
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
if regress:
assert labels in ['max', 'min']
logit = logit_fn(x_in, t, *args, **kwargs)
if labels == 'max':
prob = logit.sum()
elif labels == 'min':
prob = -logit.sum()
else:
logit = logit_fn(x_in, t, *args, **kwargs)
log_prob = F.log_softmax(logit, dim=-1)
prob = log_prob[range(len(logit)), labels.view(-1)].sum()
classifier_grad = torch.autograd.grad(prob, x_in)[0]
return classifier_grad
return classifier_grad_fn
def get_logit_fn(sde, classifier, train=False, continuous=False):
classifier_fn = get_model_fn(classifier, train=train)
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
def logit_fn(x, t, *args, **kwargs):
# Scale neural network output by standard deviation and flip sign
if continuous or isinstance(sde, sde_lib.subVPSDE):
# For VP-trained models, t=0 corresponds to the lowest noise level
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
labels = t * 999
logit = classifier_fn(x, labels, *args, **kwargs)
else:
# For VP-trained models, t=0 corresponds to the lowest noise level
labels = t * (sde.N - 1)
logit = classifier_fn(x, labels, *args, **kwargs)
return logit
elif isinstance(sde, sde_lib.VESDE):
def logit_fn(x, t, *args, **kwargs):
if continuous:
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VE-trained models, t=0 corresponds to the highest noise level
labels = sde.T - t
labels *= sde.N - 1
labels = torch.round(labels).long()
logit = classifier_fn(x, labels, *args, **kwargs)
return logit
return logit_fn
def get_predictor_fn(sde, model, train=False, continuous=False):
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
model: A predictor model.
train: `True` for training and `False` for evaluation.
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
Returns:
A score function.
"""
model_fn = get_model_fn(model, train=train)
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
def predictor_fn(x, t, *args, **kwargs):
# Scale neural network output by standard deviation and flip sign
if continuous or isinstance(sde, sde_lib.subVPSDE):
# For VP-trained models, t=0 corresponds to the lowest noise level
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
labels = t * 999
pred = model_fn(x, labels, *args, **kwargs)
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VP-trained models, t=0 corresponds to the lowest noise level
labels = t * (sde.N - 1)
pred = model_fn(x, labels, *args, **kwargs)
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
labels.long()]
return pred
elif isinstance(sde, sde_lib.VESDE):
def predictor_fn(x, t, *args, **kwargs):
if continuous:
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VE-trained models, t=0 corresponds to the highest noise level
labels = sde.T - t
labels *= sde.N - 1
labels = torch.round(labels).long()
pred = model_fn(x, labels, *args, **kwargs)
return pred
else:
raise NotImplementedError(
f"SDE class {sde.__class__.__name__} not yet supported.")
return predictor_fn
def to_flattened_numpy(x):
"""Flatten a torch tensor `x` and convert it to numpy."""
return x.detach().cpu().numpy().reshape((-1,))
def from_flattened_numpy(x, shape):
"""Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
return torch.from_numpy(x.reshape(shape))
@torch.no_grad()
def mask_adj2node(adj_mask):
"""Convert batched adjacency mask matrices to batched node mask matrices.
Args:
adj_mask: [B, N, N] Batched adjacency mask matrices without self-loop edge.
Output:
node_mask: [B, N] Batched node mask matrices indicating the valid nodes.
"""
batch_size, max_num_nodes, _ = adj_mask.shape
node_mask = adj_mask[:, 0, :].clone()
node_mask[:, 0] = 1
return node_mask
@torch.no_grad()
def get_rw_feat(k_step, dense_adj):
"""Compute k_step Random Walk for given dense adjacency matrix."""
rw_list = []
deg = dense_adj.sum(-1, keepdims=True)
AD = dense_adj / (deg + 1e-8)
rw_list.append(AD)
for _ in range(k_step):
rw = torch.bmm(rw_list[-1], AD)
rw_list.append(rw)
rw_map = torch.stack(rw_list[1:], dim=1) # [B, k_step, N, N]
rw_landing = torch.diagonal(
rw_map, offset=0, dim1=2, dim2=3) # [B, k_step, N]
rw_landing = rw_landing.permute(0, 2, 1) # [B, N, rw_depth]
# get the shortest path distance indices
tmp_rw = rw_map.sort(dim=1)[0]
spd_ind = (tmp_rw <= 0).sum(dim=1) # [B, N, N]
spd_onehot = torch.nn.functional.one_hot(
spd_ind, num_classes=k_step+1).to(torch.float)
spd_onehot = spd_onehot.permute(0, 3, 1, 2) # [B, kstep, N, N]
return rw_landing, spd_onehot