Update Q workflow
This commit is contained in:
@@ -41,8 +41,8 @@ class QuantTransformer(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
hidden_size=48,
|
||||
depth=5,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
@@ -62,7 +62,7 @@ class QuantTransformer(Model):
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.depth = depth
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
@@ -79,7 +79,7 @@ class QuantTransformer(Model):
|
||||
"Transformer parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndepth : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
@@ -93,7 +93,7 @@ class QuantTransformer(Model):
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
depth,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
@@ -112,7 +112,9 @@ class QuantTransformer(Model):
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.model = TransformerModel(d_feat=self.d_feat)
|
||||
self.model = TransformerModel(d_feat=self.d_feat,
|
||||
embed_dim=self.hidden_size,
|
||||
depth=self.depth)
|
||||
self.logger.info('model: {:}'.format(self.model))
|
||||
self.logger.info('model size: {:.3f} MB'.format(count_parameters_in_MB(self.model)))
|
||||
|
||||
|
Reference in New Issue
Block a user