Update LFNA
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#####################################################
|
||||
# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 --epochs 100000 --meta_batch 16
|
||||
# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 --epochs 100000 --meta_batch 64
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
@@ -33,7 +33,7 @@ from lfna_models import HyperNet
|
||||
def main(args):
|
||||
logger, env_info, model_kwargs = lfna_setup(args)
|
||||
dynamic_env = env_info["dynamic_env"]
|
||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||
model = get_model(**model_kwargs)
|
||||
criterion = torch.nn.MSELoss()
|
||||
|
||||
logger.log("There are {:} weights.".format(model.get_w_container().numel()))
|
||||
@@ -72,7 +72,7 @@ def main(args):
|
||||
)
|
||||
|
||||
limit_bar = float(iepoch + 1) / args.epochs * total_bar
|
||||
limit_bar = min(max(0, int(limit_bar)), total_bar)
|
||||
limit_bar = min(max(32, int(limit_bar)), total_bar)
|
||||
losses = []
|
||||
for ibatch in range(args.meta_batch):
|
||||
cur_time = random.randint(0, limit_bar)
|
||||
|
Reference in New Issue
Block a user