Update MLAML
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#####################################################
|
||||
# python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5
|
||||
# python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16
|
||||
# python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32
|
||||
# python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32
|
||||
# python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5
|
||||
# python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5
|
||||
# python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5
|
||||
# python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
@@ -155,6 +155,8 @@ def main(args):
|
||||
allys = allys.view(-1)
|
||||
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
|
||||
future_container = maml.adapt(historical_x, historical_y)
|
||||
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = maml.predict(future_x, future_container)
|
||||
future_loss = maml.criterion(future_y_hat, future_y)
|
||||
meta_losses.append(future_loss)
|
||||
@@ -195,8 +197,6 @@ def main(args):
|
||||
train_results = train_metric.get_info()
|
||||
return train_results, future_container
|
||||
|
||||
train_results, future_container = finetune(0)
|
||||
|
||||
metric = metric_cls(True)
|
||||
per_timestamp_time, start_time = AverageMeter(), time.time()
|
||||
for idx, (future_time, (future_x, future_y)) in enumerate(test_env):
|
||||
@@ -212,7 +212,9 @@ def main(args):
|
||||
)
|
||||
|
||||
# build optimizer
|
||||
future_x.to(args.device), future_y.to(args.device)
|
||||
train_results, future_container = finetune(idx)
|
||||
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = maml.predict(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
metric(future_y_hat, future_y)
|
||||
@@ -237,7 +239,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./outputs/lfna-synthetic/use-maml-nft",
|
||||
default="./outputs/GeMOSA-synthetic/use-maml-ft",
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@@ -155,6 +155,8 @@ def main(args):
|
||||
allys = allys.view(-1)
|
||||
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
|
||||
future_container = maml.adapt(historical_x, historical_y)
|
||||
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = maml.predict(future_x, future_container)
|
||||
future_loss = maml.criterion(future_y_hat, future_y)
|
||||
meta_losses.append(future_loss)
|
||||
@@ -212,7 +214,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# build optimizer
|
||||
future_x.to(args.device), future_y.to(args.device)
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = maml.predict(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
metric(future_y_hat, future_y)
|
||||
@@ -237,7 +239,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./outputs/lfna-synthetic/use-maml-nft",
|
||||
default="./outputs/GeMOSA-synthetic/use-maml-nft",
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
Reference in New Issue
Block a user