Update LFNA
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#####################################################
|
||||
# python exps/LFNA/basic-prev.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1
|
||||
# python exps/LFNA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1
|
||||
# python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
@@ -41,7 +41,7 @@ def main(args):
|
||||
w_container_per_epoch = dict()
|
||||
|
||||
per_timestamp_time, start_time = AverageMeter(), time.time()
|
||||
for idx in range(1, env_info["total"]):
|
||||
for idx in range(args.prev_time, env_info["total"]):
|
||||
|
||||
need_time = "Time Left: {:}".format(
|
||||
convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
|
||||
@@ -53,8 +53,8 @@ def main(args):
|
||||
+ need_time
|
||||
)
|
||||
# train the same data
|
||||
historical_x = env_info["{:}-x".format(idx - 1)]
|
||||
historical_y = env_info["{:}-y".format(idx - 1)]
|
||||
historical_x = env_info["{:}-x".format(idx - args.prev_time)]
|
||||
historical_y = env_info["{:}-y".format(idx - args.prev_time)]
|
||||
# build model
|
||||
model = get_model(**model_kwargs)
|
||||
print(model)
|
||||
@@ -160,6 +160,12 @@ if __name__ == "__main__":
|
||||
default=0.1,
|
||||
help="The initial learning rate for the optimizer (default is Adam)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prev_time",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The gap between prev_time and current_timestamp",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
@@ -184,7 +190,12 @@ if __name__ == "__main__":
|
||||
if args.rand_seed is None or args.rand_seed < 0:
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||
args.save_dir = "{:}-{:}-d{:}".format(
|
||||
args.save_dir, args.env_version, args.hidden_dim
|
||||
args.save_dir = "{:}-d{:}_e{:}_lr{:}-prev{:}-env{:}".format(
|
||||
args.save_dir,
|
||||
args.hidden_dim,
|
||||
args.epochs,
|
||||
args.init_lr,
|
||||
args.prev_time,
|
||||
args.env_version,
|
||||
)
|
||||
main(args)
|
||||
|
Reference in New Issue
Block a user