Prototype MAML

This commit is contained in:
D-X-Y
2021-05-10 01:02:38 +08:00
parent 6e7b1c551f
commit cbd2afb4ef
14 changed files with 1497 additions and 702 deletions

View File

@@ -1,7 +1,8 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/basic-same.py --srange 1-999
# python exps/LFNA/basic-same.py --srange 1-999 --env_version v1 --hidden_dim 16
# python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@@ -22,6 +23,8 @@ from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from datasets.synthetic_core import get_synthetic_env
from models.xcore import get_model
from lfna_utils import lfna_setup
def subsample(historical_x, historical_y, maxn=10000):
total = historical_x.size(0)
@@ -33,22 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000):
def main(args):
prepare_seed(args.rand_seed)
logger = prepare_logger(args)
cache_path = (logger.path(None) / ".." / "env-info.pth").resolve()
if cache_path.exists():
env_info = torch.load(cache_path)
else:
env_info = dict()
dynamic_env = get_synthetic_env()
env_info["total"] = len(dynamic_env)
for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)):
env_info["{:}-timestamp".format(idx)] = timestamp
env_info["{:}-x".format(idx)] = _allx
env_info["{:}-y".format(idx)] = _ally
env_info["dynamic_env"] = dynamic_env
torch.save(env_info, cache_path)
logger, env_info, model_kwargs = lfna_setup(args)
# check indexes to be evaluated
to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
@@ -78,16 +66,6 @@ def main(args):
historical_x = env_info["{:}-x".format(idx)]
historical_y = env_info["{:}-y".format(idx)]
# build model
mean, std = historical_x.mean().item(), historical_x.std().item()
model_kwargs = dict(
input_dim=1,
output_dim=1,
act_cls="leaky_relu",
norm_cls="identity",
# norm_cls="simple_norm",
# mean=mean,
# std=std,
)
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
@@ -151,9 +129,9 @@ def main(args):
logger,
)
logger.log("")
per_timestamp_time.update(time.time() - start_time)
start_time = time.time()
save_checkpoint(
{"w_container_per_epoch": w_container_per_epoch},
logger.path(None) / "final-ckp.pth",
@@ -172,6 +150,18 @@ if __name__ == "__main__":
default="./outputs/lfna-synthetic/use-same-timestamp",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--hidden_dim",
type=int,
required=True,
help="The hidden dimension.",
)
parser.add_argument(
"--init_lr",
type=float,
@@ -205,4 +195,7 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-{:}-d{:}".format(
args.save_dir, args.env_version, args.hidden_dim
)
main(args)