Files
xautodl/lib/models/xcore.py

38 lines
1.4 KiB
Python
Raw Normal View History

2021-04-29 02:17:44 -07:00
#######################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
2021-04-29 14:28:37 +08:00
#######################################################
# Use module in xlayers to construct different models #
#######################################################
from typing import List, Text, Dict, Any
2021-04-29 16:30:47 +08:00
import torch
2021-04-29 14:28:37 +08:00
# Public API: the only name exported by `from <this module> import *`.
__all__ = ["get_model"]
2021-04-29 16:30:47 +08:00
from xlayers.super_core import SuperSequential
2021-04-29 14:28:37 +08:00
from xlayers.super_core import SuperLinear
2021-05-06 16:38:58 +08:00
from xlayers.super_core import super_name2norm
from xlayers.super_core import super_name2activation
2021-04-29 14:28:37 +08:00
def get_model(config: Dict[Text, Any], **kwargs):
    """Build the model selected by ``config["model_type"]``.

    Only ``"simple_mlp"`` is supported; it is also the default when the key
    is absent from ``config``. The resulting network is a normalization
    layer followed by three ``SuperLinear`` layers, with an activation after
    each of the first two.

    Args:
        config: Model configuration. Only the ``"model_type"`` key is read
            here; every other setting comes from ``kwargs``.
        **kwargs: Settings for ``"simple_mlp"``:
            act_cls: Required. Key into ``super_name2activation``.
            norm_cls: Required. Key into ``super_name2norm``.
            input_dim: Required. Input size of the first linear layer.
            output_dim: Required. Output size of the last linear layer.
            mean, std: Optional (default ``None``). Forwarded as keyword
                arguments to the normalization layer.
            hidden_dim1: Optional (default 200). Width of the first hidden
                layer.
            hidden_dim2: Optional (default 100). Width of the second hidden
                layer.

    Returns:
        The assembled ``SuperSequential`` model.

    Raises:
        TypeError: If ``model_type`` is not a supported model type.
        KeyError: If a required keyword argument is missing (and, assuming
            the name-lookup tables are plain mappings, if ``act_cls`` or
            ``norm_cls`` names an unknown entry).
    """
    model_type = config.get("model_type", "simple_mlp")

    # Reject unsupported types up front so the construction code below stays
    # flat. TypeError is kept (rather than ValueError) for backward
    # compatibility with existing callers.
    if model_type != "simple_mlp":
        raise TypeError("Unknown model type: {:}".format(model_type))

    # Resolve the layer factories from their registered names.
    act_cls = super_name2activation[kwargs["act_cls"]]
    norm_cls = super_name2norm[kwargs["norm_cls"]]

    mean, std = kwargs.get("mean", None), kwargs.get("std", None)
    hidden_dim1 = kwargs.get("hidden_dim1", 200)
    hidden_dim2 = kwargs.get("hidden_dim2", 100)

    model = SuperSequential(
        norm_cls(mean=mean, std=std),
        SuperLinear(kwargs["input_dim"], hidden_dim1),
        act_cls(),
        SuperLinear(hidden_dim1, hidden_dim2),
        act_cls(),
        SuperLinear(hidden_dim2, kwargs["output_dim"]),
    )
    return model
2021-05-07 10:26:35 +08:00