Update misc
This commit is contained in:
@@ -11,6 +11,7 @@ import math
|
||||
from typing import Optional, Callable
|
||||
|
||||
from xautodl import spaces
|
||||
from .misc_utils import ParsedExpression
|
||||
from .super_module import SuperModule
|
||||
from .super_module import IntSpaceType
|
||||
from .super_module import BoolSpaceType
|
||||
@@ -24,6 +25,17 @@ class SuperReArrange(SuperModule):
|
||||
|
||||
self._pattern = pattern
|
||||
self._axes_lengths = axes_lengths
|
||||
axes_lengths = tuple(sorted(self._axes_lengths.items()))
|
||||
# Perform initial parsing of pattern and provided supplementary info
|
||||
# axes_lengths is a tuple of tuples (axis_name, axis_length)
|
||||
left, right = pattern.split("->")
|
||||
left = ParsedExpression(left)
|
||||
right = ParsedExpression(right)
|
||||
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
print("-")
|
||||
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
@@ -31,13 +43,16 @@ class SuperReArrange(SuperModule):
|
||||
return root_node
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
self.forward_raw(input)
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
raise NotImplementedError
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
params = repr(self._pattern)
|
||||
for axis, length in self._axes_lengths.items():
|
||||
params += ", {}={}".format(axis, length)
|
||||
return "{}({})".format(self.__class__.__name__, params)
|
||||
return "{:}".format(params)
|
||||
|
Reference in New Issue
Block a user