Update Q workflow

This commit is contained in:
D-X-Y
2021-03-04 13:55:48 +00:00
parent e329b78cf4
commit 192c25eb42
2 changed files with 65 additions and 28 deletions

View File

@@ -1,12 +1,12 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
# Refer to:
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
# python exps/trading/workflow_tt.py
#####################################################
import sys, site, argparse
import sys, argparse
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
@@ -15,19 +15,11 @@ if str(lib_dir) not in sys.path:
import qlib
from qlib.config import C
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.utils import flatten_dict
from qlib.log import set_log_basic_config
def main(xargs):
@@ -73,13 +65,51 @@ def main(xargs):
},
}
task = {"model": model_config, "dataset": dataset_config}
port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.strategy",
"kwargs": {
"topk": 50,
"n_drop": 5,
},
},
"backtest": {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": "SH000300",
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
},
}
record_config = [
{"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
"kwargs": dict(ana_long_short=False, ann_scaler=252),
},
{
"class": "PortAnaRecord",
"module_path": "qlib.workflow.record_temp",
"kwargs": dict(config=port_analysis_config),
},
]
task = dict(model=model_config, dataset=dataset_config, record=record_config)
model = init_instance_by_config(model_config)
dataset = init_instance_by_config(dataset_config)
# start exp to train model
with R.start(experiment_name="train_tt_model"):
set_log_basic_config(R.get_recorder().root_uri / 'log.log')
model = init_instance_by_config(model_config)
dataset = init_instance_by_config(dataset_config)
R.log_params(**flatten_dict(task))
model.fit(dataset)
R.save_objects(trained_model=model)
@@ -87,14 +117,19 @@ def main(xargs):
# prediction
recorder = R.get_recorder()
print(recorder)
sr = SignalRecord(model, dataset, recorder)
sr.generate()
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()
for record in task["record"]:
record = record.copy()
if record["class"] == "SignalRecord":
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
record["kwargs"].update(srconf)
sr = init_instance_by_config(record)
sr.generate()
else:
rconf = {"recorder": recorder}
record["kwargs"].update(rconf)
ar = init_instance_by_config(record)
ar.generate()
if __name__ == "__main__":