Move str2bool to config_utils

This commit is contained in:
D-X-Y
2021-03-30 09:17:05 +00:00
parent 9fc2c991f5
commit c2270fd153
16 changed files with 519 additions and 305 deletions

View File

@@ -15,7 +15,7 @@
# python exps/trading/baselines.py --alg TabNet #
# #
# python exps/trading/baselines.py --alg Transformer#
# python exps/trading/baselines.py --alg TSF
# python exps/trading/baselines.py --alg TSF
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0
#####################################################
import sys
@@ -30,6 +30,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import arg_str2bool
from procedures.q_exps import update_gpu
from procedures.q_exps import update_market
from procedures.q_exps import run_exp
@@ -182,6 +183,12 @@ if __name__ == "__main__":
help="The market indicator.",
)
parser.add_argument("--times", type=int, default=5, help="The repeated run times.")
parser.add_argument(
"--shared_dataset",
type=arg_str2bool,
default=False,
help="Whether to share the dataset for all algorithms?",
)
parser.add_argument(
"--gpu", type=int, default=0, help="The GPU ID used for train / test."
)
@@ -189,9 +196,13 @@ if __name__ == "__main__":
"--alg",
type=str,
choices=list(alg2configs.keys()),
nargs="+",
required=True,
help="The algorithm name.",
help="The algorithm name(s).",
)
args = parser.parse_args()
main(args, alg2configs[args.alg])
if len(args.alg) == 1:
main(args, alg2configs[args.alg[0]])
else:
print("-")