Move str2bool to config_utils

This commit is contained in:
D-X-Y
2021-03-30 09:17:05 +00:00
parent 9fc2c991f5
commit c2270fd153
16 changed files with 519 additions and 305 deletions

View File

@@ -15,7 +15,7 @@
# python exps/trading/baselines.py --alg TabNet #
# #
# python exps/trading/baselines.py --alg Transformer#
# python exps/trading/baselines.py --alg TSF
# python exps/trading/baselines.py --alg TSF
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0
#####################################################
import sys
@@ -30,6 +30,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import arg_str2bool
from procedures.q_exps import update_gpu
from procedures.q_exps import update_market
from procedures.q_exps import run_exp
@@ -182,6 +183,12 @@ if __name__ == "__main__":
help="The market indicator.",
)
parser.add_argument("--times", type=int, default=5, help="The repeated run times.")
parser.add_argument(
"--shared_dataset",
type=arg_str2bool,
default=False,
help="Whether to share the dataset for all algorithms?",
)
parser.add_argument(
"--gpu", type=int, default=0, help="The GPU ID used for train / test."
)
@@ -189,9 +196,13 @@ if __name__ == "__main__":
"--alg",
type=str,
choices=list(alg2configs.keys()),
nargs="+",
required=True,
help="The algorithm name.",
help="The algorithm name(s).",
)
args = parser.parse_args()
main(args, alg2configs[args.alg])
if len(args.alg) == 1:
main(args, alg2configs[args.alg[0]])
else:
print("-")

View File

@@ -15,6 +15,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import arg_str2bool
import qlib
from qlib.config import REG_CN
from qlib.workflow import R
@@ -184,16 +185,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser("Show Results")
def str2bool(v):
if isinstance(v, bool):
return v
elif v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
parser.add_argument(
"--save_dir",
type=str,
@@ -203,7 +194,7 @@ if __name__ == "__main__":
)
parser.add_argument(
"--verbose",
type=str2bool,
type=arg_str2bool,
default=False,
help="Print detailed log information or not.",
)
@@ -228,7 +219,7 @@ if __name__ == "__main__":
info_dict["heads"],
info_dict["values"],
info_dict["names"],
space=14,
space=18,
verbose=True,
sort_key=True,
)