first commit

This commit is contained in:
CownowAn
2024-03-15 14:38:51 +00:00
commit bc2ed1304f
321 changed files with 44802 additions and 0 deletions

View File

@@ -0,0 +1,117 @@
import math
import torch
from torch.nn import Parameter
import torch.nn.functional as F
from models.GDSS.layers import DenseGCNConv, MLP
# from ..utils.graph_utils import mask_adjs, mask_x
from .graph_utils import mask_x, mask_adjs
# -------- Graph Multi-Head Attention (GMH) --------
# -------- From Baek et al. (2021) --------
class Attention(torch.nn.Module):
def __init__(self, in_dim, attn_dim, out_dim, num_heads=4, conv='GCN'):
super(Attention, self).__init__()
self.num_heads = num_heads
self.attn_dim = attn_dim
self.out_dim = out_dim
self.conv = conv
self.gnn_q, self.gnn_k, self.gnn_v = self.get_gnn(in_dim, attn_dim, out_dim, conv)
self.activation = torch.tanh
self.softmax_dim = 2
def forward(self, x, adj, flags, attention_mask=None):
if self.conv == 'GCN':
Q = self.gnn_q(x, adj)
K = self.gnn_k(x, adj)
else:
Q = self.gnn_q(x)
K = self.gnn_k(x)
V = self.gnn_v(x, adj)
dim_split = self.attn_dim // self.num_heads
Q_ = torch.cat(Q.split(dim_split, 2), 0)
K_ = torch.cat(K.split(dim_split, 2), 0)
if attention_mask is not None:
attention_mask = torch.cat([attention_mask for _ in range(self.num_heads)], 0)
attention_score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.out_dim)
A = self.activation( attention_mask + attention_score )
else:
A = self.activation( Q_.bmm(K_.transpose(1,2))/math.sqrt(self.out_dim) ) # (B x num_heads) x N x N
# -------- (B x num_heads) x N x N --------
A = A.view(-1, *adj.shape)
A = A.mean(dim=0)
A = (A + A.transpose(-1,-2))/2
return V, A
def get_gnn(self, in_dim, attn_dim, out_dim, conv='GCN'):
if conv == 'GCN':
gnn_q = DenseGCNConv(in_dim, attn_dim)
gnn_k = DenseGCNConv(in_dim, attn_dim)
gnn_v = DenseGCNConv(in_dim, out_dim)
return gnn_q, gnn_k, gnn_v
elif conv == 'MLP':
num_layers=2
gnn_q = MLP(num_layers, in_dim, 2*attn_dim, attn_dim, activate_func=torch.tanh)
gnn_k = MLP(num_layers, in_dim, 2*attn_dim, attn_dim, activate_func=torch.tanh)
gnn_v = DenseGCNConv(in_dim, out_dim)
return gnn_q, gnn_k, gnn_v
else:
raise NotImplementedError(f'{conv} not implemented.')
# -------- Layer of ScoreNetworkA --------
class AttentionLayer(torch.nn.Module):
def __init__(self, num_linears, conv_input_dim, attn_dim, conv_output_dim, input_dim, output_dim,
num_heads=4, conv='GCN'):
super(AttentionLayer, self).__init__()
self.attn = torch.nn.ModuleList()
for _ in range(input_dim):
self.attn_dim = attn_dim
self.attn.append(Attention(conv_input_dim, self.attn_dim, conv_output_dim,
num_heads=num_heads, conv=conv))
self.hidden_dim = 2*max(input_dim, output_dim)
self.mlp = MLP(num_linears, 2*input_dim, self.hidden_dim, output_dim, use_bn=False, activate_func=F.elu)
self.multi_channel = MLP(2, input_dim*conv_output_dim, self.hidden_dim, conv_output_dim,
use_bn=False, activate_func=F.elu)
def forward(self, x, adj, flags):
"""
:param x: B x N x F_i
:param adj: B x C_i x N x N
:return: x_out: B x N x F_o, adj_out: B x C_o x N x N
"""
mask_list = []
x_list = []
for _ in range(len(self.attn)):
_x, mask = self.attn[_](x, adj[:,_,:,:], flags)
mask_list.append(mask.unsqueeze(-1))
x_list.append(_x)
x_out = mask_x(self.multi_channel(torch.cat(x_list, dim=-1)), flags)
x_out = torch.tanh(x_out)
mlp_in = torch.cat([torch.cat(mask_list, dim=-1), adj.permute(0,2,3,1)], dim=-1)
shape = mlp_in.shape
mlp_out = self.mlp(mlp_in.view(-1, shape[-1]))
_adj = mlp_out.view(shape[0], shape[1], shape[2], -1).permute(0,3,1,2)
_adj = _adj + _adj.transpose(-1,-2)
adj_out = mask_adjs(_adj, flags)
return x_out, adj_out

View File

