Prototype MAML
This commit is contained in:
@@ -21,8 +21,12 @@ def get_model(config: Dict[Text, Any], **kwargs):
|
||||
act_cls = super_name2activation[kwargs["act_cls"]]
|
||||
norm_cls = super_name2norm[kwargs["norm_cls"]]
|
||||
mean, std = kwargs.get("mean", None), kwargs.get("std", None)
|
||||
hidden_dim1 = kwargs.get("hidden_dim1", 200)
|
||||
hidden_dim2 = kwargs.get("hidden_dim2", 100)
|
||||
if "hidden_dim" in kwargs:
|
||||
hidden_dim1 = kwargs.get("hidden_dim")
|
||||
hidden_dim2 = kwargs.get("hidden_dim")
|
||||
else:
|
||||
hidden_dim1 = kwargs.get("hidden_dim1", 200)
|
||||
hidden_dim2 = kwargs.get("hidden_dim2", 100)
|
||||
model = SuperSequential(
|
||||
norm_cls(mean=mean, std=std),
|
||||
SuperLinear(kwargs["input_dim"], hidden_dim1),
|
||||
@@ -34,4 +38,3 @@ def get_model(config: Dict[Text, Any], **kwargs):
|
||||
else:
|
||||
raise TypeError("Unkonwn model type: {:}".format(model_type))
|
||||
return model
|
||||
|
||||
|
Reference in New Issue
Block a user