Reformulate Q-Transformer

This commit is contained in:
D-X-Y
2021-03-06 21:35:26 -08:00
parent 53e1441c8d
commit c0481a2357
2 changed files with 370 additions and 375 deletions

View File

@@ -4,7 +4,7 @@
# Refer to:
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
# python exps/trading/workflow_tt.py --market all --gpu 1
# python exps/trading/workflow_tt.py --gpu 1 --market csi300
#####################################################
import sys, argparse
from pathlib import Path
@@ -63,7 +63,8 @@ def main(xargs):
"class": "QuantTransformer",
"module_path": "trade_models",
"kwargs": {
"loss": "mse",
"net_config": None,
"opt_config": None,
"GPU": "0",
"metric": "loss",
},
@@ -107,20 +108,23 @@ def main(xargs):
provider_uri = "~/.qlib/qlib_data/cn_data"
qlib.init(provider_uri=provider_uri, region=REG_CN)
save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
dataset = init_instance_by_config(dataset_config)
for irun in range(xargs.times):
xmodel_config = model_config.copy()
xmodel_config = update_gpu(xmodel_config, xags.gpu)
task = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
run_exp(task_config, dataset, "Transformer", "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir)
xmodel_config = update_gpu(xmodel_config, xargs.gpu)
task_config = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
run_exp(task_config, dataset, xargs.name, "recorder-{:02d}-{:02d}".format(irun, xargs.times), save_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Vanilla Transformable Transformer")
parser.add_argument("--save_dir", type=str, default="./outputs/tt-ml-runs", help="The checkpoint directory.")
parser.add_argument("--save_dir", type=str, default="./outputs/vtt-runs", help="The checkpoint directory.")
parser.add_argument("--name", type=str, default="Transformer", help="The experiment name.")
parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
parser.add_argument("--market", type=str, default="csi300", help="The market indicator.")
parser.add_argument("--market", type=str, default="all", help="The market indicator.")
args = parser.parse_args()
main(args)