@@ -0,0 +1,209 @@
import torch
import torch.nn.functional as F
import networkx as nx
import numpy as np
# -------- Mask batch of node features with 0-1 flags tensor --------
def mask_x(x, flags):
if flags is None:
flags = torch.ones((x.shape[0], x.shape[1]), device=x.device)
return x * flags[:,:,None]
# -------- Mask batch of adjacency matrices with 0-1 flags tensor --------
def mask_adjs(adjs, flags):
"""
:param adjs: B x N x N or B x C x N x N
:param flags: B x N
:return:
"""
if flags is None:
flags = torch.ones((adjs.shape[0], adjs.shape[-1]), device=adjs.device)
if len(adjs.shape) == 4:
flags = flags.unsqueeze(1) # B x 1 x N
adjs = adjs * flags.unsqueeze(-1)
adjs = adjs * flags.unsqueeze(-2)
return adjs
# -------- Create flags tensor from graph dataset --------
def node_flags(adj, eps=1e-5):
flags = torch.abs(adj).sum(-1).gt(eps).to(dtype=torch.float32)
if len(flags.shape)==3:
flags = flags[:,0,:]
return flags
# -------- Create initial node features --------
def init_features(init, adjs=None, nfeat=10):
if init=='zeros':
feature = torch.zeros((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
elif init=='ones':
feature = torch.ones((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
elif init=='deg':
feature = adjs.sum(dim=-1).to(torch.long)
num_classes = nfeat
try:
feature = F.one_hot(feature, num_classes=num_classes).to(torch.float32)
except:
print(feature.max().item())
raise NotImplementedError(f'max_feat_num mismatch')
else:
raise NotImplementedError(f'{init} not implemented')
flags = node_flags(adjs)
return mask_x(feature, flags)
# -------- Sample initial flags tensor from the training graph set --------
def init_flags(graph_list, config, batch_size=None):
if batch_size is None:
batch_size = config.data.batch_size
max_node_num = config.data.max_node_num
graph_tensor = graphs_to_tensor(graph_list, max_node_num)
idx = np.random.randint(0, len(graph_list), batch_size)
flags = node_flags(graph_tensor[idx])
return flags
# -------- Generate noise --------
def gen_noise(x, flags, sym=True):
z = torch.randn_like(x)
if sym:
z = z.triu(1)
z = z + z.transpose(-1,-2)
z = mask_adjs(z, flags)
else:
z = mask_x(z, flags)
return z
# -------- Quantize generated graphs --------
def quantize(adjs, thr=0.5):
adjs_ = torch.where(adjs < thr, torch.zeros_like(adjs), torch.ones_like(adjs))
return adjs_
# -------- Quantize generated molecules --------
# adjs: 32 x 9 x 9
def quantize_mol(adjs):
if type(adjs).__name__ == 'Tensor':
adjs = adjs.detach().cpu()
else:
adjs = torch.tensor(adjs)
adjs[adjs >= 2.5] = 3
adjs[torch.bitwise_and(adjs >= 1.5, adjs < 2.5)] = 2
adjs[torch.bitwise_and(adjs >= 0.5, adjs < 1.5)] = 1
adjs[adjs < 0.5] = 0
return np.array(adjs.to(torch.int64))
def adjs_to_graphs(adjs, is_cuda=False):
graph_list = []
for adj in adjs:
if is_cuda:
adj = adj.detach().cpu().numpy()
G = nx.from_numpy_matrix(adj)
G.remove_edges_from(nx.selfloop_edges(G))
G.remove_nodes_from(list(nx.isolates(G)))
if G.number_of_nodes() < 1:
G.add_node(1)
graph_list.append(G)
return graph_list
# -------- Check if the adjacency matrices are symmetric --------
def check_sym(adjs, print_val=False):
sym_error = (adjs-adjs.transpose(-1,-2)).abs().sum([0,1,2])
if not sym_error < 1e-2:
raise ValueError(f'Not symmetric: {sym_error:.4e}')
if print_val:
print(f'{sym_error:.4e}')
# -------- Create higher order adjacency matrices --------
def pow_tensor(x, cnum):
# x : B x N x N
x_ = x.clone()
xc = [x.unsqueeze(1)]
for _ in range(cnum-1):
x_ = torch.bmm(x_, x)
xc.append(x_.unsqueeze(1))
xc = torch.cat(xc, dim=1)
return xc
# -------- Create padded adjacency matrices --------
def pad_adjs(ori_adj, node_number):
a = ori_adj
ori_len = a.shape[-1]
if ori_len == node_number:
return a
if ori_len > node_number:
raise ValueError(f'ori_len {ori_len} > node_number {node_number}')
a = np.concatenate([a, np.zeros([ori_len, node_number - ori_len])], axis=-1)
a = np.concatenate([a, np.zeros([node_number - ori_len, node_number])], axis=0)
return a
def graphs_to_tensor(graph_list, max_node_num):
adjs_list = []
max_node_num = max_node_num
for g in graph_list:
assert isinstance(g, nx.Graph)
node_list = []
for v, feature in g.nodes.data('feature'):
node_list.append(v)
adj = nx.to_numpy_matrix(g, nodelist=node_list)
padded_adj = pad_adjs(adj, node_number=max_node_num)
adjs_list.append(padded_adj)
del graph_list
adjs_np = np.asarray(adjs_list)
del adjs_list
adjs_tensor = torch.tensor(adjs_np, dtype=torch.float32)
del adjs_np
return adjs_tensor
def graphs_to_adj(graph, max_node_num):
max_node_num = max_node_num
assert isinstance(graph, nx.Graph)
node_list = []
for v, feature in graph.nodes.data('feature'):
node_list.append(v)
adj = nx.to_numpy_matrix(graph, nodelist=node_list)
padded_adj = pad_adjs(adj, node_number=max_node_num)
adj = torch.tensor(padded_adj, dtype=torch.float32)
del padded_adj
return adj
def node_feature_to_matrix(x):
"""
:param x: BS x N x F
:return:
x_pair: BS x N x N x 2F
"""
x_b = x.unsqueeze(-2).expand(x.size(0), x.size(1), x.size(1), -1) # BS x N x N x F
x_pair = torch.cat([x_b, x_b.transpose(1, 2)], dim=-1) # BS x N x N x 2F
return x_pair

View File

@@ -0,0 +1,153 @@
import torch
from torch.nn import Parameter
import torch.nn.functional as F
import math
from typing import Any
def glorot(tensor):
if tensor is not None:
stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
tensor.data.uniform_(-stdv, stdv)
def zeros(tensor):
if tensor is not None:
tensor.data.fill_(0)
def reset(value: Any):
if hasattr(value, 'reset_parameters'):
value.reset_parameters()
else:
for child in value.children() if hasattr(value, 'children') else []:
reset(child)
# -------- GCN layer --------
class DenseGCNConv(torch.nn.Module):
r"""See :class:`torch_geometric.nn.conv.GCNConv`.
"""
def __init__(self, in_channels, out_channels, improved=False, bias=True):
super(DenseGCNConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.improved = improved
self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
glorot(self.weight)
zeros(self.bias)
def forward(self, x, adj, mask=None, add_loop=True):
r"""
Args:
x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
\times N \times F}`, with batch-size :math:`B`, (maximum)
number of nodes :math:`N` for each graph, and feature
dimension :math:`F`.
adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B
\times N \times N}`. The adjacency tensor is broadcastable in
the batch dimension, resulting in a shared adjacency matrix for
the complete batch.
mask (BoolTensor, optional): Mask matrix
:math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
the valid nodes for each graph. (default: :obj:`None`)
add_loop (bool, optional): If set to :obj:`False`, the layer will
not automatically add self-loops to the adjacency matrices.
(default: :obj:`True`)
"""
x = x.unsqueeze(0) if x.dim() == 2 else x
adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
B, N, _ = adj.size()
if add_loop:
adj = adj.clone()
idx = torch.arange(N, dtype=torch.long, device=adj.device)
adj[:, idx, idx] = 1 if not self.improved else 2
out = torch.matmul(x, self.weight)
deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)
adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
out = torch.matmul(adj, out)
if self.bias is not None:
out = out + self.bias
if mask is not None:
out = out * mask.view(B, N, 1).to(x.dtype)
return out
def __repr__(self):
return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
self.out_channels)
# -------- MLP layer --------
class MLP(torch.nn.Module):
def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
"""
num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to linear model.
input_dim: dimensionality of input features
hidden_dim: dimensionality of hidden units at ALL layers
output_dim: number of classes for prediction
num_classes: the number of classes of input, to be treated with different gains and biases,
(see the definition of class `ConditionalLayer1d`)
"""
super(MLP, self).__init__()
self.linear_or_not = True # default is linear model
self.num_layers = num_layers
self.use_bn = use_bn
self.activate_func = activate_func
if num_layers < 1:
raise ValueError("number of layers should be positive!")
elif num_layers == 1:
# Linear model
self.linear = torch.nn.Linear(input_dim, output_dim)
else:
# Multi-layer model
self.linear_or_not = False
self.linears = torch.nn.ModuleList()
self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
for layer in range(num_layers - 2):
self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
self.linears.append(torch.nn.Linear(hidden_dim, output_dim))
if self.use_bn:
self.batch_norms = torch.nn.ModuleList()
for layer in range(num_layers - 1):
self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))
def forward(self, x):
"""
:param x: [num_classes * batch_size, N, F_i], batch of node features
note that in self.cond_layers[layer],
`x` is splited into `num_classes` groups in dim=0,
and then treated with different gains and biases
"""
if self.linear_or_not:
# If linear model
return self.linear(x)
else:
# If MLP
h = x
for layer in range(self.num_layers - 1):
h = self.linears[layer](h)
if self.use_bn:
h = self.batch_norms[layer](h)
h = self.activate_func(h)
return self.linears[self.num_layers - 1](h)

View File

@@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F
from models.GDSS.layers import DenseGCNConv, MLP
from .graph_utils import mask_x, pow_tensor
from .attention import AttentionLayer
from .. import utils
@utils.register_model(name='ScoreNetworkX')
class ScoreNetworkX(torch.nn.Module):
# def __init__(self, max_feat_num, depth, nhid):
def __init__(self, config):
super(ScoreNetworkX, self).__init__()
self.nfeat = config.data.n_vocab
self.depth = config.model.depth
self.nhid = config.model.nhid
self.layers = torch.nn.ModuleList()
for _ in range(self.depth):
if _ == 0:
self.layers.append(DenseGCNConv(self.nfeat, self.nhid))
else:
self.layers.append(DenseGCNConv(self.nhid, self.nhid))
self.fdim = self.nfeat + self.depth * self.nhid
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=self.nfeat,
use_bn=False, activate_func=F.elu)
self.activation = torch.tanh
def forward(self, x, time_cond, maskX, flags=None):
x_list = [x]
for _ in range(self.depth):
x = self.layers[_](x, maskX)
x = self.activation(x)
x_list.append(x)
xs = torch.cat(x_list, dim=-1) # B x N x (F + num_layers x H)
out_shape = (x.shape[0], x.shape[1], -1)
x = self.final(xs).view(*out_shape)
x = mask_x(x, flags)
return x
@utils.register_model(name='ScoreNetworkX_GMH')
class ScoreNetworkX_GMH(torch.nn.Module):
# def __init__(self, max_feat_num, depth, nhid, num_linears,
# c_init, c_hid, c_final, adim, num_heads=4, conv='GCN'):
def __init__(self, config):
super().__init__()
self.max_feat_num = config.data.n_vocab
self.depth = config.model.depth
self.nhid = config.model.nhid
self.c_init = config.model.c_init
self.c_hid = config.model.c_hid
self.c_final = config.model.c_final
self.num_linears = config.model.num_linears
self.num_heads = config.model.num_heads
self.conv = config.model.conv
self.adim = config.model.adim
self.layers = torch.nn.ModuleList()
for _ in range(self.depth):
if _ == 0:
self.layers.append(AttentionLayer(self.num_linears, self.max_feat_num,
self.nhid, self.nhid, self.c_init,
self.c_hid, self.num_heads, self.conv))
elif _ == self.depth - 1:
self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
self.nhid, self.c_hid,
self.c_final, self.num_heads, self.conv))
else:
self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
self.nhid, self.c_hid,
self.c_hid, self.num_heads, self.conv))
fdim = self.max_feat_num + self.depth * self.nhid
self.final = MLP(num_layers=3, input_dim=fdim, hidden_dim=2*fdim, output_dim=self.max_feat_num,
use_bn=False, activate_func=F.elu)
self.activation = torch.tanh
def forward(self, x, time_cond, maskX, flags=None):
adjc = pow_tensor(maskX, self.c_init)
x_list = [x]
for _ in range(self.depth):
x, adjc = self.layers[_](x, adjc, flags)
x = self.activation(x)
x_list.append(x)
xs = torch.cat(x_list, dim=-1) # B x N x (F + num_layers x H)
out_shape = (x.shape[0], x.shape[1], -1)
x = self.final(xs).view(*out_shape)
x = mask_x(x, flags)
return x

0
MobileNetV3/models/__init__.py Executable file
View File

352
MobileNetV3/models/cate.py Normal file
View File

