Update to accomendate last updates of qlib
This commit is contained in:
@@ -22,13 +22,13 @@ if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from procedures.q_exps import update_gpu
|
||||
from procedures.q_exps import update_market
|
||||
from procedures.q_exps import run_exp
|
||||
|
||||
import qlib
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.log import set_log_basic_config
|
||||
|
||||
|
||||
def retrieve_configs():
|
||||
@@ -49,6 +49,7 @@ def retrieve_configs():
|
||||
alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml"
|
||||
# DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis, https://arxiv.org/pdf/2010.01265.pdf
|
||||
alg2names["DoubleE"] = "workflow_config_doubleensemble_Alpha360.yaml"
|
||||
alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml"
|
||||
|
||||
# find the yaml paths
|
||||
alg2paths = OrderedDict()
|
||||
@@ -66,6 +67,7 @@ def main(xargs, exp_yaml):
|
||||
|
||||
with open(exp_yaml) as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
config = update_market(config, xargs.market)
|
||||
config = update_gpu(config, xargs.gpu)
|
||||
|
||||
qlib.init(**config.get("qlib_init"))
|
||||
@@ -77,7 +79,7 @@ def main(xargs, exp_yaml):
|
||||
|
||||
for irun in range(xargs.times):
|
||||
run_exp(
|
||||
config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir
|
||||
config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), '{:}-{:}'.format(xargs.save_dir, xargs.market)
|
||||
)
|
||||
|
||||
|
||||
@@ -87,6 +89,7 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser("Baselines")
|
||||
parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.")
|
||||
parser.add_argument("--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator.")
|
||||
parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
|
||||
parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
|
||||
parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.")
|
||||
|
Reference in New Issue
Block a user