Reformulate via black

This commit is contained in:
D-X-Y
2021-03-17 09:25:58 +00:00
parent a9093e41e1
commit f98edea22a
59 changed files with 12289 additions and 8918 deletions

View File

@@ -72,7 +72,7 @@ def retrieve_configs():
def main(xargs, exp_yaml):
assert Path(exp_yaml).exists(), "{:} does not exist.".format(exp_yaml)
pprint('Run {:}'.format(xargs.alg))
pprint("Run {:}".format(xargs.alg))
with open(exp_yaml) as fp:
config = yaml.safe_load(fp)
config = update_market(config, xargs.market)
@@ -87,7 +87,11 @@ def main(xargs, exp_yaml):
for irun in range(xargs.times):
run_exp(
config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), '{:}-{:}'.format(xargs.save_dir, xargs.market)
config.get("task"),
dataset,
xargs.alg,
"recorder-{:02d}-{:02d}".format(irun, xargs.times),
"{:}-{:}".format(xargs.save_dir, xargs.market),
)
@@ -97,7 +101,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser("Baselines")
parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.")
parser.add_argument("--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator.")
parser.add_argument(
"--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator."
)
parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.")

View File

@@ -179,4 +179,3 @@ if __name__ == "__main__":
all_info_dict.append(info_dict)
info_dict = QResult.merge_dict(all_info_dict)
compare_results(info_dict["heads"], info_dict["values"], info_dict["names"], space=15, verbose=True, sort_key=True)