Update synthetic

This commit is contained in:
D-X-Y
2021-05-09 23:36:55 +08:00
parent 9168c62855
commit 6e7b1c551f
2 changed files with 15 additions and 6 deletions

View File

@@ -1,7 +1,7 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/basic-his.py --srange 1-999
# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@@ -36,12 +36,14 @@ def main(args):
prepare_seed(args.rand_seed)
logger = prepare_logger(args)
cache_path = (logger.path(None) / ".." / "env-info.pth").resolve()
cache_path = (
logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version)
).resolve()
if cache_path.exists():
env_info = torch.load(cache_path)
else:
env_info = dict()
dynamic_env = get_synthetic_env()
dynamic_env = get_synthetic_env(version=args.env_version)
env_info["total"] = len(dynamic_env)
for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)):
env_info["{:}-timestamp".format(idx)] = timestamp
@@ -169,6 +171,12 @@ if __name__ == "__main__":
default="./outputs/lfna-synthetic/use-all-past-data",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--init_lr",
type=float,
@@ -202,4 +210,5 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-{:}".format(args.save_dir, args.env_version)
main(args)