add autodl

This commit is contained in:
mhz
2024-08-25 18:02:31 +02:00
parent 192f286cfb
commit a0a25f291c
431 changed files with 50646 additions and 8 deletions

View File

@@ -0,0 +1,7 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# This file is expected to be self-contained, except
# for importing from spaces to include the search space.
#####################################################
from .super_core import *

View File

@@ -0,0 +1,154 @@
# borrowed from https://github.com/arogozhnikov/einops/blob/master/einops/parsing.py
import warnings
import keyword
from typing import List
class AnonymousAxis:
"""Important thing: all instances of this class are not equal to each other"""
def __init__(self, value: str):
self.value = int(value)
if self.value <= 1:
if self.value == 1:
raise ValueError(
"No need to create anonymous axis of length 1. Report this as an issue"
)
else:
raise ValueError(
"Anonymous axis should have positive length, not {}".format(
self.value
)
)
def __repr__(self):
return "{}-axis".format(str(self.value))
class ParsedExpression:
"""
Immutable structure that contains information about one side of an expression (e.g. 'b c (h w)')
and keeps information needed by downstream processing
"""
def __init__(self, expression):
self.identifiers = set()
# that's axes like 2, 3 or 5. Axes with size 1 are exceptional and replaced with empty composition
self.has_non_unitary_anonymous_axes = False
# composition keeps structure of composite axes, see how different corner cases are handled in tests
self.composition = []
if "." in expression:
raise ValueError("Does not support . in the expression.")
bracket_group = None
def add_axis_name(x):
if x is not None:
if x in self.identifiers:
raise ValueError(
'Indexing expression contains duplicate dimension "{}"'.format(
x
)
)
is_number = str.isdecimal(x)
if is_number and int(x) == 1:
# handling the case of anonymous axis of length 1
if bracket_group is None:
self.composition.append([])
else:
pass # no need to think about 1s inside parentheses
return
is_axis_name, reason = self.check_axis_name(x, return_reason=True)
if not (is_number or is_axis_name):
raise ValueError(
"Invalid axis identifier: {}\n{}".format(x, reason)
)
if is_number:
x = AnonymousAxis(x)
self.identifiers.add(x)
if is_number:
self.has_non_unitary_anonymous_axes = True
if bracket_group is None:
self.composition.append([x])
else:
bracket_group.append(x)
current_identifier = None
for char in expression:
if char in "() ":
add_axis_name(current_identifier)
current_identifier = None
if char == "(":
if bracket_group is not None:
raise ValueError(
"Axis composition is one-level (brackets inside brackets not allowed)"
)
bracket_group = []
elif char == ")":
if bracket_group is None:
raise ValueError("Brackets are not balanced")
self.composition.append(bracket_group)
bracket_group = None
elif str.isalnum(char) or char == "_":
if current_identifier is None:
current_identifier = char
else:
current_identifier += char
else:
raise ValueError("Unknown character '{}'".format(char))
if bracket_group is not None:
raise ValueError(
'Imbalanced parentheses in expression: "{}"'.format(expression)
)
add_axis_name(current_identifier)
def flat_axes_order(self) -> List:
result = []
for composed_axis in self.composition:
assert isinstance(composed_axis, list), "does not work with ellipsis"
for axis in composed_axis:
result.append(axis)
return result
def has_composed_axes(self) -> bool:
# this will ignore 1 inside brackets
for axes in self.composition:
if isinstance(axes, list) and len(axes) > 1:
return True
return False
@staticmethod
def check_axis_name(name: str, return_reason=False):
"""
Valid axes names are python identifiers except keywords,
and additionally should not start or end with underscore
"""
if not str.isidentifier(name):
result = False, "not a valid python identifier"
elif name[0] == "_" or name[-1] == "_":
result = False, "axis name should should not start or end with underscore"
else:
if keyword.iskeyword(name):
warnings.warn(
"It is discouraged to use axes names that are keywords: {}".format(
name
),
RuntimeWarning,
)
if name in ["axis"]:
warnings.warn(
"It is discouraged to use 'axis' as an axis name "
"and will raise an error in future",
FutureWarning,
)
result = True, None
if return_reason:
return result
else:
return result[0]
def __repr__(self) -> str:
return "{name}({composition})".format(
name=self.__class__.__name__, composition=self.composition
)

View File

@@ -0,0 +1,124 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Callable
from xautodl import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
class SuperReLU(SuperModule):
"""Applies a the rectified linear unit function element-wise."""
def __init__(self, inplace: bool = False) -> None:
super(SuperReLU, self).__init__()
self._inplace = inplace
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.relu(input, inplace=self._inplace)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)
def extra_repr(self) -> str:
return "inplace=True" if self._inplace else ""
class SuperGELU(SuperModule):
"""Applies a the Gaussian Error Linear Units function element-wise."""
def __init__(self) -> None:
super(SuperGELU, self).__init__()
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.gelu(input)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)
class SuperSigmoid(SuperModule):
"""Applies a the Sigmoid function element-wise."""
def __init__(self) -> None:
super(SuperSigmoid, self).__init__()
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return torch.sigmoid(input)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)
class SuperLeakyReLU(SuperModule):
"""https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#LeakyReLU"""
def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
super(SuperLeakyReLU, self).__init__()
self._negative_slope = negative_slope
self._inplace = inplace
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.leaky_relu(input, self._negative_slope, self._inplace)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)
def extra_repr(self) -> str:
inplace_str = ", inplace=True" if self._inplace else ""
return "negative_slope={}{}".format(self._negative_slope, inplace_str)
class SuperTanh(SuperModule):
"""Applies a the Tanh function element-wise."""
def __init__(self) -> None:
super(SuperTanh, self).__init__()
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return torch.tanh(input)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)

View File

@@ -0,0 +1,341 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
from typing import Optional, Text
import torch
import torch.nn as nn
import torch.nn.functional as F
from xautodl import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
from .super_dropout import SuperDropout, SuperDrop
from .super_linear import SuperLinear
class SuperSelfAttention(SuperModule):
"""The super model for attention layer."""
def __init__(
self,
input_dim: IntSpaceType,
proj_dim: Optional[IntSpaceType],
num_heads: IntSpaceType,
qkv_bias: BoolSpaceType = False,
attn_drop: Optional[float] = None,
proj_drop: Optional[float] = None,
use_mask=False,
):
super(SuperSelfAttention, self).__init__()
self._input_dim = input_dim
self._proj_dim = proj_dim
self._num_heads = num_heads
self._qkv_bias = qkv_bias
self._use_mask = use_mask
self._infinity = 1e9
mul_head_dim = (
spaces.get_max(input_dim) // spaces.get_min(num_heads)
) * spaces.get_min(num_heads)
assert mul_head_dim == spaces.get_max(input_dim)
self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.attn_drop = SuperDrop(attn_drop or 0.0, [-1, -1, -1, -1], recover=True)
if proj_dim is not None:
self.proj = SuperLinear(input_dim, proj_dim)
self.proj_drop = SuperDropout(proj_drop or 0.0)
else:
self.proj = None
@property
def num_heads(self):
return spaces.get_max(self._num_heads)
@property
def input_dim(self):
return spaces.get_max(self._input_dim)
@property
def proj_dim(self):
return spaces.get_max(self._proj_dim)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
space_q = self.q_fc.abstract_search_space
space_k = self.k_fc.abstract_search_space
space_v = self.v_fc.abstract_search_space
if not spaces.is_determined(self._num_heads):
root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
if not spaces.is_determined(space_q):
root_node.append("q_fc", space_q)
if not spaces.is_determined(space_k):
root_node.append("k_fc", space_k)
if not spaces.is_determined(space_v):
root_node.append("v_fc", space_v)
if self.proj is not None:
space_proj = self.proj.abstract_search_space
if not spaces.is_determined(space_proj):
root_node.append("proj", space_proj)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperSelfAttention, self).apply_candidate(abstract_child)
if "q_fc" in abstract_child:
self.q_fc.apply_candidate(abstract_child["q_fc"])
if "k_fc" in abstract_child:
self.k_fc.apply_candidate(abstract_child["k_fc"])
if "v_fc" in abstract_child:
self.v_fc.apply_candidate(abstract_child["v_fc"])
if "proj" in abstract_child:
self.proj.apply_candidate(abstract_child["proj"])
def forward_qkv(self, input: torch.Tensor, num_head: int) -> torch.Tensor:
B, N, C = input.shape
q = self.q_fc(input)
k = self.k_fc(input)
v = self.v_fc(input)
if num_head > C:
raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
head_dim = C // num_head
# process the first [num_head * head_dim] part
q_v1 = (
q[:, :, : num_head * head_dim]
.reshape(B, N, num_head, head_dim)
.permute(0, 2, 1, 3)
)
k_v1 = (
k[:, :, : num_head * head_dim]
.reshape(B, N, num_head, head_dim)
.permute(0, 2, 1, 3)
)
v_v1 = (
v[:, :, : num_head * head_dim]
.reshape(B, N, num_head, head_dim)
.permute(0, 2, 1, 3)
)
attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
if self._use_mask:
mask = torch.triu(
torch.ones((N, N), dtype=torch.bool, device=input.device), 1
)
mask = torch.unsqueeze(torch.unsqueeze(mask, dim=0), dim=0)
attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * N
attn_v1 = self.attn_drop(attn_v1)
feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
if C == head_dim * num_head:
feats = feats_v1
else: # The channels cannot be divided evenly by num_head; the remainder forms an additional head
q_v2 = q[:, :, num_head * head_dim :]
k_v2 = k[:, :, num_head * head_dim :]
v_v2 = v[:, :, num_head * head_dim :]
attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
attn_v2 = attn_v2.softmax(dim=-1)
attn_v2 = self.attn_drop(attn_v2)
feats_v2 = attn_v2 @ v_v2
feats = torch.cat([feats_v1, feats_v2], dim=-1)
return feats
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
# check the num_heads:
if not spaces.is_determined(self._num_heads):
num_heads = self.abstract_child["_num_heads"].value
else:
num_heads = spaces.get_determined_value(self._num_heads)
feats = self.forward_qkv(input, num_heads)
if self.proj is None:
return feats
else:
outs = self.proj(feats)
outs = self.proj_drop(outs)
return outs
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
feats = self.forward_qkv(input, self.num_heads)
if self.proj is None:
return feats
else:
outs = self.proj(feats)
outs = self.proj_drop(outs)
return outs
def extra_repr(self) -> str:
return (
"input_dim={:}, proj_dim={:}, num_heads={:}, mask={:}, infinity={:}".format(
self._input_dim,
self._proj_dim,
self._num_heads,
self._use_mask,
self._infinity,
)
)
class SuperQKVAttention(SuperModule):
"""The super model for attention layer."""
def __init__(
self,
in_q_dim: IntSpaceType,
in_k_dim: IntSpaceType,
in_v_dim: IntSpaceType,
proj_dim: IntSpaceType,
num_heads: IntSpaceType,
qkv_bias: BoolSpaceType = False,
attn_drop: Optional[float] = None,
proj_drop: Optional[float] = None,
):
super(SuperQKVAttention, self).__init__()
self._in_v_dim = in_v_dim
self._in_q_dim = in_q_dim
self._in_k_dim = in_k_dim
self._proj_dim = proj_dim
self._num_heads = num_heads
self._qkv_bias = qkv_bias
self.q_fc = SuperLinear(in_q_dim, proj_dim, bias=qkv_bias)
self.k_fc = SuperLinear(in_k_dim, proj_dim, bias=qkv_bias)
self.v_fc = SuperLinear(in_v_dim, proj_dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop or 0.0)
self.proj = SuperLinear(proj_dim, proj_dim)
self.proj_drop = nn.Dropout(proj_drop or 0.0)
self._infinity = 1e9
@property
def num_heads(self):
return spaces.get_max(self._num_heads)
@property
def in_v_dim(self):
return spaces.get_max(self._in_v_dim)
@property
def in_q_dim(self):
return spaces.get_max(self._in_q_dim)
@property
def in_k_dim(self):
return spaces.get_max(self._in_k_dim)
@property
def proj_dim(self):
return spaces.get_max(self._proj_dim)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
space_q = self.q_fc.abstract_search_space
space_k = self.k_fc.abstract_search_space
space_v = self.v_fc.abstract_search_space
space_proj = self.proj.abstract_search_space
if not spaces.is_determined(self._num_heads):
root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
if not spaces.is_determined(space_q):
root_node.append("q_fc", space_q)
if not spaces.is_determined(space_k):
root_node.append("k_fc", space_k)
if not spaces.is_determined(space_v):
root_node.append("v_fc", space_v)
if not spaces.is_determined(space_proj):
root_node.append("proj", space_proj)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperQKVAttention, self).apply_candidate(abstract_child)
if "q_fc" in abstract_child:
self.q_fc.apply_candidate(abstract_child["q_fc"])
if "k_fc" in abstract_child:
self.k_fc.apply_candidate(abstract_child["k_fc"])
if "v_fc" in abstract_child:
self.v_fc.apply_candidate(abstract_child["v_fc"])
if "proj" in abstract_child:
self.proj.apply_candidate(abstract_child["proj"])
def forward_qkv(
self, q_tensor, k_tensor, v_tensor, num_head: int, mask=None
) -> torch.Tensor:
q = self.q_fc(q_tensor)
B, N, C = q.shape
k = self.k_fc(k_tensor)
B0, S, _ = k.shape
v = self.v_fc(v_tensor)
assert B0 == v.shape[0] and S == v.shape[1]
head_dim = C // num_head
if num_head > C:
raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
q_v1 = (
q[:, :, : num_head * head_dim]
.reshape(B, N, num_head, head_dim)
.permute(0, 2, 1, 3)
)
k_v1 = (
k[:, :, : num_head * head_dim]
.reshape(B0, S, num_head, head_dim)
.permute(0, 2, 1, 3)
)
# compute the attention map
attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
if mask is not None:
mask = torch.unsqueeze(mask, dim=1)
attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * S
attn_v1 = self.attn_drop(attn_v1)
v_v1 = (
v[:, :, : num_head * head_dim]
.reshape(B0, S, num_head, head_dim)
.permute(0, 2, 1, 3)
)
feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
# process the first [num_head * head_dim] part
if C == head_dim * num_head:
feats = feats_v1
else: # The channels cannot be divided evenly by num_head; the remainder forms an additional head
# [might have bugs, did not check yet]
q_v2 = q[:, :, num_head * head_dim :]
k_v2 = k[:, :, num_head * head_dim :]
v_v2 = v[:, :, num_head * head_dim :]
attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
attn_v2 = attn_v2.softmax(dim=-1)
attn_v2 = self.attn_drop(attn_v2)
feats_v2 = attn_v2 @ v_v2
feats = torch.cat([feats_v1, feats_v2], dim=-1)
return feats
def forward_candidate(
self, q_tensor, k_tensor, v_tensor, mask=None
) -> torch.Tensor:
# check the num_heads:
if not spaces.is_determined(self._num_heads):
num_heads = self.abstract_child["_num_heads"].value
else:
num_heads = spaces.get_determined_value(self._num_heads)
feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads, mask)
outs = self.proj(feats)
outs = self.proj_drop(outs)
return outs
def forward_raw(self, q_tensor, k_tensor, v_tensor, mask=None) -> torch.Tensor:
feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads, mask)
outs = self.proj(feats)
outs = self.proj_drop(outs)
return outs
def extra_repr(self) -> str:
return "input_dim={:}, proj_dim={:}, num_heads={:}, infinity={:}".format(
(self.in_q_dim, self.in_k_dim, self.in_v_dim),
self._proj_dim,
self._num_heads,
self._infinity,
)

View File

@@ -0,0 +1,113 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
from typing import Optional, Text
import torch
import torch.nn as nn
import torch.nn.functional as F
from xautodl import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
from .super_linear import SuperLinear
class SuperQKVAttentionV2(SuperModule):
"""The super model for attention layer."""
def __init__(
self,
qk_att_dim: int,
in_v_dim: int,
hidden_dim: int,
num_heads: int,
proj_dim: int,
qkv_bias: bool = False,
attn_drop: Optional[float] = None,
proj_drop: Optional[float] = None,
):
super(SuperQKVAttentionV2, self).__init__()
self._in_v_dim = in_v_dim
self._qk_att_dim = qk_att_dim
self._proj_dim = proj_dim
self._hidden_dim = hidden_dim
self._num_heads = num_heads
self._qkv_bias = qkv_bias
self.qk_fc = SuperLinear(qk_att_dim, num_heads, bias=qkv_bias)
self.v_fc = SuperLinear(in_v_dim, hidden_dim * num_heads, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop or 0.0)
self.proj = SuperLinear(hidden_dim * num_heads, proj_dim)
self.proj_drop = nn.Dropout(proj_drop or 0.0)
self._infinity = 1e9
@property
def num_heads(self):
return spaces.get_max(self._num_heads)
@property
def in_v_dim(self):
return spaces.get_max(self._in_v_dim)
@property
def qk_att_dim(self):
return spaces.get_max(self._qk_att_dim)
@property
def hidden_dim(self):
return spaces.get_max(self._hidden_dim)
@property
def proj_dim(self):
return spaces.get_max(self._proj_dim)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
raise NotImplementedError
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperQKVAttentionV2, self).apply_candidate(abstract_child)
raise NotImplementedError
def forward_qkv(
self, qk_att_tensor, v_tensor, num_head: int, mask=None
) -> torch.Tensor:
qk_att = self.qk_fc(qk_att_tensor)
B, N, S, _ = qk_att.shape
assert _ == num_head
attn_v1 = qk_att.permute(0, 3, 1, 2)
if mask is not None:
mask = torch.unsqueeze(mask, dim=1)
attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * S
attn_v1 = self.attn_drop(attn_v1)
v = self.v_fc(v_tensor)
B0, _, _ = v.shape
v_v1 = v.reshape(B0, S, num_head, -1).permute(0, 2, 1, 3)
feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
return feats_v1
def forward_candidate(self, qk_att_tensor, v_tensor, mask=None) -> torch.Tensor:
return self.forward_raw(qk_att_tensor, v_tensor, mask)
def forward_raw(self, qk_att_tensor, v_tensor, mask=None) -> torch.Tensor:
feats = self.forward_qkv(qk_att_tensor, v_tensor, self.num_heads, mask)
outs = self.proj(feats)
outs = self.proj_drop(outs)
return outs
def extra_repr(self) -> str:
return "input_dim={:}, hidden_dim={:}, proj_dim={:}, num_heads={:}, infinity={:}".format(
(self.qk_att_dim, self.in_v_dim),
self._hidden_dim,
self._proj_dim,
self._num_heads,
self._infinity,
)

View File

@@ -0,0 +1,120 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
from itertools import islice
import operator
from collections import OrderedDict
from typing import Optional, Union, Callable, TypeVar, Iterator
from xautodl import spaces
from .super_module import SuperModule
T = TypeVar("T", bound=SuperModule)
class SuperSequential(SuperModule):
"""A sequential container wrapped with 'Super' ability.
Modules will be added to it in the order they are passed in the constructor.
Alternatively, an ordered dict of modules can also be passed in.
To make it easier to understand, here is a small example::
# Example of using Sequential
model = SuperSequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
"""
def __init__(self, *args):
super(SuperSequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
if not isinstance(args, (list, tuple)):
raise ValueError("Invalid input type: {:}".format(type(args)))
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def _get_item_by_idx(self, iterator, idx) -> T:
"""Get the idx-th item of the iterator"""
size = len(self)
idx = operator.index(idx)
if not -size <= idx < size:
raise IndexError("index {} is out of range".format(idx))
idx %= size
return next(islice(iterator, idx, None))
def __getitem__(self, idx) -> Union["SuperSequential", T]:
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
return self._get_item_by_idx(self._modules.values(), idx)
def __setitem__(self, idx: int, module: SuperModule) -> None:
key: str = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)
def __delitem__(self, idx: Union[slice, int]) -> None:
if isinstance(idx, slice):
for key in list(self._modules.keys())[idx]:
delattr(self, key)
else:
key = self._get_item_by_idx(self._modules.keys(), idx)
delattr(self, key)
def __len__(self) -> int:
return len(self._modules)
def __dir__(self):
keys = super(SuperSequential, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def __iter__(self) -> Iterator[SuperModule]:
return iter(self._modules.values())
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
for index, module in enumerate(self):
if not isinstance(module, SuperModule):
continue
space = module.abstract_search_space
if not spaces.is_determined(space):
root_node.append(str(index), space)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperSequential, self).apply_candidate(abstract_child)
for index, module in enumerate(self):
if str(index) in abstract_child:
module.apply_candidate(abstract_child[str(index)])
def forward_candidate(self, input):
return self.forward_raw(input)
def forward_raw(self, input):
for module in self:
input = module(input)
return input
def forward_with_container(self, input, container, prefix=[]):
for index, module in enumerate(self):
input = module.forward_with_container(
input, container, prefix + [str(index)]
)
return input
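Illustrative sketch (not part of this commit): forward_with_container lets a sequential model run a functional forward with weights pulled from a TensorContainer (see get_w_container in super_module later in this commit); the import paths are assumed.
import torch
from xautodl.xlayers.super_core import SuperSequential, SuperLinear, SuperReLU  # paths assumed
net = SuperSequential(SuperLinear(8, 16), SuperReLU(), SuperLinear(16, 4))
container = net.get_w_container()              # collects named parameters and buffers
x = torch.randn(2, 8)
y1 = net.forward_raw(x)                        # forward with the modules' own weights
y2 = net.forward_with_container(x, container)  # forward with weights queried by name
# y1 and y2 should match, since the container holds the very same tensors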

View File

@@ -0,0 +1,51 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from .super_module import SuperRunMode
from .super_module import IntSpaceType
from .super_module import LayerOrder
from .super_module import SuperModule
from .super_container import SuperSequential
from .super_linear import SuperLinear
from .super_linear import SuperMLPv1, SuperMLPv2
from .super_norm import SuperSimpleNorm
from .super_norm import SuperLayerNorm1D
from .super_norm import SuperSimpleLearnableNorm
from .super_norm import SuperIdentity
from .super_dropout import SuperDropout
from .super_dropout import SuperDrop
super_name2norm = {
"simple_norm": SuperSimpleNorm,
"simple_learn_norm": SuperSimpleLearnableNorm,
"layer_norm_1d": SuperLayerNorm1D,
"identity": SuperIdentity,
}
from .super_attention import SuperSelfAttention
from .super_attention import SuperQKVAttention
from .super_attention_v2 import SuperQKVAttentionV2
from .super_transformer import SuperTransformerEncoderLayer
from .super_activations import SuperReLU
from .super_activations import SuperLeakyReLU
from .super_activations import SuperTanh
from .super_activations import SuperGELU
from .super_activations import SuperSigmoid
super_name2activation = {
"relu": SuperReLU,
"sigmoid": SuperSigmoid,
"gelu": SuperGELU,
"leaky_relu": SuperLeakyReLU,
"tanh": SuperTanh,
}
from .super_trade_stem import SuperAlphaEBDv1
from .super_positional_embedding import SuperDynamicPositionE
from .super_positional_embedding import SuperPositionalEncoder
from .super_rearrange import SuperReArrange
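Illustrative sketch (not part of this commit): the registries exported above map plain names to layer classes, e.g.:
from xautodl.xlayers.super_core import super_name2activation, super_name2norm  # path assumed
act = super_name2activation["gelu"]()          # -> SuperGELU()
norm = super_name2norm["layer_norm_1d"](64)    # -> SuperLayerNorm1D over a 64-dim feature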

View File

@@ -0,0 +1,83 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Callable, Tuple
from xautodl import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
class SuperDropout(SuperModule):
"""Applies a the dropout function element-wise."""
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
super(SuperDropout, self).__init__()
self._p = p
self._inplace = inplace
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.dropout(input, self._p, self.training, self._inplace)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)
def extra_repr(self) -> str:
xstr = "inplace=True" if self._inplace else ""
return "p={:}".format(self._p) + ", " + xstr
class SuperDrop(SuperModule):
"""Applies a the drop-path function element-wise."""
def __init__(self, p: float, dims: Tuple[int], recover: bool = True) -> None:
super(SuperDrop, self).__init__()
self._p = p
self._dims = dims
self._recover = recover
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
if not self.training or self._p <= 0:
return input
keep_prob = 1 - self._p
shape = [input.shape[0]] + [
x if y == -1 else y for x, y in zip(input.shape[1:], self._dims)
]
random_tensor = keep_prob + torch.rand(
shape, dtype=input.dtype, device=input.device
)
random_tensor.floor_() # binarize
if self._recover:
return input.div(keep_prob) * random_tensor
else:
return input * random_tensor # as masks
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)
def extra_repr(self) -> str:
return (
"p={:}".format(self._p)
+ ", dims={:}".format(self._dims)
+ ", recover={:}".format(self._recover)
)

View File

@@ -0,0 +1,310 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Callable
from xautodl import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
class SuperLinear(SuperModule):
"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`"""
def __init__(
self,
in_features: IntSpaceType,
out_features: IntSpaceType,
bias: BoolSpaceType = True,
) -> None:
super(SuperLinear, self).__init__()
# the raw input args
self._in_features = in_features
self._out_features = out_features
self._bias = bias
# weights to be optimized
self.register_parameter(
"_super_weight",
torch.nn.Parameter(torch.Tensor(self.out_features, self.in_features)),
)
if self.bias:
self.register_parameter(
"_super_bias", torch.nn.Parameter(torch.Tensor(self.out_features))
)
else:
self.register_parameter("_super_bias", None)
self.reset_parameters()
@property
def in_features(self):
return spaces.get_max(self._in_features)
@property
def out_features(self):
return spaces.get_max(self._out_features)
@property
def bias(self):
return spaces.has_categorical(self._bias, True)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
if not spaces.is_determined(self._in_features):
root_node.append(
"_in_features", self._in_features.abstract(reuse_last=True)
)
if not spaces.is_determined(self._out_features):
root_node.append(
"_out_features", self._out_features.abstract(reuse_last=True)
)
if not spaces.is_determined(self._bias):
root_node.append("_bias", self._bias.abstract(reuse_last=True))
return root_node
def reset_parameters(self) -> None:
nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))
if self.bias:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._super_weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self._super_bias, -bound, bound)
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
# check inputs ->
if not spaces.is_determined(self._in_features):
expected_input_dim = self.abstract_child["_in_features"].value
else:
expected_input_dim = spaces.get_determined_value(self._in_features)
if input.size(-1) != expected_input_dim:
raise ValueError(
"Expect the input dim of {:} instead of {:}".format(
expected_input_dim, input.size(-1)
)
)
# create the weight matrix
if not spaces.is_determined(self._out_features):
out_dim = self.abstract_child["_out_features"].value
else:
out_dim = spaces.get_determined_value(self._out_features)
candidate_weight = self._super_weight[:out_dim, :expected_input_dim]
# create the bias matrix
if not spaces.is_determined(self._bias):
if self.abstract_child["_bias"].value:
candidate_bias = self._super_bias[:out_dim]
else:
candidate_bias = None
else:
if spaces.get_determined_value(self._bias):
candidate_bias = self._super_bias[:out_dim]
else:
candidate_bias = None
return F.linear(input, candidate_weight, candidate_bias)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self._super_weight, self._super_bias)
def extra_repr(self) -> str:
return "in_features={:}, out_features={:}, bias={:}".format(
self._in_features, self._out_features, self._bias
)
def forward_with_container(self, input, container, prefix=[]):
super_weight_name = ".".join(prefix + ["_super_weight"])
super_weight = container.query(super_weight_name)
super_bias_name = ".".join(prefix + ["_super_bias"])
if container.has(super_bias_name):
super_bias = container.query(super_bias_name)
else:
super_bias = None
return F.linear(input, super_weight, super_bias)
class SuperMLPv1(SuperModule):
"""An MLP layer: FC -> Activation -> Drop -> FC -> Drop."""
def __init__(
self,
in_features: IntSpaceType,
hidden_features: IntSpaceType,
out_features: IntSpaceType,
act_layer: Callable[[], nn.Module] = nn.GELU,
drop: Optional[float] = None,
):
super(SuperMLPv1, self).__init__()
self._in_features = in_features
self._hidden_features = hidden_features
self._out_features = out_features
self._drop_rate = drop
self.fc1 = SuperLinear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = SuperLinear(hidden_features, out_features)
self.drop = nn.Dropout(drop or 0.0)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
space_fc1 = self.fc1.abstract_search_space
space_fc2 = self.fc2.abstract_search_space
if not spaces.is_determined(space_fc1):
root_node.append("fc1", space_fc1)
if not spaces.is_determined(space_fc2):
root_node.append("fc2", space_fc2)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperMLPv1, self).apply_candidate(abstract_child)
if "fc1" in abstract_child:
self.fc1.apply_candidate(abstract_child["fc1"])
if "fc2" in abstract_child:
self.fc2.apply_candidate(abstract_child["fc2"])
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
x = self.fc1(input)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def extra_repr(self) -> str:
return "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format(
self._in_features,
self._hidden_features,
self._out_features,
self._drop_rate,
)
class SuperMLPv2(SuperModule):
"""An MLP layer: FC -> Activation -> Drop -> FC -> Drop."""
def __init__(
self,
in_features: IntSpaceType,
hidden_multiplier: IntSpaceType,
out_features: IntSpaceType,
act_layer: Callable[[], nn.Module] = nn.GELU,
drop: Optional[float] = None,
):
super(SuperMLPv2, self).__init__()
self._in_features = in_features
self._hidden_multiplier = hidden_multiplier
self._out_features = out_features
self._drop_rate = drop
self._create_linear(
"fc1", self.in_features, int(self.in_features * self.hidden_multiplier)
)
self._create_linear(
"fc2", int(self.in_features * self.hidden_multiplier), self.out_features
)
self.act = act_layer()
self.drop = nn.Dropout(drop or 0.0)
self.reset_parameters()
@property
def in_features(self):
return spaces.get_max(self._in_features)
@property
def hidden_multiplier(self):
return spaces.get_max(self._hidden_multiplier)
@property
def out_features(self):
return spaces.get_max(self._out_features)
def _create_linear(self, name, inC, outC):
self.register_parameter(
"{:}_super_weight".format(name), torch.nn.Parameter(torch.Tensor(outC, inC))
)
self.register_parameter(
"{:}_super_bias".format(name), torch.nn.Parameter(torch.Tensor(outC))
)
def reset_parameters(self) -> None:
nn.init.kaiming_uniform_(self.fc1_super_weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.fc2_super_weight, a=math.sqrt(5))
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.fc1_super_weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.fc1_super_bias, -bound, bound)
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.fc2_super_weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.fc2_super_bias, -bound, bound)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
if not spaces.is_determined(self._in_features):
root_node.append(
"_in_features", self._in_features.abstract(reuse_last=True)
)
if not spaces.is_determined(self._hidden_multiplier):
root_node.append(
"_hidden_multiplier", self._hidden_multiplier.abstract(reuse_last=True)
)
if not spaces.is_determined(self._out_features):
root_node.append(
"_out_features", self._out_features.abstract(reuse_last=True)
)
return root_node
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
# check inputs ->
if not spaces.is_determined(self._in_features):
expected_input_dim = self.abstract_child["_in_features"].value
else:
expected_input_dim = spaces.get_determined_value(self._in_features)
if input.size(-1) != expected_input_dim:
raise ValueError(
"Expect the input dim of {:} instead of {:}".format(
expected_input_dim, input.size(-1)
)
)
# create the weight and bias matrix for fc1
if not spaces.is_determined(self._hidden_multiplier):
hmul = self.abstract_child["_hidden_multiplier"].value * expected_input_dim
else:
hmul = spaces.get_determined_value(self._hidden_multiplier)
hidden_dim = int(expected_input_dim * hmul)
_fc1_weight = self.fc1_super_weight[:hidden_dim, :expected_input_dim]
_fc1_bias = self.fc1_super_bias[:hidden_dim]
x = F.linear(input, _fc1_weight, _fc1_bias)
x = self.act(x)
x = self.drop(x)
# create the weight and bias matrix for fc2
if not spaces.is_determined(self._out_features):
out_dim = self.abstract_child["_out_features"].value
else:
out_dim = spaces.get_determined_value(self._out_features)
_fc2_weight = self.fc2_super_weight[:out_dim, :hidden_dim]
_fc2_bias = self.fc2_super_bias[:out_dim]
x = F.linear(x, _fc2_weight, _fc2_bias)
x = self.drop(x)
return x
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
x = F.linear(input, self.fc1_super_weight, self.fc1_super_bias)
x = self.act(x)
x = self.drop(x)
x = F.linear(x, self.fc2_super_weight, self.fc2_super_bias)
x = self.drop(x)
return x
def extra_repr(self) -> str:
return "in_features={:}, hidden_multiplier={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format(
self._in_features,
self._hidden_multiplier,
self._out_features,
self._drop_rate,
)

View File

@@ -0,0 +1,227 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import os
from pathlib import Path
import abc
import tempfile
import warnings
from typing import Optional, Union, Callable
import torch
import torch.nn as nn
from enum import Enum
from xautodl import spaces
from .super_utils import IntSpaceType, BoolSpaceType
from .super_utils import LayerOrder, SuperRunMode
from .super_utils import TensorContainer
from .super_utils import ShapeContainer
BEST_DIR_KEY = "best_model_dir"
BEST_NAME_KEY = "best_model_name"
BEST_SCORE_KEY = "best_model_score"
ENABLE_CANDIDATE = 0
DISABLE_CANDIDATE = 1
class SuperModule(abc.ABC, nn.Module):
"""This class equips the nn.Module class with the ability to apply AutoDL."""
def __init__(self):
super(SuperModule, self).__init__()
self._super_run_type = SuperRunMode.Default
self._abstract_child = None
self._verbose = False
self._meta_info = {}
self._candidate_mode = DISABLE_CANDIDATE
def set_super_run_type(self, super_run_type):
def _reset_super_run(m):
if isinstance(m, SuperModule):
m._super_run_type = super_run_type
self.apply(_reset_super_run)
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
if not isinstance(module, SuperModule):
warnings.warn(
"Add {:}:{:} module, which is not SuperModule, into {:}".format(
name, module.__class__.__name__, self.__class__.__name__
)
+ "\n"
+ "It may cause some functions invalid."
)
super(SuperModule, self).add_module(name, module)
def apply_verbose(self, verbose):
def _reset_verbose(m):
if isinstance(m, SuperModule):
m._verbose = verbose
self.apply(_reset_verbose)
def apply_candidate(self, abstract_child):
if not isinstance(abstract_child, spaces.VirtualNode):
raise ValueError(
"Invalid abstract child program: {:}".format(abstract_child)
)
self._abstract_child = abstract_child
def enable_candidate(self):
def _enable_candidate(m):
if isinstance(m, SuperModule):
m._candidate_mode = ENABLE_CANDIDATE
self.apply(_enable_candidate)
def disable_candidate(self):
def _disable_candidate(m):
if isinstance(m, SuperModule):
m._candidate_mode = DISABLE_CANDIDATE
self.apply(_disable_candidate)
def get_w_container(self):
container = TensorContainer()
for name, param in self.named_parameters():
container.append(name, param, True)
for name, buf in self.named_buffers():
container.append(name, buf, False)
return container
def analyze_weights(self):
with torch.no_grad():
for name, param in self.named_parameters():
shapestr = "[{:10s}] shape={:}".format(name, list(param.shape))
finalstr = shapestr + ", {:.2f} +- {:.2f}".format(
param.mean(), param.std()
)
print(finalstr)
def numel(self, buffer=True):
total = 0
for name, param in self.named_parameters():
total += param.numel()
if buffer:
for name, buf in self.named_buffers():
total += buf.numel()
return total
def set_best_dir(self, xdir):
self._meta_info[BEST_DIR_KEY] = str(xdir)
Path(xdir).mkdir(parents=True, exist_ok=True)
def set_best_name(self, xname):
self._meta_info[BEST_NAME_KEY] = str(xname)
def save_best(self, score):
if BEST_DIR_KEY not in self._meta_info:
tempdir = tempfile.mkdtemp("-xlayers")
self._meta_info[BEST_DIR_KEY] = tempdir
if BEST_SCORE_KEY not in self._meta_info:
self._meta_info[BEST_SCORE_KEY] = None
best_score = self._meta_info[BEST_SCORE_KEY]
if best_score is None or best_score <= score:
best_save_name = self._meta_info.get(
BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
)
best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
self._meta_info[BEST_SCORE_KEY] = score
torch.save(self.state_dict(), best_save_path)
return True, self._meta_info[BEST_SCORE_KEY]
else:
return False, self._meta_info[BEST_SCORE_KEY]
def load_best(self, best_save_name=None):
if BEST_DIR_KEY not in self._meta_info:
raise ValueError("Please set BEST_DIR_KEY at first")
if best_save_name is None:
best_save_name = self._meta_info.get(
BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
)
best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
state_dict = torch.load(best_save_path)
self.load_state_dict(state_dict)
def has_best(self, best_name=None):
if BEST_DIR_KEY not in self._meta_info:
raise ValueError("Please set BEST_DIR_KEY at first")
if best_name is None:
best_save_name = self._meta_info.get(
BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
)
else:
best_save_name = best_name
best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
return os.path.exists(best_save_path)
@property
def abstract_search_space(self):
raise NotImplementedError
@property
def super_run_type(self):
return self._super_run_type
@property
def abstract_child(self):
return self._abstract_child
@property
def verbose(self):
return self._verbose
@abc.abstractmethod
def forward_raw(self, *inputs):
"""Use the largest candidate for forward. Similar to the original PyTorch model."""
raise NotImplementedError
@abc.abstractmethod
def forward_candidate(self, *inputs):
raise NotImplementedError
@property
def name_with_id(self):
return "name={:}, id={:}".format(self.__class__.__name__, id(self))
def get_shape_str(self, tensors):
if isinstance(tensors, (list, tuple)):
shapes = [self.get_shape_str(tensor) for tensor in tensors]
if len(shapes) == 1:
return shapes[0]
else:
return ", ".join(shapes)
elif isinstance(tensors, (torch.Tensor, nn.Parameter)):
return str(tuple(tensors.shape))
else:
raise TypeError("Invalid input type: {:}.".format(type(tensors)))
def forward(self, *inputs):
if self.verbose:
print(
"[{:}] inputs shape: {:}".format(
self.name_with_id, self.get_shape_str(inputs)
)
)
if self.super_run_type == SuperRunMode.FullModel:
outputs = self.forward_raw(*inputs)
elif self.super_run_type == SuperRunMode.Candidate:
if self._candidate_mode == DISABLE_CANDIDATE:
raise ValueError("candidate mode is disabled")
outputs = self.forward_candidate(*inputs)
else:
raise ValueError(
"Unknown Super Model Run Mode: {:}".format(self.super_run_type)
)
if self.verbose:
print(
"[{:}] outputs shape: {:}".format(
self.name_with_id, self.get_shape_str(outputs)
)
)
return outputs
def forward_with_container(self, inputs, container, prefix=[]):
raise NotImplementedError
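Illustrative sketch (not part of this commit): a minimal, hypothetical SuperModule subclass showing the hooks a concrete layer implements; the pattern mirrors SuperReLU earlier in this commit and the import path is assumed.
import torch
from xautodl import spaces
from xautodl.xlayers.super_core import SuperModule  # path assumed

class SuperScale(SuperModule):
    """Toy layer: multiply the input by a fixed scale (nothing searchable)."""

    def __init__(self, scale: float = 2.0):
        super(SuperScale, self).__init__()
        self._scale = scale

    @property
    def abstract_search_space(self):
        return spaces.VirtualNode(id(self))  # empty search space

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        return input * self._scale

layer = SuperScale(3.0)
print(layer.forward_raw(torch.ones(2)))  # tensor([3., 3.])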

View File

@@ -0,0 +1,224 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Callable
from xautodl import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
class SuperLayerNorm1D(SuperModule):
"""Super Layer Norm."""
def __init__(
self, dim: IntSpaceType, eps: float = 1e-6, elementwise_affine: bool = True
) -> None:
super(SuperLayerNorm1D, self).__init__()
self._in_dim = dim
self._eps = eps
self._elementwise_affine = elementwise_affine
if self._elementwise_affine:
self.register_parameter("weight", nn.Parameter(torch.Tensor(self.in_dim)))
self.register_parameter("bias", nn.Parameter(torch.Tensor(self.in_dim)))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.reset_parameters()
@property
def in_dim(self):
return spaces.get_max(self._in_dim)
@property
def eps(self):
return self._eps
def reset_parameters(self) -> None:
if self._elementwise_affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
if not spaces.is_determined(self._in_dim):
root_node.append("_in_dim", self._in_dim.abstract(reuse_last=True))
return root_node
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
# check inputs ->
if not spaces.is_determined(self._in_dim):
expected_input_dim = self.abstract_child["_in_dim"].value
else:
expected_input_dim = spaces.get_determined_value(self._in_dim)
if input.size(-1) != expected_input_dim:
raise ValueError(
"Expect the input dim of {:} instead of {:}".format(
expected_input_dim, input.size(-1)
)
)
if self._elementwise_affine:
weight = self.weight[:expected_input_dim]
bias = self.bias[:expected_input_dim]
else:
weight, bias = None, None
return F.layer_norm(input, (expected_input_dim,), weight, bias, self.eps)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return F.layer_norm(input, (self.in_dim,), self.weight, self.bias, self.eps)
def forward_with_container(self, input, container, prefix=[]):
super_weight_name = ".".join(prefix + ["weight"])
if container.has(super_weight_name):
weight = container.query(super_weight_name)
else:
weight = None
super_bias_name = ".".join(prefix + ["bias"])
if container.has(super_bias_name):
bias = container.query(super_bias_name)
else:
bias = None
return F.layer_norm(input, (self.in_dim,), weight, bias, self.eps)
def extra_repr(self) -> str:
return (
"shape={in_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format(
in_dim=self._in_dim,
eps=self._eps,
elementwise_affine=self._elementwise_affine,
)
)
class SuperSimpleNorm(SuperModule):
"""Super simple normalization."""
def __init__(self, mean, std, inplace=False) -> None:
super(SuperSimpleNorm, self).__init__()
self.register_buffer("_mean", torch.tensor(mean, dtype=torch.float))
self.register_buffer("_std", torch.tensor(std, dtype=torch.float))
self._inplace = inplace
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
# check inputs ->
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
if not self._inplace:
tensor = input.clone()
else:
tensor = input
mean = torch.as_tensor(self._mean, dtype=tensor.dtype, device=tensor.device)
std = torch.as_tensor(self._std, dtype=tensor.dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(
"std evaluated to zero after conversion to {}, leading to division by zero.".format(
tensor.dtype
)
)
while mean.ndim < tensor.ndim:
mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
return tensor.sub_(mean).div_(std)
def extra_repr(self) -> str:
return "mean={mean}, std={std}, inplace={inplace}".format(
mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
)
class SuperSimpleLearnableNorm(SuperModule):
"""Super simple normalization."""
def __init__(self, mean=0, std=1, eps=1e-6, inplace=False) -> None:
super(SuperSimpleLearnableNorm, self).__init__()
self.register_parameter(
"_mean", nn.Parameter(torch.tensor(mean, dtype=torch.float))
)
self.register_parameter(
"_std", nn.Parameter(torch.tensor(std, dtype=torch.float))
)
self._eps = eps
self._inplace = inplace
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
# check inputs ->
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
if not self._inplace:
tensor = input.clone()
else:
tensor = input
mean, std = (
self._mean.to(tensor.device),
torch.abs(self._std.to(tensor.device)) + self._eps,
)
if (std == 0).any():
raise ValueError("std leads to division by zero.")
while mean.ndim < tensor.ndim:
mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
return tensor.sub_(mean).div_(std)
def forward_with_container(self, input, container, prefix=[]):
if not self._inplace:
tensor = input.clone()
else:
tensor = input
mean_name = ".".join(prefix + ["_mean"])
std_name = ".".join(prefix + ["_std"])
mean, std = (
container.query(mean_name).to(tensor.device),
torch.abs(container.query(std_name).to(tensor.device)) + self._eps,
)
while mean.ndim < tensor.ndim:
mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
return tensor.sub_(mean).div_(std)
def extra_repr(self) -> str:
return "mean={mean}, std={std}, inplace={inplace}".format(
mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
)
class SuperIdentity(SuperModule):
"""Super identity mapping layer."""
def __init__(self, inplace=False, **kwargs) -> None:
super(SuperIdentity, self).__init__()
self._inplace = inplace
@property
def abstract_search_space(self):
return spaces.VirtualNode(id(self))
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
# check inputs ->
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
if not self._inplace:
tensor = input.clone()
else:
tensor = input
return tensor
def extra_repr(self) -> str:
return "inplace={inplace}".format(inplace=self._inplace)
def forward_with_container(self, input, container, prefix=[]):
return self.forward_raw(input)

View File

@@ -0,0 +1,105 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
import torch
import torch.nn as nn
import math
from xautodl import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
class SuperDynamicPositionE(SuperModule):
"""Applies a positional encoding to the input positions."""
def __init__(self, dimension: int, scale: float = 1.0) -> None:
super(SuperDynamicPositionE, self).__init__()
self._scale = scale
self._dimension = dimension
# weights to be optimized
self.register_buffer(
"_div_term",
torch.exp(
torch.arange(0, dimension, 2).float() * (-math.log(10000.0) / dimension)
),
)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
return root_node
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
positions = torch.unsqueeze(input * self._scale, dim=-1)
divisions = torch.reshape(
self._div_term, [1] * input.ndim + [self._div_term.numel()]
)
values = positions / divisions
embeds = torch.cat((torch.sin(values), torch.cos(values)), dim=-1)
return embeds
def extra_repr(self) -> str:
return "scale={:}, dim={:}".format(self._scale, self._dimension)
class SuperPositionalEncoder(SuperModule):
"""Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65
"""
def __init__(self, d_model: IntSpaceType, max_seq_len: int, dropout: float = 0.1):
super(SuperPositionalEncoder, self).__init__()
self._d_model = d_model
# create constant 'pe' matrix with values dependant on
# pos and i
self.dropout = nn.Dropout(p=dropout)
self.register_buffer("pe", self.create_pos_embed(max_seq_len, self.d_model))
@property
def d_model(self):
return spaces.get_max(self._d_model)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
if not spaces.is_determined(self._d_model):
root_node.append("_d_model", self._d_model.abstract(reuse_last=True))
return root_node
def create_pos_embed(self, max_seq_len, d_model):
pe = torch.zeros(max_seq_len, d_model)
for pos in range(max_seq_len):
for i in range(0, d_model):
div = 10000 ** ((i // 2) * 2 / d_model)
value = pos / div
if i % 2 == 0:
pe[pos, i] = math.sin(value)
else:
pe[pos, i] = math.cos(value)
return pe.unsqueeze(0)
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
batch, seq, fdim = input.shape[:3]
embeddings = self.pe[:, :seq]
if not spaces.is_determined(self._d_model):
expected_d_model = self.abstract_child["_d_model"].value
else:
expected_d_model = spaces.get_determined_value(self._d_model)
assert fdim == expected_d_model, "{:} vs {:}".format(fdim, expected_d_model)
embeddings = torch.nn.functional.interpolate(
embeddings, size=(expected_d_model), mode="linear", align_corners=True
)
outs = self.dropout(input + embeddings)
return outs
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
batch, seq, fdim = input.shape[:3]
embeddings = self.pe[:, :seq]
outs = self.dropout(input + embeddings)
return outs
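Illustrative sketch (not part of this commit): the double loop in create_pos_embed builds the standard sinusoidal table; a vectorized equivalent (assuming an even d_model) is:
import torch

def sinusoid_table(max_seq_len: int, d_model: int) -> torch.Tensor:
    # assumes d_model is even so sin/cos columns pair up exactly
    position = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)      # (L, 1)
    div = torch.pow(10000.0, torch.arange(0, d_model, 2).float() / d_model)   # (d_model/2,)
    pe = torch.zeros(max_seq_len, d_model)
    pe[:, 0::2] = torch.sin(position / div)
    pe[:, 1::2] = torch.cos(position / div)
    return pe.unsqueeze(0)                                                    # (1, L, d_model)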

View File

@@ -0,0 +1,187 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#############################################################
# Borrow the idea of https://github.com/arogozhnikov/einops #
#############################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import itertools
import functools
from collections import OrderedDict
from typing import Optional, Callable
from xautodl import spaces
from .misc_utils import ParsedExpression, AnonymousAxis
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
class SuperReArrange(SuperModule):
"""Applies the rearrange operation."""
def __init__(self, pattern, **axes_lengths):
super(SuperReArrange, self).__init__()
self._pattern = pattern
self._axes_lengths = axes_lengths
axes_lengths = tuple(sorted(self._axes_lengths.items()))
# Perform initial parsing of pattern and provided supplementary info
# axes_lengths is a tuple of tuples (axis_name, axis_length)
left, right = pattern.split("->")
left = ParsedExpression(left)
right = ParsedExpression(right)
difference = set.symmetric_difference(left.identifiers, right.identifiers)
if difference:
raise ValueError(
"Identifiers only on one side of expression (should be on both): {}".format(
difference
)
)
# parsing all dimensions to find out lengths
axis_name2known_length = OrderedDict()
for composite_axis in left.composition:
for axis_name in composite_axis:
if isinstance(axis_name, AnonymousAxis):
axis_name2known_length[axis_name] = axis_name.value
else:
axis_name2known_length[axis_name] = None
for axis_name in right.identifiers:
if axis_name not in axis_name2known_length:
if isinstance(axis_name, AnonymousAxis):
axis_name2known_length[axis_name] = axis_name.value
else:
axis_name2known_length[axis_name] = None
axis_name2position = {
name: position for position, name in enumerate(axis_name2known_length)
}
for elementary_axis, axis_length in axes_lengths:
if not ParsedExpression.check_axis_name(elementary_axis):
raise ValueError("Invalid name for an axis", elementary_axis)
if elementary_axis not in axis_name2known_length:
raise ValueError(
"Axis {} is not used in transform".format(elementary_axis)
)
axis_name2known_length[elementary_axis] = axis_length
input_composite_axes = []
# some of shapes will be inferred later - all information is prepared for faster inference
for composite_axis in left.composition:
known = {
axis
for axis in composite_axis
if axis_name2known_length[axis] is not None
}
unknown = {
axis for axis in composite_axis if axis_name2known_length[axis] is None
}
if len(unknown) > 1:
raise ValueError("Could not infer sizes for {}".format(unknown))
assert len(unknown) + len(known) == len(composite_axis)
input_composite_axes.append(
(
[axis_name2position[axis] for axis in known],
[axis_name2position[axis] for axis in unknown],
)
)
axis_position_after_reduction = {}
for axis_name in itertools.chain(*left.composition):
if axis_name in right.identifiers:
axis_position_after_reduction[axis_name] = len(
axis_position_after_reduction
)
result_axes_grouping = []
for composite_axis in right.composition:
result_axes_grouping.append(
[axis_name2position[axis] for axis in composite_axis]
)
ordered_axis_right = list(itertools.chain(*right.composition))
axes_permutation = tuple(
axis_position_after_reduction[axis]
for axis in ordered_axis_right
if axis in left.identifiers
)
#
self.input_composite_axes = input_composite_axes
self.output_composite_axes = result_axes_grouping
self.elementary_axes_lengths = list(axis_name2known_length.values())
self.axes_permutation = axes_permutation
    @functools.lru_cache(maxsize=1024)
    def reconstruct_from_shape(self, shape):
        # Infer the concrete length of every elementary axis from an input shape;
        # results are cached per shape, so repeated calls with the same shape are cheap.
if len(shape) != len(self.input_composite_axes):
raise ValueError(
"Expected {} dimensions, got {}".format(
len(self.input_composite_axes), len(shape)
)
)
axes_lengths = list(self.elementary_axes_lengths)
for input_axis, (known_axes, unknown_axes) in enumerate(
self.input_composite_axes
):
length = shape[input_axis]
known_product = 1
for axis in known_axes:
known_product *= axes_lengths[axis]
if len(unknown_axes) == 0:
if (
isinstance(length, int)
and isinstance(known_product, int)
and length != known_product
):
raise ValueError(
"Shape mismatch, {} != {}".format(length, known_product)
)
else:
if (
isinstance(length, int)
and isinstance(known_product, int)
and length % known_product != 0
):
raise ValueError(
"Shape mismatch, can't divide axis of length {} in chunks of {}".format(
length, known_product
)
)
(unknown_axis,) = unknown_axes
axes_lengths[unknown_axis] = length // known_product
# at this point all axes_lengths are computed (either have values or variables, but not Nones)
final_shape = []
for output_axis, grouping in enumerate(self.output_composite_axes):
lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
final_shape.append(int(np.prod(lengths)))
axes_reordering = self.axes_permutation
return axes_lengths, axes_reordering, final_shape
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
return root_node
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
init_shape, axes_reordering, final_shape = self.reconstruct_from_shape(
tuple(input.shape)
)
tensor = torch.reshape(input, init_shape)
tensor = tensor.permute(axes_reordering)
tensor = torch.reshape(tensor, final_shape)
return tensor
def extra_repr(self) -> str:
params = repr(self._pattern)
for axis, length in self._axes_lengths.items():
params += ", {}={}".format(axis, length)
return "{:}".format(params)

View File

@@ -0,0 +1,58 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from xautodl import spaces
from .super_linear import SuperLinear
from .super_module import SuperModule
from .super_module import IntSpaceType
class SuperAlphaEBDv1(SuperModule):
"""A simple layer to convert the raw trading data from 1-D to 2-D data and apply an FC layer."""
def __init__(self, d_feat: int, embed_dim: IntSpaceType):
super(SuperAlphaEBDv1, self).__init__()
self._d_feat = d_feat
self._embed_dim = embed_dim
self.proj = SuperLinear(d_feat, embed_dim)
@property
def embed_dim(self):
return spaces.get_max(self._embed_dim)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
space = self.proj.abstract_search_space
if not spaces.is_determined(space):
root_node.append("proj", space)
if not spaces.is_determined(self._embed_dim):
root_node.append("_embed_dim", self._embed_dim.abstract(reuse_last=True))
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperAlphaEBDv1, self).apply_candidate(abstract_child)
if "proj" in abstract_child:
self.proj.apply_candidate(abstract_child["proj"])
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
x = input.reshape(len(input), self._d_feat, -1) # [N, F*T] -> [N, F, T]
x = x.permute(0, 2, 1) # [N, F, T] -> [N, T, F]
if not spaces.is_determined(self._embed_dim):
embed_dim = self.abstract_child["_embed_dim"].value
else:
embed_dim = spaces.get_determined_value(self._embed_dim)
out = self.proj(x) * math.sqrt(embed_dim)
return out
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
x = input.reshape(len(input), self._d_feat, -1) # [N, F*T] -> [N, F, T]
x = x.permute(0, 2, 1) # [N, F, T] -> [N, T, F]
out = self.proj(x) * math.sqrt(self.embed_dim)
return out

View File

@@ -0,0 +1,127 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
from typing import Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from xautodl import spaces
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
from .super_module import LayerOrder
from .super_module import SuperModule
from .super_linear import SuperMLPv2
from .super_norm import SuperLayerNorm1D
from .super_attention import SuperSelfAttention
class SuperTransformerEncoderLayer(SuperModule):
"""TransformerEncoderLayer is made up of self-attn and feedforward network.
This is a super model for TransformerEncoderLayer that can support search for the transformer encoder layer.
Reference:
- Paper: Attention Is All You Need, NeurIPS 2017
- PyTorch Implementation: https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
Details:
the original post-norm version: MHA -> residual -> norm -> MLP -> residual -> norm
the pre-norm version: norm -> MHA -> residual -> norm -> MLP -> residual
"""
def __init__(
self,
d_model: IntSpaceType,
num_heads: IntSpaceType,
qkv_bias: BoolSpaceType = False,
mlp_hidden_multiplier: IntSpaceType = 4,
dropout: Optional[float] = None,
att_dropout: Optional[float] = None,
norm_affine: bool = True,
act_layer: Callable[[], nn.Module] = nn.GELU,
order: LayerOrder = LayerOrder.PreNorm,
use_mask: bool = False,
):
super(SuperTransformerEncoderLayer, self).__init__()
mha = SuperSelfAttention(
d_model,
d_model,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=att_dropout,
proj_drop=None,
use_mask=use_mask,
)
mlp = SuperMLPv2(
d_model,
hidden_multiplier=mlp_hidden_multiplier,
out_features=d_model,
act_layer=act_layer,
drop=dropout,
)
if order is LayerOrder.PreNorm:
self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
self.mha = mha
self.drop = nn.Dropout(dropout or 0.0)
self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
self.mlp = mlp
elif order is LayerOrder.PostNorm:
self.mha = mha
self.drop1 = nn.Dropout(dropout or 0.0)
self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
self.mlp = mlp
self.drop2 = nn.Dropout(dropout or 0.0)
self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
else:
raise ValueError("Unknown order: {:}".format(order))
self._order = order
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
xdict = dict(
mha=self.mha.abstract_search_space,
norm1=self.norm1.abstract_search_space,
mlp=self.mlp.abstract_search_space,
norm2=self.norm2.abstract_search_space,
)
for key, space in xdict.items():
if not spaces.is_determined(space):
root_node.append(key, space)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperTransformerEncoderLayer, self).apply_candidate(abstract_child)
valid_keys = ["mha", "norm1", "mlp", "norm2"]
for key in valid_keys:
if key in abstract_child:
getattr(self, key).apply_candidate(abstract_child[key])
def forward_candidate(self, inputs: torch.Tensor) -> torch.Tensor:
return self.forward_raw(inputs)
def forward_raw(self, inputs: torch.Tensor) -> torch.Tensor:
if self._order is LayerOrder.PreNorm:
# https://github.com/google-research/vision_transformer/blob/master/vit_jax/models.py#L135
x = self.norm1(inputs)
x = self.mha(x)
x = self.drop(x)
x = x + inputs
# feed-forward layer -- MLP
y = self.norm2(x)
outs = x + self.mlp(y)
elif self._order is LayerOrder.PostNorm:
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoder
# multi-head attention
x = self.mha(inputs)
x = inputs + self.drop1(x)
x = self.norm1(x)
# feed-forward layer -- MLP
y = self.mlp(x)
y = x + self.drop2(y)
outs = self.norm2(y)
else:
raise ValueError("Unknown order: {:}".format(self._order))
return outs
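# --- Hedged usage sketch (illustrative only, not part of the commit) ---
# Build a pre-norm encoder layer with plain integers for the searchable fields
# and run the raw (non-candidate) forward pass; SuperSelfAttention, SuperMLPv2,
# and SuperLayerNorm1D are assumed to accept plain ints like their torch analogs.
encoder = SuperTransformerEncoderLayer(d_model=64, num_heads=4, dropout=0.1)
tokens = torch.randn(8, 16, 64)        # [batch, seq, d_model]
out = encoder.forward_raw(tokens)      # expected shape: [8, 16, 64]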

View File

@@ -0,0 +1,222 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import abc
import warnings
from typing import Optional, Union, Callable
import torch
import torch.nn as nn
from enum import Enum
from xautodl import spaces
IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
BoolSpaceType = Union[bool, spaces.Categorical]
class LayerOrder(Enum):
"""This class defines the enumerations for order of operation in a residual or normalization-based layer."""
PreNorm = "pre-norm"
PostNorm = "post-norm"
class SuperRunMode(Enum):
"""This class defines the enumerations for Super Model Running Mode."""
FullModel = "fullmodel"
Candidate = "candidate"
Default = "fullmodel"
class ShapeContainer:
"""A class to maintain the shape of each weight tensor for a model."""
def __init__(self):
self._names = []
self._shapes = []
self._name2index = dict()
self._param_or_buffers = []
@property
def shapes(self):
return self._shapes
def __getitem__(self, index):
return self._shapes[index]
    def translate(self, tensors, all_none_match=True):
        # Reshape a flat list of tensors into the recorded shapes; when a source
        # tensor has more elements than needed, the extra values are dropped
        # (only allowed if all_none_match is True).
result = TensorContainer()
for index, name in enumerate(self._names):
cur_num = tensors[index].numel()
expected_num = self._shapes[index].numel()
if cur_num < expected_num or (
cur_num > expected_num and not all_none_match
):
raise ValueError("Invalid {:} vs {:}".format(cur_num, expected_num))
cur_tensor = tensors[index].view(-1)[:expected_num]
new_tensor = torch.reshape(cur_tensor, self._shapes[index])
result.append(name, new_tensor, self._param_or_buffers[index])
return result
def append(self, name, shape, param_or_buffer):
if not isinstance(shape, torch.Size):
raise TypeError(
"The input tensor must be torch.Size instead of {:}".format(type(shape))
)
self._names.append(name)
self._shapes.append(shape)
self._param_or_buffers.append(param_or_buffer)
assert name not in self._name2index, "The [{:}] has already been added.".format(
name
)
self._name2index[name] = len(self._names) - 1
def query(self, name):
if not self.has(name):
raise ValueError(
"The {:} is not in {:}".format(name, list(self._name2index.keys()))
)
index = self._name2index[name]
return self._shapes[index]
def has(self, name):
return name in self._name2index
def has_prefix(self, prefix):
for name, idx in self._name2index.items():
if name.startswith(prefix):
return name
return False
def numel(self, index=None):
if index is None:
shapes = self._shapes
else:
shapes = [self._shapes[index]]
total = 0
for shape in shapes:
total += shape.numel()
return total
def __len__(self):
return len(self._names)
def __repr__(self):
return "{name}({num} tensors)".format(
name=self.__class__.__name__, num=len(self)
)
class TensorContainer:
"""A class to maintain both parameters and buffers for a model."""
def __init__(self):
self._names = []
self._tensors = []
self._param_or_buffers = []
self._name2index = dict()
def additive(self, tensors):
result = TensorContainer()
for index, name in enumerate(self._names):
new_tensor = self._tensors[index] + tensors[index]
result.append(name, new_tensor, self._param_or_buffers[index])
return result
def create_container(self, tensors):
result = TensorContainer()
for index, name in enumerate(self._names):
new_tensor = tensors[index]
result.append(name, new_tensor, self._param_or_buffers[index])
return result
def no_grad_clone(self):
result = TensorContainer()
with torch.no_grad():
for index, name in enumerate(self._names):
result.append(
name, self._tensors[index].clone(), self._param_or_buffers[index]
)
return result
def to_shape_container(self):
result = ShapeContainer()
for index, name in enumerate(self._names):
result.append(
name, self._tensors[index].shape, self._param_or_buffers[index]
)
return result
def requires_grad_(self, requires_grad=True):
for tensor in self._tensors:
tensor.requires_grad_(requires_grad)
def parameters(self):
return self._tensors
@property
def tensors(self):
return self._tensors
def flatten(self, tensors=None):
if tensors is None:
tensors = self._tensors
tensors = [tensor.view(-1) for tensor in tensors]
return torch.cat(tensors)
def unflatten(self, tensor):
tensors, s = [], 0
for raw_tensor in self._tensors:
length = raw_tensor.numel()
x = torch.reshape(tensor[s : s + length], shape=raw_tensor.shape)
tensors.append(x)
s += length
return tensors
def append(self, name, tensor, param_or_buffer):
if not isinstance(tensor, torch.Tensor):
raise TypeError(
"The input tensor must be torch.Tensor instead of {:}".format(
type(tensor)
)
)
self._names.append(name)
self._tensors.append(tensor)
self._param_or_buffers.append(param_or_buffer)
assert name not in self._name2index, "The [{:}] has already been added.".format(
name
)
self._name2index[name] = len(self._names) - 1
def query(self, name):
if not self.has(name):
raise ValueError(
"The {:} is not in {:}".format(name, list(self._name2index.keys()))
)
index = self._name2index[name]
return self._tensors[index]
def has(self, name):
return name in self._name2index
def has_prefix(self, prefix):
for name, idx in self._name2index.items():
if name.startswith(prefix):
return name
return False
def numel(self):
total = 0
for tensor in self._tensors:
total += tensor.numel()
return total
def __len__(self):
return len(self._names)
def __repr__(self):
return "{name}({num} tensors)".format(
name=self.__class__.__name__, num=len(self)
)
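# --- Hedged usage sketch (illustrative only, not part of the commit) ---
# Round-trip a named set of tensors through flatten/unflatten and rebuild it
# from the recorded shapes via ShapeContainer.translate; names are arbitrary.
params = TensorContainer()
params.append("weight", torch.randn(2, 3), True)
params.append("bias", torch.zeros(3), True)
flat = params.flatten()                          # 1-D tensor with 2*3 + 3 = 9 values
shapes = params.to_shape_container()             # ShapeContainer with the same names
rebuilt = shapes.translate(params.unflatten(flat))
assert rebuilt.query("weight").shape == (2, 3)   # torch.Size compares equal to a tuple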

View File

@@ -0,0 +1,84 @@
# Borrowed from https://github.com/rwightman/pytorch-image-models
import torch
import torch.nn as nn
import math
import warnings
# setup for xlayers
from . import super_core
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
if isinstance(tensor, list):
return [_no_grad_trunc_normal_(x, mean, std, a, b) for x in tensor]
else:
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def init_transformer(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, super_core.SuperLinear):
trunc_normal_(m._super_weight, std=0.02)
if m._super_bias is not None:
nn.init.constant_(m._super_bias, 0)
elif isinstance(m, super_core.SuperLayerNorm1D):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
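# --- Hedged usage sketch (illustrative only, not part of the commit) ---
# init_transformer also handles plain nn.Linear modules, so it can be applied
# to an ordinary torch model; trunc_normal_ can likewise be used standalone.
net = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4))
net.apply(init_transformer)              # truncated-normal weights, zero biases
w = trunc_normal_(torch.empty(3, 5))     # values clipped to [-2, 2] by default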