Update base models
This commit is contained in:
@@ -28,13 +28,24 @@ def lfna_setup(args):
|
||||
env_info["dynamic_env"] = dynamic_env
|
||||
torch.save(env_info, cache_path)
|
||||
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
config=dict(model_type="simple_mlp"),
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
hidden_dim=args.hidden_dim,
|
||||
act_cls="leaky_relu",
|
||||
norm_cls="identity",
|
||||
)
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
config=dict(model_type="norm_mlp"),
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
hidden_dims=[args.hidden_dim] * 2,
|
||||
act_cls="gelu",
|
||||
norm_cls="layer_norm_1d",
|
||||
)
|
||||
return logger, env_info, model_kwargs
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user