first commit

CownowAn
2024-03-15 14:38:51 +00:00
commit bc2ed1304f
321 changed files with 44802 additions and 0 deletions


@@ -0,0 +1,117 @@
import math

import torch
from torch.nn import Parameter
import torch.nn.functional as F

from models.GDSS.layers import DenseGCNConv, MLP
# from ..utils.graph_utils import mask_adjs, mask_x
from .graph_utils import mask_x, mask_adjs


# -------- Graph Multi-Head Attention (GMH) --------
# -------- From Baek et al. (2021) --------
class Attention(torch.nn.Module):

    def __init__(self, in_dim, attn_dim, out_dim, num_heads=4, conv='GCN'):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.attn_dim = attn_dim
        self.out_dim = out_dim
        self.conv = conv
        self.gnn_q, self.gnn_k, self.gnn_v = self.get_gnn(in_dim, attn_dim, out_dim, conv)
        self.activation = torch.tanh
        self.softmax_dim = 2

    def forward(self, x, adj, flags, attention_mask=None):
        if self.conv == 'GCN':
            Q = self.gnn_q(x, adj)
            K = self.gnn_k(x, adj)
        else:
            Q = self.gnn_q(x)
            K = self.gnn_k(x)

        V = self.gnn_v(x, adj)
        dim_split = self.attn_dim // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)

        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask for _ in range(self.num_heads)], 0)
            attention_score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim)
            A = self.activation(attention_mask + attention_score)
        else:
            A = self.activation(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim))  # (B x num_heads) x N x N

        # -------- (B x num_heads) x N x N --------
        A = A.view(-1, *adj.shape)
        A = A.mean(dim=0)
        A = (A + A.transpose(-1, -2)) / 2

        return V, A

    def get_gnn(self, in_dim, attn_dim, out_dim, conv='GCN'):
        if conv == 'GCN':
            gnn_q = DenseGCNConv(in_dim, attn_dim)
            gnn_k = DenseGCNConv(in_dim, attn_dim)
            gnn_v = DenseGCNConv(in_dim, out_dim)
            return gnn_q, gnn_k, gnn_v

        elif conv == 'MLP':
            num_layers = 2
            gnn_q = MLP(num_layers, in_dim, 2 * attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_k = MLP(num_layers, in_dim, 2 * attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_v = DenseGCNConv(in_dim, out_dim)
            return gnn_q, gnn_k, gnn_v

        else:
            raise NotImplementedError(f'{conv} not implemented.')


# -------- Layer of ScoreNetworkA --------
class AttentionLayer(torch.nn.Module):

    def __init__(self, num_linears, conv_input_dim, attn_dim, conv_output_dim, input_dim, output_dim,
                 num_heads=4, conv='GCN'):
        super(AttentionLayer, self).__init__()
        self.attn = torch.nn.ModuleList()
        for _ in range(input_dim):
            self.attn_dim = attn_dim
            self.attn.append(Attention(conv_input_dim, self.attn_dim, conv_output_dim,
                                       num_heads=num_heads, conv=conv))

        self.hidden_dim = 2 * max(input_dim, output_dim)
        self.mlp = MLP(num_linears, 2 * input_dim, self.hidden_dim, output_dim, use_bn=False, activate_func=F.elu)
        self.multi_channel = MLP(2, input_dim * conv_output_dim, self.hidden_dim, conv_output_dim,
                                 use_bn=False, activate_func=F.elu)

    def forward(self, x, adj, flags):
        """
        :param x:  B x N x F_i
        :param adj: B x C_i x N x N
        :return: x_out: B x N x F_o, adj_out: B x C_o x N x N
        """
        mask_list = []
        x_list = []
        for _ in range(len(self.attn)):
            _x, mask = self.attn[_](x, adj[:, _, :, :], flags)
            mask_list.append(mask.unsqueeze(-1))
            x_list.append(_x)
        x_out = mask_x(self.multi_channel(torch.cat(x_list, dim=-1)), flags)
        x_out = torch.tanh(x_out)

        mlp_in = torch.cat([torch.cat(mask_list, dim=-1), adj.permute(0, 2, 3, 1)], dim=-1)
        shape = mlp_in.shape
        mlp_out = self.mlp(mlp_in.view(-1, shape[-1]))
        _adj = mlp_out.view(shape[0], shape[1], shape[2], -1).permute(0, 3, 1, 2)
        _adj = _adj + _adj.transpose(-1, -2)
        adj_out = mask_adjs(_adj, flags)

        return x_out, adj_out
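
The module above (presumably models/GDSS/attention.py, judging by the relative imports elsewhere in this commit) can be smoke-tested on its own. A minimal sketch, with the module path assumed and dimensions chosen only for illustration; attn_dim must be divisible by num_heads for the head split to concatenate cleanly:

import torch
from models.GDSS.attention import AttentionLayer  # module path is an assumption

B, N = 4, 9            # batch size, max number of nodes
F_in, F_out = 10, 16   # node-feature dims in / out
C_in, C_out = 2, 2     # adjacency channels in / out

layer = AttentionLayer(num_linears=2, conv_input_dim=F_in, attn_dim=16,
                       conv_output_dim=F_out, input_dim=C_in, output_dim=C_out,
                       num_heads=4, conv='GCN')

x = torch.randn(B, N, F_in)
adj = torch.rand(B, C_in, N, N)
adj = (adj + adj.transpose(-1, -2)) / 2   # symmetrize each channel
flags = torch.ones(B, N)                  # all nodes valid

x_out, adj_out = layer(x, adj, flags)
print(x_out.shape)    # torch.Size([4, 9, 16])
print(adj_out.shape)  # torch.Size([4, 2, 9, 9])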


@@ -0,0 +1,209 @@
import torch
import torch.nn.functional as F
import networkx as nx
import numpy as np


# -------- Mask batch of node features with 0-1 flags tensor --------
def mask_x(x, flags):
    if flags is None:
        flags = torch.ones((x.shape[0], x.shape[1]), device=x.device)
    return x * flags[:, :, None]


# -------- Mask batch of adjacency matrices with 0-1 flags tensor --------
def mask_adjs(adjs, flags):
    """
    :param adjs:  B x N x N or B x C x N x N
    :param flags: B x N
    :return:
    """
    if flags is None:
        flags = torch.ones((adjs.shape[0], adjs.shape[-1]), device=adjs.device)

    if len(adjs.shape) == 4:
        flags = flags.unsqueeze(1)  # B x 1 x N
    adjs = adjs * flags.unsqueeze(-1)
    adjs = adjs * flags.unsqueeze(-2)
    return adjs


# -------- Create flags tensor from graph dataset --------
def node_flags(adj, eps=1e-5):
    flags = torch.abs(adj).sum(-1).gt(eps).to(dtype=torch.float32)
    if len(flags.shape) == 3:
        flags = flags[:, 0, :]
    return flags


# -------- Create initial node features --------
def init_features(init, adjs=None, nfeat=10):
    if init == 'zeros':
        feature = torch.zeros((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
    elif init == 'ones':
        feature = torch.ones((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
    elif init == 'deg':
        feature = adjs.sum(dim=-1).to(torch.long)
        num_classes = nfeat
        try:
            feature = F.one_hot(feature, num_classes=num_classes).to(torch.float32)
        except:
            print(feature.max().item())
            raise NotImplementedError(f'max_feat_num mismatch')
    else:
        raise NotImplementedError(f'{init} not implemented')

    flags = node_flags(adjs)
    return mask_x(feature, flags)


# -------- Sample initial flags tensor from the training graph set --------
def init_flags(graph_list, config, batch_size=None):
    if batch_size is None:
        batch_size = config.data.batch_size
    max_node_num = config.data.max_node_num
    graph_tensor = graphs_to_tensor(graph_list, max_node_num)
    idx = np.random.randint(0, len(graph_list), batch_size)
    flags = node_flags(graph_tensor[idx])
    return flags


# -------- Generate noise --------
def gen_noise(x, flags, sym=True):
    z = torch.randn_like(x)
    if sym:
        z = z.triu(1)
        z = z + z.transpose(-1, -2)
        z = mask_adjs(z, flags)
    else:
        z = mask_x(z, flags)
    return z


# -------- Quantize generated graphs --------
def quantize(adjs, thr=0.5):
    adjs_ = torch.where(adjs < thr, torch.zeros_like(adjs), torch.ones_like(adjs))
    return adjs_


# -------- Quantize generated molecules --------
# adjs: 32 x 9 x 9
def quantize_mol(adjs):
    if type(adjs).__name__ == 'Tensor':
        adjs = adjs.detach().cpu()
    else:
        adjs = torch.tensor(adjs)
    adjs[adjs >= 2.5] = 3
    adjs[torch.bitwise_and(adjs >= 1.5, adjs < 2.5)] = 2
    adjs[torch.bitwise_and(adjs >= 0.5, adjs < 1.5)] = 1
    adjs[adjs < 0.5] = 0
    return np.array(adjs.to(torch.int64))
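
# Illustrative example of the thresholding above (values are made up):
#   quantize_mol(torch.tensor([[0.2, 1.1], [2.7, 1.6]]))  ->  array([[0, 1], [3, 2]])
# i.e. each entry snaps to the nearest bond order in {0, 1, 2, 3}.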

def adjs_to_graphs(adjs, is_cuda=False):
    graph_list = []
    for adj in adjs:
        if is_cuda:
            adj = adj.detach().cpu().numpy()
        G = nx.from_numpy_matrix(adj)
        G.remove_edges_from(nx.selfloop_edges(G))
        G.remove_nodes_from(list(nx.isolates(G)))
        if G.number_of_nodes() < 1:
            G.add_node(1)
        graph_list.append(G)
    return graph_list


# -------- Check if the adjacency matrices are symmetric --------
def check_sym(adjs, print_val=False):
    sym_error = (adjs - adjs.transpose(-1, -2)).abs().sum([0, 1, 2])
    if not sym_error < 1e-2:
        raise ValueError(f'Not symmetric: {sym_error:.4e}')
    if print_val:
        print(f'{sym_error:.4e}')


# -------- Create higher order adjacency matrices --------
def pow_tensor(x, cnum):
    # x : B x N x N
    x_ = x.clone()
    xc = [x.unsqueeze(1)]
    for _ in range(cnum - 1):
        x_ = torch.bmm(x_, x)
        xc.append(x_.unsqueeze(1))
    xc = torch.cat(xc, dim=1)
    return xc


# -------- Create padded adjacency matrices --------
def pad_adjs(ori_adj, node_number):
    a = ori_adj
    ori_len = a.shape[-1]
    if ori_len == node_number:
        return a
    if ori_len > node_number:
        raise ValueError(f'ori_len {ori_len} > node_number {node_number}')
    a = np.concatenate([a, np.zeros([ori_len, node_number - ori_len])], axis=-1)
    a = np.concatenate([a, np.zeros([node_number - ori_len, node_number])], axis=0)
    return a


def graphs_to_tensor(graph_list, max_node_num):
    adjs_list = []
    for g in graph_list:
        assert isinstance(g, nx.Graph)
        node_list = []
        for v, feature in g.nodes.data('feature'):
            node_list.append(v)

        adj = nx.to_numpy_matrix(g, nodelist=node_list)
        padded_adj = pad_adjs(adj, node_number=max_node_num)
        adjs_list.append(padded_adj)

    del graph_list

    adjs_np = np.asarray(adjs_list)
    del adjs_list

    adjs_tensor = torch.tensor(adjs_np, dtype=torch.float32)
    del adjs_np

    return adjs_tensor


def graphs_to_adj(graph, max_node_num):
    assert isinstance(graph, nx.Graph)
    node_list = []
    for v, feature in graph.nodes.data('feature'):
        node_list.append(v)

    adj = nx.to_numpy_matrix(graph, nodelist=node_list)
    padded_adj = pad_adjs(adj, node_number=max_node_num)
    adj = torch.tensor(padded_adj, dtype=torch.float32)
    del padded_adj
    return adj


def node_feature_to_matrix(x):
    """
    :param x:  BS x N x F
    :return:
    x_pair: BS x N x N x 2F
    """
    x_b = x.unsqueeze(-2).expand(x.size(0), x.size(1), x.size(1), -1)  # BS x N x N x F
    x_pair = torch.cat([x_b, x_b.transpose(1, 2)], dim=-1)  # BS x N x N x 2F
    return x_pair
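
A short usage sketch of the masking helpers above. The module path (models/GDSS/graph_utils.py) is inferred from the relative imports in this commit and is an assumption, and the toy graph is made up:

import torch
from models.GDSS.graph_utils import node_flags, gen_noise, pow_tensor  # path assumed

B, N = 2, 5
adj = torch.zeros(B, N, N)
adj[:, 0, 1] = adj[:, 1, 0] = 1.0         # a single edge; nodes 2..4 act as padding

flags = node_flags(adj)                   # B x N, 1.0 wherever a node has at least one edge
noise = gen_noise(adj, flags, sym=True)   # symmetric noise, zeroed on padded nodes
adjc = pow_tensor(adj, cnum=3)            # B x 3 x N x N, stacking A, A^2, A^3

print(flags[0])                                        # tensor([1., 1., 0., 0., 0.])
print(torch.allclose(noise, noise.transpose(-1, -2)))  # True
print(adjc.shape)                                      # torch.Size([2, 3, 5, 5])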


@@ -0,0 +1,153 @@
import torch
from torch.nn import Parameter
import torch.nn.functional as F
import math
from typing import Any


def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)


def reset(value: Any):
    if hasattr(value, 'reset_parameters'):
        value.reset_parameters()
    else:
        for child in value.children() if hasattr(value, 'children') else []:
            reset(child)


# -------- GCN layer --------
class DenseGCNConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GCNConv`.
    """
    def __init__(self, in_channels, out_channels, improved=False, bias=True):
        super(DenseGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved

        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, adj, mask=None, add_loop=True):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
                \times N \times F}`, with batch-size :math:`B`, (maximum)
                number of nodes :math:`N` for each graph, and feature
                dimension :math:`F`.
            adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B
                \times N \times N}`. The adjacency tensor is broadcastable in
                the batch dimension, resulting in a shared adjacency matrix for
                the complete batch.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, _ = adj.size()

        if add_loop:
            adj = adj.clone()
            idx = torch.arange(N, dtype=torch.long, device=adj.device)
            adj[:, idx, idx] = 1 if not self.improved else 2

        out = torch.matmul(x, self.weight)

        # Symmetric normalization: D^{-1/2} A D^{-1/2} (A includes self-loops when add_loop=True)
        deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)
        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        out = torch.matmul(adj, out)

        if self.bias is not None:
            out = out + self.bias

        if mask is not None:
            out = out * mask.view(B, N, 1).to(x.dtype)

        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

# -------- MLP layer --------
class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
        """
        num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to linear model.
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: number of classes for prediction
        num_classes: the number of classes of input, to be treated with different gains and biases,
                     (see the definition of class `ConditionalLayer1d`)
        """
        super(MLP, self).__init__()

        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.use_bn = use_bn
        self.activate_func = activate_func

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = torch.nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()

            self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(torch.nn.Linear(hidden_dim, output_dim))

            if self.use_bn:
                self.batch_norms = torch.nn.ModuleList()
                for layer in range(num_layers - 1):
                    self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))

    def forward(self, x):
        """
        :param x: [num_classes * batch_size, N, F_i], batch of node features
            note that in self.cond_layers[layer],
            `x` is split into `num_classes` groups in dim=0,
            and then treated with different gains and biases
        """
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for layer in range(self.num_layers - 1):
                h = self.linears[layer](h)
                if self.use_bn:
                    h = self.batch_norms[layer](h)
                h = self.activate_func(h)
            return self.linears[self.num_layers - 1](h)
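
The two building blocks above can be checked standalone. A minimal sketch, assuming the module is importable as models.GDSS.layers (as the other files in this commit do) and with arbitrary toy dimensions:

import torch
import torch.nn.functional as F
from models.GDSS.layers import DenseGCNConv, MLP

B, N, F_in, F_out = 2, 6, 8, 16
x = torch.randn(B, N, F_in)
adj = torch.rand(B, N, N)
adj = (adj + adj.transpose(-1, -2)) / 2           # dense, symmetric adjacency

gcn = DenseGCNConv(F_in, F_out)
h = gcn(x, adj)                                   # B x N x F_out; self-loops and normalization applied inside

mlp = MLP(num_layers=3, input_dim=F_out, hidden_dim=32, output_dim=4,
          use_bn=False, activate_func=F.elu)
out = mlp(h)                                      # B x N x 4; Linear layers act on the last dim
print(h.shape, out.shape)                         # torch.Size([2, 6, 16]) torch.Size([2, 6, 4])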


@@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F

from models.GDSS.layers import DenseGCNConv, MLP
from .graph_utils import mask_x, pow_tensor
from .attention import AttentionLayer
from .. import utils


@utils.register_model(name='ScoreNetworkX')
class ScoreNetworkX(torch.nn.Module):

    # def __init__(self, max_feat_num, depth, nhid):
    def __init__(self, config):
        super(ScoreNetworkX, self).__init__()

        self.nfeat = config.data.n_vocab
        self.depth = config.model.depth
        self.nhid = config.model.nhid

        self.layers = torch.nn.ModuleList()
        for _ in range(self.depth):
            if _ == 0:
                self.layers.append(DenseGCNConv(self.nfeat, self.nhid))
            else:
                self.layers.append(DenseGCNConv(self.nhid, self.nhid))

        self.fdim = self.nfeat + self.depth * self.nhid
        self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2 * self.fdim, output_dim=self.nfeat,
                         use_bn=False, activate_func=F.elu)

        self.activation = torch.tanh

    def forward(self, x, time_cond, maskX, flags=None):
        x_list = [x]
        for _ in range(self.depth):
            x = self.layers[_](x, maskX)
            x = self.activation(x)
            x_list.append(x)

        xs = torch.cat(x_list, dim=-1)  # B x N x (F + num_layers x H)
        out_shape = (x.shape[0], x.shape[1], -1)
        x = self.final(xs).view(*out_shape)

        x = mask_x(x, flags)

        return x


@utils.register_model(name='ScoreNetworkX_GMH')
class ScoreNetworkX_GMH(torch.nn.Module):
    # def __init__(self, max_feat_num, depth, nhid, num_linears,
    #              c_init, c_hid, c_final, adim, num_heads=4, conv='GCN'):
    def __init__(self, config):
        super().__init__()

        self.max_feat_num = config.data.n_vocab
        self.depth = config.model.depth
        self.nhid = config.model.nhid
        self.c_init = config.model.c_init
        self.c_hid = config.model.c_hid
        self.c_final = config.model.c_final
        self.num_linears = config.model.num_linears
        self.num_heads = config.model.num_heads
        self.conv = config.model.conv
        self.adim = config.model.adim

        self.layers = torch.nn.ModuleList()
        for _ in range(self.depth):
            if _ == 0:
                self.layers.append(AttentionLayer(self.num_linears, self.max_feat_num,
                                                  self.nhid, self.nhid, self.c_init,
                                                  self.c_hid, self.num_heads, self.conv))
            elif _ == self.depth - 1:
                self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
                                                  self.nhid, self.c_hid,
                                                  self.c_final, self.num_heads, self.conv))
            else:
                self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
                                                  self.nhid, self.c_hid,
                                                  self.c_hid, self.num_heads, self.conv))

        fdim = self.max_feat_num + self.depth * self.nhid
        self.final = MLP(num_layers=3, input_dim=fdim, hidden_dim=2 * fdim, output_dim=self.max_feat_num,
                         use_bn=False, activate_func=F.elu)

        self.activation = torch.tanh

    def forward(self, x, time_cond, maskX, flags=None):
        adjc = pow_tensor(maskX, self.c_init)

        x_list = [x]
        for _ in range(self.depth):
            x, adjc = self.layers[_](x, adjc, flags)
            x = self.activation(x)
            x_list.append(x)

        xs = torch.cat(x_list, dim=-1)  # B x N x (F + num_layers x H)
        out_shape = (x.shape[0], x.shape[1], -1)
        x = self.final(xs).view(*out_shape)
        x = mask_x(x, flags)

        return x
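
ScoreNetworkX only reads config.data.n_vocab, config.model.depth and config.model.nhid, so it can be exercised with a stub config. A minimal sketch; the SimpleNamespace stub and the module path are assumptions, and time_cond is accepted but unused by this network:

import torch
from types import SimpleNamespace
from models.GDSS.scorenet import ScoreNetworkX  # file name is an assumption

config = SimpleNamespace(
    data=SimpleNamespace(n_vocab=10),
    model=SimpleNamespace(depth=2, nhid=16),
)
model = ScoreNetworkX(config)

B, N = 4, 8
x = torch.randn(B, N, config.data.n_vocab)        # node features
adj = torch.rand(B, N, N)
adj = (adj + adj.transpose(-1, -2)) / 2           # dense adjacency passed as maskX
t = torch.rand(B)                                 # kept only for interface parity

score = model(x, t, adj)                          # B x N x n_vocab score over node features
print(score.shape)                                # torch.Size([4, 8, 10])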