Update models

This commit is contained in:
D-X-Y
2021-03-23 11:13:51 +00:00
parent 01397660de
commit 379b904203
7 changed files with 175 additions and 38 deletions

View File

@@ -193,19 +193,15 @@ def get_transformer(config):
raise ValueError("Invalid Configuration: {:}".format(config))
name = config.get("name", "basic")
if name == "basic":
model = TransformerModel(
model = SuperTransformer(
d_feat=config.get("d_feat"),
embed_dim=config.get("embed_dim"),
depth=config.get("depth"),
stem_dim=config.get("stem_dim"),
embed_dims=config.get("embed_dims"),
num_heads=config.get("num_heads"),
mlp_ratio=config.get("mlp_ratio"),
mlp_hidden_multipliers=config.get("mlp_hidden_multipliers"),
qkv_bias=config.get("qkv_bias"),
qk_scale=config.get("qkv_scale"),
pos_drop=config.get("pos_drop"),
mlp_drop_rate=config.get("mlp_drop_rate"),
attn_drop_rate=config.get("attn_drop_rate"),
drop_path_rate=config.get("drop_path_rate"),
norm_layer=config.get("norm_layer", None),
other_drop=config.get("other_drop"),
)
else:
raise ValueError("Unknown model name: {:}".format(name))