Upgrade spaces and add more tests
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
|
||||
##################################################
|
||||
# Use noise as prediction #
|
||||
##################################################
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
@@ -27,7 +29,11 @@ class NAIVE_V1(Model):
|
||||
self.d_feat = d_feat
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info("NAIVE-V1 parameters setting: d_feat={:}, seed={:}".format(self.d_feat, self.seed))
|
||||
self.logger.info(
|
||||
"NAIVE-V1 parameters setting: d_feat={:}, seed={:}".format(
|
||||
self.d_feat, self.seed
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
random.seed(self.seed)
|
||||
@@ -49,7 +55,9 @@ class NAIVE_V1(Model):
|
||||
|
||||
def model(self, x):
|
||||
num = len(x)
|
||||
return np.random.normal(loc=self._mean, scale=self._std, size=num).astype(x.dtype)
|
||||
return np.random.normal(loc=self._mean, scale=self._std, size=num).astype(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
def _prepare_dataset(df_data):
|
||||
@@ -71,9 +79,15 @@ class NAIVE_V1(Model):
|
||||
# df_train['feature']['CLOSE1'].values
|
||||
# train_dataset['features'][:, -1]
|
||||
masks = ~np.isnan(train_dataset["labels"])
|
||||
self._mean, self._std = np.mean(train_dataset["labels"][masks]), np.std(train_dataset["labels"][masks])
|
||||
train_mse_loss = self.mse(self.model(train_dataset["features"]), train_dataset["labels"])
|
||||
valid_mse_loss = self.mse(self.model(valid_dataset["features"]), valid_dataset["labels"])
|
||||
self._mean, self._std = np.mean(train_dataset["labels"][masks]), np.std(
|
||||
train_dataset["labels"][masks]
|
||||
)
|
||||
train_mse_loss = self.mse(
|
||||
self.model(train_dataset["features"]), train_dataset["labels"]
|
||||
)
|
||||
valid_mse_loss = self.mse(
|
||||
self.model(valid_dataset["features"]), valid_dataset["labels"]
|
||||
)
|
||||
self.logger.info("Training MSE loss: {:}".format(train_mse_loss))
|
||||
self.logger.info("Validation MSE loss: {:}".format(valid_mse_loss))
|
||||
self.fitted = True
|
||||
|
Reference in New Issue
Block a user