Allow different timeframes
This commit is contained in:
@@ -13,13 +13,15 @@
|
||||
# python exps/trading/baselines.py --alg LightGBM #
|
||||
# python exps/trading/baselines.py --alg DoubleE #
|
||||
# python exps/trading/baselines.py --alg TabNet #
|
||||
# #
|
||||
# python exps/trading/baselines.py --alg Transformer#
|
||||
# #############################
|
||||
# python exps/trading/baselines.py --alg Transformer
|
||||
# python exps/trading/baselines.py --alg TSF
|
||||
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0
|
||||
#####################################################
|
||||
# python exps/trading/baselines.py --alg TSF-2x24-drop0_0 --market csi300
|
||||
# python exps/trading/baselines.py --alg TSF-6x32-drop0_0 --market csi300
|
||||
#################################################################################
|
||||
import sys
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
@@ -60,7 +62,7 @@ def to_layer(config, embed_dim, depth):
|
||||
|
||||
def extend_transformer_settings(alg2configs, name):
|
||||
config = copy.deepcopy(alg2configs[name])
|
||||
for i in range(1, 8):
|
||||
for i in range(1, 9):
|
||||
for j in (6, 12, 24, 32, 48, 64):
|
||||
for k1 in (0, 0.1, 0.2):
|
||||
for k2 in (0, 0.1):
|
||||
@@ -70,6 +72,31 @@ def extend_transformer_settings(alg2configs, name):
|
||||
return alg2configs
|
||||
|
||||
|
||||
def replace_start_time(config, start_time):
|
||||
config = copy.deepcopy(config)
|
||||
xtime = datetime.strptime(start_time, "%Y-%m-%d")
|
||||
config["data_handler_config"]["start_time"] = xtime.date()
|
||||
config["data_handler_config"]["fit_start_time"] = xtime.date()
|
||||
config["task"]["dataset"]["kwargs"]["segments"]["train"][0] = xtime.date()
|
||||
return config
|
||||
|
||||
|
||||
def extend_train_data(alg2configs, name):
|
||||
config = copy.deepcopy(alg2configs[name])
|
||||
start_times = (
|
||||
"2008-01-01",
|
||||
"2009-01-01",
|
||||
"2010-01-01",
|
||||
"2011-01-01",
|
||||
"2012-01-01",
|
||||
"2013-01-01",
|
||||
)
|
||||
for start_time in start_times:
|
||||
config = replace_start_time(config, start_time)
|
||||
alg2configs[name + "s{:}".format(start_time)] = config
|
||||
return alg2configs
|
||||
|
||||
|
||||
def refresh_record(alg2configs):
|
||||
alg2configs = copy.deepcopy(alg2configs)
|
||||
for key, config in alg2configs.items():
|
||||
@@ -133,6 +160,9 @@ def retrieve_configs():
|
||||
)
|
||||
alg2configs = extend_transformer_settings(alg2configs, "TSF")
|
||||
alg2configs = refresh_record(alg2configs)
|
||||
# extend the algorithms by different train-data
|
||||
for name in ("TSF-2x24-drop0_0", "TSF-6x32-drop0_0"):
|
||||
alg2configs = extend_train_data(alg2configs, name)
|
||||
print(
|
||||
"There are {:} algorithms : {:}".format(
|
||||
len(alg2configs), list(alg2configs.keys())
|
||||
|
Reference in New Issue
Block a user