This commit is contained in:
D-X-Y
2021-03-23 09:43:45 +00:00
parent dd27f15e27
commit 01397660de
5 changed files with 78 additions and 77 deletions

View File

@@ -1,75 +0,0 @@
import os
import sys
import ruamel.yaml as yaml
import pprint
import numpy as np
import pandas as pd
from pathlib import Path
import qlib
from qlib import config as qconfig
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
# Initialize qlib against the locally-downloaded China-market data bundle.
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)

# Processors applied at inference time: robust z-score normalization of
# features (outliers clipped), then NaN filling.
_infer_processors = [
    {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}},
    {"class": "Fillna", "kwargs": {"fields_group": "feature"}},
]
# Processors applied only when building the learning (training) view:
# drop rows with NaN labels, then cross-sectionally rank-normalize labels.
_learn_processors = [
    {"class": "DropnaLabel"},
    {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
]
# Alpha360 feature handler over CSI-300 constituents; the fit window for
# the normalization statistics ends 2014-12-31.
_handler_config = {
    "class": "Alpha360",
    "module_path": "qlib.contrib.data.handler",
    "kwargs": {
        "start_time": "2008-01-01",
        "end_time": "2020-08-01",
        "fit_start_time": "2008-01-01",
        "fit_end_time": "2014-12-31",
        "instruments": "csi300",
        "infer_processors": _infer_processors,
        "learn_processors": _learn_processors,
        "label": ["Ref($close, -2) / Ref($close, -1) - 1"],
    },
}
# Full DatasetH specification: the handler above plus chronological
# train/valid/test date segments.
dataset_config = {
    "class": "DatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": _handler_config,
        "segments": {
            "train": ("2008-01-01", "2014-12-31"),
            "valid": ("2015-01-01", "2016-12-31"),
            "test": ("2017-01-01", "2020-08-01"),
        },
    },
}
if __name__ == "__main__":
    # Locate the demo GRU workflow YAML inside the co-located qlib checkout
    # (two levels up from this script, under .latent-data/qlib).
    qlib_root_dir = (Path(__file__).parent / ".." / ".." / ".latent-data" / "qlib").resolve()
    demo_yaml_path = qlib_root_dir / "examples" / "benchmarks" / "GRU" / "workflow_config_gru_Alpha360.yaml"
    print("Demo-workflow-yaml: {:}".format(demo_yaml_path))
    # safe_load only — never full yaml.load, even on trusted project config.
    with open(demo_yaml_path, "r") as fp:
        config = yaml.safe_load(fp)
    # Show the demo's dataset section for comparison against dataset_config.
    pprint.pprint(config["task"]["dataset"])

    # Build the dataset declared above and materialize the three splits
    # using the learn-time processed view (DK_L).
    dataset = init_instance_by_config(dataset_config)
    pprint.pprint(dataset_config)
    pprint.pprint(dataset)
    df_train, df_valid, df_test = dataset.prepare(
        ["train", "valid", "test"],
        col_set=["feature", "label"],
        data_key=DataHandlerLP.DK_L,
    )
    x_train, y_train = df_train["feature"], df_train["label"]
    # Fix: removed leftover debugger residue (`import pdb; pdb.set_trace()`)
    # that halted the script unconditionally before completion.
    print("Complete")

75
notebooks/spaces/test.py Normal file
View File

@@ -0,0 +1,75 @@
import os
import sys
import qlib
import pprint
import numpy as np
import pandas as pd
from pathlib import Path
import torch
# NOTE(review): notebook-style hack — rebinds ``__file__`` to the directory
# resolved from the *literal string* "__file__" (i.e. effectively the current
# working directory), so the path logic below also works inside a Jupyter
# kernel where ``__file__`` is undefined. When run as a plain script this
# clobbers the real ``__file__``; confirm that is intended.
__file__ = os.path.dirname(os.path.realpath("__file__"))
# Make the project's ``lib`` directory (a sibling of this notebook's parent)
# importable so the ``trade_models`` import below resolves.
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
print("library path: {:}".format(lib_dir))
assert lib_dir.exists(), "{:} does not exist".format(lib_dir)
# Insert at position 0 so the local copy shadows any installed package.
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from trade_models import get_transformer
from qlib import config as qconfig
from qlib.utils import init_instance_by_config
from qlib.model.base import Model
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
# Boot qlib against the local China-market data bundle.
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)

# Alpha360 feature handler over CSI-100 constituents (default processors);
# normalization statistics are fitted on 2008-01-01 .. 2014-12-31.
_handler_spec = {
    "class": "Alpha360",
    "module_path": "qlib.contrib.data.handler",
    "kwargs": {
        "start_time": "2008-01-01",
        "end_time": "2020-08-01",
        "fit_start_time": "2008-01-01",
        "fit_end_time": "2014-12-31",
        "instruments": "csi100",
    },
}
# DatasetH specification: the handler above plus chronological
# train/valid/test date segments.
dataset_config = {
    "class": "DatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": _handler_spec,
        "segments": {
            "train": ("2008-01-01", "2014-12-31"),
            "valid": ("2015-01-01", "2016-12-31"),
            "test": ("2017-01-01", "2020-08-01"),
        },
    },
}
pprint.pprint(dataset_config)
dataset = init_instance_by_config(dataset_config)
# Materialize the learn-time (DK_L) view of all three splits.
df_train, df_valid, df_test = dataset.prepare(
    ["train", "valid", "test"],
    col_set=["feature", "label"],
    data_key=DataHandlerLP.DK_L,
)

model = get_transformer(None)
print(model)

features = torch.from_numpy(df_train["feature"].values).float()
labels = torch.from_numpy(df_train["label"].values).squeeze().float()

# Score the first 2000 training rows. Fix: the forward pass is wrapped in
# no_grad() — this is a forward-only sanity check, and without it the
# model output carries a grad graph, making the plain `.numpy()` calls
# below raise RuntimeError ("Can't call numpy() on Tensor that requires
# grad").
batch = list(range(2000))
with torch.no_grad():
    predicts = model(features[batch])
# Drop rows whose label is NaN before computing the loss.
mask = ~torch.isnan(labels[batch])
pred = predicts[mask]
label = labels[batch][mask]
loss = torch.nn.functional.mse_loss(pred, label)

# Cross-check the torch MSE against sklearn's implementation.
from sklearn.metrics import mean_squared_error

# detach() is belt-and-braces here (no_grad already prevents a graph).
mse_loss = mean_squared_error(pred.detach().numpy(), label.numpy())