@@ -0,0 +1,352 @@
import torch.nn as nn
import torch
import functools
from torch_geometric.utils import dense_to_sparse
from . import utils, layers, gnns
import torch
import torch.nn as nn
import torch.nn.functional as F
from .transformer import Encoder, SemanticEmbedding
from models.GDSS.layers import MLP
from .set_encoder.setenc_models import SetPool
""" Transformer Encoder """
class GraphEncoder(nn.Module):
def __init__(self, config):
super(GraphEncoder, self).__init__()
# Forward Transformers
self.encoder_f = Encoder(config)
def forward(self, x, mask):
h_f, hs_f, attns_f = self.encoder_f(x, mask)
h = torch.cat(hs_f, dim=-1)
return h
@staticmethod
def get_embeddings(h_x):
h_x = h_x.cpu()
return h_x[:, -1]
class CLSHead(nn.Module):
def __init__(self, config, init_weights=None):
super(CLSHead, self).__init__()
self.layer_1 = nn.Linear(config.d_model, config.d_model)
self.dropout = nn.Dropout(p=config.dropout)
self.layer_2 = nn.Linear(config.d_model, config.n_vocab)
if init_weights is not None:
self.layer_2.weight = init_weights
def forward(self, x):
x = self.dropout(torch.tanh(self.layer_1(x)))
return F.log_softmax(self.layer_2(x), dim=-1)
@utils.register_model(name='CATE')
class CATE(nn.Module):
def __init__(self, config):
super(CATE, self).__init__()
# Shared Embedding Layer
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
self.dropout_op = nn.Dropout(p=config.model.dropout)
self.d_model = config.model.graph_encoder.d_model
self.act = act = get_act(config)
# Time
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
# 2 GraphEncoder for X and Y
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
use_bn=False, activate_func=F.elu)
self.pos_enc_type = config.model.pos_enc_type
self.pos_encoder = PositionalEncoding_StageWise(d_model=self.d_model, max_len=config.data.max_node)
def forward(self, X, time_cond, maskX):
# Shared Embeddings
emb_x = self.dropout_op(self.opEmb(X))
if self.pos_encoder is not None:
emb_p = self.pos_encoder(emb_x) # [20, 64]
emb_x = emb_x + emb_p
# Time embedding
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
emb_t = self.timeEmb1(emb_t) # [32, 512]
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
emb_t = emb_t.unsqueeze(1)
emb = emb_x + emb_t
h_x = self.graph_encoder(emb, maskX)
h_x = self.final(h_x)
"""
Shape: Batch Size, Length (with Pad), Feature Dim (forward) + Feature Dim (backward)
*HINT: X1 X2 X3 [PAD] [PAD]
"""
return h_x
@utils.register_model(name='PredictorCATE')
class PredictorCATE(nn.Module):
def __init__(self, config):
super(PredictorCATE, self).__init__()
# Shared Embedding Layer
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
self.dropout_op = nn.Dropout(p=config.model.dropout)
self.d_model = config.model.graph_encoder.d_model
self.act = act = get_act(config)
# Time
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
# 2 GraphEncoder for X and Y
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
use_bn=False, activate_func=F.elu)
self.rdim = int(config.data.max_node * config.data.n_vocab)
self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=1,
use_bn=False, activate_func=F.elu)
def forward(self, X, time_cond, maskX):
# Shared Embeddings
emb_x = self.dropout_op(self.opEmb(X))
# Time embedding
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
emb_t = self.timeEmb1(emb_t) # [32, 512]
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
emb_t = emb_t.unsqueeze(1)
emb = emb_x + emb_t
# h_x = self.graph_encoder(emb_x, maskX)
h_x = self.graph_encoder(emb, maskX)
h_x = self.final(h_x)
"""
Shape: Batch Size, Length (with Pad), Feature Dim (forward) + Feature Dim (backward)
*HINT: X1 X2 X3 [PAD] [PAD]
"""
h_x = h_x.reshape(h_x.size(0), -1)
h_x = self.regeress(h_x)
return h_x
class PositionalEncoding_StageWise(nn.Module):
def __init__(self, d_model, max_len):
super(PositionalEncoding_StageWise, self).__init__()
NUM_STAGE = 5
max_len = int(max_len / NUM_STAGE)
self.encoding = torch.zeros(max_len, d_model)
pos = torch.arange(0, max_len)
pos = pos.float().unsqueeze(dim=1)
_2i = torch.arange(0, d_model, step=2).float()
# (max_len, 1) / (d_model/2 ) -> (max_len, d_model/2)
self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model))) # (4, 64)
self.encoding = torch.cat([self.encoding] * NUM_STAGE, dim=0)
def forward(self, x):
batch_size, seq_len, _ = x.size()
return self.encoding[:seq_len, :].to(x.device)
@utils.register_model(name='MetaPredictorCATE')
class MetaPredictorCATE(nn.Module):
def __init__(self, config):
super(MetaPredictorCATE, self).__init__()
self.input_type= config.model.input_type
self.hs = config.model.hs
# Shared Embedding Layer
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
self.dropout_op = nn.Dropout(p=config.model.dropout)
self.d_model = config.model.graph_encoder.d_model
self.act = act = get_act(config)
# Time
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
# 2 GraphEncoder for X and Y
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
use_bn=False, activate_func=F.elu)
self.rdim = int(config.data.max_node * config.data.n_vocab)
self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=2*self.rdim,
use_bn=False, activate_func=F.elu)
# Set
self.nz = config.model.nz
self.num_sample = config.model.num_sample
self.intra_setpool = SetPool(dim_input=512,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.inter_setpool = SetPool(dim_input=self.nz,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.set_fc = nn.Sequential(
nn.Linear(512, self.nz),
nn.ReLU())
input_dim = 0
if 'D' in self.input_type:
input_dim += self.nz
if 'A' in self.input_type:
input_dim += 2*self.rdim
self.pred_fc = nn.Sequential(
nn.Linear(input_dim, self.hs),
nn.Tanh(),
nn.Linear(self.hs, 1)
)
self.sample_state = False
self.D_mu = None
def arch_encode(self, X, time_cond, maskX):
# Shared Embeddings
emb_x = self.dropout_op(self.opEmb(X))
# Time embedding
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
emb_t = self.timeEmb1(emb_t) # [32, 512]
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
emb_t = emb_t.unsqueeze(1)
emb = emb_x + emb_t
h_x = self.graph_encoder(emb, maskX)
h_x = self.final(h_x)
h_x = h_x.reshape(h_x.size(0), -1)
h_x = self.regeress(h_x)
return h_x
def set_encode(self, task):
proto_batch = []
for x in task:
cls_protos = self.intra_setpool(
x.view(-1, self.num_sample, 512)).squeeze(1)
proto_batch.append(
self.inter_setpool(cls_protos.unsqueeze(0)))
v = torch.stack(proto_batch).squeeze()
return v
def predict(self, D_mu, A_mu):
input_vec = []
if 'D' in self.input_type:
input_vec.append(D_mu)
if 'A' in self.input_type:
input_vec.append(A_mu)
input_vec = torch.cat(input_vec, dim=1)
return self.pred_fc(input_vec)
def forward(self, X, time_cond, maskX, task):
if self.sample_state:
if self.D_mu is None:
self.D_mu = self.set_encode(task)
D_mu = self.D_mu
else:
D_mu = self.set_encode(task)
A_mu = self.arch_encode(X, time_cond, maskX)
y_pred = self.predict(D_mu, A_mu)
return y_pred
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]
import math
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')

View File

@@ -0,0 +1,355 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops.layers.torch import Rearrange
from einops import rearrange
import numpy as np
from . import utils
class SinusoidalPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions, embedding_dim):
super().__init__(num_positions, embedding_dim) # torch.nn.Embedding(num_embeddings, embedding_dim)
self.weight = self._init_weight(self.weight) # self.weight => nn.Embedding(num_positions, embedding_dim).weight
@staticmethod
def _init_weight(out: nn.Parameter):
n_pos, embed_dim = out.shape
pe = nn.Parameter(torch.zeros(out.shape))
for pos in range(n_pos):
for i in range(0, embed_dim, 2):
pe[pos, i].data.copy_( torch.tensor( np.sin(pos / (10000 ** ( i / embed_dim)))) )
pe[pos, i + 1].data.copy_( torch.tensor( np.cos(pos / (10000 ** ((i + 1) / embed_dim)))) )
pe.detach_()
return pe
@torch.no_grad()
def forward(self, input_ids):
bsz, seq_len = input_ids.shape[:2] # for x, seq_len = max_node_num
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
return super().forward(positions)
class MLP(nn.Module):
def __init__(
self,
dim_in,
dim_out,
*,
expansion_factor = 2.,
depth = 2,
norm = False,
):
super().__init__()
hidden_dim = int(expansion_factor * dim_out)
norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
layers = [nn.Sequential(
nn.Linear(dim_in, hidden_dim),
nn.SiLU(),
norm_fn()
)]
for _ in range(depth - 1):
layers.append(nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
norm_fn()
))
layers.append(nn.Linear(hidden_dim, dim_out))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x.float())
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
dtype, device = x.dtype, x.device
assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
def is_float_dtype(dtype):
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
class PositionWiseFeedForward(nn.Module):
def __init__(self, emb_dim: int, d_ff: int, dropout: float = 0.1):
super(PositionWiseFeedForward, self).__init__()
self.activation = nn.ReLU()
self.w_1 = nn.Linear(emb_dim, d_ff)
self.w_2 = nn.Linear(d_ff, emb_dim)
self.dropout = dropout
def forward(self, x):
residual = x
x = self.activation(self.w_1(x))
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.w_2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
return x + residual # residual connection for preventing gradient vanishing
class MultiHeadAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
emb_dim,
num_heads,
dropout=0.0,
bias=False,
encoder_decoder_attention=False, # otherwise self_attention
causal = True
):
super().__init__()
self.emb_dim = emb_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = emb_dim // num_heads
assert self.head_dim * num_heads == self.emb_dim, "emb_dim must be divisible by num_heads"
self.encoder_decoder_attention = encoder_decoder_attention
self.causal = causal
self.q_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
self.k_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
self.v_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
self.out_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (
self.num_heads,
self.head_dim,
)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
# This is equivalent to
# return x.transpose(1,2)
def scaled_dot_product(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.BoolTensor):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.emb_dim) # QK^T/sqrt(d)
if attention_mask is not None:
attn_weights = attn_weights.masked_fill(attention_mask.unsqueeze(1), float("-inf"))
attn_weights = F.softmax(attn_weights, dim=-1) # softmax(QK^T/sqrt(d))
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_probs, value) # softmax(QK^T/sqrt(d))V
return attn_output, attn_probs
def MultiHead_scaled_dot_product(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.BoolTensor):
attention_mask = attention_mask.bool()
attn_weights = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.head_dim) # QK^T/sqrt(d) # [6, 6]
# Attention mask
if attention_mask is not None:
if self.causal:
# (seq_len x seq_len)
attn_weights = attn_weights.masked_fill(attention_mask.unsqueeze(0).unsqueeze(1), float("-inf"))
else:
# (batch_size x seq_len)
attn_weights = attn_weights.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
attn_weights = F.softmax(attn_weights, dim=-1) # softmax(QK^T/sqrt(d))
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_probs, value) # softmax(QK^T/sqrt(d))V
attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
concat_attn_output_shape = attn_output.size()[:-2] + (self.emb_dim,)
attn_output = attn_output.view(*concat_attn_output_shape)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
attention_mask: torch.Tensor = None,
):
q = self.q_proj(query)
# Enc-Dec attention
if self.encoder_decoder_attention:
k = self.k_proj(key)
v = self.v_proj(key)
# Self attention
else:
k = self.k_proj(query)
v = self.v_proj(query)
q = self.transpose_for_scores(q)
k = self.transpose_for_scores(k)
v = self.transpose_for_scores(v)
attn_output, attn_weights = self.MultiHead_scaled_dot_product(q,k,v,attention_mask)
return attn_output, attn_weights
class EncoderLayer(nn.Module):
def __init__(self, emb_dim, ffn_dim, attention_heads,
attention_dropout, dropout):
super().__init__()
self.emb_dim = emb_dim
self.ffn_dim = ffn_dim
self.self_attn = MultiHeadAttention(
emb_dim=self.emb_dim,
num_heads=attention_heads,
dropout=attention_dropout)
self.self_attn_layer_norm = nn.LayerNorm(self.emb_dim)
self.dropout = dropout
self.activation_fn = nn.ReLU()
self.PositionWiseFeedForward = PositionWiseFeedForward(self.emb_dim, self.ffn_dim, dropout)
self.final_layer_norm = nn.LayerNorm(self.emb_dim)
def forward(self, x, encoder_padding_mask):
residual = x
x, attn_weights = self.self_attn(query=x, key=x, attention_mask=encoder_padding_mask)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
x = self.PositionWiseFeedForward(x)
x = self.final_layer_norm(x)
if torch.isinf(x).any() or torch.isnan(x).any():
clamp_value = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp_value, max=clamp_value)
return x, attn_weights
@utils.register_model(name='DAGformer')
class DAGformer(torch.nn.Module):
def __init__(self, config):
# max_feat_num,
# max_node_num,
# emb_dim,
# ffn_dim,
# encoder_layers,
# attention_heads,
# attention_dropout,
# dropout,
# hs,
# time_dep=True,
# num_timesteps=None,
# return_attn=False,
# except_inout=False,
# connect_prev=True
# ):
super().__init__()
self.dropout = config.model.dropout
self.time_dep = config.model.time_dep
self.return_attn = config.model.return_attn
max_feat_num = config.data.n_vocab
max_node_num = config.data.max_node
emb_dim = config.model.emb_dim
# num_timesteps = config.model.num_scales
num_timesteps = None
self.x_embedding = MLP(max_feat_num, emb_dim)
# position embedding with topological order
self.position_embedding = SinusoidalPositionalEmbedding(max_node_num, emb_dim)
if self.time_dep:
self.time_embedding = nn.Sequential(
nn.Embedding(num_timesteps, emb_dim) if num_timesteps is not None
else nn.Sequential(SinusoidalPosEmb(emb_dim), MLP(emb_dim, emb_dim)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n=1)
)
self.layers = nn.ModuleList([EncoderLayer(emb_dim,
config.model.ffn_dim,
config.model.attention_heads,
config.model.attention_dropout,
config.model.dropout)
for _ in range(config.model.encoder_layers)])
self.pred_fc = nn.Sequential(
nn.Linear(emb_dim, config.model.hs),
nn.Tanh(),
nn.Linear(config.model.hs, 1),
# nn.Sigmoid()
)
# -------- Load Constant Adj Matrix (START) --------- #
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# from utils.graph_utils import get_const_adj
# mat = get_const_adj(
# except_inout=except_inout,
# shape_adj=(1, max_node_num, max_node_num),
# device=torch.device('cpu'),
# connect_prev=connect_prev)[0].cpu()
# is_triu_ = is_triu(mat)
# if is_triu_:
# self.adj_ = mat.T.to(self.device)
# else:
# self.adj_ = mat.to(self.device)
# -------- Load Constant Adj Matrix (END) --------- #
def forward(self, x, t, adj, flags=None):
"""
:param x: B x N x F_i
:param adjs: B x C_i x N x N
:return: x_o: B x N x F_o, new_adjs: B x C_o x N x N
"""
assert len(x.shape) == 3
self_attention_mask = torch.eye(adj.size(1)).to(self.device)
# attention_mask = 1. - (self_attention_mask + self.adj_)
attention_mask = 1. - (self_attention_mask + adj[0])
# -------- Generate input for DAGformer ------- #
x_embed = self.x_embedding(x)
# x_embed = x
x_pos = self.position_embedding(x).unsqueeze(0)
if self.time_dep:
time_embed = self.time_embedding(t)
x = x_embed + x_pos
if self.time_dep:
x = x + time_embed
x = F.dropout(x, p=self.dropout, training=self.training)
self_attn_scores = []
for encoder_layer in self.layers:
x, attn = encoder_layer(x, attention_mask)
self_attn_scores.append(attn.detach())
x = self.pred_fc(x[:, -1, :]) # [256, 16]
if self.return_attn:
return x, self_attn_scores
else:
return x

