Update ablation for GeMOSA

This commit is contained in:
D-X-Y
2021-05-27 21:51:42 +08:00
parent ffc0d16d6c
commit 726dbef326
3 changed files with 16 additions and 4 deletions

View File

@@ -20,9 +20,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
SuperLinear(100, 1),
).to(device)
model.train()
optimizer = torch.optim.Adam(
model.parameters(), lr=max_lr, amsgrad=True
)
optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, amsgrad=True)
loss_func = torch.nn.MSELoss()
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
@@ -47,7 +45,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
best_param = copy.deepcopy(model.state_dict())
# print('loss={:}, best-loss={:}'.format(loss.item(), best_loss))
model.load_state_dict(best_param)
return model, loss_func, best_loss