Fix black errors
This commit is contained in:
@@ -17,7 +17,7 @@ from xlayers import trunc_normal_
|
||||
from xlayers import super_core
|
||||
|
||||
|
||||
__all__ = ["DefaultSearchSpace"]
|
||||
__all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"]
|
||||
|
||||
|
||||
def _get_mul_specs(candidates, num):
|
||||
@@ -41,6 +41,7 @@ def _assert_types(x, expected_types):
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_NET_CONFIG = None
|
||||
_default_max_depth = 5
|
||||
DefaultSearchSpace = dict(
|
||||
d_feat=6,
|
||||
@@ -163,7 +164,9 @@ class SuperTransformer(super_core.SuperModule):
|
||||
else:
|
||||
stem_dim = spaces.get_determined_value(self._stem_dim)
|
||||
cls_tokens = self.cls_token.expand(batch, -1, -1)
|
||||
cls_tokens = F.interpolate(cls_tokens, size=(stem_dim), mode="linear", align_corners=True)
|
||||
cls_tokens = F.interpolate(
|
||||
cls_tokens, size=(stem_dim), mode="linear", align_corners=True
|
||||
)
|
||||
feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
|
||||
feats_w_tp = self.pos_embed(feats_w_ct)
|
||||
xfeats = self.backbone(feats_w_tp)
|
||||
|
Reference in New Issue
Block a user