Update super cores
This commit is contained in:
5
lib/layers/super_core.py
Normal file
5
lib/layers/super_core.py
Normal file
@@ -0,0 +1,5 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
from .super_module import SuperModule
|
||||
from .super_mlp import SuperLinear
|
@@ -1,38 +1,71 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Optional
|
||||
from torch import Tensor
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
|
||||
import spaces
|
||||
from layers.super_module import SuperModule
|
||||
from layers.super_module import SuperModule
|
||||
from layers.super_module import SuperRunType
|
||||
|
||||
IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
|
||||
BoolSpaceType = Union[bool, spaces.Categorical]
|
||||
|
||||
|
||||
class SuperLinear(SuperModule):
|
||||
"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`"""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
in_features: IntSpaceType,
|
||||
out_features: IntSpaceType,
|
||||
bias: BoolSpaceType = True,
|
||||
) -> None:
|
||||
super(SuperLinear, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = Parameter(torch.Tensor(out_features, in_features))
|
||||
|
||||
# the raw input args
|
||||
self._in_features = in_features
|
||||
self._out_features = out_features
|
||||
self._bias = bias
|
||||
|
||||
self._super_weight = Parameter(
|
||||
torch.Tensor(self.out_features, self.in_features)
|
||||
)
|
||||
if bias:
|
||||
self.bias = Parameter(torch.Tensor(out_features))
|
||||
self._super_bias = Parameter(torch.Tensor(self.out_features))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
self.register_parameter("_super_bias", None)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
if self.bias is not None:
|
||||
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
init.uniform_(self.bias, -bound, bound)
|
||||
@property
|
||||
def in_features(self):
|
||||
return spaces.get_max(self._in_features)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
@property
|
||||
def out_features(self):
|
||||
return spaces.get_max(self._out_features)
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return spaces.has_categorical(self._bias, True)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))
|
||||
if self.bias:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._super_weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(self._super_bias, -bound, bound)
|
||||
|
||||
def forward_raw(self, input: Tensor) -> Tensor:
|
||||
return F.linear(input, self._super_weight, self._super_bias)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "in_features={:}, out_features={:}, bias={:}".format(
|
||||
self.in_features, self.out_features, self.bias is not None
|
||||
self.in_features, self.out_features, self.bias
|
||||
)
|
||||
|
||||
|
||||
|
@@ -4,6 +4,14 @@
|
||||
|
||||
import abc
|
||||
import torch.nn as nn
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SuperRunMode(Enum):
|
||||
"""This class defines the enumerations for Super Model Running Mode."""
|
||||
|
||||
FullModel = "fullmodel"
|
||||
Default = "fullmodel"
|
||||
|
||||
|
||||
class SuperModule(abc.ABCMeta, nn.Module):
|
||||
@@ -11,7 +19,24 @@ class SuperModule(abc.ABCMeta, nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(SuperModule, self).__init__()
|
||||
self._super_run_type = SuperRunMode.default
|
||||
|
||||
@abc.abstractmethod
|
||||
def abstract_search_space(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def super_run_type(self):
|
||||
return self._super_run_type
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward_raw(self, *inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, *inputs):
|
||||
if self.super_run_type == SuperRunMode.FullModel:
|
||||
return self.forward_raw(*inputs)
|
||||
else:
|
||||
raise ModeError(
|
||||
"Unknown Super Model Run Mode: {:}".format(self.super_run_type)
|
||||
)
|
||||
|
@@ -9,3 +9,5 @@ from .basic_space import Continuous
|
||||
from .basic_space import Integer
|
||||
from .basic_op import has_categorical
|
||||
from .basic_op import has_continuous
|
||||
from .basic_op import get_min
|
||||
from .basic_op import get_max
|
||||
|
@@ -1,4 +1,7 @@
|
||||
from spaces.basic_space import Space
|
||||
from spaces.basic_space import Integer
|
||||
from spaces.basic_space import Continuous
|
||||
from spaces.basic_space import Categorical
|
||||
from spaces.basic_space import _EPS
|
||||
|
||||
|
||||
@@ -14,3 +17,33 @@ def has_continuous(space_or_value, x):
|
||||
return space_or_value.has(x)
|
||||
else:
|
||||
return abs(space_or_value - x) <= _EPS
|
||||
|
||||
|
||||
def get_max(space_or_value):
|
||||
if isinstance(space_or_value, Integer):
|
||||
return max(space_or_value.candidates)
|
||||
elif isinstance(space_or_value, Continuous):
|
||||
return space_or_value.upper
|
||||
elif isinstance(space_or_value, Categorical):
|
||||
values = []
|
||||
for index in range(len(space_or_value)):
|
||||
max_value = get_max(space_or_value[index])
|
||||
values.append(max_value)
|
||||
return max(values)
|
||||
else:
|
||||
return space_or_value
|
||||
|
||||
|
||||
def get_min(space_or_value):
|
||||
if isinstance(space_or_value, Integer):
|
||||
return min(space_or_value.candidates)
|
||||
elif isinstance(space_or_value, Continuous):
|
||||
return space_or_value.lower
|
||||
elif isinstance(space_or_value, Categorical):
|
||||
values = []
|
||||
for index in range(len(space_or_value)):
|
||||
min_value = get_min(space_or_value[index])
|
||||
values.append(min_value)
|
||||
return min(values)
|
||||
else:
|
||||
return space_or_value
|
||||
|
@@ -10,6 +10,9 @@ import numpy as np
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
__all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"]
|
||||
|
||||
_EPS = 1e-9
|
||||
|
||||
|
||||
@@ -54,6 +57,10 @@ class Categorical(Space):
|
||||
), "default >= {:}".format(len(self._candidates))
|
||||
assert len(self) > 0, "Please provide at least one candidate"
|
||||
|
||||
@property
|
||||
def candidates(self):
|
||||
return self._candidates
|
||||
|
||||
@property
|
||||
def determined(self):
|
||||
if len(self) == 1:
|
||||
|
Reference in New Issue
Block a user