first commit
This commit is contained in:
103
MobileNetV3/models/GDSS/scorenetx.py
Normal file
103
MobileNetV3/models/GDSS/scorenetx.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.GDSS.layers import DenseGCNConv, MLP
|
||||
from .graph_utils import mask_x, pow_tensor
|
||||
from .attention import AttentionLayer
|
||||
from .. import utils
|
||||
|
||||
@utils.register_model(name='ScoreNetworkX')
|
||||
class ScoreNetworkX(torch.nn.Module):
|
||||
|
||||
# def __init__(self, max_feat_num, depth, nhid):
|
||||
def __init__(self, config):
|
||||
|
||||
super(ScoreNetworkX, self).__init__()
|
||||
|
||||
self.nfeat = config.data.n_vocab
|
||||
self.depth = config.model.depth
|
||||
self.nhid = config.model.nhid
|
||||
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for _ in range(self.depth):
|
||||
if _ == 0:
|
||||
self.layers.append(DenseGCNConv(self.nfeat, self.nhid))
|
||||
else:
|
||||
self.layers.append(DenseGCNConv(self.nhid, self.nhid))
|
||||
|
||||
self.fdim = self.nfeat + self.depth * self.nhid
|
||||
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=self.nfeat,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
self.activation = torch.tanh
|
||||
|
||||
def forward(self, x, time_cond, maskX, flags=None):
|
||||
|
||||
x_list = [x]
|
||||
for _ in range(self.depth):
|
||||
x = self.layers[_](x, maskX)
|
||||
x = self.activation(x)
|
||||
x_list.append(x)
|
||||
|
||||
xs = torch.cat(x_list, dim=-1) # B x N x (F + num_layers x H)
|
||||
out_shape = (x.shape[0], x.shape[1], -1)
|
||||
x = self.final(xs).view(*out_shape)
|
||||
|
||||
x = mask_x(x, flags)
|
||||
return x
|
||||
|
||||
|
||||
@utils.register_model(name='ScoreNetworkX_GMH')
|
||||
class ScoreNetworkX_GMH(torch.nn.Module):
|
||||
# def __init__(self, max_feat_num, depth, nhid, num_linears,
|
||||
# c_init, c_hid, c_final, adim, num_heads=4, conv='GCN'):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.max_feat_num = config.data.n_vocab
|
||||
self.depth = config.model.depth
|
||||
self.nhid = config.model.nhid
|
||||
self.c_init = config.model.c_init
|
||||
self.c_hid = config.model.c_hid
|
||||
self.c_final = config.model.c_final
|
||||
self.num_linears = config.model.num_linears
|
||||
self.num_heads = config.model.num_heads
|
||||
self.conv = config.model.conv
|
||||
self.adim = config.model.adim
|
||||
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for _ in range(self.depth):
|
||||
if _ == 0:
|
||||
self.layers.append(AttentionLayer(self.num_linears, self.max_feat_num,
|
||||
self.nhid, self.nhid, self.c_init,
|
||||
self.c_hid, self.num_heads, self.conv))
|
||||
elif _ == self.depth - 1:
|
||||
self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
|
||||
self.nhid, self.c_hid,
|
||||
self.c_final, self.num_heads, self.conv))
|
||||
else:
|
||||
self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
|
||||
self.nhid, self.c_hid,
|
||||
self.c_hid, self.num_heads, self.conv))
|
||||
|
||||
fdim = self.max_feat_num + self.depth * self.nhid
|
||||
self.final = MLP(num_layers=3, input_dim=fdim, hidden_dim=2*fdim, output_dim=self.max_feat_num,
|
||||
use_bn=False, activate_func=F.elu)
|
||||
|
||||
self.activation = torch.tanh
|
||||
|
||||
def forward(self, x, time_cond, maskX, flags=None):
|
||||
adjc = pow_tensor(maskX, self.c_init)
|
||||
|
||||
x_list = [x]
|
||||
for _ in range(self.depth):
|
||||
x, adjc = self.layers[_](x, adjc, flags)
|
||||
x = self.activation(x)
|
||||
x_list.append(x)
|
||||
|
||||
xs = torch.cat(x_list, dim=-1) # B x N x (F + num_layers x H)
|
||||
out_shape = (x.shape[0], x.shape[1], -1)
|
||||
x = self.final(xs).view(*out_shape)
|
||||
x = mask_x(x, flags)
|
||||
|
||||
return x
|
Reference in New Issue
Block a user