Reformulate Q-Transformer
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
# Refer to:
|
||||
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb
|
||||
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
|
||||
# python exps/trading/workflow_tt.py --market all --gpu 1
|
||||
# python exps/trading/workflow_tt.py --gpu 1 --market csi300
|
||||
#####################################################
|
||||
import sys, argparse
|
||||
from pathlib import Path
|
||||
@@ -63,7 +63,8 @@ def main(xargs):
|
||||
"class": "QuantTransformer",
|
||||
"module_path": "trade_models",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"net_config": None,
|
||||
"opt_config": None,
|
||||
"GPU": "0",
|
||||
"metric": "loss",
|
||||
},
|
||||
@@ -107,20 +108,23 @@ def main(xargs):
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
for irun in range(xargs.times):
|
||||
xmodel_config = model_config.copy()
|
||||
xmodel_config = update_gpu(xmodel_config, xags.gpu)
|
||||
task = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
|
||||
run_exp(task_config, dataset, "Transformer", "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir)
|
||||
xmodel_config = update_gpu(xmodel_config, xargs.gpu)
|
||||
task_config = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
|
||||
|
||||
run_exp(task_config, dataset, xargs.name, "recorder-{:02d}-{:02d}".format(irun, xargs.times), save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Vanilla Transformable Transformer")
|
||||
parser.add_argument("--save_dir", type=str, default="./outputs/tt-ml-runs", help="The checkpoint directory.")
|
||||
parser.add_argument("--save_dir", type=str, default="./outputs/vtt-runs", help="The checkpoint directory.")
|
||||
parser.add_argument("--name", type=str, default="Transformer", help="The experiment name.")
|
||||
parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
|
||||
parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
|
||||
parser.add_argument("--market", type=str, default="csi300", help="The market indicator.")
|
||||
parser.add_argument("--market", type=str, default="all", help="The market indicator.")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
Reference in New Issue
Block a user