Fix test bugs
This commit is contained in:
@@ -78,7 +78,14 @@ def main(args):
|
||||
historical_y = env_info["{:}-y".format(idx)]
|
||||
# build model
|
||||
mean, std = historical_x.mean().item(), historical_x.std().item()
|
||||
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
|
||||
model_kwargs = dict(
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
act_cls="leaky_relu",
|
||||
norm_cls="simple_norm",
|
||||
mean=mean,
|
||||
std=std,
|
||||
)
|
||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||
# build optimizer
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
||||
|
Reference in New Issue
Block a user