Fix test bugs

This commit is contained in:
D-X-Y
2021-05-06 16:43:31 +08:00
parent 4c14c7b85b
commit f6a024a6ff
4 changed files with 20 additions and 4 deletions

View File

@@ -82,7 +82,14 @@ def main(args):
historical_x, historical_y = subsample(historical_x, historical_y)
# build model
mean, std = historical_x.mean().item(), historical_x.std().item()
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
model_kwargs = dict(
input_dim=1,
output_dim=1,
act_cls="leaky_relu",
norm_cls="simple_norm",
mean=mean,
std=std,
)
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)