Update super cores

D-X-Y
2021-03-18 18:32:26 +08:00
parent 63c8bb9bc8
commit eabdd21d97
9 changed files with 209 additions and 18 deletions

lib/layers/super_core.py (new file)

@@ -0,0 +1,5 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from .super_module import SuperModule
from .super_mlp import SuperLinear

lib/layers/super_mlp.py

@@ -1,38 +1,71 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Optional, Union

import spaces
from layers.super_module import SuperModule
from layers.super_module import SuperRunMode

IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
BoolSpaceType = Union[bool, spaces.Categorical]


class SuperLinear(SuperModule):
    """Applies a linear transformation to the incoming data: :math:`y = xA^T + b`"""

    def __init__(
        self,
        in_features: IntSpaceType,
        out_features: IntSpaceType,
        bias: BoolSpaceType = True,
    ) -> None:
        super(SuperLinear, self).__init__()

        # the raw input args
        self._in_features = in_features
        self._out_features = out_features
        self._bias = bias
        # the super weight is sized by the largest candidate dimensions
        self._super_weight = Parameter(
            torch.Tensor(self.out_features, self.in_features)
        )
        if self.bias:
            self._super_bias = Parameter(torch.Tensor(self.out_features))
        else:
            self.register_parameter("_super_bias", None)
        self.reset_parameters()

    @property
    def in_features(self):
        return spaces.get_max(self._in_features)

    @property
    def out_features(self):
        return spaces.get_max(self._out_features)

    @property
    def bias(self):
        return spaces.has_categorical(self._bias, True)

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))
        if self.bias:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._super_weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self._super_bias, -bound, bound)

    def forward_raw(self, input: Tensor) -> Tensor:
        return F.linear(input, self._super_weight, self._super_bias)

    def extra_repr(self) -> str:
        return "in_features={:}, out_features={:}, bias={:}".format(
            self.in_features, self.out_features, self.bias
        )
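
Note (illustration, not part of the commit): a minimal smoke test of the new SuperLinear. In this commit SuperModule.abstract_search_space is still abstract and SuperLinear does not override it yet, so the sketch stubs it out with a hypothetical _DemoLinear subclass:

import torch
import spaces
from layers.super_mlp import SuperLinear

class _DemoLinear(SuperLinear):
    # hypothetical stub: needed only because SuperLinear does not yet
    # implement the abstract abstract_search_space hook at this commit
    def abstract_search_space(self):
        return None

layer = _DemoLinear(32, spaces.Categorical(16, 32, 64))
x = torch.randn(4, 32)
y = layer(x)    # default run mode is FullModel, so forward() -> forward_raw()
print(y.shape)  # torch.Size([4, 64]); out_features resolves via spaces.get_max
print(layer)    # extra_repr shows the resolved in/out features and bias flag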

lib/layers/super_module.py

@@ -4,6 +4,14 @@
import abc
import torch.nn as nn
from enum import Enum


class SuperRunMode(Enum):
    """This class defines the enumerations for Super Model Running Mode."""

    FullModel = "fullmodel"
    Default = "fullmodel"


class SuperModule(abc.ABC, nn.Module):
@@ -11,7 +19,24 @@ class SuperModule(abc.ABC, nn.Module):
    def __init__(self):
        super(SuperModule, self).__init__()
        self._super_run_type = SuperRunMode.Default

    @abc.abstractmethod
    def abstract_search_space(self):
        raise NotImplementedError

    @property
    def super_run_type(self):
        return self._super_run_type

    @abc.abstractmethod
    def forward_raw(self, *inputs):
        raise NotImplementedError

    def forward(self, *inputs):
        if self.super_run_type == SuperRunMode.FullModel:
            return self.forward_raw(*inputs)
        else:
            raise ValueError(
                "Unknown Super Model Run Mode: {:}".format(self.super_run_type)
            )
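
Note (illustration, not part of the commit): a sketch of the SuperModule contract. A subclass implements the two abstract hooks, and forward() dispatches by run mode; since Default aliases FullModel, a fresh module routes straight to forward_raw:

import torch
import torch.nn as nn
from layers.super_module import SuperModule, SuperRunMode

class ToyModule(SuperModule):
    def __init__(self):
        super(ToyModule, self).__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def abstract_search_space(self):
        return None  # nothing searchable in this toy example

    def forward_raw(self, x):
        return x * self.scale

net = ToyModule()
assert net.super_run_type == SuperRunMode.FullModel  # Default is an alias
out = net(torch.randn(2, 3))  # forward() routes to forward_raw()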

lib/spaces/__init__.py

@@ -9,3 +9,5 @@ from .basic_space import Continuous
from .basic_space import Integer
from .basic_op import has_categorical
from .basic_op import has_continuous
from .basic_op import get_min
from .basic_op import get_max
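
Note (illustration, not part of the commit): with these re-exports, the membership checks and the new min/max helpers are all reachable from the package root. A small sketch, assuming Categorical exposes a has() membership test just as has_continuous uses for Continuous:

import spaces

print(spaces.has_categorical(spaces.Categorical("relu", "tanh"), "relu"))  # True
print(spaces.has_continuous(0.5, 0.5))  # True: plain values compare within _EPS
print(spaces.get_max(spaces.Categorical(16, 32)))  # 32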

lib/spaces/basic_op.py

@@ -1,4 +1,7 @@
from spaces.basic_space import Space
from spaces.basic_space import Integer
from spaces.basic_space import Continuous
from spaces.basic_space import Categorical
from spaces.basic_space import _EPS
@@ -14,3 +17,33 @@ def has_continuous(space_or_value, x):
        return space_or_value.has(x)
    else:
        return abs(space_or_value - x) <= _EPS


def get_max(space_or_value):
    if isinstance(space_or_value, Integer):
        return max(space_or_value.candidates)
    elif isinstance(space_or_value, Continuous):
        return space_or_value.upper
    elif isinstance(space_or_value, Categorical):
        values = []
        for index in range(len(space_or_value)):
            max_value = get_max(space_or_value[index])
            values.append(max_value)
        return max(values)
    else:
        return space_or_value


def get_min(space_or_value):
    if isinstance(space_or_value, Integer):
        return min(space_or_value.candidates)
    elif isinstance(space_or_value, Continuous):
        return space_or_value.lower
    elif isinstance(space_or_value, Categorical):
        values = []
        for index in range(len(space_or_value)):
            min_value = get_min(space_or_value[index])
            values.append(min_value)
        return min(values)
    else:
        return space_or_value
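
Note (illustration, not part of the commit): get_max/get_min accept plain values, Integer/Continuous spaces, and nested Categorical spaces. A quick sketch, assuming Continuous is constructed as Continuous(lower, upper):

import spaces

print(spaces.get_max(64))                           # 64: plain values pass through
print(spaces.get_min(spaces.Continuous(0.1, 0.9)))  # 0.1: the lower bound
nested = spaces.Categorical(spaces.Categorical(2, 4), 8)
print(spaces.get_max(nested))                       # 8: recursion over candidates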

lib/spaces/basic_space.py

@@ -10,6 +10,9 @@ import numpy as np
from typing import Optional

__all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"]

_EPS = 1e-9
@@ -54,6 +57,10 @@ class Categorical(Space):
        ), "default >= {:}".format(len(self._candidates))
        assert len(self) > 0, "Please provide at least one candidate"

    @property
    def candidates(self):
        return self._candidates

    @property
    def determined(self):
        if len(self) == 1: