Complete Super Linear
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
from .super_module import SuperRunMode
|
||||
from .super_module import SuperModule
|
||||
from .super_mlp import SuperLinear
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#####################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
@@ -52,14 +53,15 @@ class SuperLinear(SuperModule):
|
||||
def bias(self):
|
||||
return spaces.has_categorical(self._bias, True)
|
||||
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
root_node = spaces.VirtualNode(id(self))
|
||||
if not spaces.is_determined(self._in_features):
|
||||
root_node.append("_in_features", self._in_features)
|
||||
root_node.append("_in_features", self._in_features.abstract())
|
||||
if not spaces.is_determined(self._out_features):
|
||||
root_node.append("_out_features", self._out_features)
|
||||
root_node.append("_out_features", self._out_features.abstract())
|
||||
if not spaces.is_determined(self._bias):
|
||||
root_node.append("_bias", self._bias)
|
||||
root_node.append("_bias", self._bias.abstract())
|
||||
return root_node
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
@@ -69,6 +71,37 @@ class SuperLinear(SuperModule):
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(self._super_bias, -bound, bound)
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# check inputs ->
|
||||
if not spaces.is_determined(self._in_features):
|
||||
expected_input_dim = self.abstract_child["_in_features"].value
|
||||
else:
|
||||
expected_input_dim = spaces.get_determined_value(self._in_features)
|
||||
if input.size(-1) != expected_input_dim:
|
||||
raise ValueError(
|
||||
"Expect the input dim of {:} instead of {:}".format(
|
||||
expected_input_dim, input.size(-1)
|
||||
)
|
||||
)
|
||||
# create the weight matrix
|
||||
if not spaces.is_determined(self._out_features):
|
||||
out_dim = self.abstract_child["_out_features"].value
|
||||
else:
|
||||
out_dim = spaces.get_determined_value(self._out_features)
|
||||
candidate_weight = self._super_weight[:out_dim, :expected_input_dim]
|
||||
# create the bias matrix
|
||||
if not spaces.is_determined(self._bias):
|
||||
if self.abstract_child["_bias"].value:
|
||||
candidate_bias = self._super_bias[:out_dim]
|
||||
else:
|
||||
candidate_bias = None
|
||||
else:
|
||||
if spaces.get_determined_value(self._bias):
|
||||
candidate_bias = self._super_bias[:out_dim]
|
||||
else:
|
||||
candidate_bias = None
|
||||
return F.linear(input, candidate_weight, candidate_bias)
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(input, self._super_weight, self._super_bias)
|
||||
|
||||
@@ -78,8 +111,9 @@ class SuperLinear(SuperModule):
|
||||
)
|
||||
|
||||
|
||||
class SuperMLP(nn.Module):
|
||||
# MLP: FC -> Activation -> Drop -> FC -> Drop
|
||||
class SuperMLP(SuperModule):
|
||||
"""An MLP layer: FC -> Activation -> Drop -> FC -> Drop."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
@@ -88,13 +122,13 @@ class SuperMLP(nn.Module):
|
||||
act_layer=nn.GELU,
|
||||
drop: Optional[float] = None,
|
||||
):
|
||||
super(MLP, self).__init__()
|
||||
super(SuperMLP, self).__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop or 0)
|
||||
self.drop = nn.Dropout(drop or 0.0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
|
@@ -6,11 +6,14 @@ import abc
|
||||
import torch.nn as nn
|
||||
from enum import Enum
|
||||
|
||||
import spaces
|
||||
|
||||
|
||||
class SuperRunMode(Enum):
|
||||
"""This class defines the enumerations for Super Model Running Mode."""
|
||||
|
||||
FullModel = "fullmodel"
|
||||
Candidate = "candidate"
|
||||
Default = "fullmodel"
|
||||
|
||||
|
||||
@@ -20,8 +23,23 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
def __init__(self):
|
||||
super(SuperModule, self).__init__()
|
||||
self._super_run_type = SuperRunMode.Default
|
||||
self._abstract_child = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_super_run_type(self, super_run_type):
|
||||
def _reset_super_run(m):
|
||||
if isinstance(m, SuperModule):
|
||||
m._super_run_type = super_run_type
|
||||
|
||||
self.apply(_reset_super_run)
|
||||
|
||||
def apply_candiate(self, abstract_child):
|
||||
if not isinstance(abstract_child, spaces.VirtualNode):
|
||||
raise ValueError(
|
||||
"Invalid abstract child program: {:}".format(abstract_child)
|
||||
)
|
||||
self._abstract_child = abstract_child
|
||||
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -29,13 +47,24 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
def super_run_type(self):
|
||||
return self._super_run_type
|
||||
|
||||
@property
|
||||
def abstract_child(self):
|
||||
return self._abstract_child
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward_raw(self, *inputs):
|
||||
"""Use the largest candidate for forward. Similar to the original PyTorch model."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward_candidate(self, *inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, *inputs):
|
||||
if self.super_run_type == SuperRunMode.FullModel:
|
||||
return self.forward_raw(*inputs)
|
||||
elif self.super_run_type == SuperRunMode.Candidate:
|
||||
return self.forward_candidate(*inputs)
|
||||
else:
|
||||
raise ModeError(
|
||||
"Unknown Super Model Run Mode: {:}".format(self.super_run_type)
|
||||
|
Reference in New Issue
Block a user