first commit
MobileNetV3/models/GDSS/attention.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import math
import torch
from torch.nn import Parameter
import torch.nn.functional as F

from models.GDSS.layers import DenseGCNConv, MLP
# from ..utils.graph_utils import mask_adjs, mask_x
from .graph_utils import mask_x, mask_adjs


# -------- Graph Multi-Head Attention (GMH) --------
# -------- From Baek et al. (2021) --------
class Attention(torch.nn.Module):

    def __init__(self, in_dim, attn_dim, out_dim, num_heads=4, conv='GCN'):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.attn_dim = attn_dim
        self.out_dim = out_dim
        self.conv = conv

        self.gnn_q, self.gnn_k, self.gnn_v = self.get_gnn(in_dim, attn_dim, out_dim, conv)
        self.activation = torch.tanh
        self.softmax_dim = 2

    def forward(self, x, adj, flags, attention_mask=None):

        if self.conv == 'GCN':
            Q = self.gnn_q(x, adj)
            K = self.gnn_k(x, adj)
        else:
            Q = self.gnn_q(x)
            K = self.gnn_k(x)

        V = self.gnn_v(x, adj)
        dim_split = self.attn_dim // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)

        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask for _ in range(self.num_heads)], 0)
            attention_score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.out_dim)
            A = self.activation( attention_mask + attention_score )
        else:
            A = self.activation( Q_.bmm(K_.transpose(1,2))/math.sqrt(self.out_dim) ) # (B x num_heads) x N x N

        # -------- (B x num_heads) x N x N --------
        A = A.view(-1, *adj.shape)
        A = A.mean(dim=0)
        A = (A + A.transpose(-1,-2))/2

        return V, A

    def get_gnn(self, in_dim, attn_dim, out_dim, conv='GCN'):

        if conv == 'GCN':
            gnn_q = DenseGCNConv(in_dim, attn_dim)
            gnn_k = DenseGCNConv(in_dim, attn_dim)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        elif conv == 'MLP':
            num_layers = 2
            gnn_q = MLP(num_layers, in_dim, 2*attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_k = MLP(num_layers, in_dim, 2*attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        else:
            raise NotImplementedError(f'{conv} not implemented.')


# -------- Layer of ScoreNetworkA --------
class AttentionLayer(torch.nn.Module):

    def __init__(self, num_linears, conv_input_dim, attn_dim, conv_output_dim, input_dim, output_dim,
                 num_heads=4, conv='GCN'):

        super(AttentionLayer, self).__init__()

        self.attn = torch.nn.ModuleList()
        for _ in range(input_dim):
            self.attn_dim = attn_dim
            self.attn.append(Attention(conv_input_dim, self.attn_dim, conv_output_dim,
                                       num_heads=num_heads, conv=conv))

        self.hidden_dim = 2*max(input_dim, output_dim)
        self.mlp = MLP(num_linears, 2*input_dim, self.hidden_dim, output_dim, use_bn=False, activate_func=F.elu)
        self.multi_channel = MLP(2, input_dim*conv_output_dim, self.hidden_dim, conv_output_dim,
                                 use_bn=False, activate_func=F.elu)

    def forward(self, x, adj, flags):
        """
        :param x: B x N x F_i
        :param adj: B x C_i x N x N
        :return: x_out: B x N x F_o, adj_out: B x C_o x N x N
        """
        mask_list = []
        x_list = []
        for _ in range(len(self.attn)):
            _x, mask = self.attn[_](x, adj[:,_,:,:], flags)
            mask_list.append(mask.unsqueeze(-1))
            x_list.append(_x)
        x_out = mask_x(self.multi_channel(torch.cat(x_list, dim=-1)), flags)
        x_out = torch.tanh(x_out)

        mlp_in = torch.cat([torch.cat(mask_list, dim=-1), adj.permute(0,2,3,1)], dim=-1)
        shape = mlp_in.shape
        mlp_out = self.mlp(mlp_in.view(-1, shape[-1]))
        _adj = mlp_out.view(shape[0], shape[1], shape[2], -1).permute(0,3,1,2)
        _adj = _adj + _adj.transpose(-1,-2)
        adj_out = mask_adjs(_adj, flags)

        return x_out, adj_out
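
# -------- Usage sketch: shape check for AttentionLayer --------
# A minimal sketch, not taken from the GDSS source: it assumes the repository root is on
# PYTHONPATH so the models.GDSS package imports (the same root the file above uses), and all
# tensor sizes below are arbitrary.
import torch
from models.GDSS.attention import AttentionLayer

B, N, F_in, F_out = 4, 9, 8, 16        # batch, nodes, node-feature dims (arbitrary)
C_in, C_out = 2, 3                     # input/output adjacency channels (arbitrary)

layer = AttentionLayer(num_linears=2, conv_input_dim=F_in, attn_dim=16,
                       conv_output_dim=F_out, input_dim=C_in, output_dim=C_out,
                       num_heads=4, conv='GCN')

x = torch.randn(B, N, F_in)                    # B x N x F_i node features
adj = torch.rand(B, C_in, N, N)                # B x C_i x N x N multi-channel adjacency
adj = (adj + adj.transpose(-1, -2)) / 2        # symmetrize each channel
flags = torch.ones(B, N)                       # all nodes valid

x_out, adj_out = layer(x, adj, flags)
print(x_out.shape)                             # torch.Size([4, 9, 16])   -> B x N x F_o
print(adj_out.shape)                           # torch.Size([4, 3, 9, 9]) -> B x C_o x N x N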
MobileNetV3/models/GDSS/graph_utils.py (new file, 209 lines)
@@ -0,0 +1,209 @@
import torch
import torch.nn.functional as F
import networkx as nx
import numpy as np


# -------- Mask batch of node features with 0-1 flags tensor --------
def mask_x(x, flags):

    if flags is None:
        flags = torch.ones((x.shape[0], x.shape[1]), device=x.device)
    return x * flags[:,:,None]


# -------- Mask batch of adjacency matrices with 0-1 flags tensor --------
def mask_adjs(adjs, flags):
    """
    :param adjs: B x N x N or B x C x N x N
    :param flags: B x N
    :return:
    """
    if flags is None:
        flags = torch.ones((adjs.shape[0], adjs.shape[-1]), device=adjs.device)

    if len(adjs.shape) == 4:
        flags = flags.unsqueeze(1)  # B x 1 x N
    adjs = adjs * flags.unsqueeze(-1)
    adjs = adjs * flags.unsqueeze(-2)
    return adjs


# -------- Create flags tensor from graph dataset --------
def node_flags(adj, eps=1e-5):

    flags = torch.abs(adj).sum(-1).gt(eps).to(dtype=torch.float32)

    if len(flags.shape)==3:
        flags = flags[:,0,:]
    return flags
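
# -------- Usage sketch for the masking helpers above --------
# Illustrative only; sizes are arbitrary and the import path assumes the repository root is on
# PYTHONPATH. node_flags recovers the 0-1 validity mask of a zero-padded adjacency batch, and
# mask_x / mask_adjs zero out the padded node rows/columns.
import torch
from models.GDSS.graph_utils import node_flags, mask_x, mask_adjs

B, N, F = 2, 5, 3
adjs = torch.zeros(B, N, N)
adjs[:, :3, :3] = 1.0                      # only the first 3 nodes of each graph are "real"

flags = node_flags(adjs)                   # B x N -> [[1,1,1,0,0], [1,1,1,0,0]]
x = mask_x(torch.randn(B, N, F), flags)    # padded node rows become 0
adjs_c = mask_adjs(adjs.unsqueeze(1), flags)   # also handles B x C x N x N input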


# -------- Create initial node features --------
def init_features(init, adjs=None, nfeat=10):

    if init=='zeros':
        feature = torch.zeros((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
    elif init=='ones':
        feature = torch.ones((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
    elif init=='deg':
        feature = adjs.sum(dim=-1).to(torch.long)
        num_classes = nfeat
        try:
            feature = F.one_hot(feature, num_classes=num_classes).to(torch.float32)
        except:
            print(feature.max().item())
            raise NotImplementedError(f'max_feat_num mismatch')
    else:
        raise NotImplementedError(f'{init} not implemented')

    flags = node_flags(adjs)

    return mask_x(feature, flags)


# -------- Sample initial flags tensor from the training graph set --------
def init_flags(graph_list, config, batch_size=None):
    if batch_size is None:
        batch_size = config.data.batch_size
    max_node_num = config.data.max_node_num
    graph_tensor = graphs_to_tensor(graph_list, max_node_num)
    idx = np.random.randint(0, len(graph_list), batch_size)
    flags = node_flags(graph_tensor[idx])

    return flags


# -------- Generate noise --------
def gen_noise(x, flags, sym=True):
    z = torch.randn_like(x)
    if sym:
        z = z.triu(1)
        z = z + z.transpose(-1,-2)
        z = mask_adjs(z, flags)
    else:
        z = mask_x(z, flags)
    return z
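
# -------- Usage sketch for gen_noise / quantize --------
# Illustrative only; sizes are arbitrary and the import path assumes the repository root is on
# PYTHONPATH. With sym=True the noise is symmetric, has a zero diagonal, and padded rows/columns
# are masked out, matching the adjacency convention above.
import torch
from models.GDSS.graph_utils import gen_noise, node_flags, quantize

B, N = 2, 5
adjs = torch.zeros(B, N, N)
adjs[:, :4, :4] = 1.0
flags = node_flags(adjs)

z = gen_noise(adjs, flags, sym=True)
assert torch.allclose(z, z.transpose(-1, -2))                      # symmetric
assert torch.all(z[:, 4, :] == 0) and torch.all(z[:, :, 4] == 0)   # padded node masked

a01 = quantize(torch.rand(B, N, N), thr=0.5)                       # threshold scores to {0, 1}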


# -------- Quantize generated graphs --------
def quantize(adjs, thr=0.5):
    adjs_ = torch.where(adjs < thr, torch.zeros_like(adjs), torch.ones_like(adjs))
    return adjs_


# -------- Quantize generated molecules --------
# adjs: 32 x 9 x 9
def quantize_mol(adjs):
    if type(adjs).__name__ == 'Tensor':
        adjs = adjs.detach().cpu()
    else:
        adjs = torch.tensor(adjs)
    adjs[adjs >= 2.5] = 3
    adjs[torch.bitwise_and(adjs >= 1.5, adjs < 2.5)] = 2
    adjs[torch.bitwise_and(adjs >= 0.5, adjs < 1.5)] = 1
    adjs[adjs < 0.5] = 0
    return np.array(adjs.to(torch.int64))


def adjs_to_graphs(adjs, is_cuda=False):
    graph_list = []
    for adj in adjs:
        if is_cuda:
            adj = adj.detach().cpu().numpy()
        G = nx.from_numpy_matrix(adj)
        G.remove_edges_from(nx.selfloop_edges(G))
        G.remove_nodes_from(list(nx.isolates(G)))
        if G.number_of_nodes() < 1:
            G.add_node(1)
        graph_list.append(G)
    return graph_list


# -------- Check if the adjacency matrices are symmetric --------
def check_sym(adjs, print_val=False):
    sym_error = (adjs-adjs.transpose(-1,-2)).abs().sum([0,1,2])
    if not sym_error < 1e-2:
        raise ValueError(f'Not symmetric: {sym_error:.4e}')
    if print_val:
        print(f'{sym_error:.4e}')


# -------- Create higher order adjacency matrices --------
def pow_tensor(x, cnum):
    # x : B x N x N
    x_ = x.clone()
    xc = [x.unsqueeze(1)]
    for _ in range(cnum-1):
        x_ = torch.bmm(x_, x)
        xc.append(x_.unsqueeze(1))
    xc = torch.cat(xc, dim=1)

    return xc
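
# -------- Usage sketch for pow_tensor --------
# Illustrative only; sizes are arbitrary and the import path assumes the repository root is on
# PYTHONPATH. pow_tensor stacks A, A^2, ..., A^cnum along a channel dimension; ScoreNetworkX_GMH
# below builds its initial multi-channel adjacency this way via pow_tensor(maskX, c_init).
import torch
from models.GDSS.graph_utils import pow_tensor

B, N, cnum = 2, 4, 3
A = torch.rand(B, N, N)
Ac = pow_tensor(A, cnum)                               # B x cnum x N x N
assert Ac.shape == (B, cnum, N, N)
assert torch.allclose(Ac[:, 0], A)                     # channel 0: A
assert torch.allclose(Ac[:, 1], torch.bmm(A, A))       # channel 1: A @ A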


# -------- Create padded adjacency matrices --------
def pad_adjs(ori_adj, node_number):
    a = ori_adj
    ori_len = a.shape[-1]
    if ori_len == node_number:
        return a
    if ori_len > node_number:
        raise ValueError(f'ori_len {ori_len} > node_number {node_number}')
    a = np.concatenate([a, np.zeros([ori_len, node_number - ori_len])], axis=-1)
    a = np.concatenate([a, np.zeros([node_number - ori_len, node_number])], axis=0)
    return a


def graphs_to_tensor(graph_list, max_node_num):
    adjs_list = []
    max_node_num = max_node_num

    for g in graph_list:
        assert isinstance(g, nx.Graph)
        node_list = []
        for v, feature in g.nodes.data('feature'):
            node_list.append(v)

        adj = nx.to_numpy_matrix(g, nodelist=node_list)
        padded_adj = pad_adjs(adj, node_number=max_node_num)
        adjs_list.append(padded_adj)

    del graph_list

    adjs_np = np.asarray(adjs_list)
    del adjs_list

    adjs_tensor = torch.tensor(adjs_np, dtype=torch.float32)
    del adjs_np

    return adjs_tensor


def graphs_to_adj(graph, max_node_num):
    max_node_num = max_node_num

    assert isinstance(graph, nx.Graph)
    node_list = []
    for v, feature in graph.nodes.data('feature'):
        node_list.append(v)

    adj = nx.to_numpy_matrix(graph, nodelist=node_list)
    padded_adj = pad_adjs(adj, node_number=max_node_num)

    adj = torch.tensor(padded_adj, dtype=torch.float32)
    del padded_adj

    return adj


def node_feature_to_matrix(x):
    """
    :param x: BS x N x F
    :return:
        x_pair: BS x N x N x 2F
    """
    x_b = x.unsqueeze(-2).expand(x.size(0), x.size(1), x.size(1), -1) # BS x N x N x F
    x_pair = torch.cat([x_b, x_b.transpose(1, 2)], dim=-1) # BS x N x N x 2F

    return x_pair
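
# -------- Usage sketch: graphs -> padded tensors -> flags -> initial features --------
# Illustrative only; toy graphs, arbitrary sizes, and repository-root imports assumed. Note that
# graphs_to_tensor / graphs_to_adj / adjs_to_graphs rely on nx.to_numpy_matrix and
# nx.from_numpy_matrix, which were removed in networkx 3.0, so this module assumes networkx < 3.0.
import networkx as nx
import torch
from models.GDSS.graph_utils import graphs_to_tensor, node_flags, init_features

graphs = [nx.cycle_graph(4), nx.path_graph(6)]
adjs = graphs_to_tensor(graphs, max_node_num=8)        # 2 x 8 x 8, zero-padded
flags = node_flags(adjs)                               # 2 x 8
x0 = init_features('deg', adjs, nfeat=4)               # one-hot degrees, 2 x 8 x 4 (nfeat > max degree)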
MobileNetV3/models/GDSS/layers.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import torch
from torch.nn import Parameter
import torch.nn.functional as F
import math
from typing import Any


def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)

def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)

def reset(value: Any):
    if hasattr(value, 'reset_parameters'):
        value.reset_parameters()
    else:
        for child in value.children() if hasattr(value, 'children') else []:
            reset(child)

# -------- GCN layer --------
class DenseGCNConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GCNConv`.
    """
    def __init__(self, in_channels, out_channels, improved=False, bias=True):
        super(DenseGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved

        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)


    def forward(self, x, adj, mask=None, add_loop=True):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
                \times N \times F}`, with batch-size :math:`B`, (maximum)
                number of nodes :math:`N` for each graph, and feature
                dimension :math:`F`.
            adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B
                \times N \times N}`. The adjacency tensor is broadcastable in
                the batch dimension, resulting in a shared adjacency matrix for
                the complete batch.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, _ = adj.size()

        if add_loop:
            adj = adj.clone()
            idx = torch.arange(N, dtype=torch.long, device=adj.device)
            adj[:, idx, idx] = 1 if not self.improved else 2

        out = torch.matmul(x, self.weight)
        deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)

        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        out = torch.matmul(adj, out)

        if self.bias is not None:
            out = out + self.bias

        if mask is not None:
            out = out * mask.view(B, N, 1).to(x.dtype)

        return out


    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

# -------- MLP layer --------
class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
        """
        num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to a linear model.
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: number of classes for prediction
        num_classes: the number of classes of input, to be treated with different gains and biases,
                (see the definition of class `ConditionalLayer1d`)
        """

        super(MLP, self).__init__()

        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.use_bn = use_bn
        self.activate_func = activate_func

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = torch.nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()

            self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(torch.nn.Linear(hidden_dim, output_dim))

            if self.use_bn:
                self.batch_norms = torch.nn.ModuleList()
                for layer in range(num_layers - 1):
                    self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))


    def forward(self, x):
        """
        :param x: [num_classes * batch_size, N, F_i], batch of node features
            note that in self.cond_layers[layer],
            `x` is split into `num_classes` groups in dim=0,
            and then treated with different gains and biases
        """
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for layer in range(self.num_layers - 1):
                h = self.linears[layer](h)
                if self.use_bn:
                    h = self.batch_norms[layer](h)
                h = self.activate_func(h)
            return self.linears[self.num_layers - 1](h)
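
# -------- Usage sketch for DenseGCNConv and MLP --------
# Illustrative only; sizes are arbitrary and the import path assumes the repository root is on
# PYTHONPATH. DenseGCNConv takes dense B x N x F features with a dense B x N x N adjacency
# (self-loops added, symmetric normalization applied); MLP acts position-wise on the last dim.
import torch
import torch.nn.functional as F
from models.GDSS.layers import DenseGCNConv, MLP

B, N, F_in, F_out = 2, 6, 5, 7
x = torch.randn(B, N, F_in)
adj = torch.rand(B, N, N)
adj = (adj + adj.transpose(-1, -2)) / 2

gcn = DenseGCNConv(F_in, F_out)
h = gcn(x, adj)
print(h.shape)                                 # torch.Size([2, 6, 7])

mlp = MLP(num_layers=3, input_dim=F_out, hidden_dim=16, output_dim=4,
          use_bn=False, activate_func=F.elu)
print(mlp(h).shape)                            # torch.Size([2, 6, 4])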
MobileNetV3/models/GDSS/scorenetx.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F

from models.GDSS.layers import DenseGCNConv, MLP
from .graph_utils import mask_x, pow_tensor
from .attention import AttentionLayer
from .. import utils

@utils.register_model(name='ScoreNetworkX')
class ScoreNetworkX(torch.nn.Module):

    # def __init__(self, max_feat_num, depth, nhid):
    def __init__(self, config):

        super(ScoreNetworkX, self).__init__()

        self.nfeat = config.data.n_vocab
        self.depth = config.model.depth
        self.nhid = config.model.nhid

        self.layers = torch.nn.ModuleList()
        for _ in range(self.depth):
            if _ == 0:
                self.layers.append(DenseGCNConv(self.nfeat, self.nhid))
            else:
                self.layers.append(DenseGCNConv(self.nhid, self.nhid))

        self.fdim = self.nfeat + self.depth * self.nhid
        self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=self.nfeat,
                         use_bn=False, activate_func=F.elu)

        self.activation = torch.tanh

    def forward(self, x, time_cond, maskX, flags=None):

        x_list = [x]
        for _ in range(self.depth):
            x = self.layers[_](x, maskX)
            x = self.activation(x)
            x_list.append(x)

        xs = torch.cat(x_list, dim=-1) # B x N x (F + num_layers x H)
        out_shape = (x.shape[0], x.shape[1], -1)
        x = self.final(xs).view(*out_shape)

        x = mask_x(x, flags)
        return x


@utils.register_model(name='ScoreNetworkX_GMH')
class ScoreNetworkX_GMH(torch.nn.Module):
    # def __init__(self, max_feat_num, depth, nhid, num_linears,
    #              c_init, c_hid, c_final, adim, num_heads=4, conv='GCN'):
    def __init__(self, config):
        super().__init__()

        self.max_feat_num = config.data.n_vocab
        self.depth = config.model.depth
        self.nhid = config.model.nhid
        self.c_init = config.model.c_init
        self.c_hid = config.model.c_hid
        self.c_final = config.model.c_final
        self.num_linears = config.model.num_linears
        self.num_heads = config.model.num_heads
        self.conv = config.model.conv
        self.adim = config.model.adim

        self.layers = torch.nn.ModuleList()
        for _ in range(self.depth):
            if _ == 0:
                self.layers.append(AttentionLayer(self.num_linears, self.max_feat_num,
                                                  self.nhid, self.nhid, self.c_init,
                                                  self.c_hid, self.num_heads, self.conv))
            elif _ == self.depth - 1:
                self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
                                                  self.nhid, self.c_hid,
                                                  self.c_final, self.num_heads, self.conv))
            else:
                self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
                                                  self.nhid, self.c_hid,
                                                  self.c_hid, self.num_heads, self.conv))

        fdim = self.max_feat_num + self.depth * self.nhid
        self.final = MLP(num_layers=3, input_dim=fdim, hidden_dim=2*fdim, output_dim=self.max_feat_num,
                         use_bn=False, activate_func=F.elu)

        self.activation = torch.tanh

    def forward(self, x, time_cond, maskX, flags=None):
        adjc = pow_tensor(maskX, self.c_init)

        x_list = [x]
        for _ in range(self.depth):
            x, adjc = self.layers[_](x, adjc, flags)
            x = self.activation(x)
            x_list.append(x)

        xs = torch.cat(x_list, dim=-1) # B x N x (F + num_layers x H)
        out_shape = (x.shape[0], x.shape[1], -1)
        x = self.final(xs).view(*out_shape)
        x = mask_x(x, flags)

        return x
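
# -------- Construction sketch for ScoreNetworkX --------
# Illustrative only: types.SimpleNamespace stands in for the project's config object and the
# field values are arbitrary. Importing this module also requires the sibling models.utils
# registry to be importable, since the file decorates both classes with @utils.register_model.
# ScoreNetworkX_GMH is driven the same way but additionally reads c_init, c_hid, c_final,
# num_linears, num_heads, conv and adim from config.model.
import torch
from types import SimpleNamespace
from models.GDSS.scorenetx import ScoreNetworkX

config = SimpleNamespace(
    data=SimpleNamespace(n_vocab=10),
    model=SimpleNamespace(depth=2, nhid=16),
)
net = ScoreNetworkX(config)

B, N = 4, 9
x = torch.randn(B, N, 10)                      # B x N x n_vocab node features
maskX = torch.rand(B, N, N)                    # dense adjacency fed to the GCN layers
maskX = (maskX + maskX.transpose(-1, -2)) / 2
t = torch.rand(B)                              # time_cond; unused inside this forward

print(net(x, t, maskX).shape)                  # torch.Size([4, 9, 10])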