142
MobileNetV3/models/digcn.py Normal file
View File

@@ -0,0 +1,142 @@
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
# which was authored by Yuge Zhang, 2020
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import utils
from models.cate import PositionalEncoding_StageWise
def normalize_adj(adj):
# Row-normalize matrix
last_dim = adj.size(-1)
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
return torch.div(adj, rowsum)
def graph_pooling(inputs, num_vertices):
num_vertices = num_vertices.to(inputs.device)
out = inputs.sum(1)
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
class DirectedGraphConvolution(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
self.dropout = nn.Dropout(0.1)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight1.data)
nn.init.xavier_uniform_(self.weight2.data)
def forward(self, inputs, adj):
inputs = inputs.to(self.weight1.device)
adj = adj.to(self.weight1.device)
norm_adj = normalize_adj(adj)
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
out = (output1 + output2) / 2
out = self.dropout(out)
return out
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
# if nasbench-101: initial_hidden=5. if nasbench-201: initial_hidden=7
@utils.register_model(name='NeuralPredictor')
class NeuralPredictor(nn.Module):
# def __init__(self, initial_hidden=5, gcn_hidden=144, gcn_layers=4, linear_hidden=128):
def __init__(self, config):
super().__init__()
self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
config.model.graph_encoder.gcn_hidden)
for i in range(config.model.graph_encoder.gcn_layers)]
self.gcn = nn.ModuleList(self.gcn)
self.dropout = nn.Dropout(0.1)
self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
self.fc2 = nn.Linear(config.model.graph_encoder.linear_hidden, 1, bias=False)
# Time
self.d_model = config.model.graph_encoder.gcn_hidden
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
self.act = act = get_act(config)
# self.pos_enc_type = config.model.pos_enc_type
# if self.pos_enc_type == 1:
# raise NotImplementedError
# elif self.pos_enc_type == 2:
# self.pos_encoder = PositionalEncoding_StageWise(d_model=config.model.graph_encoder.gcn_hidden, max_len=config.data.max_node)
# elif self.pos_enc_type == 3:
# raise NotImplementedError
# else:
# self.pos_encoder = None
# def forward(self, inputs):
def forward(self, X, time_cond, maskX):
# numv, adj, out = inputs["num_vertices"], inputs["adjacency"], inputs["operations"]
out = X # (5, 20, 10)
adj = maskX # (1, 20, 20)
# # pos embedding
# if self.pos_encoder is not None:
# emb_p = self.pos_encoder(out) # [20, 64]
# out = out + emb_p
numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device) # 20
gs = adj.size(1) # graph node number
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
emb_t = self.timeEmb1(emb_t)
emb_t = self.timeEmb2(self.act(emb_t)) # (5, 144)
adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device)) # assuming diagonal is not 1
for layer in self.gcn:
out = layer(out, adj_with_diag)
out = graph_pooling(out, numv) # out: 5, 20, 144
# time
out = out + emb_t
out = self.fc1(out) # (5, 128)
out = self.dropout(out)
# out = self.fc2(out).view(-1)
out = self.fc2(out)
return out
import math
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')

View File

