update_name
0    graph_dit/models/__init__.py    Normal file
119  graph_dit/models/conditions.py  Normal file
@@ -0,0 +1,119 @@
import torch
import torch.nn as nn
import math


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t = t.view(-1)
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class CategoricalEmbedder(nn.Module):
    """
    Embeds categorical conditions such as data sources into vector representations.
    Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        # One extra row serves as the "dropped label" embedding for classifier-free guidance.
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None, t=None):
        labels = labels.long().view(-1)
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        if train:
            # Add Gaussian noise to the label embeddings during training.
            noise = torch.randn_like(embeddings)
            embeddings = embeddings + noise
        return embeddings


class ClusterContinuousEmbedder(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0

        if use_cfg_embedding:
            self.embedding_drop = nn.Embedding(1, hidden_size)

        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size, bias=True),
            nn.Softmax(dim=1),
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob

    def forward(self, labels, train, force_drop_ids=None, timestep=None):
        use_dropout = self.dropout_prob > 0
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = None

        if train and use_dropout:
            drop_ids_rand = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            if force_drop_ids is not None:
                drop_ids = torch.logical_or(drop_ids, drop_ids_rand)
            else:
                drop_ids = drop_ids_rand

        if drop_ids is not None:
            embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device)
            embeddings[~drop_ids] = self.mlp(labels[~drop_ids])
            embeddings[drop_ids] += self.embedding_drop.weight[0]
        else:
            embeddings = self.mlp(labels)

        if train:
            # Add Gaussian noise to the condition embeddings during training.
            noise = torch.randn_like(embeddings)
            embeddings = embeddings + noise
        return embeddings
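For reference, a minimal smoke test of the three embedders above. This is not part of the commit: the batch size, inputs, and the hidden size of 384 (taken from the Denoiser default in transformer.py) are illustrative assumptions.

# Hypothetical usage sketch; all shapes here are assumptions, not part of the commit.
import torch

t_emb = TimestepEmbedder(hidden_size=384)
y_emb = CategoricalEmbedder(num_classes=2, hidden_size=384, dropout_prob=0.1)
c_emb = ClusterContinuousEmbedder(input_size=2, hidden_size=384, dropout_prob=0.1)

t = torch.randint(0, 1000, (8,))        # one diffusion timestep per graph
labels = torch.randint(0, 2, (8,))      # a categorical condition, e.g. a class id
props = torch.rand(8, 2)                # a continuous condition, e.g. two properties

print(t_emb(t).shape)                   # torch.Size([8, 384])
print(y_emb(labels, train=True).shape)  # torch.Size([8, 384]); labels may be dropped for CFG
print(c_emb(props, train=False).shape)  # torch.Size([8, 384])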
114  graph_dit/models/layers.py  Normal file
@@ -0,0 +1,114 @@
from torch.jit import Final
import torch.nn.functional as F
from itertools import repeat
import collections.abc

import torch
import torch.nn as nn


class Attention(nn.Module):
    fast_attn: Final[bool]

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0,
        proj_drop=0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.scale = self.head_dim ** -0.5
        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
        assert self.fast_attn, "scaled_dot_product_attention not implemented"

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def dot_product_attention(self, q, k, v):
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn_sfmx = attn.softmax(dim=-1)
        attn_sfmx = self.attn_drop(attn_sfmx)
        x = attn_sfmx @ v
        return x, attn

    def forward(self, x, node_mask):
        B, N, D = x.shape

        # B, head, N, head_dim
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # B, head, N, head_dim
        q, k = self.q_norm(q), self.k_norm(k)

        # A query/key pair is attendable only if both nodes are real (not padding).
        attn_mask = (node_mask[:, None, :, None] & node_mask[:, None, None, :]).expand(-1, self.num_heads, N, N)
        # Re-open rows that are fully masked (padded queries) so softmax does not produce NaNs.
        attn_mask[attn_mask.sum(-1) == 0] = True

        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p,
            attn_mask=attn_mask,
        )

        x = x.transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


to_2tuple = _ntuple(2)
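A minimal sketch of how the masked Attention and Mlp blocks might be exercised. This is not part of the commit; the batch size, node count, hidden width, and the boolean node_mask with a shorter second graph are illustrative assumptions.

# Hypothetical usage sketch; shapes and padding pattern are assumptions.
import torch

attn = Attention(dim=384, num_heads=16, qkv_bias=True, qk_norm=True)
mlp = Mlp(in_features=384, hidden_features=4 * 384)

x = torch.randn(2, 30, 384)                      # (batch, nodes, hidden)
node_mask = torch.ones(2, 30, dtype=torch.bool)  # True = real node, False = padding
node_mask[1, 20:] = False                        # second graph has only 20 real nodes

out = attn(x, node_mask)  # real nodes attend only to real nodes; padded rows are kept non-empty to avoid NaNs
out = mlp(out)
print(out.shape)          # torch.Size([2, 30, 384])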
184  graph_dit/models/transformer.py  Normal file
@@ -0,0 +1,184 @@
import torch
import torch.nn as nn

import utils
from models.layers import Attention, Mlp
from models.conditions import TimestepEmbedder, CategoricalEmbedder, ClusterContinuousEmbedder


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class Denoiser(nn.Module):
    def __init__(
        self,
        max_n_nodes,
        hidden_size=384,
        depth=12,
        num_heads=16,
        mlp_ratio=4.0,
        drop_condition=0.1,
        Xdim=118,
        Edim=5,
        ydim=3,
        task_type='regression',
    ):
        super().__init__()
        self.num_heads = num_heads
        self.ydim = ydim
        # Node features are concatenated with the flattened edge features before embedding.
        self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False)

        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedding_list = torch.nn.ModuleList()

        self.y_embedding_list.append(ClusterContinuousEmbedder(2, hidden_size, drop_condition))
        for i in range(ydim - 2):
            if task_type == 'regression':
                self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition))
            else:
                self.y_embedding_list.append(CategoricalEmbedder(2, hidden_size, drop_condition))

        self.encoders = nn.ModuleList(
            [
                SELayer(hidden_size, num_heads, mlp_ratio=mlp_ratio)
                for _ in range(depth)
            ]
        )

        self.out_layer = OutLayer(
            max_n_nodes=max_n_nodes,
            hidden_size=hidden_size,
            atom_type=Xdim,
            bond_type=Edim,
            mlp_ratio=mlp_ratio,
            num_heads=num_heads,
        )

        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        def _constant_init(module, i):
            if isinstance(module, nn.Linear):
                nn.init.constant_(module.weight, i)
                if module.bias is not None:
                    nn.init.constant_(module.bias, i)

        self.apply(_basic_init)

        # Zero-init the adaLN modulation layers so each block starts close to an identity mapping.
        for block in self.encoders:
            _constant_init(block.adaLN_modulation[0], 0)
        _constant_init(self.out_layer.adaLN_modulation[0], 0)

    def forward(self, x, e, node_mask, y, t, unconditioned):
        # Conditions with NaN targets are force-dropped (classifier-free guidance).
        force_drop_id = torch.zeros_like(y.sum(-1))
        force_drop_id[torch.isnan(y.sum(-1))] = 1
        if unconditioned:
            force_drop_id = torch.ones_like(y[:, 0])

        x_in, e_in, y_in = x, e, y
        bs, n, _ = x.size()
        x = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)
        x = self.x_embedder(x)

        c1 = self.t_embedder(t)
        for i in range(1, self.ydim):
            if i == 1:
                c2 = self.y_embedding_list[i - 1](y[:, :2], self.training, force_drop_id, t)
            else:
                c2 = c2 + self.y_embedding_list[i - 1](y[:, i:i + 1], self.training, force_drop_id, t)
        c = c1 + c2

        for i, block in enumerate(self.encoders):
            x = block(x, c, node_mask)

        # X: B * N * dx, E: B * N * N * de
        X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask)
        return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)


class SELayer(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.dropout = 0.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)

        self.attn = Attention(
            hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=True, **block_kwargs
        )

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            drop=self.dropout,
        )

        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c, node_mask):
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * modulate(self.norm1(self.attn(x, node_mask=node_mask)), shift_msa, scale_msa)
        x = x + gate_mlp.unsqueeze(1) * modulate(self.norm2(self.mlp(x)), shift_mlp, scale_mlp)
        return x


class OutLayer(nn.Module):
    # Structure Output Layer
    def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None):
        super().__init__()
        self.atom_type = atom_type
        self.bond_type = bond_type
        final_size = atom_type + max_n_nodes * bond_type
        self.xedecoder = Mlp(in_features=hidden_size,
                             out_features=final_size, drop=0)

        self.norm_final = nn.LayerNorm(final_size, elementwise_affine=False)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * final_size, bias=True),
        )

    def forward(self, x, x_in, e_in, c, t, node_mask):
        x_all = self.xedecoder(x)
        B, N, D = x_all.size()
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x_all = modulate(self.norm_final(x_all), shift, scale)

        atom_out = x_all[:, :, :self.atom_type]
        atom_out = x_in + atom_out

        bond_out = x_all[:, :, self.atom_type:].reshape(B, N, N, self.bond_type)
        bond_out = e_in + bond_out

        ##### standardize adj_out
        edge_mask = (~node_mask)[:, :, None] & (~node_mask)[:, None, :]
        diag_mask = (
            torch.eye(N, dtype=torch.bool)
            .unsqueeze(0)
            .expand(B, -1, -1)
            .type_as(edge_mask)
        )
        bond_out.masked_fill_(edge_mask[:, :, :, None], 0)
        bond_out.masked_fill_(diag_mask[:, :, :, None], 0)
        # Symmetrize the bond predictions.
        bond_out = 1 / 2 * (bond_out + torch.transpose(bond_out, 1, 2))

        return atom_out, bond_out, None
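A hedged sketch of how the Denoiser might be driven end to end. This is not part of the commit: it assumes the script runs from the graph_dit/ package root so that models.* and utils resolve, that utils.PlaceHolder exposes the X and E attributes it is constructed with, and that the dense one-hot shapes below match how the data module encodes graphs.

# Hypothetical driver; shapes, paths, and PlaceHolder attribute access are assumptions.
import torch
from models.transformer import Denoiser

max_n_nodes = 30
model = Denoiser(max_n_nodes=max_n_nodes, Xdim=118, Edim=5, ydim=3, task_type='regression')

bs = 4
x = torch.randn(bs, max_n_nodes, 118)             # dense node features (e.g. one-hot atom types)
e = torch.randn(bs, max_n_nodes, max_n_nodes, 5)  # dense edge features (e.g. one-hot bond types)
node_mask = torch.ones(bs, max_n_nodes, dtype=torch.bool)
y = torch.rand(bs, 3)                             # conditioning targets; NaN entries are force-dropped
t = torch.randint(0, 1000, (bs,))

out = model(x, e, node_mask, y, t, unconditioned=False)
print(out.X.shape, out.E.shape)                   # expected (4, 30, 118) and (4, 30, 30, 5)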