Update to accomendate last updates of qlib

This commit is contained in:
D-X-Y
2021-03-11 03:09:55 +00:00
parent 731bda649f
commit 58907a2387
5 changed files with 109 additions and 9 deletions

View File

@@ -22,13 +22,13 @@ if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from procedures.q_exps import update_gpu
from procedures.q_exps import update_market
from procedures.q_exps import run_exp
import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.utils import flatten_dict
from qlib.log import set_log_basic_config
def retrieve_configs():
@@ -49,6 +49,7 @@ def retrieve_configs():
alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml"
# DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis, https://arxiv.org/pdf/2010.01265.pdf
alg2names["DoubleE"] = "workflow_config_doubleensemble_Alpha360.yaml"
alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml"
# find the yaml paths
alg2paths = OrderedDict()
@@ -66,6 +67,7 @@ def main(xargs, exp_yaml):
with open(exp_yaml) as fp:
config = yaml.safe_load(fp)
config = update_market(config, xargs.market)
config = update_gpu(config, xargs.gpu)
qlib.init(**config.get("qlib_init"))
@@ -77,7 +79,7 @@ def main(xargs, exp_yaml):
for irun in range(xargs.times):
run_exp(
config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir
config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), '{:}-{:}'.format(xargs.save_dir, xargs.market)
)
@@ -87,6 +89,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser("Baselines")
parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.")
parser.add_argument("--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator.")
parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.")