Update LFNA

This commit is contained in:
D-X-Y
2021-05-12 20:32:50 +08:00
parent 06f4a1f1cf
commit 0b1ca45c44
8 changed files with 121 additions and 15 deletions

View File

@@ -1,7 +1,7 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 --epochs 100000 --meta_batch 16
# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 --epochs 100000 --meta_batch 64
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@@ -33,7 +33,7 @@ from lfna_models import HyperNet
def main(args):
logger, env_info, model_kwargs = lfna_setup(args)
dynamic_env = env_info["dynamic_env"]
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
model = get_model(**model_kwargs)
criterion = torch.nn.MSELoss()
logger.log("There are {:} weights.".format(model.get_w_container().numel()))
@@ -72,7 +72,7 @@ def main(args):
)
limit_bar = float(iepoch + 1) / args.epochs * total_bar
limit_bar = min(max(0, int(limit_bar)), total_bar)
limit_bar = min(max(32, int(limit_bar)), total_bar)
losses = []
for ibatch in range(args.meta_batch):
cur_time = random.randint(0, limit_bar)