Rerange experimental
This commit is contained in:
44
xautodl/xlayers/super_rearrange.py
Normal file
44
xautodl/xlayers/super_rearrange.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#############################################################
|
||||
# Borrow the idea of https://github.com/arogozhnikov/einops #
|
||||
#############################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import math
|
||||
from typing import Optional, Callable
|
||||
|
||||
from xautodl import spaces
|
||||
from .super_module import SuperModule
|
||||
from .super_module import IntSpaceType
|
||||
from .super_module import BoolSpaceType
|
||||
|
||||
|
||||
class SuperRearrange(SuperModule):
|
||||
"""Applies the rearrange operation."""
|
||||
|
||||
def __init__(self, pattern, **axes_lengths):
|
||||
super(SuperRearrange, self).__init__()
|
||||
|
||||
self._pattern = pattern
|
||||
self._axes_lengths = axes_lengths
|
||||
self.reset_parameters()
|
||||
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
root_node = spaces.VirtualNode(id(self))
|
||||
return root_node
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
params = repr(self._pattern)
|
||||
for axis, length in self._axes_lengths.items():
|
||||
params += ", {}={}".format(axis, length)
|
||||
return "{}({})".format(self.__class__.__name__, params)
|
Reference in New Issue
Block a user