Update LFNA

This commit is contained in:
D-X-Y
2021-05-15 16:01:40 +08:00
parent b81ef2dd74
commit 72f240bf0a
12 changed files with 128 additions and 1050 deletions

View File

@@ -1,7 +1,7 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/basic-prev.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/LFNA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
#####################################################
import sys, time, copy, torch, random, argparse
@@ -41,7 +41,7 @@ def main(args):
w_container_per_epoch = dict()
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx in range(1, env_info["total"]):
for idx in range(args.prev_time, env_info["total"]):
need_time = "Time Left: {:}".format(
convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
@@ -53,8 +53,8 @@ def main(args):
+ need_time
)
# train the same data
historical_x = env_info["{:}-x".format(idx - 1)]
historical_y = env_info["{:}-y".format(idx - 1)]
historical_x = env_info["{:}-x".format(idx - args.prev_time)]
historical_y = env_info["{:}-y".format(idx - args.prev_time)]
# build model
model = get_model(**model_kwargs)
print(model)
@@ -160,6 +160,12 @@ if __name__ == "__main__":
default=0.1,
help="The initial learning rate for the optimizer (default is Adam)",
)
parser.add_argument(
"--prev_time",
type=int,
default=5,
help="The gap between prev_time and current_timestamp",
)
parser.add_argument(
"--batch_size",
type=int,
@@ -184,7 +190,12 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-{:}-d{:}".format(
args.save_dir, args.env_version, args.hidden_dim
args.save_dir = "{:}-d{:}_e{:}_lr{:}-prev{:}-env{:}".format(
args.save_dir,
args.hidden_dim,
args.epochs,
args.init_lr,
args.prev_time,
args.env_version,
)
main(args)