Re-org debug codes

This commit is contained in:
D-X-Y
2021-05-13 08:39:19 +00:00
parent 0138e71cf2
commit 17955123a0
4 changed files with 96 additions and 67 deletions

View File

@@ -1,8 +1,8 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16
# python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim
# python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/LFNA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@@ -58,7 +58,6 @@ def main(args):
# build model
model = get_model(**model_kwargs)
print(model)
model.analyze_weights()
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss()
@@ -85,6 +84,7 @@ def main(args):
best_loss = loss.item()
best_param = copy.deepcopy(model.state_dict())
model.load_state_dict(best_param)
model.analyze_weights()
with torch.no_grad():
train_metric(preds, historical_y)
train_results = train_metric.get_info()