Update models
This commit is contained in:
@@ -193,19 +193,15 @@ def get_transformer(config):
|
||||
raise ValueError("Invalid Configuration: {:}".format(config))
|
||||
name = config.get("name", "basic")
|
||||
if name == "basic":
|
||||
model = TransformerModel(
|
||||
model = SuperTransformer(
|
||||
d_feat=config.get("d_feat"),
|
||||
embed_dim=config.get("embed_dim"),
|
||||
depth=config.get("depth"),
|
||||
stem_dim=config.get("stem_dim"),
|
||||
embed_dims=config.get("embed_dims"),
|
||||
num_heads=config.get("num_heads"),
|
||||
mlp_ratio=config.get("mlp_ratio"),
|
||||
mlp_hidden_multipliers=config.get("mlp_hidden_multipliers"),
|
||||
qkv_bias=config.get("qkv_bias"),
|
||||
qk_scale=config.get("qkv_scale"),
|
||||
pos_drop=config.get("pos_drop"),
|
||||
mlp_drop_rate=config.get("mlp_drop_rate"),
|
||||
attn_drop_rate=config.get("attn_drop_rate"),
|
||||
drop_path_rate=config.get("drop_path_rate"),
|
||||
norm_layer=config.get("norm_layer", None),
|
||||
other_drop=config.get("other_drop"),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown model name: {:}".format(name))
|
||||
|
Reference in New Issue
Block a user