Update QuantTransformer workflow: rename `num_layers` to `depth`, change `hidden_size` default from 64 to 48, and pass `embed_dim`/`depth` through to `TransformerModel`

This commit is contained in:
D-X-Y
2021-03-04 13:55:48 +00:00
parent e329b78cf4
commit 192c25eb42
2 changed files with 65 additions and 28 deletions

View File

@@ -41,8 +41,8 @@ class QuantTransformer(Model):
def __init__(
self,
d_feat=6,
hidden_size=64,
num_layers=2,
hidden_size=48,
depth=5,
dropout=0.0,
n_epochs=200,
lr=0.001,
@@ -62,7 +62,7 @@ class QuantTransformer(Model):
# set hyper-parameters.
self.d_feat = d_feat
self.hidden_size = hidden_size
self.num_layers = num_layers
self.depth = depth
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
@@ -79,7 +79,7 @@ class QuantTransformer(Model):
"Transformer parameters setting:"
"\nd_feat : {}"
"\nhidden_size : {}"
"\nnum_layers : {}"
"\ndepth : {}"
"\ndropout : {}"
"\nn_epochs : {}"
"\nlr : {}"
@@ -93,7 +93,7 @@ class QuantTransformer(Model):
"\nseed : {}".format(
d_feat,
hidden_size,
num_layers,
depth,
dropout,
n_epochs,
lr,
@@ -112,7 +112,9 @@ class QuantTransformer(Model):
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.model = TransformerModel(d_feat=self.d_feat)
self.model = TransformerModel(d_feat=self.d_feat,
embed_dim=self.hidden_size,
depth=self.depth)
self.logger.info('model: {:}'.format(self.model))
self.logger.info('model size: {:.3f} MB'.format(count_parameters_in_MB(self.model)))