Prototype MAML

This commit is contained in:
D-X-Y
2021-05-10 01:02:38 +08:00
parent 6e7b1c551f
commit cbd2afb4ef
14 changed files with 1497 additions and 702 deletions

View File

@@ -21,8 +21,12 @@ def get_model(config: Dict[Text, Any], **kwargs):
act_cls = super_name2activation[kwargs["act_cls"]]
norm_cls = super_name2norm[kwargs["norm_cls"]]
mean, std = kwargs.get("mean", None), kwargs.get("std", None)
hidden_dim1 = kwargs.get("hidden_dim1", 200)
hidden_dim2 = kwargs.get("hidden_dim2", 100)
if "hidden_dim" in kwargs:
hidden_dim1 = kwargs.get("hidden_dim")
hidden_dim2 = kwargs.get("hidden_dim")
else:
hidden_dim1 = kwargs.get("hidden_dim1", 200)
hidden_dim2 = kwargs.get("hidden_dim2", 100)
model = SuperSequential(
norm_cls(mean=mean, std=std),
SuperLinear(kwargs["input_dim"], hidden_dim1),
@@ -34,4 +38,3 @@ def get_model(config: Dict[Text, Any], **kwargs):
else:
raise TypeError("Unkonwn model type: {:}".format(model_type))
return model