@@ -0,0 +1,194 @@
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
# which was authored by Yuge Zhang, 2020
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import utils
from .set_encoder.setenc_models import SetPool
def normalize_adj(adj):
# Row-normalize matrix
last_dim = adj.size(-1)
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
return torch.div(adj, rowsum)
def graph_pooling(inputs, num_vertices):
num_vertices = num_vertices.to(inputs.device)
out = inputs.sum(1)
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
class DirectedGraphConvolution(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
self.dropout = nn.Dropout(0.1)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight1.data)
nn.init.xavier_uniform_(self.weight2.data)
def forward(self, inputs, adj):
inputs = inputs.to(self.weight1.device)
adj = adj.to(self.weight1.device)
norm_adj = normalize_adj(adj)
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
out = (output1 + output2) / 2
out = self.dropout(out)
return out
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
# if nasbench-101: initial_hidden=5. if nasbench-201: initial_hidden=7
@utils.register_model(name='MetaNeuralPredictor')
class MetaeuralPredictor(nn.Module):
# def __init__(self, initial_hidden=5, gcn_hidden=144, gcn_layers=4, linear_hidden=128):
def __init__(self, config):
super().__init__()
# Arch
self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
config.model.graph_encoder.gcn_hidden)
for i in range(config.model.graph_encoder.gcn_layers)]
self.gcn = nn.ModuleList(self.gcn)
self.dropout = nn.Dropout(0.1)
self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
# self.fc2 = nn.Linear(config.model.graph_encoder.linear_hidden, 1, bias=False)
# Time
self.d_model = config.model.graph_encoder.gcn_hidden
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
self.act = act = get_act(config)
self.input_type = config.model.input_type
self.hs = config.model.hs
# Set
self.nz = config.model.nz
self.num_sample = config.model.num_sample
self.intra_setpool = SetPool(dim_input=512,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.inter_setpool = SetPool(dim_input=self.nz,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.set_fc = nn.Sequential(
nn.Linear(512, self.nz),
nn.ReLU())
input_dim = 0
if 'D' in self.input_type:
input_dim += self.nz
if 'A' in self.input_type:
input_dim += config.model.graph_encoder.linear_hidden
self.pred_fc = nn.Sequential(
nn.Linear(input_dim, self.hs),
nn.Tanh(),
nn.Linear(self.hs, 1)
)
self.sample_state = False
self.D_mu = None
def arch_encode(self, X, time_cond, maskX):
# numv, adj, out = inputs["num_vertices"], inputs["adjacency"], inputs["operations"]
out = X
adj = maskX
numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device)
gs = adj.size(1) # graph node number
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
emb_t = self.timeEmb1(emb_t)
emb_t = self.timeEmb2(self.act(emb_t))
adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device)) # assuming diagonal is not 1
for layer in self.gcn:
out = layer(out, adj_with_diag)
out = graph_pooling(out, numv)
# time
out = out + emb_t
out = self.fc1(out)
out = self.dropout(out)
# out = self.fc2(out).view(-1)
# out = self.fc2(out)
return out
def set_encode(self, task):
proto_batch = []
for x in task:
cls_protos = self.intra_setpool(
x.view(-1, self.num_sample, 512)).squeeze(1)
proto_batch.append(
self.inter_setpool(cls_protos.unsqueeze(0)))
v = torch.stack(proto_batch).squeeze()
return v
def predict(self, D_mu, A_mu):
input_vec = []
if 'D' in self.input_type:
input_vec.append(D_mu)
if 'A' in self.input_type:
input_vec.append(A_mu)
input_vec = torch.cat(input_vec, dim=1)
return self.pred_fc(input_vec)
def forward(self, X, time_cond, maskX, task):
if self.sample_state:
if self.D_mu is None:
self.D_mu = self.set_encode(task)
D_mu = self.D_mu
else:
D_mu = self.set_encode(task)
A_mu = self.arch_encode(X, time_cond, maskX)
y_pred = self.predict(D_mu, A_mu)
return y_pred
import math
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')

85
MobileNetV3/models/ema.py Normal file
View File

@@ -0,0 +1,85 @@
import torch
class ExponentialMovingAverage:
"""
Maintains (exponential) moving average of a set of parameters.
"""
def __init__(self, parameters, decay, use_num_updates=True):
"""
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the result of `model.parameters()`.
decay: The exponential decay.
use_num_updates: Whether to use number of updates when computing averages.
"""
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
self.num_updates = 0 if use_num_updates else None
self.shadow_params = [p.clone().detach()
for p in parameters if p.requires_grad]
self.collected_params = []
def update(self, parameters):
"""
Update currently maintained parameters.
Call this every time the parameters are updated, such as the result of the `optimizer.step()` call.
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to
initialize this object.
"""
decay = self.decay
if self.num_updates is not None:
self.num_updates += 1
decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
s_param.sub_(one_minus_decay * (s_param - param))
def copy_to(self, parameters):
"""
Copy current parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages.
"""
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
if param.requires_grad:
param.data.copy_(s_param.data)
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the original optimization process.
Store the parameters before the `copy_to` method.
After validation (or model saving), use this to restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
def state_dict(self):
return dict(decay=self.decay, num_updates=self.num_updates, shadow_params=self.shadow_params)
def load_state_dict(self, state_dict):
self.decay = state_dict['decay']
self.num_updates = state_dict['num_updates']
self.shadow_params = state_dict['shadow_params']

View File

@@ -0,0 +1,82 @@
import torch.nn as nn
import torch
from .trans_layers import *
class pos_gnn(nn.Module):
def __init__(self, act, x_ch, pos_ch, out_ch, max_node, graph_layer, n_layers=3, edge_dim=None, heads=4,
temb_dim=None, dropout=0.1, attn_clamp=False):
super().__init__()
self.out_ch = out_ch
self.Dropout_0 = nn.Dropout(dropout)
self.act = act
self.max_node = max_node
self.n_layers = n_layers
if temb_dim is not None:
self.Dense_node0 = nn.Linear(temb_dim, x_ch)
self.Dense_node1 = nn.Linear(temb_dim, pos_ch)
self.Dense_edge0 = nn.Linear(temb_dim, edge_dim)
self.Dense_edge1 = nn.Linear(temb_dim, edge_dim)
self.convs = nn.ModuleList()
self.edge_convs = nn.ModuleList()
self.edge_layer = nn.Linear(edge_dim * 2 + self.out_ch, edge_dim)
for i in range(n_layers):
if i == 0:
self.convs.append(eval(graph_layer)(x_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2,
act=act, attn_clamp=attn_clamp))
else:
self.convs.append(eval(graph_layer)
(self.out_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2, act=act,
attn_clamp=attn_clamp))
self.edge_convs.append(nn.Linear(self.out_ch, edge_dim*2))
def forward(self, x_degree, x_pos, edge_index, dense_ori, dense_spd, dense_index, temb=None):
"""
Args:
x_degree: node degree feature [B*N, x_ch]
x_pos: node rwpe feature [B*N, pos_ch]
edge_index: [2, edge_length]
dense_ori: edge feature [B, N, N, nf//2]
dense_spd: edge shortest path distance feature [B, N, N, nf//2] # Do we need this part? # TODO
dense_index
temb: [B, temb_dim]
"""
B, N, _, _ = dense_ori.shape
if temb is not None:
dense_ori = dense_ori + self.Dense_edge0(self.act(temb))[:, None, None, :]
dense_spd = dense_spd + self.Dense_edge1(self.act(temb))[:, None, None, :]
temb = temb.unsqueeze(1).repeat(1, self.max_node, 1)
temb = temb.reshape(-1, temb.shape[-1])
x_degree = x_degree + self.Dense_node0(self.act(temb))
x_pos = x_pos + self.Dense_node1(self.act(temb))
dense_edge = torch.cat([dense_ori, dense_spd], dim=-1)
ori_edge_attr = dense_edge
h = x_degree
h_pos = x_pos
for i_layer in range(self.n_layers):
h_edge = dense_edge[dense_index]
# update node feature
h, h_pos = self.convs[i_layer](h, h_pos, edge_index, h_edge)
h = self.Dropout_0(h)
h_pos = self.Dropout_0(h_pos)
# update dense edge feature
h_dense_node = h.reshape(B, N, -1)
cur_edge_attr = h_dense_node.unsqueeze(1) + h_dense_node.unsqueeze(2) # [B, N, N, nf]
dense_edge = (dense_edge + self.act(self.edge_convs[i_layer](cur_edge_attr))) / math.sqrt(2.)
dense_edge = self.Dropout_0(dense_edge)
# Concat edge attribute
h_dense_edge = torch.cat([ori_edge_attr, dense_edge], dim=-1)
h_dense_edge = self.edge_layer(h_dense_edge).permute(0, 3, 1, 2)
return h_dense_edge

View File

@@ -0,0 +1,44 @@
"""Common layers"""
import torch.nn as nn
import torch
import torch.nn.functional as F
import math
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')
def conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, padding=0):
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
padding=padding)
return conv
# from DDPM
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb

171
MobileNetV3/models/pgsn.py Normal file
View File

