Update LFNA ablation codes

D-X-Y
2021-05-12 15:45:45 +08:00
parent 4da19d6efe
commit d51e5fdc7f
6 changed files with 443 additions and 280 deletions

View File

@@ -41,4 +41,5 @@ super_name2activation = {
from .super_trade_stem import SuperAlphaEBDv1
from .super_positional_embedding import SuperDynamicPositionE
from .super_positional_embedding import SuperPositionalEncoder

View File (super_positional_embedding, per the imports above)

@@ -10,6 +10,41 @@ from .super_module import SuperModule
from .super_module import IntSpaceType
class SuperDynamicPositionE(SuperModule):
"""Applies a positional encoding to the input positions."""
def __init__(self, dimension: int, scale: float = 1.0) -> None:
super(SuperDynamicPositionE, self).__init__()
self._scale = scale
self._dimension = dimension
        # fixed sinusoid frequencies, 10000^(-2i/dim); a buffer, not an optimized weight
self.register_buffer(
"_div_term",
torch.exp(
torch.arange(0, dimension, 2).float() * (-math.log(10000.0) / dimension)
),
)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
return root_node
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        # NOTE: the committed body was debug scaffolding (pdb.set_trace, a stray
        # print, and an F.linear call on weights this module never defines); the
        # sketch below assumes the intended sinusoidal encoding via _div_term.
        positions = torch.unsqueeze(input * self._scale, dim=-1)
        values = positions * self._div_term  # broadcast: (..., 1) * (dim // 2,)
        return torch.cat((torch.sin(values), torch.cos(values)), dim=-1)
def extra_repr(self) -> str:
return "scale={:}, dim={:}".format(self._scale, self._dimension)
class SuperPositionalEncoder(SuperModule):
"""Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65
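
For reference, the fixed table this docstring points to is PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). A standalone sketch of that table (a hypothetical helper, not the class body, which is truncated above):

import math
import torch

def sinusoidal_table(max_len: int, d_model: int) -> torch.Tensor:
    # Classic "Attention Is All You Need" encoding, as in the linked
    # PyTorch word_language_model example; assumes an even d_model.
    position = torch.arange(0, max_len).float().unsqueeze(1)  # (max_len, 1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )  # (d_model // 2,)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe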