Fix bugs in LFNA
This commit is contained in:
@@ -101,10 +101,7 @@ def main(args):
|
||||
)
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
milestones=[
|
||||
int(args.epochs * 0.8),
|
||||
int(args.epochs * 0.9),
|
||||
],
|
||||
milestones=[int(args.epochs * 0.8), int(args.epochs * 0.9),],
|
||||
gamma=0.1,
|
||||
)
|
||||
logger.log("The base-model is\n{:}".format(base_model))
|
||||
@@ -166,7 +163,7 @@ def main(args):
|
||||
w_container_per_epoch = dict()
|
||||
for idx in range(args.seq_length, len(eval_env)):
|
||||
# build-timestamp
|
||||
future_time = env_info["{:}-timestamp".format(idx)]
|
||||
future_time = env_info["{:}-timestamp".format(idx)].item()
|
||||
time_seqs = []
|
||||
for iseq in range(args.seq_length):
|
||||
time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
|
||||
@@ -190,7 +187,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# creating the new meta-time-embedding
|
||||
distance = meta_model.get_closest_meta_distance(future_time.item())
|
||||
distance = meta_model.get_closest_meta_distance(future_time)
|
||||
if distance < eval_env.timestamp_interval:
|
||||
continue
|
||||
#
|
||||
@@ -198,7 +195,9 @@ def main(args):
|
||||
optimizer = torch.optim.Adam(
|
||||
[new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True
|
||||
)
|
||||
meta_model.replace_append_learnt(torch.Tensor([future_time]).to(args.device), new_param)
|
||||
meta_model.replace_append_learnt(
|
||||
torch.Tensor([future_time], device=args.device), new_param
|
||||
)
|
||||
meta_model.eval()
|
||||
base_model.train()
|
||||
for iepoch in range(args.epochs):
|
||||
@@ -241,22 +240,13 @@ if __name__ == "__main__":
|
||||
help="The synthetic enviornment version.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden_dim",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The hidden dimension.",
|
||||
"--hidden_dim", type=int, default=16, help="The hidden dimension.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layer_dim",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The layer chunk dimension.",
|
||||
"--layer_dim", type=int, default=16, help="The layer chunk dimension.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time_dim",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The timestamp dimension.",
|
||||
"--time_dim", type=int, default=16, help="The timestamp dimension.",
|
||||
)
|
||||
#####
|
||||
parser.add_argument(
|
||||
@@ -272,10 +262,7 @@ if __name__ == "__main__":
|
||||
help="The weight decay for the optimizer (default is Adam)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--meta_batch",
|
||||
type=int,
|
||||
default=64,
|
||||
help="The batch size for the meta-model",
|
||||
"--meta_batch", type=int, default=64, help="The batch size for the meta-model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampler_enlarge",
|
||||
@@ -297,10 +284,7 @@ if __name__ == "__main__":
|
||||
"--workers", type=int, default=4, help="The number of workers in parallel."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="",
|
||||
"--device", type=str, default="cpu", help="",
|
||||
)
|
||||
# Random Seed
|
||||
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
|
||||
|
Reference in New Issue
Block a user