Update ablation for GeMOSA
This commit is contained in:
@@ -20,9 +20,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
|
||||
SuperLinear(100, 1),
|
||||
).to(device)
|
||||
model.train()
|
||||
optimizer = torch.optim.Adam(
|
||||
model.parameters(), lr=max_lr, amsgrad=True
|
||||
)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, amsgrad=True)
|
||||
loss_func = torch.nn.MSELoss()
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
@@ -47,7 +45,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
|
||||
if best_loss is None or best_loss > loss.item():
|
||||
best_loss = loss.item()
|
||||
best_param = copy.deepcopy(model.state_dict())
|
||||
|
||||
|
||||
# print('loss={:}, best-loss={:}'.format(loss.item(), best_loss))
|
||||
model.load_state_dict(best_param)
|
||||
return model, loss_func, best_loss
|
||||
|
Reference in New Issue
Block a user