Update LFNA ablation codes
@@ -41,4 +41,5 @@ super_name2activation = {
from .super_trade_stem import SuperAlphaEBDv1
from .super_positional_embedding import SuperDynamicPositionE
from .super_positional_embedding import SuperPositionalEncoder
@@ -10,6 +10,41 @@ from .super_module import SuperModule
from .super_module import IntSpaceType


class SuperDynamicPositionE(SuperModule):
    """Applies a positional encoding to the input positions."""

    def __init__(self, dimension: int, scale: float = 1.0) -> None:
        super(SuperDynamicPositionE, self).__init__()

        self._scale = scale
        self._dimension = dimension
        # weights to be optimized
        self.register_buffer(
            "_div_term",
            torch.exp(
                torch.arange(0, dimension, 2).float() * (-math.log(10000.0) / dimension)
            ),
        )

    @property
    def abstract_search_space(self):
        # this module has no searchable hyper-parameters, so the space is empty
        root_node = spaces.VirtualNode(id(self))
        return root_node

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        # scale the raw positions, divide by the pre-computed frequency terms,
        # and concatenate sine / cosine channels into a `dimension`-d embedding
        positions = torch.unsqueeze(input * self._scale, dim=-1)
        divisions = torch.reshape(
            self._div_term, [1] * input.ndim + [self._div_term.numel()]
        )
        values = positions / divisions
        return torch.cat((torch.sin(values), torch.cos(values)), dim=-1)

    def extra_repr(self) -> str:
        return "scale={:}, dim={:}".format(self._scale, self._dimension)


class SuperPositionalEncoder(SuperModule):
    """Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
    https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65
    """
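For reference, a minimal usage sketch of the new SuperDynamicPositionE module; the import path and the positions tensor are illustrative assumptions, not part of this commit:

import torch
from xautodl.xlayers.super_positional_embedding import SuperDynamicPositionE  # assumed path

pe = SuperDynamicPositionE(dimension=8, scale=1.0)
positions = torch.arange(4).float()   # four raw positions, shape [4]
embeds = pe.forward_raw(positions)    # sinusoidal features, shape [4, 8]
print(embeds.shape)                   # torch.Size([4, 8])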
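As a quick sanity check on the registered buffer, _div_term equals the standard Transformer inverse frequencies 1 / 10000^(2i/d); standalone arithmetic, independent of this commit:

import math
import torch

dimension = 8
div_term = torch.exp(
    torch.arange(0, dimension, 2).float() * (-math.log(10000.0) / dimension)
)
expected = 1.0 / (10000.0 ** (torch.arange(0, dimension, 2).float() / dimension))
assert torch.allclose(div_term, expected)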