Refine TT workflow
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
# Refer to:
|
||||
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb
|
||||
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
|
||||
# python exps/trading/workflow_tt.py --market all
|
||||
# python exps/trading/workflow_tt.py --market all --gpu 1
|
||||
#####################################################
|
||||
import sys, argparse
|
||||
from pathlib import Path
|
||||
@@ -13,6 +13,10 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from procedures.q_exps import update_gpu
|
||||
from procedures.q_exps import update_market
|
||||
from procedures.q_exps import run_exp
|
||||
|
||||
import qlib
|
||||
from qlib.config import C
|
||||
from qlib.config import REG_CN
|
||||
@@ -100,44 +104,23 @@ def main(xargs):
|
||||
},
|
||||
]
|
||||
|
||||
task = dict(model=model_config, dataset=dataset_config, record=record_config)
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
# start exp to train model
|
||||
with R.start(experiment_name="tt_model", uri=xargs.save_dir + "-" + xargs.market):
|
||||
set_log_basic_config(R.get_recorder().root_uri / "log.log")
|
||||
|
||||
model = init_instance_by_config(model_config)
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
R.save_objects(trained_model=model)
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
print(recorder)
|
||||
|
||||
for record in task["record"]:
|
||||
record = record.copy()
|
||||
if record["class"] == "SignalRecord":
|
||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
record["kwargs"].update(srconf)
|
||||
sr = init_instance_by_config(record)
|
||||
sr.generate()
|
||||
else:
|
||||
rconf = {"recorder": recorder}
|
||||
record["kwargs"].update(rconf)
|
||||
ar = init_instance_by_config(record)
|
||||
ar.generate()
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
for irun in range(xargs.times):
|
||||
xmodel_config = model_config.copy()
|
||||
xmodel_config = update_gpu(xmodel_config, xags.gpu)
|
||||
task = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
|
||||
run_exp(task_config, dataset, "Transformer", "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Vanilla Transformable Transformer")
|
||||
parser.add_argument("--save_dir", type=str, default="./outputs/tt-ml-runs", help="The checkpoint directory.")
|
||||
parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
|
||||
parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
|
||||
parser.add_argument("--market", type=str, default="csi300", help="The market indicator.")
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
main(args)
|
||||
|
Reference in New Issue
Block a user