Update MLAML

This commit is contained in:
D-X-Y
2021-05-27 17:41:32 +00:00
parent c6db1ef65a
commit 9af34ea94d
3 changed files with 21 additions and 27 deletions

View File

@@ -1,10 +1,10 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5
# python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16
# python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32
# python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32
# python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5
# python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5
# python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5
# python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@@ -155,6 +155,8 @@ def main(args):
allys = allys.view(-1)
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
future_container = maml.adapt(historical_x, historical_y)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = maml.criterion(future_y_hat, future_y)
meta_losses.append(future_loss)
@@ -195,8 +197,6 @@ def main(args):
train_results = train_metric.get_info()
return train_results, future_container
train_results, future_container = finetune(0)
metric = metric_cls(True)
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx, (future_time, (future_x, future_y)) in enumerate(test_env):
@@ -212,7 +212,9 @@ def main(args):
)
# build optimizer
future_x.to(args.device), future_y.to(args.device)
train_results, future_container = finetune(idx)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
metric(future_y_hat, future_y)
@@ -237,7 +239,7 @@ if __name__ == "__main__":
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/lfna-synthetic/use-maml-nft",
default="./outputs/GeMOSA-synthetic/use-maml-ft",
help="The checkpoint directory.",
)
parser.add_argument(

View File

@@ -155,6 +155,8 @@ def main(args):
allys = allys.view(-1)
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
future_container = maml.adapt(historical_x, historical_y)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = maml.criterion(future_y_hat, future_y)
meta_losses.append(future_loss)
@@ -212,7 +214,7 @@ def main(args):
)
# build optimizer
future_x.to(args.device), future_y.to(args.device)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
metric(future_y_hat, future_y)
@@ -237,7 +239,7 @@ if __name__ == "__main__":
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/lfna-synthetic/use-maml-nft",
default="./outputs/GeMOSA-synthetic/use-maml-nft",
help="The checkpoint directory.",
)
parser.add_argument(