first commit
MobileNetV3/models/GDSS/attention.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import math
import torch
from torch.nn import Parameter
import torch.nn.functional as F

from models.GDSS.layers import DenseGCNConv, MLP
# from ..utils.graph_utils import mask_adjs, mask_x
from .graph_utils import mask_x, mask_adjs


# -------- Graph Multi-Head Attention (GMH) --------
# -------- From Baek et al. (2021) --------
class Attention(torch.nn.Module):

    def __init__(self, in_dim, attn_dim, out_dim, num_heads=4, conv='GCN'):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.attn_dim = attn_dim
        self.out_dim = out_dim
        self.conv = conv

        self.gnn_q, self.gnn_k, self.gnn_v = self.get_gnn(in_dim, attn_dim, out_dim, conv)
        self.activation = torch.tanh
        self.softmax_dim = 2

    def forward(self, x, adj, flags, attention_mask=None):

        if self.conv == 'GCN':
            Q = self.gnn_q(x, adj)
            K = self.gnn_k(x, adj)
        else:
            Q = self.gnn_q(x)
            K = self.gnn_k(x)

        V = self.gnn_v(x, adj)
        dim_split = self.attn_dim // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)

        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask for _ in range(self.num_heads)], 0)
            attention_score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim)
            A = self.activation(attention_mask + attention_score)
        else:
            A = self.activation(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim))  # (B x num_heads) x N x N

        # -------- (B x num_heads) x N x N --------
        A = A.view(-1, *adj.shape)
        A = A.mean(dim=0)
        A = (A + A.transpose(-1, -2)) / 2

        return V, A

    def get_gnn(self, in_dim, attn_dim, out_dim, conv='GCN'):

        if conv == 'GCN':
            gnn_q = DenseGCNConv(in_dim, attn_dim)
            gnn_k = DenseGCNConv(in_dim, attn_dim)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        elif conv == 'MLP':
            num_layers = 2
            gnn_q = MLP(num_layers, in_dim, 2*attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_k = MLP(num_layers, in_dim, 2*attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        else:
            raise NotImplementedError(f'{conv} not implemented.')


# -------- Layer of ScoreNetworkA --------
class AttentionLayer(torch.nn.Module):

    def __init__(self, num_linears, conv_input_dim, attn_dim, conv_output_dim, input_dim, output_dim,
                 num_heads=4, conv='GCN'):

        super(AttentionLayer, self).__init__()

        self.attn = torch.nn.ModuleList()
        for _ in range(input_dim):
            self.attn_dim = attn_dim
            self.attn.append(Attention(conv_input_dim, self.attn_dim, conv_output_dim,
                                       num_heads=num_heads, conv=conv))

        self.hidden_dim = 2*max(input_dim, output_dim)
        self.mlp = MLP(num_linears, 2*input_dim, self.hidden_dim, output_dim, use_bn=False, activate_func=F.elu)
        self.multi_channel = MLP(2, input_dim*conv_output_dim, self.hidden_dim, conv_output_dim,
                                 use_bn=False, activate_func=F.elu)

    def forward(self, x, adj, flags):
        """
        :param x: B x N x F_i
        :param adj: B x C_i x N x N
        :return: x_out: B x N x F_o, adj_out: B x C_o x N x N
        """
        mask_list = []
        x_list = []
        for _ in range(len(self.attn)):
            _x, mask = self.attn[_](x, adj[:, _, :, :], flags)
            mask_list.append(mask.unsqueeze(-1))
            x_list.append(_x)
        x_out = mask_x(self.multi_channel(torch.cat(x_list, dim=-1)), flags)
        x_out = torch.tanh(x_out)

        mlp_in = torch.cat([torch.cat(mask_list, dim=-1), adj.permute(0, 2, 3, 1)], dim=-1)
        shape = mlp_in.shape
        mlp_out = self.mlp(mlp_in.view(-1, shape[-1]))
        _adj = mlp_out.view(shape[0], shape[1], shape[2], -1).permute(0, 3, 1, 2)
        _adj = _adj + _adj.transpose(-1, -2)
        adj_out = mask_adjs(_adj, flags)

        return x_out, adj_out
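A minimal shape-check sketch for the layer above (not part of the commit). The sizes are illustrative, and the import assumes the MobileNetV3 directory is the package root, as the file's own `from models.GDSS.layers import ...` implies:

import torch
from models.GDSS.attention import AttentionLayer

# Illustrative sizes: batch of 4 graphs, 8 nodes, 16-dim node features,
# 2 input adjacency channels, 3 output adjacency channels.
layer = AttentionLayer(num_linears=2, conv_input_dim=16, attn_dim=32,
                       conv_output_dim=16, input_dim=2, output_dim=3,
                       num_heads=4, conv='GCN')
x = torch.randn(4, 8, 16)        # B x N x F_i
adj = torch.rand(4, 2, 8, 8)     # B x C_i x N x N
flags = torch.ones(4, 8)         # every node marked valid
x_out, adj_out = layer(x, adj, flags)
print(x_out.shape, adj_out.shape)  # torch.Size([4, 8, 16]) torch.Size([4, 3, 8, 8])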
MobileNetV3/models/GDSS/graph_utils.py (new file, 209 lines)
@@ -0,0 +1,209 @@
import torch
import torch.nn.functional as F
import networkx as nx
import numpy as np


# -------- Mask batch of node features with 0-1 flags tensor --------
def mask_x(x, flags):

    if flags is None:
        flags = torch.ones((x.shape[0], x.shape[1]), device=x.device)
    return x * flags[:, :, None]


# -------- Mask batch of adjacency matrices with 0-1 flags tensor --------
def mask_adjs(adjs, flags):
    """
    :param adjs: B x N x N or B x C x N x N
    :param flags: B x N
    :return:
    """
    if flags is None:
        flags = torch.ones((adjs.shape[0], adjs.shape[-1]), device=adjs.device)

    if len(adjs.shape) == 4:
        flags = flags.unsqueeze(1)  # B x 1 x N
    adjs = adjs * flags.unsqueeze(-1)
    adjs = adjs * flags.unsqueeze(-2)
    return adjs


# -------- Create flags tensor from graph dataset --------
def node_flags(adj, eps=1e-5):

    flags = torch.abs(adj).sum(-1).gt(eps).to(dtype=torch.float32)

    if len(flags.shape) == 3:
        flags = flags[:, 0, :]
    return flags


# -------- Create initial node features --------
def init_features(init, adjs=None, nfeat=10):

    if init == 'zeros':
        feature = torch.zeros((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
    elif init == 'ones':
        feature = torch.ones((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
    elif init == 'deg':
        feature = adjs.sum(dim=-1).to(torch.long)
        num_classes = nfeat
        try:
            feature = F.one_hot(feature, num_classes=num_classes).to(torch.float32)
        except:
            print(feature.max().item())
            raise NotImplementedError(f'max_feat_num mismatch')
    else:
        raise NotImplementedError(f'{init} not implemented')

    flags = node_flags(adjs)

    return mask_x(feature, flags)


# -------- Sample initial flags tensor from the training graph set --------
def init_flags(graph_list, config, batch_size=None):
    if batch_size is None:
        batch_size = config.data.batch_size
    max_node_num = config.data.max_node_num
    graph_tensor = graphs_to_tensor(graph_list, max_node_num)
    idx = np.random.randint(0, len(graph_list), batch_size)
    flags = node_flags(graph_tensor[idx])

    return flags


# -------- Generate noise --------
def gen_noise(x, flags, sym=True):
    z = torch.randn_like(x)
    if sym:
        z = z.triu(1)
        z = z + z.transpose(-1, -2)
        z = mask_adjs(z, flags)
    else:
        z = mask_x(z, flags)
    return z


# -------- Quantize generated graphs --------
def quantize(adjs, thr=0.5):
    adjs_ = torch.where(adjs < thr, torch.zeros_like(adjs), torch.ones_like(adjs))
    return adjs_


# -------- Quantize generated molecules --------
# adjs: 32 x 9 x 9
def quantize_mol(adjs):
    if type(adjs).__name__ == 'Tensor':
        adjs = adjs.detach().cpu()
    else:
        adjs = torch.tensor(adjs)
    adjs[adjs >= 2.5] = 3
    adjs[torch.bitwise_and(adjs >= 1.5, adjs < 2.5)] = 2
    adjs[torch.bitwise_and(adjs >= 0.5, adjs < 1.5)] = 1
    adjs[adjs < 0.5] = 0
    return np.array(adjs.to(torch.int64))


def adjs_to_graphs(adjs, is_cuda=False):
    graph_list = []
    for adj in adjs:
        if is_cuda:
            adj = adj.detach().cpu().numpy()
        G = nx.from_numpy_matrix(adj)
        G.remove_edges_from(nx.selfloop_edges(G))
        G.remove_nodes_from(list(nx.isolates(G)))
        if G.number_of_nodes() < 1:
            G.add_node(1)
        graph_list.append(G)
    return graph_list


# -------- Check if the adjacency matrices are symmetric --------
def check_sym(adjs, print_val=False):
    sym_error = (adjs - adjs.transpose(-1, -2)).abs().sum([0, 1, 2])
    if not sym_error < 1e-2:
        raise ValueError(f'Not symmetric: {sym_error:.4e}')
    if print_val:
        print(f'{sym_error:.4e}')


# -------- Create higher order adjacency matrices --------
def pow_tensor(x, cnum):
    # x : B x N x N
    x_ = x.clone()
    xc = [x.unsqueeze(1)]
    for _ in range(cnum - 1):
        x_ = torch.bmm(x_, x)
        xc.append(x_.unsqueeze(1))
    xc = torch.cat(xc, dim=1)

    return xc


# -------- Create padded adjacency matrices --------
def pad_adjs(ori_adj, node_number):
    a = ori_adj
    ori_len = a.shape[-1]
    if ori_len == node_number:
        return a
    if ori_len > node_number:
        raise ValueError(f'ori_len {ori_len} > node_number {node_number}')
    a = np.concatenate([a, np.zeros([ori_len, node_number - ori_len])], axis=-1)
    a = np.concatenate([a, np.zeros([node_number - ori_len, node_number])], axis=0)
    return a


def graphs_to_tensor(graph_list, max_node_num):
    adjs_list = []
    max_node_num = max_node_num

    for g in graph_list:
        assert isinstance(g, nx.Graph)
        node_list = []
        for v, feature in g.nodes.data('feature'):
            node_list.append(v)

        adj = nx.to_numpy_matrix(g, nodelist=node_list)
        padded_adj = pad_adjs(adj, node_number=max_node_num)
        adjs_list.append(padded_adj)

    del graph_list

    adjs_np = np.asarray(adjs_list)
    del adjs_list

    adjs_tensor = torch.tensor(adjs_np, dtype=torch.float32)
    del adjs_np

    return adjs_tensor


def graphs_to_adj(graph, max_node_num):
    max_node_num = max_node_num

    assert isinstance(graph, nx.Graph)
    node_list = []
    for v, feature in graph.nodes.data('feature'):
        node_list.append(v)

    adj = nx.to_numpy_matrix(graph, nodelist=node_list)
    padded_adj = pad_adjs(adj, node_number=max_node_num)

    adj = torch.tensor(padded_adj, dtype=torch.float32)
    del padded_adj

    return adj


def node_feature_to_matrix(x):
    """
    :param x: BS x N x F
    :return:
    x_pair: BS x N x N x 2F
    """
    x_b = x.unsqueeze(-2).expand(x.size(0), x.size(1), x.size(1), -1)  # BS x N x N x F
    x_pair = torch.cat([x_b, x_b.transpose(1, 2)], dim=-1)  # BS x N x N x 2F

    return x_pair
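A small sketch of how the masking helpers above compose (illustrative tensors only; the import assumes the same package root as the file itself):

import torch
from models.GDSS.graph_utils import node_flags, mask_adjs, pow_tensor, quantize

# Two graphs padded to 5 nodes; the second graph only uses 3 of them.
adjs = torch.zeros(2, 5, 5)
adjs[0, 0, 1] = adjs[0, 1, 0] = 1.0
adjs[1, :3, :3] = 1.0 - torch.eye(3)

flags = node_flags(adjs)          # (2, 5) 0-1 tensor marking non-isolated nodes
adjc = pow_tensor(adjs, cnum=3)   # (2, 3, 5, 5): A, A^2, A^3 as channels
adjc = mask_adjs(adjc, flags)     # zero out rows/columns of padded nodes
binary = quantize(adjs, thr=0.5)  # threshold back to a 0/1 adjacency
print(flags.sum(dim=-1))          # tensor([2., 3.])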
MobileNetV3/models/GDSS/layers.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import torch
from torch.nn import Parameter
import torch.nn.functional as F
import math
from typing import Any


def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)


def reset(value: Any):
    if hasattr(value, 'reset_parameters'):
        value.reset_parameters()
    else:
        for child in value.children() if hasattr(value, 'children') else []:
            reset(child)


# -------- GCN layer --------
class DenseGCNConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GCNConv`.
    """
    def __init__(self, in_channels, out_channels, improved=False, bias=True):
        super(DenseGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved

        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, adj, mask=None, add_loop=True):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
                \times N \times F}`, with batch-size :math:`B`, (maximum)
                number of nodes :math:`N` for each graph, and feature
                dimension :math:`F`.
            adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B
                \times N \times N}`. The adjacency tensor is broadcastable in
                the batch dimension, resulting in a shared adjacency matrix for
                the complete batch.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, _ = adj.size()

        if add_loop:
            adj = adj.clone()
            idx = torch.arange(N, dtype=torch.long, device=adj.device)
            adj[:, idx, idx] = 1 if not self.improved else 2

        out = torch.matmul(x, self.weight)
        deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)

        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        out = torch.matmul(adj, out)

        if self.bias is not None:
            out = out + self.bias

        if mask is not None:
            out = out * mask.view(B, N, 1).to(x.dtype)

        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


# -------- MLP layer --------
class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
        """
        num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to linear model.
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: number of classes for prediction
        num_classes: the number of classes of input, to be treated with different gains and biases,
                     (see the definition of class `ConditionalLayer1d`)
        """

        super(MLP, self).__init__()

        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.use_bn = use_bn
        self.activate_func = activate_func

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = torch.nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()

            self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(torch.nn.Linear(hidden_dim, output_dim))

            if self.use_bn:
                self.batch_norms = torch.nn.ModuleList()
                for layer in range(num_layers - 1):
                    self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))

    def forward(self, x):
        """
        :param x: [num_classes * batch_size, N, F_i], batch of node features
        note that in self.cond_layers[layer],
        `x` is split into `num_classes` groups in dim=0,
        and then treated with different gains and biases
        """
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for layer in range(self.num_layers - 1):
                h = self.linears[layer](h)
                if self.use_bn:
                    h = self.batch_norms[layer](h)
                h = self.activate_func(h)
            return self.linears[self.num_layers - 1](h)
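A brief sketch of the two building blocks above on random inputs (shapes are illustrative; import path assumes the MobileNetV3 package root):

import torch
import torch.nn.functional as F
from models.GDSS.layers import DenseGCNConv, MLP

x = torch.randn(4, 8, 10)   # B x N x F node features
adj = torch.rand(4, 8, 8)   # B x N x N dense adjacency

gcn = DenseGCNConv(in_channels=10, out_channels=16)
h = gcn(x, adj)             # self-loops + symmetric normalization -> (4, 8, 16)

mlp = MLP(num_layers=3, input_dim=16, hidden_dim=32, output_dim=1,
          use_bn=False, activate_func=F.elu)
score = mlp(h)              # applied feature-wise -> (4, 8, 1)
print(h.shape, score.shape)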
MobileNetV3/models/GDSS/scorenetx.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F

from models.GDSS.layers import DenseGCNConv, MLP
from .graph_utils import mask_x, pow_tensor
from .attention import AttentionLayer
from .. import utils


@utils.register_model(name='ScoreNetworkX')
class ScoreNetworkX(torch.nn.Module):

    # def __init__(self, max_feat_num, depth, nhid):
    def __init__(self, config):

        super(ScoreNetworkX, self).__init__()

        self.nfeat = config.data.n_vocab
        self.depth = config.model.depth
        self.nhid = config.model.nhid

        self.layers = torch.nn.ModuleList()
        for _ in range(self.depth):
            if _ == 0:
                self.layers.append(DenseGCNConv(self.nfeat, self.nhid))
            else:
                self.layers.append(DenseGCNConv(self.nhid, self.nhid))

        self.fdim = self.nfeat + self.depth * self.nhid
        self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=self.nfeat,
                         use_bn=False, activate_func=F.elu)

        self.activation = torch.tanh

    def forward(self, x, time_cond, maskX, flags=None):

        x_list = [x]
        for _ in range(self.depth):
            x = self.layers[_](x, maskX)
            x = self.activation(x)
            x_list.append(x)

        xs = torch.cat(x_list, dim=-1)  # B x N x (F + num_layers x H)
        out_shape = (x.shape[0], x.shape[1], -1)
        x = self.final(xs).view(*out_shape)

        x = mask_x(x, flags)
        return x


@utils.register_model(name='ScoreNetworkX_GMH')
class ScoreNetworkX_GMH(torch.nn.Module):
    # def __init__(self, max_feat_num, depth, nhid, num_linears,
    #              c_init, c_hid, c_final, adim, num_heads=4, conv='GCN'):
    def __init__(self, config):
        super().__init__()

        self.max_feat_num = config.data.n_vocab
        self.depth = config.model.depth
        self.nhid = config.model.nhid
        self.c_init = config.model.c_init
        self.c_hid = config.model.c_hid
        self.c_final = config.model.c_final
        self.num_linears = config.model.num_linears
        self.num_heads = config.model.num_heads
        self.conv = config.model.conv
        self.adim = config.model.adim

        self.layers = torch.nn.ModuleList()
        for _ in range(self.depth):
            if _ == 0:
                self.layers.append(AttentionLayer(self.num_linears, self.max_feat_num,
                                                  self.nhid, self.nhid, self.c_init,
                                                  self.c_hid, self.num_heads, self.conv))
            elif _ == self.depth - 1:
                self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
                                                  self.nhid, self.c_hid,
                                                  self.c_final, self.num_heads, self.conv))
            else:
                self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
                                                  self.nhid, self.c_hid,
                                                  self.c_hid, self.num_heads, self.conv))

        fdim = self.max_feat_num + self.depth * self.nhid
        self.final = MLP(num_layers=3, input_dim=fdim, hidden_dim=2*fdim, output_dim=self.max_feat_num,
                         use_bn=False, activate_func=F.elu)

        self.activation = torch.tanh

    def forward(self, x, time_cond, maskX, flags=None):
        adjc = pow_tensor(maskX, self.c_init)

        x_list = [x]
        for _ in range(self.depth):
            x, adjc = self.layers[_](x, adjc, flags)
            x = self.activation(x)
            x_list.append(x)

        xs = torch.cat(x_list, dim=-1)  # B x N x (F + num_layers x H)
        out_shape = (x.shape[0], x.shape[1], -1)
        x = self.final(xs).view(*out_shape)
        x = mask_x(x, flags)

        return x
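Both score networks are built from a config object; a minimal sketch with an ad-hoc namespace carrying only the fields ScoreNetworkX reads (values are illustrative, and it assumes the package's model registry in models/utils is importable):

from types import SimpleNamespace
import torch
from models.GDSS.scorenetx import ScoreNetworkX

config = SimpleNamespace(
    data=SimpleNamespace(n_vocab=10),
    model=SimpleNamespace(depth=3, nhid=16),
)
net = ScoreNetworkX(config)

x = torch.randn(4, 8, 10)   # B x N x n_vocab node features
adj = torch.rand(4, 8, 8)   # dense adjacency passed as maskX
score = net(x, time_cond=None, maskX=adj, flags=None)
print(score.shape)          # torch.Size([4, 8, 10]) -- same shape as x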
MobileNetV3/models/__init__.py (new executable file, empty)
MobileNetV3/models/cate.py (new file, 352 lines)
@@ -0,0 +1,352 @@
import torch.nn as nn
import torch
import functools
from torch_geometric.utils import dense_to_sparse

from . import utils, layers, gnns

import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer import Encoder, SemanticEmbedding
from models.GDSS.layers import MLP
from .set_encoder.setenc_models import SetPool


""" Transformer Encoder """
class GraphEncoder(nn.Module):
    def __init__(self, config):
        super(GraphEncoder, self).__init__()
        # Forward Transformers
        self.encoder_f = Encoder(config)

    def forward(self, x, mask):
        h_f, hs_f, attns_f = self.encoder_f(x, mask)
        h = torch.cat(hs_f, dim=-1)
        return h

    @staticmethod
    def get_embeddings(h_x):
        h_x = h_x.cpu()
        return h_x[:, -1]


class CLSHead(nn.Module):
    def __init__(self, config, init_weights=None):
        super(CLSHead, self).__init__()
        self.layer_1 = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(p=config.dropout)
        self.layer_2 = nn.Linear(config.d_model, config.n_vocab)
        if init_weights is not None:
            self.layer_2.weight = init_weights

    def forward(self, x):
        x = self.dropout(torch.tanh(self.layer_1(x)))
        return F.log_softmax(self.layer_2(x), dim=-1)


@utils.register_model(name='CATE')
class CATE(nn.Module):
    def __init__(self, config):
        super(CATE, self).__init__()
        # Shared Embedding Layer
        self.opEmb = SemanticEmbedding(config.model.graph_encoder)
        self.dropout_op = nn.Dropout(p=config.model.dropout)
        self.d_model = config.model.graph_encoder.d_model
        self.act = act = get_act(config)
        # Time
        self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
        self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)

        # 2 GraphEncoder for X and Y
        self.graph_encoder = GraphEncoder(config.model.graph_encoder)

        self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
        self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
                         use_bn=False, activate_func=F.elu)

        self.pos_enc_type = config.model.pos_enc_type
        self.pos_encoder = PositionalEncoding_StageWise(d_model=self.d_model, max_len=config.data.max_node)

    def forward(self, X, time_cond, maskX):

        # Shared Embeddings
        emb_x = self.dropout_op(self.opEmb(X))

        if self.pos_encoder is not None:
            emb_p = self.pos_encoder(emb_x)  # [20, 64]
            emb_x = emb_x + emb_p
        # Time embedding
        timesteps = time_cond
        emb_t = get_timestep_embedding(timesteps, self.d_model)  # time embedding
        emb_t = self.timeEmb1(emb_t)  # [32, 512]
        emb_t = self.timeEmb2(self.act(emb_t))  # [32, 64]
        emb_t = emb_t.unsqueeze(1)
        emb = emb_x + emb_t

        h_x = self.graph_encoder(emb, maskX)
        h_x = self.final(h_x)

        """
        Shape: Batch Size, Length (with Pad), Feature Dim (forward) + Feature Dim (backward)
        *HINT: X1 X2 X3 [PAD] [PAD]
        """
        return h_x


@utils.register_model(name='PredictorCATE')
class PredictorCATE(nn.Module):
    def __init__(self, config):
        super(PredictorCATE, self).__init__()
        # Shared Embedding Layer
        self.opEmb = SemanticEmbedding(config.model.graph_encoder)
        self.dropout_op = nn.Dropout(p=config.model.dropout)
        self.d_model = config.model.graph_encoder.d_model
        self.act = act = get_act(config)
        # Time
        self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
        self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)

        # 2 GraphEncoder for X and Y
        self.graph_encoder = GraphEncoder(config.model.graph_encoder)

        self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
        self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
                         use_bn=False, activate_func=F.elu)

        self.rdim = int(config.data.max_node * config.data.n_vocab)
        self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=1,
                            use_bn=False, activate_func=F.elu)

    def forward(self, X, time_cond, maskX):

        # Shared Embeddings
        emb_x = self.dropout_op(self.opEmb(X))

        # Time embedding
        timesteps = time_cond
        emb_t = get_timestep_embedding(timesteps, self.d_model)  # time embedding
        emb_t = self.timeEmb1(emb_t)  # [32, 512]
        emb_t = self.timeEmb2(self.act(emb_t))  # [32, 64]
        emb_t = emb_t.unsqueeze(1)

        emb = emb_x + emb_t

        # h_x = self.graph_encoder(emb_x, maskX)
        h_x = self.graph_encoder(emb, maskX)
        h_x = self.final(h_x)

        """
        Shape: Batch Size, Length (with Pad), Feature Dim (forward) + Feature Dim (backward)
        *HINT: X1 X2 X3 [PAD] [PAD]
        """
        h_x = h_x.reshape(h_x.size(0), -1)
        h_x = self.regeress(h_x)

        return h_x


class PositionalEncoding_StageWise(nn.Module):

    def __init__(self, d_model, max_len):

        super(PositionalEncoding_StageWise, self).__init__()

        NUM_STAGE = 5
        max_len = int(max_len / NUM_STAGE)
        self.encoding = torch.zeros(max_len, d_model)

        pos = torch.arange(0, max_len)
        pos = pos.float().unsqueeze(dim=1)

        _2i = torch.arange(0, d_model, step=2).float()

        # (max_len, 1) / (d_model/2) -> (max_len, d_model/2)
        self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))  # (4, 64)
        self.encoding = torch.cat([self.encoding] * NUM_STAGE, dim=0)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        return self.encoding[:seq_len, :].to(x.device)


@utils.register_model(name='MetaPredictorCATE')
class MetaPredictorCATE(nn.Module):
    def __init__(self, config):
        super(MetaPredictorCATE, self).__init__()

        self.input_type = config.model.input_type
        self.hs = config.model.hs

        # Shared Embedding Layer
        self.opEmb = SemanticEmbedding(config.model.graph_encoder)
        self.dropout_op = nn.Dropout(p=config.model.dropout)
        self.d_model = config.model.graph_encoder.d_model
        self.act = act = get_act(config)
        # Time
        self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
        self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)

        # 2 GraphEncoder for X and Y
        self.graph_encoder = GraphEncoder(config.model.graph_encoder)

        self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
        self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
                         use_bn=False, activate_func=F.elu)

        self.rdim = int(config.data.max_node * config.data.n_vocab)
        self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=2*self.rdim,
                            use_bn=False, activate_func=F.elu)

        # Set
        self.nz = config.model.nz
        self.num_sample = config.model.num_sample
        self.intra_setpool = SetPool(dim_input=512,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.inter_setpool = SetPool(dim_input=self.nz,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.set_fc = nn.Sequential(
            nn.Linear(512, self.nz),
            nn.ReLU())

        input_dim = 0
        if 'D' in self.input_type:
            input_dim += self.nz
        if 'A' in self.input_type:
            input_dim += 2*self.rdim

        self.pred_fc = nn.Sequential(
            nn.Linear(input_dim, self.hs),
            nn.Tanh(),
            nn.Linear(self.hs, 1)
        )

        self.sample_state = False
        self.D_mu = None

    def arch_encode(self, X, time_cond, maskX):
        # Shared Embeddings
        emb_x = self.dropout_op(self.opEmb(X))

        # Time embedding
        timesteps = time_cond
        emb_t = get_timestep_embedding(timesteps, self.d_model)  # time embedding
        emb_t = self.timeEmb1(emb_t)  # [32, 512]
        emb_t = self.timeEmb2(self.act(emb_t))  # [32, 64]
        emb_t = emb_t.unsqueeze(1)
        emb = emb_x + emb_t

        h_x = self.graph_encoder(emb, maskX)
        h_x = self.final(h_x)

        h_x = h_x.reshape(h_x.size(0), -1)
        h_x = self.regeress(h_x)
        return h_x

    def set_encode(self, task):
        proto_batch = []
        for x in task:
            cls_protos = self.intra_setpool(
                x.view(-1, self.num_sample, 512)).squeeze(1)
            proto_batch.append(
                self.inter_setpool(cls_protos.unsqueeze(0)))
        v = torch.stack(proto_batch).squeeze()
        return v

    def predict(self, D_mu, A_mu):
        input_vec = []
        if 'D' in self.input_type:
            input_vec.append(D_mu)
        if 'A' in self.input_type:
            input_vec.append(A_mu)
        input_vec = torch.cat(input_vec, dim=1)
        return self.pred_fc(input_vec)

    def forward(self, X, time_cond, maskX, task):
        if self.sample_state:
            if self.D_mu is None:
                self.D_mu = self.set_encode(task)
            D_mu = self.D_mu
        else:
            D_mu = self.set_encode(task)
        A_mu = self.arch_encode(X, time_cond, maskX)
        y_pred = self.predict(D_mu, A_mu)
        return y_pred


class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
        )
        # NOTE: `conv_nd` and `QKVAttention` are not defined or imported in this
        # file; they come from the codebase this class was adapted from and must
        # be provided before AttentionPool2d can be instantiated.
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]


import math
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb


def get_act(config):
    """Get activation functions from the config file."""

    if config.model.nonlinearity.lower() == 'elu':
        return nn.ELU()
    elif config.model.nonlinearity.lower() == 'relu':
        return nn.ReLU()
    elif config.model.nonlinearity.lower() == 'lrelu':
        return nn.LeakyReLU(negative_slope=0.2)
    elif config.model.nonlinearity.lower() == 'swish':
        return nn.SiLU()
    elif config.model.nonlinearity.lower() == 'tanh':
        return nn.Tanh()
    else:
        raise NotImplementedError('activation function does not exist!')
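The time-conditioning path is the easiest part of this file to exercise in isolation; a short sketch of the sinusoidal timestep embedding and the stage-wise positional encoding defined above (sizes are illustrative, and importing models.cate assumes its transformer and set_encoder dependencies are available):

import torch
from models.cate import get_timestep_embedding, PositionalEncoding_StageWise

t = torch.randint(0, 1000, (32,))      # one diffusion timestep per sample
emb_t = get_timestep_embedding(t, 64)  # (32, 64) sin/cos features

pos = PositionalEncoding_StageWise(d_model=64, max_len=20)  # 20 nodes over 5 stages
x = torch.randn(32, 20, 64)            # B x N x d_model token embeddings
emb_p = pos(x)                         # (20, 64), broadcast over the batch
x = x + emb_p
print(emb_t.shape, emb_p.shape)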
MobileNetV3/models/dagformer.py (new file, 355 lines)
@@ -0,0 +1,355 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops.layers.torch import Rearrange
from einops import rearrange
import numpy as np

from . import utils


class SinusoidalPositionalEmbedding(nn.Embedding):

    def __init__(self, num_positions, embedding_dim):
        super().__init__(num_positions, embedding_dim)  # torch.nn.Embedding(num_embeddings, embedding_dim)
        self.weight = self._init_weight(self.weight)  # self.weight => nn.Embedding(num_positions, embedding_dim).weight

    @staticmethod
    def _init_weight(out: nn.Parameter):
        n_pos, embed_dim = out.shape
        pe = nn.Parameter(torch.zeros(out.shape))
        for pos in range(n_pos):
            for i in range(0, embed_dim, 2):
                pe[pos, i].data.copy_(torch.tensor(np.sin(pos / (10000 ** (i / embed_dim)))))
                pe[pos, i + 1].data.copy_(torch.tensor(np.cos(pos / (10000 ** ((i + 1) / embed_dim)))))
        pe.detach_()

        return pe

    @torch.no_grad()
    def forward(self, input_ids):
        bsz, seq_len = input_ids.shape[:2]  # for x, seq_len = max_node_num
        positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions)


class MLP(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        *,
        expansion_factor = 2.,
        depth = 2,
        norm = False,
    ):
        super().__init__()
        hidden_dim = int(expansion_factor * dim_out)
        norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()

        layers = [nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.SiLU(),
            norm_fn()
        )]

        for _ in range(depth - 1):
            layers.append(nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                norm_fn()
            ))

        layers.append(nn.Linear(hidden_dim, dim_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x.float())


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        dtype, device = x.dtype, x.device
        assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'

        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device, dtype=dtype) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim=-1).type(dtype)


def is_float_dtype(dtype):
    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])


class PositionWiseFeedForward(nn.Module):

    def __init__(self, emb_dim: int, d_ff: int, dropout: float = 0.1):
        super(PositionWiseFeedForward, self).__init__()

        self.activation = nn.ReLU()
        self.w_1 = nn.Linear(emb_dim, d_ff)
        self.w_2 = nn.Linear(d_ff, emb_dim)
        self.dropout = dropout

    def forward(self, x):
        residual = x
        x = self.activation(self.w_1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.w_2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x + residual  # residual connection for preventing gradient vanishing


class MultiHeadAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        emb_dim,
        num_heads,
        dropout=0.0,
        bias=False,
        encoder_decoder_attention=False,  # otherwise self_attention
        causal = True
    ):
        super().__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = emb_dim // num_heads
        assert self.head_dim * num_heads == self.emb_dim, "emb_dim must be divisible by num_heads"

        self.encoder_decoder_attention = encoder_decoder_attention
        self.causal = causal
        self.q_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.k_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.v_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.out_proj = nn.Linear(emb_dim, emb_dim, bias=bias)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (
            self.num_heads,
            self.head_dim,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
        # This is equivalent to
        # return x.transpose(1,2)

    def scaled_dot_product(self,
                           query: torch.Tensor,
                           key: torch.Tensor,
                           value: torch.Tensor,
                           attention_mask: torch.BoolTensor):

        attn_weights = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.emb_dim)  # QK^T/sqrt(d)

        if attention_mask is not None:
            attn_weights = attn_weights.masked_fill(attention_mask.unsqueeze(1), float("-inf"))

        attn_weights = F.softmax(attn_weights, dim=-1)  # softmax(QK^T/sqrt(d))
        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_probs, value)  # softmax(QK^T/sqrt(d))V

        return attn_output, attn_probs

    def MultiHead_scaled_dot_product(self,
                                     query: torch.Tensor,
                                     key: torch.Tensor,
                                     value: torch.Tensor,
                                     attention_mask: torch.BoolTensor):
        attention_mask = attention_mask.bool()

        attn_weights = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.head_dim)  # QK^T/sqrt(d) # [6, 6]

        # Attention mask
        if attention_mask is not None:
            if self.causal:
                # (seq_len x seq_len)
                attn_weights = attn_weights.masked_fill(attention_mask.unsqueeze(0).unsqueeze(1), float("-inf"))
            else:
                # (batch_size x seq_len)
                attn_weights = attn_weights.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2), float("-inf"))

        attn_weights = F.softmax(attn_weights, dim=-1)  # softmax(QK^T/sqrt(d))
        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.matmul(attn_probs, value)  # softmax(QK^T/sqrt(d))V
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        concat_attn_output_shape = attn_output.size()[:-2] + (self.emb_dim,)
        attn_output = attn_output.view(*concat_attn_output_shape)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        attention_mask: torch.Tensor = None,
    ):

        q = self.q_proj(query)
        # Enc-Dec attention
        if self.encoder_decoder_attention:
            k = self.k_proj(key)
            v = self.v_proj(key)
        # Self attention
        else:
            k = self.k_proj(query)
            v = self.v_proj(query)

        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)

        attn_output, attn_weights = self.MultiHead_scaled_dot_product(q, k, v, attention_mask)
        return attn_output, attn_weights


class EncoderLayer(nn.Module):
    def __init__(self, emb_dim, ffn_dim, attention_heads,
                 attention_dropout, dropout):
        super().__init__()
        self.emb_dim = emb_dim
        self.ffn_dim = ffn_dim
        self.self_attn = MultiHeadAttention(
            emb_dim=self.emb_dim,
            num_heads=attention_heads,
            dropout=attention_dropout)
        self.self_attn_layer_norm = nn.LayerNorm(self.emb_dim)
        self.dropout = dropout
        self.activation_fn = nn.ReLU()
        self.PositionWiseFeedForward = PositionWiseFeedForward(self.emb_dim, self.ffn_dim, dropout)
        self.final_layer_norm = nn.LayerNorm(self.emb_dim)

    def forward(self, x, encoder_padding_mask):

        residual = x
        x, attn_weights = self.self_attn(query=x, key=x, attention_mask=encoder_padding_mask)

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.self_attn_layer_norm(x)
        x = self.PositionWiseFeedForward(x)
        x = self.final_layer_norm(x)
        if torch.isinf(x).any() or torch.isnan(x).any():
            clamp_value = torch.finfo(x.dtype).max - 1000
            x = torch.clamp(x, min=-clamp_value, max=clamp_value)
        return x, attn_weights


@utils.register_model(name='DAGformer')
class DAGformer(torch.nn.Module):
    def __init__(self, config):
        # max_feat_num,
        # max_node_num,
        # emb_dim,
        # ffn_dim,
        # encoder_layers,
        # attention_heads,
        # attention_dropout,
        # dropout,
        # hs,
        # time_dep=True,
        # num_timesteps=None,
        # return_attn=False,
        # except_inout=False,
        # connect_prev=True
        # ):
        super().__init__()

        self.dropout = config.model.dropout
        self.time_dep = config.model.time_dep
        self.return_attn = config.model.return_attn
        max_feat_num = config.data.n_vocab
        max_node_num = config.data.max_node
        emb_dim = config.model.emb_dim
        # num_timesteps = config.model.num_scales
        num_timesteps = None

        self.x_embedding = MLP(max_feat_num, emb_dim)
        # position embedding with topological order
        self.position_embedding = SinusoidalPositionalEmbedding(max_node_num, emb_dim)

        if self.time_dep:
            self.time_embedding = nn.Sequential(
                nn.Embedding(num_timesteps, emb_dim) if num_timesteps is not None
                else nn.Sequential(SinusoidalPosEmb(emb_dim), MLP(emb_dim, emb_dim)),  # also offer a continuous version of timestep embeddings, with a 2 layer MLP
                Rearrange('b (n d) -> b n d', n=1)
            )

        self.layers = nn.ModuleList([EncoderLayer(emb_dim,
                                                  config.model.ffn_dim,
                                                  config.model.attention_heads,
                                                  config.model.attention_dropout,
                                                  config.model.dropout)
                                     for _ in range(config.model.encoder_layers)])

        self.pred_fc = nn.Sequential(
            nn.Linear(emb_dim, config.model.hs),
            nn.Tanh(),
            nn.Linear(config.model.hs, 1),
            # nn.Sigmoid()
        )

        # -------- Load Constant Adj Matrix (START) --------- #
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # from utils.graph_utils import get_const_adj
        # mat = get_const_adj(
        #     except_inout=except_inout,
        #     shape_adj=(1, max_node_num, max_node_num),
        #     device=torch.device('cpu'),
        #     connect_prev=connect_prev)[0].cpu()
        # is_triu_ = is_triu(mat)
        # if is_triu_:
        #     self.adj_ = mat.T.to(self.device)
        # else:
        #     self.adj_ = mat.to(self.device)
        # -------- Load Constant Adj Matrix (END) --------- #

    def forward(self, x, t, adj, flags=None):
        """
        :param x: B x N x F_i
        :param adjs: B x C_i x N x N
        :return: x_o: B x N x F_o, new_adjs: B x C_o x N x N
        """

        assert len(x.shape) == 3

        self_attention_mask = torch.eye(adj.size(1)).to(self.device)
        # attention_mask = 1. - (self_attention_mask + self.adj_)
        attention_mask = 1. - (self_attention_mask + adj[0])

        # -------- Generate input for DAGformer ------- #
        x_embed = self.x_embedding(x)
        # x_embed = x
        x_pos = self.position_embedding(x).unsqueeze(0)
        if self.time_dep:
            time_embed = self.time_embedding(t)

        x = x_embed + x_pos
        if self.time_dep:
            x = x + time_embed
        x = F.dropout(x, p=self.dropout, training=self.training)

        self_attn_scores = []
        for encoder_layer in self.layers:
            x, attn = encoder_layer(x, attention_mask)
            self_attn_scores.append(attn.detach())

        x = self.pred_fc(x[:, -1, :])  # [256, 16]

        if self.return_attn:
            return x, self_attn_scores
        else:
            return x
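The key step in `DAGformer.forward` is turning a DAG adjacency matrix into an attention mask: a node may attend to itself and to the nodes its adjacency row marks with 1, and every other position is masked. A standalone sketch of just that mask construction, with a hypothetical 4-node chain (pure tensor code, no repo imports):

import torch

N = 4
adj = torch.zeros(N, N)
adj[1, 0] = adj[2, 1] = adj[3, 2] = 1.0  # row i marks the nodes that node i may attend to

self_attention_mask = torch.eye(N)
attention_mask = 1. - (self_attention_mask + adj)  # 1 = masked, 0 = attend
print(attention_mask)
# Inside MultiHead_scaled_dot_product the 1-entries are filled with -inf
# before the softmax, so each node only mixes with itself and its adj entries.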
MobileNetV3/models/digcn.py (new file, 142 lines)
@@ -0,0 +1,142 @@
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
# which was authored by Yuge Zhang, 2020

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import utils
from models.cate import PositionalEncoding_StageWise


def normalize_adj(adj):
    # Row-normalize matrix
    last_dim = adj.size(-1)
    rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
    return torch.div(adj, rowsum)


def graph_pooling(inputs, num_vertices):
    num_vertices = num_vertices.to(inputs.device)
    out = inputs.sum(1)
    return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))


class DirectedGraphConvolution(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
        self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
        self.dropout = nn.Dropout(0.1)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight1.data)
        nn.init.xavier_uniform_(self.weight2.data)

    def forward(self, inputs, adj):
        inputs = inputs.to(self.weight1.device)
        adj = adj.to(self.weight1.device)
        norm_adj = normalize_adj(adj)
        output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
        inv_norm_adj = normalize_adj(adj.transpose(1, 2))
        output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
        out = (output1 + output2) / 2
        out = self.dropout(out)
        return out

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'


# if nasbench-101: initial_hidden=5. if nasbench-201: initial_hidden=7
@utils.register_model(name='NeuralPredictor')
class NeuralPredictor(nn.Module):
    # def __init__(self, initial_hidden=5, gcn_hidden=144, gcn_layers=4, linear_hidden=128):
    def __init__(self, config):
        super().__init__()
        self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
                                             config.model.graph_encoder.gcn_hidden)
                    for i in range(config.model.graph_encoder.gcn_layers)]
        self.gcn = nn.ModuleList(self.gcn)
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
        self.fc2 = nn.Linear(config.model.graph_encoder.linear_hidden, 1, bias=False)
        # Time
        self.d_model = config.model.graph_encoder.gcn_hidden
        self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
        self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
        self.act = act = get_act(config)

        # self.pos_enc_type = config.model.pos_enc_type
        # if self.pos_enc_type == 1:
        #     raise NotImplementedError
        # elif self.pos_enc_type == 2:
        #     self.pos_encoder = PositionalEncoding_StageWise(d_model=config.model.graph_encoder.gcn_hidden, max_len=config.data.max_node)
        # elif self.pos_enc_type == 3:
        #     raise NotImplementedError
        # else:
        #     self.pos_encoder = None

    # def forward(self, inputs):
    def forward(self, X, time_cond, maskX):
        # numv, adj, out = inputs["num_vertices"], inputs["adjacency"], inputs["operations"]
        out = X  # (5, 20, 10)
        adj = maskX  # (1, 20, 20)

        # # pos embedding
        # if self.pos_encoder is not None:
        #     emb_p = self.pos_encoder(out)  # [20, 64]
        #     out = out + emb_p
        numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device)  # 20
        gs = adj.size(1)  # graph node number

        timesteps = time_cond
        emb_t = get_timestep_embedding(timesteps, self.d_model)  # time embedding
        emb_t = self.timeEmb1(emb_t)
        emb_t = self.timeEmb2(self.act(emb_t))  # (5, 144)

        adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device))  # assuming diagonal is not 1
        for layer in self.gcn:
            out = layer(out, adj_with_diag)
        out = graph_pooling(out, numv)  # out: 5, 20, 144
        # time
        out = out + emb_t
        out = self.fc1(out)  # (5, 128)
        out = self.dropout(out)
        # out = self.fc2(out).view(-1)
        out = self.fc2(out)
        return out


import math
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb


def get_act(config):
    """Get activation functions from the config file."""

    if config.model.nonlinearity.lower() == 'elu':
        return nn.ELU()
    elif config.model.nonlinearity.lower() == 'relu':
        return nn.ReLU()
    elif config.model.nonlinearity.lower() == 'lrelu':
        return nn.LeakyReLU(negative_slope=0.2)
    elif config.model.nonlinearity.lower() == 'swish':
        return nn.SiLU()
    elif config.model.nonlinearity.lower() == 'tanh':
        return nn.Tanh()
    else:
        raise NotImplementedError('activation function does not exist!')
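A toy run of the directed convolution above: the layer averages messages propagated along the row-normalized adjacency and along its transpose (sizes illustrative; the import path assumes the MobileNetV3 package root):

import torch
from models.digcn import DirectedGraphConvolution, normalize_adj, graph_pooling

B, N = 2, 5
ops = torch.randn(B, N, 7)                       # operation features per node
adj = torch.randint(0, 2, (B, N, N)).float()     # random directed adjacency

layer = DirectedGraphConvolution(in_features=7, out_features=16)
adj_with_diag = normalize_adj(adj + torch.eye(N))  # add self-loops, then row-normalize
h = layer(ops, adj_with_diag)                      # (B, N, 16)
g = graph_pooling(h, torch.tensor([N, N]))         # mean over nodes -> (B, 16)
print(h.shape, g.shape)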
MobileNetV3/models/digcn_meta.py (new file, 194 lines)
@@ -0,0 +1,194 @@
|
||||
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
|
||||
# which was authored by Yuge Zhang, 2020
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from . import utils
|
||||
from .set_encoder.setenc_models import SetPool
|
||||
|
||||
def normalize_adj(adj):
|
||||
# Row-normalize matrix
|
||||
last_dim = adj.size(-1)
|
||||
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
|
||||
return torch.div(adj, rowsum)
|
||||
|
||||
|
||||
def graph_pooling(inputs, num_vertices):
|
||||
num_vertices = num_vertices.to(inputs.device)
|
||||
out = inputs.sum(1)
|
||||
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
|
||||
|
||||
|
||||
class DirectedGraphConvolution(nn.Module):
|
||||
def __init__(self, in_features, out_features):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
|
||||
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.xavier_uniform_(self.weight1.data)
|
||||
nn.init.xavier_uniform_(self.weight2.data)
|
||||
|
||||
def forward(self, inputs, adj):
|
||||
inputs = inputs.to(self.weight1.device)
|
||||
adj = adj.to(self.weight1.device)
|
||||
norm_adj = normalize_adj(adj)
|
||||
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
|
||||
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
|
||||
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
|
||||
out = (output1 + output2) / 2
|
||||
out = self.dropout(out)
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + ' (' \
|
||||
+ str(self.in_features) + ' -> ' \
|
||||
+ str(self.out_features) + ')'
|
||||
|
||||
# if nasbench-101: initial_hidden=5. if nasbench-201: initial_hidden=7
|
||||
@utils.register_model(name='MetaNeuralPredictor')
|
||||
class MetaeuralPredictor(nn.Module):
|
||||
    # def __init__(self, initial_hidden=5, gcn_hidden=144, gcn_layers=4, linear_hidden=128):
    def __init__(self, config):
        super().__init__()
        # Arch
        self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
                                             config.model.graph_encoder.gcn_hidden)
                    for i in range(config.model.graph_encoder.gcn_layers)]
        self.gcn = nn.ModuleList(self.gcn)
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
        # self.fc2 = nn.Linear(config.model.graph_encoder.linear_hidden, 1, bias=False)

        # Time
        self.d_model = config.model.graph_encoder.gcn_hidden
        self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
        self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)

        self.act = act = get_act(config)
        self.input_type = config.model.input_type
        self.hs = config.model.hs

        # Set
        self.nz = config.model.nz
        self.num_sample = config.model.num_sample
        self.intra_setpool = SetPool(dim_input=512,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.inter_setpool = SetPool(dim_input=self.nz,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.set_fc = nn.Sequential(
            nn.Linear(512, self.nz),
            nn.ReLU())

        input_dim = 0
        if 'D' in self.input_type:
            input_dim += self.nz
        if 'A' in self.input_type:
            input_dim += config.model.graph_encoder.linear_hidden

        self.pred_fc = nn.Sequential(
            nn.Linear(input_dim, self.hs),
            nn.Tanh(),
            nn.Linear(self.hs, 1)
        )

        self.sample_state = False
        self.D_mu = None

    def arch_encode(self, X, time_cond, maskX):
        # numv, adj, out = inputs["num_vertices"], inputs["adjacency"], inputs["operations"]
        out = X
        adj = maskX
        numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device)
        gs = adj.size(1)  # graph node number

        timesteps = time_cond
        emb_t = get_timestep_embedding(timesteps, self.d_model)  # time embedding
        emb_t = self.timeEmb1(emb_t)
        emb_t = self.timeEmb2(self.act(emb_t))

        adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device))  # assuming diagonal is not 1
        for layer in self.gcn:
            out = layer(out, adj_with_diag)
        out = graph_pooling(out, numv)
        # time
        out = out + emb_t
        out = self.fc1(out)
        out = self.dropout(out)

        # out = self.fc2(out).view(-1)
        # out = self.fc2(out)
        return out

    def set_encode(self, task):
        proto_batch = []
        for x in task:
            cls_protos = self.intra_setpool(
                x.view(-1, self.num_sample, 512)).squeeze(1)
            proto_batch.append(
                self.inter_setpool(cls_protos.unsqueeze(0)))
        v = torch.stack(proto_batch).squeeze()
        return v

    def predict(self, D_mu, A_mu):
        input_vec = []
        if 'D' in self.input_type:
            input_vec.append(D_mu)
        if 'A' in self.input_type:
            input_vec.append(A_mu)
        input_vec = torch.cat(input_vec, dim=1)
        return self.pred_fc(input_vec)

    def forward(self, X, time_cond, maskX, task):
        if self.sample_state:
            if self.D_mu is None:
                self.D_mu = self.set_encode(task)
            D_mu = self.D_mu
        else:
            D_mu = self.set_encode(task)
        A_mu = self.arch_encode(X, time_cond, maskX)
        y_pred = self.predict(D_mu, A_mu)
        return y_pred


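# Minimal sketch of how `predict` fuses the dataset embedding D_mu and the architecture
# embedding A_mu when config.model.input_type is 'DA'. The sizes below (nz=56,
# linear_hidden=128, hs=64, batch=4) are hypothetical placeholders, not values taken
# from any config in this repo.
def _example_predict_fusion():
    import torch
    import torch.nn as nn

    nz, linear_hidden, hs, batch = 56, 128, 64, 4
    D_mu = torch.randn(batch, nz)             # set-encoded dataset embedding
    A_mu = torch.randn(batch, linear_hidden)  # GCN-encoded architecture embedding

    input_type = 'DA'
    input_vec = []
    if 'D' in input_type:
        input_vec.append(D_mu)
    if 'A' in input_type:
        input_vec.append(A_mu)
    input_vec = torch.cat(input_vec, dim=1)   # [batch, nz + linear_hidden]

    pred_fc = nn.Sequential(nn.Linear(nz + linear_hidden, hs), nn.Tanh(), nn.Linear(hs, 1))
    return pred_fc(input_vec)                 # [batch, 1] predicted performance

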
import math


def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb


def get_act(config):
    """Get activation functions from the config file."""

    if config.model.nonlinearity.lower() == 'elu':
        return nn.ELU()
    elif config.model.nonlinearity.lower() == 'relu':
        return nn.ReLU()
    elif config.model.nonlinearity.lower() == 'lrelu':
        return nn.LeakyReLU(negative_slope=0.2)
    elif config.model.nonlinearity.lower() == 'swish':
        return nn.SiLU()
    elif config.model.nonlinearity.lower() == 'tanh':
        return nn.Tanh()
    else:
        raise NotImplementedError('activation function does not exist!')
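

# Minimal usage sketch for the sinusoidal timestep embedding above: a batch of scalar
# diffusion steps becomes a [B, embedding_dim] tensor, half sines / half cosines, as in
# DDPM-style positional encodings. The embedding_dim of 144 is an arbitrary example value.
def _example_timestep_embedding():
    import torch

    timesteps = torch.tensor([0., 10., 999.])          # one scalar time per sample
    emb = get_timestep_embedding(timesteps, embedding_dim=144)
    assert emb.shape == (3, 144)                       # [B, d_model]
    return emb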
85
MobileNetV3/models/ema.py
Normal file
85
MobileNetV3/models/ema.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import torch
|
||||
|
||||
|
||||
class ExponentialMovingAverage:
|
||||
"""
|
||||
Maintains (exponential) moving average of a set of parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, parameters, decay, use_num_updates=True):
|
||||
"""
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; usually the result of `model.parameters()`.
|
||||
decay: The exponential decay.
|
||||
use_num_updates: Whether to use number of updates when computing averages.
|
||||
"""
|
||||
if decay < 0.0 or decay > 1.0:
|
||||
raise ValueError('Decay must be between 0 and 1')
|
||||
self.decay = decay
|
||||
self.num_updates = 0 if use_num_updates else None
|
||||
self.shadow_params = [p.clone().detach()
|
||||
for p in parameters if p.requires_grad]
|
||||
self.collected_params = []
|
||||
|
||||
def update(self, parameters):
|
||||
"""
|
||||
Update currently maintained parameters.
|
||||
|
||||
Call this every time the parameters are updated, such as the result of the `optimizer.step()` call.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to
|
||||
initialize this object.
|
||||
"""
|
||||
decay = self.decay
|
||||
if self.num_updates is not None:
|
||||
self.num_updates += 1
|
||||
decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
|
||||
one_minus_decay = 1.0 - decay
|
||||
with torch.no_grad():
|
||||
parameters = [p for p in parameters if p.requires_grad]
|
||||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
s_param.sub_(one_minus_decay * (s_param - param))
|
||||
|
||||
def copy_to(self, parameters):
|
||||
"""
|
||||
Copy current parameters into given collection of parameters.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored moving averages.
|
||||
"""
|
||||
parameters = [p for p in parameters if p.requires_grad]
|
||||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
if param.requires_grad:
|
||||
param.data.copy_(s_param.data)
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the current parameters for restoring later.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters stored with the `store` method.
|
||||
Useful to validate the model with EMA parameters without affecting the original optimization process.
|
||||
Store the parameters before the `copy_to` method.
|
||||
After validation (or model saving), use this to restore the former parameters.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
||||
|
||||
def state_dict(self):
|
||||
return dict(decay=self.decay, num_updates=self.num_updates, shadow_params=self.shadow_params)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.decay = state_dict['decay']
|
||||
self.num_updates = state_dict['num_updates']
|
||||
self.shadow_params = state_dict['shadow_params']
|
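

# Minimal, self-contained sketch of the intended ExponentialMovingAverage workflow,
# following the class's own docstrings: update after every optimizer step, then
# store / copy_to / restore around evaluation. The toy nn.Linear model and SGD loop
# are placeholders for illustration only.
def _example_ema_usage():
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

    for _ in range(3):                          # toy training steps
        loss = model(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema.update(model.parameters())          # track averages after each step

    ema.store(model.parameters())               # stash the raw weights
    ema.copy_to(model.parameters())             # evaluate with EMA weights
    # ... run validation here ...
    ema.restore(model.parameters())             # put raw weights back for training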
82
MobileNetV3/models/gnns.py
Normal file
82
MobileNetV3/models/gnns.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from .trans_layers import *
|
||||
|
||||
|
||||
class pos_gnn(nn.Module):
|
||||
def __init__(self, act, x_ch, pos_ch, out_ch, max_node, graph_layer, n_layers=3, edge_dim=None, heads=4,
|
||||
temb_dim=None, dropout=0.1, attn_clamp=False):
|
||||
super().__init__()
|
||||
self.out_ch = out_ch
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.act = act
|
||||
self.max_node = max_node
|
||||
self.n_layers = n_layers
|
||||
|
||||
if temb_dim is not None:
|
||||
self.Dense_node0 = nn.Linear(temb_dim, x_ch)
|
||||
self.Dense_node1 = nn.Linear(temb_dim, pos_ch)
|
||||
self.Dense_edge0 = nn.Linear(temb_dim, edge_dim)
|
||||
self.Dense_edge1 = nn.Linear(temb_dim, edge_dim)
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.edge_convs = nn.ModuleList()
|
||||
self.edge_layer = nn.Linear(edge_dim * 2 + self.out_ch, edge_dim)
|
||||
|
||||
for i in range(n_layers):
|
||||
if i == 0:
|
||||
self.convs.append(eval(graph_layer)(x_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2,
|
||||
act=act, attn_clamp=attn_clamp))
|
||||
else:
|
||||
self.convs.append(eval(graph_layer)
|
||||
(self.out_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2, act=act,
|
||||
attn_clamp=attn_clamp))
|
||||
self.edge_convs.append(nn.Linear(self.out_ch, edge_dim*2))
|
||||
|
||||
def forward(self, x_degree, x_pos, edge_index, dense_ori, dense_spd, dense_index, temb=None):
|
||||
"""
|
||||
Args:
|
||||
x_degree: node degree feature [B*N, x_ch]
|
||||
x_pos: node rwpe feature [B*N, pos_ch]
|
||||
edge_index: [2, edge_length]
|
||||
dense_ori: edge feature [B, N, N, nf//2]
|
||||
dense_spd: edge shortest path distance feature [B, N, N, nf//2] # Do we need this part? # TODO
|
||||
dense_index
|
||||
temb: [B, temb_dim]
|
||||
"""
|
||||
|
||||
B, N, _, _ = dense_ori.shape
|
||||
|
||||
if temb is not None:
|
||||
dense_ori = dense_ori + self.Dense_edge0(self.act(temb))[:, None, None, :]
|
||||
dense_spd = dense_spd + self.Dense_edge1(self.act(temb))[:, None, None, :]
|
||||
|
||||
temb = temb.unsqueeze(1).repeat(1, self.max_node, 1)
|
||||
temb = temb.reshape(-1, temb.shape[-1])
|
||||
x_degree = x_degree + self.Dense_node0(self.act(temb))
|
||||
x_pos = x_pos + self.Dense_node1(self.act(temb))
|
||||
|
||||
dense_edge = torch.cat([dense_ori, dense_spd], dim=-1)
|
||||
|
||||
ori_edge_attr = dense_edge
|
||||
h = x_degree
|
||||
h_pos = x_pos
|
||||
|
||||
for i_layer in range(self.n_layers):
|
||||
h_edge = dense_edge[dense_index]
|
||||
# update node feature
|
||||
h, h_pos = self.convs[i_layer](h, h_pos, edge_index, h_edge)
|
||||
h = self.Dropout_0(h)
|
||||
h_pos = self.Dropout_0(h_pos)
|
||||
|
||||
# update dense edge feature
|
||||
h_dense_node = h.reshape(B, N, -1)
|
||||
cur_edge_attr = h_dense_node.unsqueeze(1) + h_dense_node.unsqueeze(2) # [B, N, N, nf]
|
||||
dense_edge = (dense_edge + self.act(self.edge_convs[i_layer](cur_edge_attr))) / math.sqrt(2.)
|
||||
dense_edge = self.Dropout_0(dense_edge)
|
||||
|
||||
# Concat edge attribute
|
||||
h_dense_edge = torch.cat([ori_edge_attr, dense_edge], dim=-1)
|
||||
h_dense_edge = self.edge_layer(h_dense_edge).permute(0, 3, 1, 2)
|
||||
|
||||
return h_dense_edge
|
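

# Minimal sketch of the dense edge update used inside pos_gnn.forward: node features are
# lifted to pairwise edge features by a broadcast sum, which is symmetric by construction.
# The sizes (B=2, N=5, nf=8) are placeholder values.
def _example_dense_edge_update():
    import torch

    B, N, nf = 2, 5, 8
    h_dense_node = torch.randn(B, N, nf)
    # pairwise sum of node features gives a dense edge update of shape [B, N, N, nf]
    cur_edge_attr = h_dense_node.unsqueeze(1) + h_dense_node.unsqueeze(2)
    assert cur_edge_attr.shape == (B, N, N, nf)
    # e_ij = h_i + h_j = e_ji, so the update is symmetric in the node pair
    assert torch.allclose(cur_edge_attr, cur_edge_attr.transpose(1, 2))
    return cur_edge_attr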
44
MobileNetV3/models/layers.py
Normal file
44
MobileNetV3/models/layers.py
Normal file
@@ -0,0 +1,44 @@
"""Common layers"""

import torch.nn as nn
import torch
import torch.nn.functional as F
import math


def get_act(config):
    """Get activation functions from the config file."""

    if config.model.nonlinearity.lower() == 'elu':
        return nn.ELU()
    elif config.model.nonlinearity.lower() == 'relu':
        return nn.ReLU()
    elif config.model.nonlinearity.lower() == 'lrelu':
        return nn.LeakyReLU(negative_slope=0.2)
    elif config.model.nonlinearity.lower() == 'swish':
        return nn.SiLU()
    elif config.model.nonlinearity.lower() == 'tanh':
        return nn.Tanh()
    else:
        raise NotImplementedError('activation function does not exist!')


def conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, padding=0):
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
                     padding=padding)
    return conv


# from DDPM
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
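

# Minimal sketch of how conv1x1 is used downstream (e.g. in PGSN): a one-channel
# adjacency "image" [B, 1, N, N] is projected to nf//2 edge-feature channels. The
# sizes (B=2, N=20, nf=64) are placeholder values.
def _example_conv1x1_edge_projection():
    import torch

    B, N = 2, 20
    adj = torch.rand(B, 1, N, N)              # one-channel adjacency "image"
    proj = conv1x1(1, 64 // 2)                # project to nf//2 edge channels
    dense_edge = proj(adj)
    assert dense_edge.shape == (B, 32, N, N)
    return dense_edge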
171
MobileNetV3/models/pgsn.py
Normal file
171
MobileNetV3/models/pgsn.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import functools
|
||||
from torch_geometric.utils import dense_to_sparse
|
||||
|
||||
from . import utils, layers, gnns
|
||||
|
||||
get_act = layers.get_act
|
||||
conv1x1 = layers.conv1x1
|
||||
|
||||
|
||||
@utils.register_model(name='PGSN')
|
||||
class PGSN(nn.Module):
|
||||
"""Position enhanced graph score network."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.act = act = get_act(config)
|
||||
|
||||
# get model construction paras
|
||||
self.nf = nf = config.model.nf
|
||||
self.num_gnn_layers = num_gnn_layers = config.model.num_gnn_layers
|
||||
dropout = config.model.dropout
|
||||
self.embedding_type = embedding_type = config.model.embedding_type.lower()
|
||||
self.rw_depth = rw_depth = config.model.rw_depth
|
||||
self.edge_th = config.model.edge_th
|
||||
|
||||
modules = []
|
||||
# timestep/noise_level embedding; only for continuous training
|
||||
if embedding_type == 'positional':
|
||||
embed_dim = nf
|
||||
else:
|
||||
raise ValueError(f'embedding type {embedding_type} unknown.')
|
||||
|
||||
# timestep embedding layers
|
||||
modules.append(nn.Linear(embed_dim, nf * 4))
|
||||
modules.append(nn.Linear(nf * 4, nf * 4))
|
||||
|
||||
# graph size condition embedding
|
||||
self.size_cond = size_cond = config.model.size_cond
|
||||
if size_cond:
|
||||
self.size_onehot = functools.partial(nn.functional.one_hot, num_classes=config.data.max_node + 1)
|
||||
modules.append(nn.Linear(config.data.max_node + 1, nf * 4))
|
||||
modules.append(nn.Linear(nf * 4, nf * 4))
|
||||
|
||||
channels = config.data.num_channels
|
||||
assert channels == 1, "Without edge features."
|
||||
|
||||
# degree onehot
|
||||
self.degree_max = self.config.data.max_node // 2
|
||||
self.degree_onehot = functools.partial(
|
||||
nn.functional.one_hot,
|
||||
num_classes=self.degree_max + 1)
|
||||
|
||||
# project edge features
|
||||
modules.append(conv1x1(channels, nf // 2))
|
||||
modules.append(conv1x1(rw_depth + 1, nf // 2))
|
||||
|
||||
# project node features
|
||||
self.x_ch = nf
|
||||
self.pos_ch = nf // 2
|
||||
modules.append(nn.Linear(self.degree_max + 1, self.x_ch))
|
||||
modules.append(nn.Linear(rw_depth, self.pos_ch))
|
||||
|
||||
# GNN
|
||||
modules.append(gnns.pos_gnn(act, self.x_ch, self.pos_ch, nf, config.data.max_node,
|
||||
config.model.graph_layer, num_gnn_layers,
|
||||
heads=config.model.heads, edge_dim=nf//2, temb_dim=nf * 4,
|
||||
dropout=dropout, attn_clamp=config.model.attn_clamp))
|
||||
|
||||
# output
|
||||
modules.append(conv1x1(nf // 2, nf // 2))
|
||||
modules.append(conv1x1(nf // 2, channels))
|
||||
|
||||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
def forward(self, x, time_cond, *args, **kwargs):
|
||||
mask = kwargs['mask']
|
||||
modules = self.all_modules
|
||||
m_idx = 0
|
||||
|
||||
# Sinusoidal positional embeddings
|
||||
timesteps = time_cond
|
||||
temb = layers.get_timestep_embedding(timesteps, self.nf)
|
||||
|
||||
# time embedding
|
||||
temb = modules[m_idx](temb) # [32, 512]
|
||||
m_idx += 1
|
||||
temb = modules[m_idx](self.act(temb)) # [32, 512]
|
||||
m_idx += 1
|
||||
|
||||
if self.size_cond:
|
||||
with torch.no_grad():
|
||||
node_mask = utils.mask_adj2node(mask.squeeze(1)) # [B, N]
|
||||
num_node = torch.sum(node_mask, dim=-1) # [B]
|
||||
num_node = self.size_onehot(num_node.to(torch.long)).to(torch.float)
|
||||
num_node_emb = modules[m_idx](num_node)
|
||||
m_idx += 1
|
||||
num_node_emb = modules[m_idx](self.act(num_node_emb))
|
||||
m_idx += 1
|
||||
temb = temb + num_node_emb
|
||||
|
||||
if not self.config.data.centered:
|
||||
# rescale the input data to [-1, 1]
|
||||
x = x * 2. - 1.
|
||||
|
||||
with torch.no_grad():
|
||||
# continuous-valued graph adjacency matrices
|
||||
cont_adj = ((x + 1.) / 2.).clone()
|
||||
cont_adj = (cont_adj * mask).squeeze(1) # [B, N, N]
|
||||
cont_adj = cont_adj.clamp(min=0., max=1.)
|
||||
if self.edge_th > 0.:
|
||||
cont_adj[cont_adj < self.edge_th] = 0.
|
||||
|
||||
# discretized graph adjacency matrices
|
||||
adj = x.squeeze(1).clone() # [B, N, N]
|
||||
adj[adj >= 0.] = 1.
|
||||
adj[adj < 0.] = 0.
|
||||
adj = adj * mask.squeeze(1)
|
||||
|
||||
# extract RWSE and Shortest-Path Distance
|
||||
x_pos, spd_onehot = utils.get_rw_feat(self.rw_depth, adj)
|
||||
# x_pos: [32, 20, 16], spd_onehot: [32, 17, 20, 20]
|
||||
|
||||
# edge [B, N, N, F]
|
||||
dense_edge_ori = modules[m_idx](x).permute(0, 2, 3, 1) # [32, 20, 20, 64]
|
||||
m_idx += 1
|
||||
dense_edge_spd = modules[m_idx](spd_onehot).permute(0, 2, 3, 1) # [32, 20, 20, 64]
|
||||
m_idx += 1
|
||||
|
||||
# Use Degree as node feature
|
||||
x_degree = torch.sum(cont_adj, dim=-1) # [B, N] # [32, 20]
|
||||
x_degree = x_degree.clamp(max=float(self.degree_max)) # [B, N] # [32, 20]
|
||||
x_degree = self.degree_onehot(x_degree.to(torch.long)).to(torch.float) # [B, N, max_node] # [32, 20, 11]
|
||||
x_degree = modules[m_idx](x_degree) # projection layer [B, N, nf] # [32, 20, 128]
|
||||
m_idx += 1
|
||||
|
||||
# pos encoding
|
||||
# x_pos: [32, 20, 16]
|
||||
x_pos = modules[m_idx](x_pos) # [32, 20, 64]
|
||||
m_idx += 1
|
||||
|
||||
# Dense to sparse node [BxN, -1]
|
||||
x_degree = x_degree.reshape(-1, self.x_ch) # [640, 128]
|
||||
x_pos = x_pos.reshape(-1, self.pos_ch) # [640, 64]
|
||||
dense_index = cont_adj.nonzero(as_tuple=True)
|
||||
edge_index, _ = dense_to_sparse(cont_adj) # [2, 5386]
|
||||
|
||||
# Run GNN layers
|
||||
h_dense_edge = modules[m_idx](x_degree, x_pos, edge_index, dense_edge_ori, dense_edge_spd, dense_index, temb)
|
||||
m_idx += 1
|
||||
|
||||
# Output
|
||||
h = self.act(modules[m_idx](self.act(h_dense_edge)))
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
|
||||
# make edge estimation symmetric
|
||||
h = (h + h.transpose(2, 3)) / 2. * mask
|
||||
|
||||
assert m_idx == len(modules)
|
||||
|
||||
return h
|
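

# Minimal sketch of the final symmetrization step in PGSN.forward: edge scores for an
# undirected graph are averaged with their transpose and masked. Sizes are placeholders.
def _example_symmetric_edge_score():
    import torch

    B, N = 2, 6
    h = torch.randn(B, 1, N, N)                # raw edge scores
    mask = torch.ones(B, 1, N, N)              # valid-edge mask (all valid here)
    h_sym = (h + h.transpose(2, 3)) / 2. * mask
    assert torch.allclose(h_sym, h_sym.transpose(2, 3))
    return h_sym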
27
MobileNetV3/models/regressor.py
Normal file
27
MobileNetV3/models/regressor.py
Normal file
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
import torch.optim as optim

from . import utils


@utils.register_model(name='MLPRegressor')
class MLPRegressor(nn.Module):
    # def __init__(self, input_size, hidden_size, output_size):
    def __init__(self, config):
        super().__init__()
        input_size = int(config.data.max_node * config.data.n_vocab)
        hidden_size = config.model.hidden_size
        output_size = config.model.output_size
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, output_size)
        self.activation = nn.ReLU()

    def forward(self, X, time_cond, maskX):
        x = X.view(X.size(0), -1)
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.activation(self.fc3(x))
        x = self.fc4(x)
        return x
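

# Minimal usage sketch for MLPRegressor. The SimpleNamespace config and the values
# (max_node=20, n_vocab=9, hidden_size=64, output_size=1) are hypothetical; the real
# values come from the project's config files. Note that this regressor ignores
# time_cond and maskX and simply flattens X.
def _example_mlp_regressor():
    import torch
    from types import SimpleNamespace

    config = SimpleNamespace(
        data=SimpleNamespace(max_node=20, n_vocab=9),
        model=SimpleNamespace(hidden_size=64, output_size=1))
    model = MLPRegressor(config)

    X = torch.randn(4, 20, 9)                 # [B, max_node, n_vocab] operation features
    time_cond = torch.zeros(4)                # unused by this regressor
    maskX = torch.ones(4, 20, 20)             # unused by this regressor
    out = model(X, time_cond, maskX)          # flattened to [B, max_node * n_vocab] inside
    assert out.shape == (4, 1)
    return out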
38
MobileNetV3/models/set_encoder/setenc_models.py
Normal file
38
MobileNetV3/models/set_encoder/setenc_models.py
Normal file
@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .setenc_modules import *


class SetPool(nn.Module):
    def __init__(self, dim_input, num_outputs, dim_output,
                 num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
        super(SetPool, self).__init__()
        if 'sab' in mode:  # [32, 400, 128]
            self.enc = nn.Sequential(
                SAB(dim_input, dim_hidden, num_heads, ln=ln),  # SAB?
                SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
        else:  # [32, 400, 128]
            self.enc = nn.Sequential(
                ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),  # SAB?
                ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
        if 'PF' in mode:  # [32, 1, 501]
            self.dec = nn.Sequential(
                PMA(dim_hidden, num_heads, num_outputs, ln=ln),
                nn.Linear(dim_hidden, dim_output))
        elif 'P' in mode:
            self.dec = nn.Sequential(
                PMA(dim_hidden, num_heads, num_outputs, ln=ln))
        else:  # torch.Size([32, 1, 501])
            self.dec = nn.Sequential(
                PMA(dim_hidden, num_heads, num_outputs, ln=ln),  # 32 1 128
                SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                nn.Linear(dim_hidden, dim_output))
        # "", sm, sab, sabsm

    def forward(self, X):
        x1 = self.enc(X)
        x2 = self.dec(x1)
        return x2
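

# Minimal sketch mirroring how MetaNeuralPredictor.set_encode uses two SetPools:
# first pool the samples of each class into class prototypes, then pool the prototypes
# into a single dataset vector. The sizes (nz=56, num_sample=20, n_cls=10, feature
# dim 512) are placeholder values.
def _example_setpool():
    import torch

    nz, num_sample, n_cls = 56, 20, 10
    intra = SetPool(dim_input=512, num_outputs=1, dim_output=nz, dim_hidden=nz, mode='sabPF')
    inter = SetPool(dim_input=nz, num_outputs=1, dim_output=nz, dim_hidden=nz, mode='sabPF')

    x = torch.randn(n_cls * num_sample, 512)                     # features of one task
    cls_protos = intra(x.view(-1, num_sample, 512)).squeeze(1)   # [n_cls, nz]
    dataset_vec = inter(cls_protos.unsqueeze(0))                 # [1, 1, nz]
    assert dataset_vec.shape == (1, 1, nz)
    return dataset_vec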
67
MobileNetV3/models/set_encoder/setenc_modules.py
Normal file
67
MobileNetV3/models/set_encoder/setenc_modules.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#####################################################################################
|
||||
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
|
||||
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
|
||||
######################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class MAB(nn.Module):
|
||||
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
|
||||
super(MAB, self).__init__()
|
||||
self.dim_V = dim_V
|
||||
self.num_heads = num_heads
|
||||
self.fc_q = nn.Linear(dim_Q, dim_V)
|
||||
self.fc_k = nn.Linear(dim_K, dim_V)
|
||||
self.fc_v = nn.Linear(dim_K, dim_V)
|
||||
if ln:
|
||||
self.ln0 = nn.LayerNorm(dim_V)
|
||||
self.ln1 = nn.LayerNorm(dim_V)
|
||||
self.fc_o = nn.Linear(dim_V, dim_V)
|
||||
|
||||
def forward(self, Q, K):
|
||||
Q = self.fc_q(Q)
|
||||
K, V = self.fc_k(K), self.fc_v(K)
|
||||
|
||||
dim_split = self.dim_V // self.num_heads
|
||||
Q_ = torch.cat(Q.split(dim_split, 2), 0)
|
||||
K_ = torch.cat(K.split(dim_split, 2), 0)
|
||||
V_ = torch.cat(V.split(dim_split, 2), 0)
|
||||
|
||||
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
|
||||
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
|
||||
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
|
||||
O = O + F.relu(self.fc_o(O))
|
||||
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
|
||||
return O
|
||||
|
||||
class SAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, ln=False):
|
||||
super(SAB, self).__init__()
|
||||
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(X, X)
|
||||
|
||||
class ISAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
|
||||
super(ISAB, self).__init__()
|
||||
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
|
||||
nn.init.xavier_uniform_(self.I)
|
||||
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
|
||||
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
|
||||
return self.mab1(X, H)
|
||||
|
||||
class PMA(nn.Module):
|
||||
def __init__(self, dim, num_heads, num_seeds, ln=False):
|
||||
super(PMA, self).__init__()
|
||||
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
|
||||
nn.init.xavier_uniform_(self.S)
|
||||
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(self.S.repeat(X.size(0), 1, 1), X)
|
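

# Minimal shape sketch for the Set Transformer blocks above: SAB is permutation-
# equivariant self-attention that keeps the set size n, while PMA pools the set down
# to k learned seed vectors. All sizes here are placeholder values.
def _example_set_transformer_shapes():
    import torch

    B, n, d_in, d, heads, k = 2, 10, 16, 32, 4, 1
    X = torch.randn(B, n, d_in)

    sab = SAB(d_in, d, heads)        # self-attention over set elements: keeps n
    pma = PMA(d, heads, k)           # pooling by multihead attention: n -> k seeds

    H = sab(X)                       # [B, n, d]
    pooled = pma(H)                  # [B, k, d]
    assert H.shape == (B, n, d) and pooled.shape == (B, k, d)
    return pooled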
144
MobileNetV3/models/trans_layers.py
Normal file
144
MobileNetV3/models/trans_layers.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import math
|
||||
from typing import Union, Tuple, Optional
|
||||
from torch_geometric.typing import PairTensor, Adj, OptTensor
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Linear
|
||||
from torch_scatter import scatter
|
||||
from torch_geometric.nn.conv import MessagePassing
|
||||
from torch_geometric.utils import softmax
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PosTransLayer(MessagePassing):
|
||||
"""Involving the edge feature and updating position feature. Multiply Msg."""
|
||||
|
||||
_alpha: OptTensor
|
||||
|
||||
def __init__(self, x_channels: int, pos_channels: int, out_channels: int,
|
||||
heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
|
||||
bias: bool = True, act=None, attn_clamp: bool = False, **kwargs):
|
||||
kwargs.setdefault('aggr', 'add')
|
||||
super(PosTransLayer, self).__init__(node_dim=0, **kwargs)
|
||||
|
||||
self.x_channels = x_channels
|
||||
self.pos_channels = pos_channels
|
||||
self.in_channels = in_channels = x_channels + pos_channels
|
||||
self.out_channels = out_channels
|
||||
self.heads = heads
|
||||
self.dropout = dropout
|
||||
self.edge_dim = edge_dim
|
||||
self.attn_clamp = attn_clamp
|
||||
|
||||
if act is None:
|
||||
self.act = nn.LeakyReLU(negative_slope=0.2)
|
||||
else:
|
||||
self.act = act
|
||||
|
||||
self.lin_key = Linear(in_channels, heads * out_channels)
|
||||
self.lin_query = Linear(in_channels, heads * out_channels)
|
||||
self.lin_value = Linear(in_channels, heads * out_channels)
|
||||
|
||||
self.lin_edge0 = Linear(edge_dim, heads * out_channels, bias=False)
|
||||
self.lin_edge1 = Linear(edge_dim, heads * out_channels, bias=False)
|
||||
|
||||
self.lin_pos = Linear(heads * out_channels, pos_channels, bias=False)
|
||||
|
||||
self.lin_skip = Linear(x_channels, heads * out_channels, bias=bias)
|
||||
self.norm1 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
|
||||
num_channels=heads * out_channels, eps=1e-6)
|
||||
self.norm2 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
|
||||
num_channels=heads * out_channels, eps=1e-6)
|
||||
# FFN
|
||||
self.FFN = nn.Sequential(Linear(heads * out_channels, heads * out_channels),
|
||||
self.act,
|
||||
Linear(heads * out_channels, heads * out_channels))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
self.lin_key.reset_parameters()
|
||||
self.lin_query.reset_parameters()
|
||||
self.lin_value.reset_parameters()
|
||||
self.lin_skip.reset_parameters()
|
||||
self.lin_edge0.reset_parameters()
|
||||
self.lin_edge1.reset_parameters()
|
||||
self.lin_pos.reset_parameters()
|
||||
|
||||
def forward(self, x: OptTensor,
|
||||
pos: Tensor,
|
||||
edge_index: Adj,
|
||||
edge_attr: OptTensor = None
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
""""""
|
||||
|
||||
H, C = self.heads, self.out_channels
|
||||
|
||||
x_feat = torch.cat([x, pos], -1)
|
||||
query = self.lin_query(x_feat).view(-1, H, C)
|
||||
key = self.lin_key(x_feat).view(-1, H, C)
|
||||
value = self.lin_value(x_feat).view(-1, H, C)
|
||||
|
||||
# propagate_type: (x: PairTensor, edge_attr: OptTensor)
|
||||
out_x, out_pos = self.propagate(edge_index, query=query, key=key, value=value, pos=pos, edge_attr=edge_attr,
|
||||
size=None)
|
||||
|
||||
out_x = out_x.view(-1, self.heads * self.out_channels)
|
||||
|
||||
# skip connection for x
|
||||
x_r = self.lin_skip(x)
|
||||
out_x = (out_x + x_r) / math.sqrt(2)
|
||||
out_x = self.norm1(out_x)
|
||||
|
||||
# FFN
|
||||
out_x = (out_x + self.FFN(out_x)) / math.sqrt(2)
|
||||
out_x = self.norm2(out_x)
|
||||
|
||||
# skip connection for pos
|
||||
out_pos = pos + torch.tanh(pos + out_pos)
|
||||
|
||||
return out_x, out_pos
|
||||
|
||||
def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
|
||||
pos_j: Tensor,
|
||||
edge_attr: OptTensor,
|
||||
index: Tensor, ptr: OptTensor,
|
||||
size_i: Optional[int]) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
edge_attn = self.lin_edge0(edge_attr).view(-1, self.heads, self.out_channels)
|
||||
alpha = (query_i * key_j * edge_attn).sum(dim=-1) / math.sqrt(self.out_channels)
|
||||
if self.attn_clamp:
|
||||
alpha = alpha.clamp(min=-5., max=5.)
|
||||
|
||||
alpha = softmax(alpha, index, ptr, size_i)
|
||||
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
|
||||
|
||||
# node feature message
|
||||
msg = value_j
|
||||
msg = msg * self.lin_edge1(edge_attr).view(-1, self.heads, self.out_channels)
|
||||
msg = msg * alpha.view(-1, self.heads, 1)
|
||||
|
||||
# node position message
|
||||
pos_msg = pos_j * self.lin_pos(msg.reshape(-1, self.heads * self.out_channels))
|
||||
|
||||
return msg, pos_msg
|
||||
|
||||
def aggregate(self, inputs: Tuple[Tensor, Tensor], index: Tensor,
|
||||
ptr: Optional[Tensor] = None,
|
||||
dim_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
|
||||
if ptr is not None:
|
||||
raise NotImplementedError("Not implement Ptr in aggregate")
|
||||
else:
|
||||
return (scatter(inputs[0], index, 0, dim_size=dim_size, reduce=self.aggr),
|
||||
scatter(inputs[1], index, 0, dim_size=dim_size, reduce="mean"))
|
||||
|
||||
def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
|
||||
return inputs
|
||||
|
||||
def __repr__(self):
|
||||
return '{}({}, {}, heads={})'.format(self.__class__.__name__,
|
||||
self.in_channels,
|
||||
self.out_channels, self.heads)
|
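

# Minimal sketch of the per-edge attention logit computed in PosTransLayer.message:
# the query/key dot product is modulated elementwise by a projected edge feature
# (lin_edge0(edge_attr)) before the softmax over each target node's incoming edges.
# E, heads and C are placeholder sizes; the tensors stand in for the real projections.
def _example_edge_modulated_attention():
    import math
    import torch

    E, heads, C = 7, 4, 16                    # E edges in the sparse graph
    query_i = torch.randn(E, heads, C)
    key_j = torch.randn(E, heads, C)
    edge_attn = torch.randn(E, heads, C)      # stands in for lin_edge0(edge_attr) reshaped

    alpha = (query_i * key_j * edge_attn).sum(dim=-1) / math.sqrt(C)   # [E, heads]
    assert alpha.shape == (E, heads)
    return alpha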
248
MobileNetV3/models/transformer.py
Executable file
248
MobileNetV3/models/transformer.py
Executable file
@@ -0,0 +1,248 @@
|
||||
from copy import deepcopy as cp
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
def clones(module, N):
|
||||
return nn.ModuleList([cp(module) for _ in range(N)])
|
||||
|
||||
def attention(query, key, value, mask = None, dropout = None):
|
||||
d_k = query.size(-1)
|
||||
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask == 0, -1e9)
|
||||
attn = F.softmax(scores, dim = -1)
|
||||
if dropout is not None:
|
||||
attn = dropout(attn)
|
||||
return torch.matmul(attn, value), attn
|
||||
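# Minimal usage sketch for the scaled dot-product `attention` helper above: a mask
# value of 0 blocks a position (filled with -1e9 before the softmax). Sizes are
# placeholder values; the mask is passed already broadcastable to [B, heads, n, n].
def _example_scaled_dot_product_attention():
    import torch

    B, heads, n, d_k = 2, 4, 5, 8
    q = torch.randn(B, heads, n, d_k)
    k = torch.randn(B, heads, n, d_k)
    v = torch.randn(B, heads, n, d_k)
    mask = torch.ones(B, 1, n, n)             # 1 = attend, 0 = blocked

    out, attn = attention(q, k, v, mask=mask)
    assert out.shape == (B, heads, n, d_k) and attn.shape == (B, heads, n, n)
    return out, attn

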
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
|
||||
self.d_model = config.d_model
|
||||
self.n_head = config.n_head
|
||||
self.d_k = config.d_model // config.n_head
|
||||
|
||||
self.linears = clones(nn.Linear(self.d_model, self.d_model), 4)
|
||||
self.dropout = nn.Dropout(p=config.dropout)
|
||||
|
||||
def forward(self, query, key, value, mask = None):
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1)
|
||||
batch_size = query.size(0)
|
||||
|
||||
query, key , value = [l(x).view(batch_size, -1, self.n_head, self.d_k).transpose(1,2) for l, x in zip(self.linears, (query, key, value))]
|
||||
x, attn = attention(query, key, value, mask = mask, dropout = self.dropout)
|
||||
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)
|
||||
return self.linears[3](x), attn
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
|
||||
self.w_1 = nn.Linear(config.d_model, config.d_ff)
|
||||
self.w_2 = nn.Linear(config.d_ff, config.d_model)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
||||
|
||||
class PositionwiseFeedForwardLast(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PositionwiseFeedForwardLast, self).__init__()
|
||||
|
||||
self.w_1 = nn.Linear(config.d_model, config.d_ff)
|
||||
self.w_2 = nn.Linear(config.d_ff, config.n_vocab)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
||||
|
||||
class SelfAttentionBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(SelfAttentionBlock, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.attn = MultiHeadAttention(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x, mask):
|
||||
x_ = self.norm(x)
|
||||
x_ , attn = self.attn(x_, x_, x_, mask)
|
||||
return self.dropout(x_) + x, attn
|
||||
|
||||
class SourceAttentionBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(SourceAttentionBlock, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.attn = MultiHeadAttention(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x, m, mask):
|
||||
x_ = self.norm(x)
|
||||
x_, attn = self.attn(x_, m, m, mask)
|
||||
return self.dropout(x_) + x, attn
|
||||
|
||||
class FeedForwardBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(FeedForwardBlock, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.feed_forward = PositionwiseFeedForward(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
x_ = self.norm(x)
|
||||
x_ = self.feed_forward(x_)
|
||||
return self.dropout(x_) + x
|
||||
|
||||
class FeedForwardBlockLast(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(FeedForwardBlockLast, self).__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(config.d_model)
|
||||
self.feed_forward = PositionwiseFeedForwardLast(config)
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
# Only for the last layer
|
||||
self.proj_fc = nn.Linear(config.d_model, config.n_vocab)
|
||||
|
||||
def forward(self, x):
|
||||
x_ = self.norm(x)
|
||||
x_ = self.feed_forward(x_)
|
||||
# return self.dropout(x_) + x
|
||||
return self.dropout(x_) + self.proj_fc(x)
|
||||
|
||||
class EncoderBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(EncoderBlock, self).__init__()
|
||||
self.self_attn = SelfAttentionBlock(config)
|
||||
self.feed_forward = FeedForwardBlock(config)
|
||||
|
||||
def forward(self, x, mask):
|
||||
x, attn = self.self_attn(x, mask)
|
||||
x = self.feed_forward(x)
|
||||
return x, attn
|
||||
|
||||
class EncoderBlockLast(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(EncoderBlockLast, self).__init__()
|
||||
self.self_attn = SelfAttentionBlock(config)
|
||||
self.feed_forward = FeedForwardBlockLast(config)
|
||||
|
||||
def forward(self, x, mask):
|
||||
x, attn = self.self_attn(x, mask)
|
||||
x = self.feed_forward(x)
|
||||
return x, attn
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(DecoderBlock, self).__init__()
|
||||
|
||||
self.self_attn = SelfAttentionBlock(config)
|
||||
self.src_attn = SourceAttentionBlock(config)
|
||||
self.feed_forward = FeedForwardBlock(config)
|
||||
|
||||
def forward(self, x, m, src_mask, tgt_mask):
|
||||
x, attn_tgt = self.self_attn(x, tgt_mask)
|
||||
x, attn_src = self.src_attn(x, m, src_mask)
|
||||
x = self.feed_forward(x)
|
||||
return x, attn_src, attn_tgt
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
# self.layers = clones(EncoderBlock(config), config.n_layers - 1)
|
||||
# self.layers.append(EncoderBlockLast(config))
|
||||
# self.norms = clones(nn.LayerNorm(config.d_model), config.n_layers - 1)
|
||||
# self.norms.append(nn.LayerNorm(config.n_vocab))
|
||||
|
||||
self.layers = clones(EncoderBlock(config), config.n_layers)
|
||||
self.norms = clones(nn.LayerNorm(config.d_model), config.n_layers)
|
||||
|
||||
def forward(self, x, mask):
|
||||
outputs = []
|
||||
attns = []
|
||||
for layer, norm in zip(self.layers, self.norms):
|
||||
x, attn = layer(x, mask)
|
||||
outputs.append(norm(x))
|
||||
attns.append(attn)
|
||||
return outputs[-1], outputs, attns
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PositionalEmbedding, self).__init__()
|
||||
|
||||
p2e = torch.zeros(config.max_len, config.d_model)
|
||||
position = torch.arange(0.0, config.max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0.0, config.d_model, 2) * (- math.log(10000.0) / config.d_model))
|
||||
p2e[:, 0::2] = torch.sin(position * div_term)
|
||||
p2e[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
self.register_buffer('p2e', p2e)
|
||||
|
||||
def forward(self, x):
|
||||
shp = x.size()
|
||||
with torch.no_grad():
|
||||
emb = torch.index_select(self.p2e, 0, x.view(-1)).view(shp + (-1,))
|
||||
return emb
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Transformer, self).__init__()
|
||||
self.p2e = PositionalEmbedding(config)
|
||||
self.encoder = Encoder(config)
|
||||
|
||||
def forward(self, input_emb, position_ids, attention_mask):
|
||||
# position embedding projection
|
||||
projection = self.p2e(position_ids) + input_emb
|
||||
return self.encoder(projection, attention_mask)
|
||||
|
||||
|
||||
class TokenTypeEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(TokenTypeEmbedding, self).__init__()
|
||||
self.t2e = nn.Embedding(config.n_token_type, config.d_model)
|
||||
self.d_model = config.d_model
|
||||
|
||||
def forward(self, x):
|
||||
return self.t2e(x) * math.sqrt(self.d_model)
|
||||
|
||||
class SemanticEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(SemanticEmbedding, self).__init__()
|
||||
# self.w2e = nn.Embedding(config.n_vocab, config.d_model)
|
||||
self.d_model = config.d_model
|
||||
self.fc = nn.Linear(config.n_vocab, config.d_model)
|
||||
|
||||
def forward(self, x):
|
||||
# return self.w2e(x) * math.sqrt(self.d_model)
|
||||
return self.fc(x) * math.sqrt(self.d_model)
|
||||
|
||||
class Embeddings(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Embeddings, self).__init__()
|
||||
|
||||
self.w2e = SemanticEmbedding(config)
|
||||
self.p2e = PositionalEmbedding(config)
|
||||
self.t2e = TokenTypeEmbedding(config)
|
||||
|
||||
self.dropout = nn.Dropout(p = config.dropout)
|
||||
|
||||
def forward(self, input_ids, position_ids = None, token_type_ids = None):
|
||||
if position_ids is None:
|
||||
batch_size, length = input_ids.size()
|
||||
with torch.no_grad():
|
||||
position_ids = torch.arange(0, length).repeat(batch_size, 1)
|
||||
if torch.cuda.is_available():
|
||||
position_ids = position_ids.cuda(device=input_ids.device)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
embeddings = self.w2e(input_ids) + self.p2e(position_ids) + self.t2e(token_type_ids)
|
||||
return self.dropout(embeddings)
|
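

# Minimal end-to-end sketch of Embeddings + Transformer from this file. The
# SimpleNamespace config and every value in it are hypothetical placeholders; the
# real hyperparameters live in the project's config files. position_ids and
# token_type_ids are passed explicitly because the inputs here are soft (3-D) tokens.
def _example_transformer_encoder():
    import torch
    from types import SimpleNamespace

    config = SimpleNamespace(d_model=32, n_head=4, d_ff=64, dropout=0.1,
                             n_layers=2, n_vocab=9, max_len=16, n_token_type=2)

    B, L = 2, 16
    emb = Embeddings(config)
    enc = Transformer(config)

    input_ids = torch.rand(B, L, config.n_vocab)           # soft one-hot operation tokens
    position_ids = torch.arange(L).repeat(B, 1)
    token_type_ids = torch.zeros(B, L, dtype=torch.long)
    input_emb = emb(input_ids, position_ids, token_type_ids)
    attention_mask = torch.ones(B, L, L)                   # all positions visible

    out, layer_outputs, attns = enc(input_emb, position_ids, attention_mask)
    assert out.shape == (B, L, config.d_model)
    return out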
301
MobileNetV3/models/utils.py
Normal file
301
MobileNetV3/models/utils.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import sde_lib
|
||||
import numpy as np
|
||||
|
||||
_MODELS = {}
|
||||
|
||||
|
||||
def register_model(cls=None, *, name=None):
|
||||
"""A decorator for registering model classes."""
|
||||
|
||||
def _register(cls):
|
||||
if name is None:
|
||||
local_name = cls.__name__
|
||||
else:
|
||||
local_name = name
|
||||
if local_name in _MODELS:
|
||||
raise ValueError(
|
||||
f'Already registered model with name: {local_name}')
|
||||
_MODELS[local_name] = cls
|
||||
return cls
|
||||
|
||||
if cls is None:
|
||||
return _register
|
||||
else:
|
||||
return _register(cls)
|
||||
|
||||
|
||||
def get_model(name):
|
||||
return _MODELS[name]
|
||||
|
||||
|
||||
def create_model(config):
|
||||
"""Create the score model."""
|
||||
model_name = config.model.name
|
||||
score_model = get_model(model_name)(config)
|
||||
score_model = score_model.to(config.device)
|
||||
if 'load_pretrained' in config['training'].keys() and config.training.load_pretrained:
|
||||
from utils import restore_checkpoint_partial
|
||||
score_model = restore_checkpoint_partial(score_model, torch.load(config.training.pretrained_model_path, map_location=config.device)['model'])
|
||||
# score_model = torch.nn.DataParallel(score_model)
|
||||
return score_model
|
||||
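# Minimal sketch of the registry pattern above: register a class under a name, then
# look it up with get_model, as create_model does with config.model.name. The
# 'ToyScoreModel' name and class are purely illustrative; re-running the registration
# with the same name would raise ValueError.
def _example_model_registry():
    import torch.nn as nn

    @register_model(name='ToyScoreModel')
    class ToyScoreModel(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.net = nn.Linear(4, 4)

        def forward(self, x, labels, *args, **kwargs):
            return self.net(x)

    model_cls = get_model('ToyScoreModel')     # look the class up by name
    return model_cls

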
|
||||
|
||||
def get_model_fn(model, train=False):
|
||||
"""Create a function to give the output of the score-based model.
|
||||
|
||||
Args:
|
||||
model: The score model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
|
||||
Returns:
|
||||
A model function.
|
||||
"""
|
||||
|
||||
def model_fn(x, labels, *args, **kwargs):
|
||||
"""Compute the output of the score-based model.
|
||||
|
||||
Args:
|
||||
x: A mini-batch of input data (Adjacency matrices).
|
||||
labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
|
||||
for different models.
|
||||
mask: Mask for adjacency matrices.
|
||||
|
||||
Returns:
|
||||
A tuple of (model output, new mutable states)
|
||||
"""
|
||||
if not train:
|
||||
model.eval()
|
||||
return model(x, labels, *args, **kwargs)
|
||||
else:
|
||||
model.train()
|
||||
return model(x, labels, *args, **kwargs)
|
||||
|
||||
return model_fn
|
||||
|
||||
|
||||
def get_score_fn(sde, model, train=False, continuous=False):
|
||||
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
model: A score model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
|
||||
|
||||
Returns:
|
||||
A score function.
|
||||
"""
|
||||
model_fn = get_model_fn(model, train=train)
|
||||
|
||||
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
|
||||
def score_fn(x, t, *args, **kwargs):
|
||||
# Scale neural network output by standard deviation and flip sign
|
||||
if continuous or isinstance(sde, sde_lib.subVPSDE):
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
|
||||
labels = t * 999
|
||||
score = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
labels = t * (sde.N - 1)
|
||||
score = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
|
||||
labels.long()]
|
||||
|
||||
score = -score / std[:, None, None]
|
||||
return score
|
||||
|
||||
elif isinstance(sde, sde_lib.VESDE):
|
||||
def score_fn(x, t, *args, **kwargs):
|
||||
if continuous:
|
||||
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VE-trained models, t=0 corresponds to the highest noise level
|
||||
labels = sde.T - t
|
||||
labels *= sde.N - 1
|
||||
labels = torch.round(labels).long()
|
||||
|
||||
score = model_fn(x, labels, *args, **kwargs)
|
||||
return score
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"SDE class {sde.__class__.__name__} not yet supported.")
|
||||
|
||||
return score_fn
|
||||
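# Minimal sketch of the rescaling that get_score_fn applies for VP-type SDEs: the
# network output is divided by the marginal std at time t and negated to become a
# score. The shapes mimic adjacency inputs [B, N, N]; the std values are placeholders
# standing in for sde.marginal_prob(...)[1].
def _example_vpsde_score_rescaling():
    import torch

    B, N = 2, 4
    model_output = torch.randn(B, N, N)
    std = torch.tensor([0.5, 0.8])             # one marginal std per sample in the batch
    score = -model_output / std[:, None, None]
    assert score.shape == (B, N, N)
    return score

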
|
||||
|
||||
def get_classifier_grad_fn(sde, classifier, train=False, continuous=False,
|
||||
regress=True, labels='max'):
|
||||
logit_fn = get_logit_fn(sde, classifier, train, continuous)
|
||||
|
||||
def classifier_grad_fn(x, t, *args, **kwargs):
|
||||
with torch.enable_grad():
|
||||
x_in = x.detach().requires_grad_(True)
|
||||
if regress:
|
||||
assert labels in ['max', 'min']
|
||||
logit = logit_fn(x_in, t, *args, **kwargs)
|
||||
prob = logit.sum()
|
||||
else:
|
||||
logit = logit_fn(x_in, t, *args, **kwargs)
|
||||
# prob = torch.nn.functional.log_softmax(logit, dim=-1)[torch.arange(labels.shape[0]), labels].sum()
|
||||
log_prob = F.log_softmax(logit, dim=-1)
|
||||
prob = log_prob[range(len(logit)), labels.view(-1)].sum()
|
||||
# prob.backward()
|
||||
# classifier_grad = x_in.grad
|
||||
classifier_grad = torch.autograd.grad(prob, x_in)[0]
|
||||
return classifier_grad
|
||||
|
||||
return classifier_grad_fn
|
||||
|
||||
|
||||
def get_logit_fn(sde, classifier, train=False, continuous=False):
|
||||
classifier_fn = get_model_fn(classifier, train=train)
|
||||
|
||||
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
|
||||
def logit_fn(x, t, *args, **kwargs):
|
||||
# Scale neural network output by standard deviation and flip sign
|
||||
if continuous or isinstance(sde, sde_lib.subVPSDE):
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
|
||||
labels = t * 999
|
||||
logit = classifier_fn(x, labels, *args, **kwargs)
|
||||
# std = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
labels = t * (sde.N - 1)
|
||||
logit = classifier_fn(x, labels, *args, **kwargs)
|
||||
# std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
|
||||
# labels.long()]
|
||||
|
||||
# score = -score / std[:, None, None]
|
||||
return logit
|
||||
|
||||
elif isinstance(sde, sde_lib.VESDE):
|
||||
def logit_fn(x, t, *args, **kwargs):
|
||||
if continuous:
|
||||
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VE-trained models, t=0 corresponds to the highest noise level
|
||||
labels = sde.T - t
|
||||
labels *= sde.N - 1
|
||||
labels = torch.round(labels).long()
|
||||
|
||||
logit = classifier_fn(x, labels, *args, **kwargs)
|
||||
return logit
|
||||
|
||||
return logit_fn
|
||||
|
||||
|
||||
def get_predictor_fn(sde, model, train=False, continuous=False):
|
||||
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
model: A predictor model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
|
||||
|
||||
Returns:
|
||||
A score function.
|
||||
"""
|
||||
model_fn = get_model_fn(model, train=train)
|
||||
|
||||
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
|
||||
def predictor_fn(x, t, *args, **kwargs):
|
||||
# Scale neural network output by standard deviation and flip sign
|
||||
if continuous or isinstance(sde, sde_lib.subVPSDE):
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
|
||||
labels = t * 999
|
||||
pred = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
labels = t * (sde.N - 1)
|
||||
pred = model_fn(x, labels, *args, **kwargs)
|
||||
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
|
||||
labels.long()]
|
||||
|
||||
# score = -score / std[:, None, None]
|
||||
return pred
|
||||
|
||||
elif isinstance(sde, sde_lib.VESDE):
|
||||
def predictor_fn(x, t, *args, **kwargs):
|
||||
if continuous:
|
||||
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
||||
else:
|
||||
# For VE-trained models, t=0 corresponds to the highest noise level
|
||||
labels = sde.T - t
|
||||
labels *= sde.N - 1
|
||||
labels = torch.round(labels).long()
|
||||
|
||||
pred = model_fn(x, labels, *args, **kwargs)
|
||||
return pred
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"SDE class {sde.__class__.__name__} not yet supported.")
|
||||
|
||||
return predictor_fn
|
||||
|
||||
|
||||
def to_flattened_numpy(x):
|
||||
"""Flatten a torch tensor `x` and convert it to numpy."""
|
||||
return x.detach().cpu().numpy().reshape((-1,))
|
||||
|
||||
|
||||
def from_flattened_numpy(x, shape):
|
||||
"""Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
|
||||
return torch.from_numpy(x.reshape(shape))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def mask_adj2node(adj_mask):
|
||||
"""Convert batched adjacency mask matrices to batched node mask matrices.
|
||||
|
||||
Args:
|
||||
adj_mask: [B, N, N] Batched adjacency mask matrices without self-loop edge.
|
||||
|
||||
Output:
|
||||
node_mask: [B, N] Batched node mask matrices indicating the valid nodes.
|
||||
"""
|
||||
|
||||
batch_size, max_num_nodes, _ = adj_mask.shape
|
||||
|
||||
node_mask = adj_mask[:, 0, :].clone()
|
||||
node_mask[:, 0] = 1
|
||||
|
||||
return node_mask
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_rw_feat(k_step, dense_adj):
|
||||
"""Compute k_step Random Walk for given dense adjacency matrix."""
|
||||
|
||||
rw_list = []
|
||||
deg = dense_adj.sum(-1, keepdims=True)
|
||||
AD = dense_adj / (deg + 1e-8)
|
||||
rw_list.append(AD)
|
||||
|
||||
for _ in range(k_step):
|
||||
rw = torch.bmm(rw_list[-1], AD)
|
||||
rw_list.append(rw)
|
||||
rw_map = torch.stack(rw_list[1:], dim=1) # [B, k_step, N, N]
|
||||
|
||||
rw_landing = torch.diagonal(
|
||||
rw_map, offset=0, dim1=2, dim2=3) # [B, k_step, N]
|
||||
rw_landing = rw_landing.permute(0, 2, 1) # [B, N, rw_depth]
|
||||
|
||||
# get the shortest path distance indices
|
||||
tmp_rw = rw_map.sort(dim=1)[0]
|
||||
spd_ind = (tmp_rw <= 0).sum(dim=1) # [B, N, N]
|
||||
|
||||
spd_onehot = torch.nn.functional.one_hot(
|
||||
spd_ind, num_classes=k_step+1).to(torch.float)
|
||||
spd_onehot = spd_onehot.permute(0, 3, 1, 2) # [B, kstep, N, N]
|
||||
|
||||
return rw_landing, spd_onehot
|
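

# Minimal sketch of get_rw_feat on a tiny 3-node path graph, showing the shapes it
# produces: [B, N, k_step] random-walk landing probabilities and [B, k_step+1, N, N]
# one-hot shortest-path-distance buckets. The graph and k_step=4 are placeholder values.
def _example_rw_features():
    import torch

    adj = torch.tensor([[[0., 1., 0.],
                         [1., 0., 1.],
                         [0., 1., 0.]]])            # [B=1, N=3, N=3]
    rw_landing, spd_onehot = get_rw_feat(k_step=4, dense_adj=adj)
    assert rw_landing.shape == (1, 3, 4)            # [B, N, k_step]
    assert spd_onehot.shape == (1, 5, 3, 3)         # [B, k_step+1, N, N]
    return rw_landing, spd_onehot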