Update base models

This commit is contained in:
D-X-Y
2021-05-12 10:58:54 +00:00
parent 4c51f62906
commit 80ccc49d92
3 changed files with 18 additions and 18 deletions

View File

@@ -28,13 +28,24 @@ def lfna_setup(args):
env_info["dynamic_env"] = dynamic_env
torch.save(env_info, cache_path)
"""
model_kwargs = dict(
config=dict(model_type="simple_mlp"),
input_dim=1,
output_dim=1,
hidden_dim=args.hidden_dim,
act_cls="leaky_relu",
norm_cls="identity",
)
"""
model_kwargs = dict(
config=dict(model_type="norm_mlp"),
input_dim=1,
output_dim=1,
hidden_dims=[args.hidden_dim] * 2,
act_cls="gelu",
norm_cls="layer_norm_1d",
)
return logger, env_info, model_kwargs