Temp / 0.5

This commit is contained in:
D-X-Y
2021-03-05 13:50:30 +00:00
parent 2fa358fdf6
commit cc28e1589e
4 changed files with 35 additions and 10 deletions

View File

@@ -1,11 +1,15 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
import torch
import torch.nn as nn
import math
class PositionalEncoder(nn.Module):
# Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
# https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65
def __init__(self, d_model, max_seq_len):
def __init__(self, d_model, max_seq_len, dropout=0.1):
super(PositionalEncoder, self).__init__()
self.d_model = d_model
# create constant 'pe' matrix with values dependant on
@@ -26,4 +30,6 @@ class PositionalEncoder(nn.Module):
def forward(self, x):
batch, seq, fdim = x.shape[:3]
embeddings = self.pe[:, :seq, :fdim]
import pdb; pdb.set_trace()
outs = self.dropout(x + embeddings)
return x + embeddings