Re-org debug codes
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#####################################################
|
||||
# python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16
|
||||
# python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim
|
||||
# python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1
|
||||
# python exps/LFNA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
@@ -58,7 +58,6 @@ def main(args):
|
||||
# build model
|
||||
model = get_model(**model_kwargs)
|
||||
print(model)
|
||||
model.analyze_weights()
|
||||
# build optimizer
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
||||
criterion = torch.nn.MSELoss()
|
||||
@@ -85,6 +84,7 @@ def main(args):
|
||||
best_loss = loss.item()
|
||||
best_param = copy.deepcopy(model.state_dict())
|
||||
model.load_state_dict(best_param)
|
||||
model.analyze_weights()
|
||||
with torch.no_grad():
|
||||
train_metric(preds, historical_y)
|
||||
train_results = train_metric.get_info()
|
||||
|
Reference in New Issue
Block a user