@@ -0,0 +1,171 @@
import torch.nn as nn
import torch
import functools
from torch_geometric.utils import dense_to_sparse
from . import utils, layers, gnns
get_act = layers.get_act
conv1x1 = layers.conv1x1
@utils.register_model(name='PGSN')
class PGSN(nn.Module):
"""Position enhanced graph score network."""
def __init__(self, config):
super().__init__()
self.config = config
self.act = act = get_act(config)
# get model construction paras
self.nf = nf = config.model.nf
self.num_gnn_layers = num_gnn_layers = config.model.num_gnn_layers
dropout = config.model.dropout
self.embedding_type = embedding_type = config.model.embedding_type.lower()
self.rw_depth = rw_depth = config.model.rw_depth
self.edge_th = config.model.edge_th
modules = []
# timestep/noise_level embedding; only for continuous training
if embedding_type == 'positional':
embed_dim = nf
else:
raise ValueError(f'embedding type {embedding_type} unknown.')
# timestep embedding layers
modules.append(nn.Linear(embed_dim, nf * 4))
modules.append(nn.Linear(nf * 4, nf * 4))
# graph size condition embedding
self.size_cond = size_cond = config.model.size_cond
if size_cond:
self.size_onehot = functools.partial(nn.functional.one_hot, num_classes=config.data.max_node + 1)
modules.append(nn.Linear(config.data.max_node + 1, nf * 4))
modules.append(nn.Linear(nf * 4, nf * 4))
channels = config.data.num_channels
assert channels == 1, "Without edge features."
# degree onehot
self.degree_max = self.config.data.max_node // 2
self.degree_onehot = functools.partial(
nn.functional.one_hot,
num_classes=self.degree_max + 1)
# project edge features
modules.append(conv1x1(channels, nf // 2))
modules.append(conv1x1(rw_depth + 1, nf // 2))
# project node features
self.x_ch = nf
self.pos_ch = nf // 2
modules.append(nn.Linear(self.degree_max + 1, self.x_ch))
modules.append(nn.Linear(rw_depth, self.pos_ch))
# GNN
modules.append(gnns.pos_gnn(act, self.x_ch, self.pos_ch, nf, config.data.max_node,
config.model.graph_layer, num_gnn_layers,
heads=config.model.heads, edge_dim=nf//2, temb_dim=nf * 4,
dropout=dropout, attn_clamp=config.model.attn_clamp))
# output
modules.append(conv1x1(nf // 2, nf // 2))
modules.append(conv1x1(nf // 2, channels))
self.all_modules = nn.ModuleList(modules)
def forward(self, x, time_cond, *args, **kwargs):
mask = kwargs['mask']
modules = self.all_modules
m_idx = 0
# Sinusoidal positional embeddings
timesteps = time_cond
temb = layers.get_timestep_embedding(timesteps, self.nf)
# time embedding
temb = modules[m_idx](temb) # [32, 512]
m_idx += 1
temb = modules[m_idx](self.act(temb)) # [32, 512]
m_idx += 1
if self.size_cond:
with torch.no_grad():
node_mask = utils.mask_adj2node(mask.squeeze(1)) # [B, N]
num_node = torch.sum(node_mask, dim=-1) # [B]
num_node = self.size_onehot(num_node.to(torch.long)).to(torch.float)
num_node_emb = modules[m_idx](num_node)
m_idx += 1
num_node_emb = modules[m_idx](self.act(num_node_emb))
m_idx += 1
temb = temb + num_node_emb
if not self.config.data.centered:
# rescale the input data to [-1, 1]
x = x * 2. - 1.
with torch.no_grad():
# continuous-valued graph adjacency matrices
cont_adj = ((x + 1.) / 2.).clone()
cont_adj = (cont_adj * mask).squeeze(1) # [B, N, N]
cont_adj = cont_adj.clamp(min=0., max=1.)
if self.edge_th > 0.:
cont_adj[cont_adj < self.edge_th] = 0.
# discretized graph adjacency matrices
adj = x.squeeze(1).clone() # [B, N, N]
adj[adj >= 0.] = 1.
adj[adj < 0.] = 0.
adj = adj * mask.squeeze(1)
# extract RWSE and Shortest-Path Distance
x_pos, spd_onehot = utils.get_rw_feat(self.rw_depth, adj)
# x_pos: [32, 20, 16], spd_onehot: [32, 17, 20, 20]
# edge [B, N, N, F]
dense_edge_ori = modules[m_idx](x).permute(0, 2, 3, 1) # [32, 20, 20, 64]
m_idx += 1
dense_edge_spd = modules[m_idx](spd_onehot).permute(0, 2, 3, 1) # [32, 20, 20, 64]
m_idx += 1
# Use Degree as node feature
x_degree = torch.sum(cont_adj, dim=-1) # [B, N] # [32, 20]
x_degree = x_degree.clamp(max=float(self.degree_max)) # [B, N] # [32, 20]
x_degree = self.degree_onehot(x_degree.to(torch.long)).to(torch.float) # [B, N, max_node] # [32, 20, 11]
x_degree = modules[m_idx](x_degree) # projection layer [B, N, nf] # [32, 20, 128]
m_idx += 1
import pdb; pdb.set_trace()
# pos encoding
# x_pos: [32, 20, 16]
x_pos = modules[m_idx](x_pos) # [32, 20, 64]
m_idx += 1
# Dense to sparse node [BxN, -1]
x_degree = x_degree.reshape(-1, self.x_ch) # [640, 128]
x_pos = x_pos.reshape(-1, self.pos_ch) # [640, 64]
dense_index = cont_adj.nonzero(as_tuple=True)
edge_index, _ = dense_to_sparse(cont_adj) # [2, 5386]
# Run GNN layers
h_dense_edge = modules[m_idx](x_degree, x_pos, edge_index, dense_edge_ori, dense_edge_spd, dense_index, temb)
m_idx += 1
import pdb; pdb.set_trace()
# Output
h = self.act(modules[m_idx](self.act(h_dense_edge)))
m_idx += 1
import pdb; pdb.set_trace()
h = modules[m_idx](h)
m_idx += 1
import pdb; pdb.set_trace()
# make edge estimation symmetric
h = (h + h.transpose(2, 3)) / 2. * mask
import pdb; pdb.set_trace()
assert m_idx == len(modules)
return h

View File

@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
import torch.optim as optim
from . import utils
@utils.register_model(name='MLPRegressor')
class MLPRegressor(nn.Module):
# def __init__(self, input_size, hidden_size, output_size):
def __init__(self, config):
super().__init__()
input_size = int(config.data.max_node * config.data.n_vocab)
hidden_size = config.model.hidden_size
output_size = config.model.output_size
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, hidden_size)
self.fc4 = nn.Linear(hidden_size, output_size)
self.activation = nn.ReLU()
def forward(self, X, time_cond, maskX):
x = X.view(X.size(0), -1)
x = self.activation(self.fc1(x))
x= self.activation(self.fc2(x))
x= self.activation(self.fc3(x))
x= self.fc4(x)
return x

View File

@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .setenc_modules import *
class SetPool(nn.Module):
def __init__(self, dim_input, num_outputs, dim_output,
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
super(SetPool, self).__init__()
if 'sab' in mode: # [32, 400, 128]
self.enc = nn.Sequential(
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
else: # [32, 400, 128]
self.enc = nn.Sequential(
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
if 'PF' in mode: # [32, 1, 501]
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
nn.Linear(dim_hidden, dim_output))
elif 'P' in mode:
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
else: # torch.Size([32, 1, 501])
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
nn.Linear(dim_hidden, dim_output))
# "", sm, sab, sabsm
def forward(self, X):
x1 = self.enc(X)
x2 = self.dec(x1)
return x2

View File

@@ -0,0 +1,67 @@
#####################################################################################
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MAB(nn.Module):
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
super(MAB, self).__init__()
self.dim_V = dim_V
self.num_heads = num_heads
self.fc_q = nn.Linear(dim_Q, dim_V)
self.fc_k = nn.Linear(dim_K, dim_V)
self.fc_v = nn.Linear(dim_K, dim_V)
if ln:
self.ln0 = nn.LayerNorm(dim_V)
self.ln1 = nn.LayerNorm(dim_V)
self.fc_o = nn.Linear(dim_V, dim_V)
def forward(self, Q, K):
Q = self.fc_q(Q)
K, V = self.fc_k(K), self.fc_v(K)
dim_split = self.dim_V // self.num_heads
Q_ = torch.cat(Q.split(dim_split, 2), 0)
K_ = torch.cat(K.split(dim_split, 2), 0)
V_ = torch.cat(V.split(dim_split, 2), 0)
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
O = O + F.relu(self.fc_o(O))
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
return O
class SAB(nn.Module):
def __init__(self, dim_in, dim_out, num_heads, ln=False):
super(SAB, self).__init__()
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
def forward(self, X):
return self.mab(X, X)
class ISAB(nn.Module):
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
super(ISAB, self).__init__()
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
nn.init.xavier_uniform_(self.I)
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
def forward(self, X):
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
return self.mab1(X, H)
class PMA(nn.Module):
def __init__(self, dim, num_heads, num_seeds, ln=False):
super(PMA, self).__init__()
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
nn.init.xavier_uniform_(self.S)
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
def forward(self, X):
return self.mab(self.S.repeat(X.size(0), 1, 1), X)

View File

@@ -0,0 +1,144 @@
import math
from typing import Union, Tuple, Optional
from torch_geometric.typing import PairTensor, Adj, OptTensor
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
import numpy as np
class PosTransLayer(MessagePassing):
"""Involving the edge feature and updating position feature. Multiply Msg."""
_alpha: OptTensor
def __init__(self, x_channels: int, pos_channels: int, out_channels: int,
heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
bias: bool = True, act=None, attn_clamp: bool = False, **kwargs):
kwargs.setdefault('aggr', 'add')
super(PosTransLayer, self).__init__(node_dim=0, **kwargs)
self.x_channels = x_channels
self.pos_channels = pos_channels
self.in_channels = in_channels = x_channels + pos_channels
self.out_channels = out_channels
self.heads = heads
self.dropout = dropout
self.edge_dim = edge_dim
self.attn_clamp = attn_clamp
if act is None:
self.act = nn.LeakyReLU(negative_slope=0.2)
else:
self.act = act
self.lin_key = Linear(in_channels, heads * out_channels)
self.lin_query = Linear(in_channels, heads * out_channels)
self.lin_value = Linear(in_channels, heads * out_channels)
self.lin_edge0 = Linear(edge_dim, heads * out_channels, bias=False)
self.lin_edge1 = Linear(edge_dim, heads * out_channels, bias=False)
self.lin_pos = Linear(heads * out_channels, pos_channels, bias=False)
self.lin_skip = Linear(x_channels, heads * out_channels, bias=bias)
self.norm1 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
num_channels=heads * out_channels, eps=1e-6)
self.norm2 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
num_channels=heads * out_channels, eps=1e-6)
# FFN
self.FFN = nn.Sequential(Linear(heads * out_channels, heads * out_channels),
self.act,
Linear(heads * out_channels, heads * out_channels))
self.reset_parameters()
def reset_parameters(self):
self.lin_key.reset_parameters()
self.lin_query.reset_parameters()
self.lin_value.reset_parameters()
self.lin_skip.reset_parameters()
self.lin_edge0.reset_parameters()
self.lin_edge1.reset_parameters()
self.lin_pos.reset_parameters()
def forward(self, x: OptTensor,
pos: Tensor,
edge_index: Adj,
edge_attr: OptTensor = None
) -> Tuple[Tensor, Tensor]:
""""""
H, C = self.heads, self.out_channels
x_feat = torch.cat([x, pos], -1)
query = self.lin_query(x_feat).view(-1, H, C)
key = self.lin_key(x_feat).view(-1, H, C)
value = self.lin_value(x_feat).view(-1, H, C)
# propagate_type: (x: PairTensor, edge_attr: OptTensor)
out_x, out_pos = self.propagate(edge_index, query=query, key=key, value=value, pos=pos, edge_attr=edge_attr,
size=None)
out_x = out_x.view(-1, self.heads * self.out_channels)
# skip connection for x
x_r = self.lin_skip(x)
out_x = (out_x + x_r) / math.sqrt(2)
out_x = self.norm1(out_x)
# FFN
out_x = (out_x + self.FFN(out_x)) / math.sqrt(2)
out_x = self.norm2(out_x)
# skip connection for pos
out_pos = pos + torch.tanh(pos + out_pos)
return out_x, out_pos
def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
pos_j: Tensor,
edge_attr: OptTensor,
index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tuple[Tensor, Tensor]:
edge_attn = self.lin_edge0(edge_attr).view(-1, self.heads, self.out_channels)
alpha = (query_i * key_j * edge_attn).sum(dim=-1) / math.sqrt(self.out_channels)
if self.attn_clamp:
alpha = alpha.clamp(min=-5., max=5.)
alpha = softmax(alpha, index, ptr, size_i)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# node feature message
msg = value_j
msg = msg * self.lin_edge1(edge_attr).view(-1, self.heads, self.out_channels)
msg = msg * alpha.view(-1, self.heads, 1)
# node position message
pos_msg = pos_j * self.lin_pos(msg.reshape(-1, self.heads * self.out_channels))
return msg, pos_msg
def aggregate(self, inputs: Tuple[Tensor, Tensor], index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
if ptr is not None:
raise NotImplementedError("Not implement Ptr in aggregate")
else:
return (scatter(inputs[0], index, 0, dim_size=dim_size, reduce=self.aggr),
scatter(inputs[1], index, 0, dim_size=dim_size, reduce="mean"))
def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
return inputs
def __repr__(self):
return '{}({}, {}, heads={})'.format(self.__class__.__name__,
self.in_channels,
self.out_channels, self.heads)

248
MobileNetV3/models/transformer.py Executable file
View File

@@ -0,0 +1,248 @@
from copy import deepcopy as cp
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def clones(module, N):
return nn.ModuleList([cp(module) for _ in range(N)])
def attention(query, key, value, mask = None, dropout = None):
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn = F.softmax(scores, dim = -1)
if dropout is not None:
attn = dropout(attn)
return torch.matmul(attn, value), attn
class MultiHeadAttention(nn.Module):
def __init__(self, config):
super(MultiHeadAttention, self).__init__()
self.d_model = config.d_model
self.n_head = config.n_head
self.d_k = config.d_model // config.n_head
self.linears = clones(nn.Linear(self.d_model, self.d_model), 4)
self.dropout = nn.Dropout(p=config.dropout)
def forward(self, query, key, value, mask = None):
if mask is not None:
mask = mask.unsqueeze(1)
batch_size = query.size(0)
query, key , value = [l(x).view(batch_size, -1, self.n_head, self.d_k).transpose(1,2) for l, x in zip(self.linears, (query, key, value))]
x, attn = attention(query, key, value, mask = mask, dropout = self.dropout)
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)
return self.linears[3](x), attn
class PositionwiseFeedForward(nn.Module):
def __init__(self, config):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(config.d_model, config.d_ff)
self.w_2 = nn.Linear(config.d_ff, config.d_model)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class PositionwiseFeedForwardLast(nn.Module):
def __init__(self, config):
super(PositionwiseFeedForwardLast, self).__init__()
self.w_1 = nn.Linear(config.d_model, config.d_ff)
self.w_2 = nn.Linear(config.d_ff, config.n_vocab)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class SelfAttentionBlock(nn.Module):
def __init__(self, config):
super(SelfAttentionBlock, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.attn = MultiHeadAttention(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x, mask):
x_ = self.norm(x)
x_ , attn = self.attn(x_, x_, x_, mask)
return self.dropout(x_) + x, attn
class SourceAttentionBlock(nn.Module):
def __init__(self, config):
super(SourceAttentionBlock, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.attn = MultiHeadAttention(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x, m, mask):
x_ = self.norm(x)
x_, attn = self.attn(x_, m, m, mask)
return self.dropout(x_) + x, attn
class FeedForwardBlock(nn.Module):
def __init__(self, config):
super(FeedForwardBlock, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.feed_forward = PositionwiseFeedForward(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x):
x_ = self.norm(x)
x_ = self.feed_forward(x_)
return self.dropout(x_) + x
class FeedForwardBlockLast(nn.Module):
def __init__(self, config):
super(FeedForwardBlockLast, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.feed_forward = PositionwiseFeedForwardLast(config)
self.dropout = nn.Dropout(p = config.dropout)
# Only for the last layer
self.proj_fc = nn.Linear(config.d_model, config.n_vocab)
def forward(self, x):
x_ = self.norm(x)
x_ = self.feed_forward(x_)
# return self.dropout(x_) + x
return self.dropout(x_) + self.proj_fc(x)
class EncoderBlock(nn.Module):
def __init__(self, config):
super(EncoderBlock, self).__init__()
self.self_attn = SelfAttentionBlock(config)
self.feed_forward = FeedForwardBlock(config)
def forward(self, x, mask):
x, attn = self.self_attn(x, mask)
x = self.feed_forward(x)
return x, attn
class EncoderBlockLast(nn.Module):
def __init__(self, config):
super(EncoderBlockLast, self).__init__()
self.self_attn = SelfAttentionBlock(config)
self.feed_forward = FeedForwardBlockLast(config)
def forward(self, x, mask):
x, attn = self.self_attn(x, mask)
x = self.feed_forward(x)
return x, attn
class DecoderBlock(nn.Module):
def __init__(self, config):
super(DecoderBlock, self).__init__()
self.self_attn = SelfAttentionBlock(config)
self.src_attn = SourceAttentionBlock(config)
self.feed_forward = FeedForwardBlock(config)
def forward(self, x, m, src_mask, tgt_mask):
x, attn_tgt = self.self_attn(x, tgt_mask)
x, attn_src = self.src_attn(x, m, src_mask)
x = self.feed_forward(x)
return x, attn_src, attn_tgt
class Encoder(nn.Module):
def __init__(self, config):
super(Encoder, self).__init__()
# self.layers = clones(EncoderBlock(config), config.n_layers - 1)
# self.layers.append(EncoderBlockLast(config))
# self.norms = clones(nn.LayerNorm(config.d_model), config.n_layers - 1)
# self.norms.append(nn.LayerNorm(config.n_vocab))
self.layers = clones(EncoderBlock(config), config.n_layers)
self.norms = clones(nn.LayerNorm(config.d_model), config.n_layers)
def forward(self, x, mask):
outputs = []
attns = []
for layer, norm in zip(self.layers, self.norms):
x, attn = layer(x, mask)
outputs.append(norm(x))
attns.append(attn)
return outputs[-1], outputs, attns
class PositionalEmbedding(nn.Module):
def __init__(self, config):
super(PositionalEmbedding, self).__init__()
p2e = torch.zeros(config.max_len, config.d_model)
position = torch.arange(0.0, config.max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0.0, config.d_model, 2) * (- math.log(10000.0) / config.d_model))
p2e[:, 0::2] = torch.sin(position * div_term)
p2e[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('p2e', p2e)
def forward(self, x):
shp = x.size()
with torch.no_grad():
emb = torch.index_select(self.p2e, 0, x.view(-1)).view(shp + (-1,))
return emb
class Transformer(nn.Module):
def __init__(self, config):
super(Transformer, self).__init__()
self.p2e = PositionalEmbedding(config)
self.encoder = Encoder(config)
def forward(self, input_emb, position_ids, attention_mask):
# position embedding projection
projection = self.p2e(position_ids) + input_emb
return self.encoder(projection, attention_mask)
class TokenTypeEmbedding(nn.Module):
def __init__(self, config):
super(TokenTypeEmbedding, self).__init__()
self.t2e = nn.Embedding(config.n_token_type, config.d_model)
self.d_model = config.d_model
def forward(self, x):
return self.t2e(x) * math.sqrt(self.d_model)
class SemanticEmbedding(nn.Module):
def __init__(self, config):
super(SemanticEmbedding, self).__init__()
# self.w2e = nn.Embedding(config.n_vocab, config.d_model)
self.d_model = config.d_model
self.fc = nn.Linear(config.n_vocab, config.d_model)
def forward(self, x):
# return self.w2e(x) * math.sqrt(self.d_model)
return self.fc(x) * math.sqrt(self.d_model)
class Embeddings(nn.Module):
def __init__(self, config):
super(Embeddings, self).__init__()
self.w2e = SemanticEmbedding(config)
self.p2e = PositionalEmbedding(config)
self.t2e = TokenTypeEmbedding(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, input_ids, position_ids = None, token_type_ids = None):
if position_ids is None:
batch_size, length = input_ids.size()
with torch.no_grad():
position_ids = torch.arange(0, length).repeat(batch_size, 1)
if torch.cuda.is_available():
position_ids = position_ids.cuda(device=input_ids.device)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
embeddings = self.w2e(input_ids) + self.p2e(position_ids) + self.t2e(token_type_ids)
return self.dropout(embeddings)

301
MobileNetV3/models/utils.py Normal file
View File

@@ -0,0 +1,301 @@
import torch
import torch.nn.functional as F
import sde_lib
import numpy as np
_MODELS = {}
def register_model(cls=None, *, name=None):
"""A decorator for registering model classes."""
def _register(cls):
if name is None:
local_name = cls.__name__
else:
local_name = name
if local_name in _MODELS:
raise ValueError(
f'Already registered model with name: {local_name}')
_MODELS[local_name] = cls
return cls
if cls is None:
return _register
else:
return _register(cls)
def get_model(name):
return _MODELS[name]
def create_model(config):
"""Create the score model."""
model_name = config.model.name
score_model = get_model(model_name)(config)
score_model = score_model.to(config.device)
if 'load_pretrained' in config['training'].keys() and config.training.load_pretrained:
from utils import restore_checkpoint_partial
score_model = restore_checkpoint_partial(score_model, torch.load(config.training.pretrained_model_path, map_location=config.device)['model'])
# score_model = torch.nn.DataParallel(score_model)
return score_model
def get_model_fn(model, train=False):
"""Create a function to give the output of the score-based model.
Args:
model: The score model.
train: `True` for training and `False` for evaluation.
Returns:
A model function.
"""
def model_fn(x, labels, *args, **kwargs):
"""Compute the output of the score-based model.
Args:
x: A mini-batch of input data (Adjacency matrices).
labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
for different models.
mask: Mask for adjacency matrices.
Returns:
A tuple of (model output, new mutable states)
"""
if not train:
model.eval()
return model(x, labels, *args, **kwargs)
else:
model.train()
return model(x, labels, *args, **kwargs)
return model_fn
def get_score_fn(sde, model, train=False, continuous=False):
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
model: A score model.
train: `True` for training and `False` for evaluation.
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
Returns:
A score function.
"""
model_fn = get_model_fn(model, train=train)
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
def score_fn(x, t, *args, **kwargs):
# Scale neural network output by standard deviation and flip sign
if continuous or isinstance(sde, sde_lib.subVPSDE):
# For VP-trained models, t=0 corresponds to the lowest noise level
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
labels = t * 999
score = model_fn(x, labels, *args, **kwargs)
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VP-trained models, t=0 corresponds to the lowest noise level
labels = t * (sde.N - 1)
score = model_fn(x, labels, *args, **kwargs)
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
labels.long()]
score = -score / std[:, None, None]
return score
elif isinstance(sde, sde_lib.VESDE):
def score_fn(x, t, *args, **kwargs):
if continuous:
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VE-trained models, t=0 corresponds to the highest noise level
labels = sde.T - t
labels *= sde.N - 1
labels = torch.round(labels).long()
score = model_fn(x, labels, *args, **kwargs)
return score
else:
raise NotImplementedError(
f"SDE class {sde.__class__.__name__} not yet supported.")
return score_fn
def get_classifier_grad_fn(sde, classifier, train=False, continuous=False,
regress=True, labels='max'):
logit_fn = get_logit_fn(sde, classifier, train, continuous)
def classifier_grad_fn(x, t, *args, **kwargs):
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
if regress:
assert labels in ['max', 'min']
logit = logit_fn(x_in, t, *args, **kwargs)
prob = logit.sum()
else:
logit = logit_fn(x_in, t, *args, **kwargs)
# prob = torch.nn.functional.log_softmax(logit, dim=-1)[torch.arange(labels.shape[0]), labels].sum()
log_prob = F.log_softmax(logit, dim=-1)
prob = log_prob[range(len(logit)), labels.view(-1)].sum()
# prob.backward()
# classifier_grad = x_in.grad
classifier_grad = torch.autograd.grad(prob, x_in)[0]
return classifier_grad
return classifier_grad_fn
def get_logit_fn(sde, classifier, train=False, continuous=False):
classifier_fn = get_model_fn(classifier, train=train)
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
def logit_fn(x, t, *args, **kwargs):
# Scale neural network output by standard deviation and flip sign
if continuous or isinstance(sde, sde_lib.subVPSDE):
# For VP-trained models, t=0 corresponds to the lowest noise level
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
labels = t * 999
logit = classifier_fn(x, labels, *args, **kwargs)
# std = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VP-trained models, t=0 corresponds to the lowest noise level
labels = t * (sde.N - 1)
logit = classifier_fn(x, labels, *args, **kwargs)
# std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
# labels.long()]
# score = -score / std[:, None, None]
return logit
elif isinstance(sde, sde_lib.VESDE):
def logit_fn(x, t, *args, **kwargs):
if continuous:
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VE-trained models, t=0 corresponds to the highest noise level
labels = sde.T - t
labels *= sde.N - 1
labels = torch.round(labels).long()
logit = classifier_fn(x, labels, *args, **kwargs)
return logit
return logit_fn
def get_predictor_fn(sde, model, train=False, continuous=False):
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
model: A predictor model.
train: `True` for training and `False` for evaluation.
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
Returns:
A score function.
"""
model_fn = get_model_fn(model, train=train)
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
def predictor_fn(x, t, *args, **kwargs):
# Scale neural network output by standard deviation and flip sign
if continuous or isinstance(sde, sde_lib.subVPSDE):
# For VP-trained models, t=0 corresponds to the lowest noise level
# The maximum value of time embedding is assumed to 999 for continuously-trained models.
labels = t * 999
pred = model_fn(x, labels, *args, **kwargs)
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VP-trained models, t=0 corresponds to the lowest noise level
labels = t * (sde.N - 1)
pred = model_fn(x, labels, *args, **kwargs)
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
labels.long()]
# score = -score / std[:, None, None]
return pred
elif isinstance(sde, sde_lib.VESDE):
def predictor_fn(x, t, *args, **kwargs):
if continuous:
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VE-trained models, t=0 corresponds to the highest noise level
labels = sde.T - t
labels *= sde.N - 1
labels = torch.round(labels).long()
pred = model_fn(x, labels, *args, **kwargs)
return pred
else:
raise NotImplementedError(
f"SDE class {sde.__class__.__name__} not yet supported.")
return predictor_fn
def to_flattened_numpy(x):
"""Flatten a torch tensor `x` and convert it to numpy."""
return x.detach().cpu().numpy().reshape((-1,))
def from_flattened_numpy(x, shape):
"""Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
return torch.from_numpy(x.reshape(shape))
@torch.no_grad()
def mask_adj2node(adj_mask):
"""Convert batched adjacency mask matrices to batched node mask matrices.
Args:
adj_mask: [B, N, N] Batched adjacency mask matrices without self-loop edge.
Output:
node_mask: [B, N] Batched node mask matrices indicating the valid nodes.
"""
batch_size, max_num_nodes, _ = adj_mask.shape
node_mask = adj_mask[:, 0, :].clone()
node_mask[:, 0] = 1
return node_mask
@torch.no_grad()
def get_rw_feat(k_step, dense_adj):
"""Compute k_step Random Walk for given dense adjacency matrix."""
rw_list = []
deg = dense_adj.sum(-1, keepdims=True)
AD = dense_adj / (deg + 1e-8)
rw_list.append(AD)
for _ in range(k_step):
rw = torch.bmm(rw_list[-1], AD)
rw_list.append(rw)
rw_map = torch.stack(rw_list[1:], dim=1) # [B, k_step, N, N]
rw_landing = torch.diagonal(
rw_map, offset=0, dim1=2, dim2=3) # [B, k_step, N]
rw_landing = rw_landing.permute(0, 2, 1) # [B, N, rw_depth]
# get the shortest path distance indices
tmp_rw = rw_map.sort(dim=1)[0]
spd_ind = (tmp_rw <= 0).sum(dim=1) # [B, N, N]
spd_onehot = torch.nn.functional.one_hot(
spd_ind, num_classes=k_step+1).to(torch.float)
spd_onehot = spd_onehot.permute(0, 3, 1, 2) # [B, kstep, N, N]
return rw_landing, spd_onehot