Move str2bool to config_utils
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
# python exps/trading/baselines.py --alg TabNet #
|
||||
# #
|
||||
# python exps/trading/baselines.py --alg Transformer#
|
||||
# python exps/trading/baselines.py --alg TSF
|
||||
# python exps/trading/baselines.py --alg TSF
|
||||
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0
|
||||
#####################################################
|
||||
import sys
|
||||
@@ -30,6 +30,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from config_utils import arg_str2bool
|
||||
from procedures.q_exps import update_gpu
|
||||
from procedures.q_exps import update_market
|
||||
from procedures.q_exps import run_exp
|
||||
@@ -182,6 +183,12 @@ if __name__ == "__main__":
|
||||
help="The market indicator.",
|
||||
)
|
||||
parser.add_argument("--times", type=int, default=5, help="The repeated run times.")
|
||||
parser.add_argument(
|
||||
"--shared_dataset",
|
||||
type=arg_str2bool,
|
||||
default=False,
|
||||
help="Whether to share the dataset for all algorithms?",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu", type=int, default=0, help="The GPU ID used for train / test."
|
||||
)
|
||||
@@ -189,9 +196,13 @@ if __name__ == "__main__":
|
||||
"--alg",
|
||||
type=str,
|
||||
choices=list(alg2configs.keys()),
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="The algorithm name.",
|
||||
help="The algorithm name(s).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args, alg2configs[args.alg])
|
||||
if len(args.alg) == 1:
|
||||
main(args, alg2configs[args.alg[0]])
|
||||
else:
|
||||
print("-")
|
||||
|
@@ -15,6 +15,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from config_utils import arg_str2bool
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
@@ -184,16 +185,6 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser("Show Results")
|
||||
|
||||
def str2bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
elif v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
@@ -203,7 +194,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
type=str2bool,
|
||||
type=arg_str2bool,
|
||||
default=False,
|
||||
help="Print detailed log information or not.",
|
||||
)
|
||||
@@ -228,7 +219,7 @@ if __name__ == "__main__":
|
||||
info_dict["heads"],
|
||||
info_dict["values"],
|
||||
info_dict["names"],
|
||||
space=14,
|
||||
space=18,
|
||||
verbose=True,
|
||||
sort_key=True,
|
||||
)
|
||||
|
Reference in New Issue
Block a user