Update xmisc

D-X-Y
2021-06-10 23:42:00 -07:00
parent 98f981dd45
commit 248686820c
8 changed files with 72 additions and 487 deletions


@@ -201,7 +201,6 @@ class SuperMLPv2(SuperModule):
         self._hidden_multiplier = hidden_multiplier
         self._out_features = out_features
         self._drop_rate = drop
-        self._params = nn.ParameterDict({})
         self._create_linear(
             "fc1", self.in_features, int(self.in_features * self.hidden_multiplier)
@@ -226,26 +225,22 @@ class SuperMLPv2(SuperModule):
         return spaces.get_max(self._out_features)

     def _create_linear(self, name, inC, outC):
-        self._params["{:}_super_weight".format(name)] = torch.nn.Parameter(
-            torch.Tensor(outC, inC)
+        self.register_parameter(
+            "{:}_super_weight".format(name), torch.nn.Parameter(torch.Tensor(outC, inC))
         )
-        self._params["{:}_super_bias".format(name)] = torch.nn.Parameter(
-            torch.Tensor(outC)
+        self.register_parameter(
+            "{:}_super_bias".format(name), torch.nn.Parameter(torch.Tensor(outC))
         )

     def reset_parameters(self) -> None:
-        nn.init.kaiming_uniform_(self._params["fc1_super_weight"], a=math.sqrt(5))
-        nn.init.kaiming_uniform_(self._params["fc2_super_weight"], a=math.sqrt(5))
-        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
-            self._params["fc1_super_weight"]
-        )
+        nn.init.kaiming_uniform_(self.fc1_super_weight, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.fc2_super_weight, a=math.sqrt(5))
+        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.fc1_super_weight)
         bound = 1 / math.sqrt(fan_in)
-        nn.init.uniform_(self._params["fc1_super_bias"], -bound, bound)
-        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
-            self._params["fc2_super_weight"]
-        )
+        nn.init.uniform_(self.fc1_super_bias, -bound, bound)
+        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.fc2_super_weight)
         bound = 1 / math.sqrt(fan_in)
-        nn.init.uniform_(self._params["fc2_super_bias"], -bound, bound)
+        nn.init.uniform_(self.fc2_super_bias, -bound, bound)

     @property
     def abstract_search_space(self):
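The rewritten `reset_parameters` follows the same recipe PyTorch's `nn.Linear.reset_parameters` uses: Kaiming-uniform weights with `a=math.sqrt(5)`, and biases drawn uniformly from ±1/sqrt(fan_in). A small worked sketch, assuming a hypothetical (outC, inC) = (8, 4) layer:

```python
import math
import torch
import torch.nn as nn

weight = torch.empty(8, 4)  # (outC, inC), as allocated by _create_linear
bias = torch.empty(8)

nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)  # fan_in == inC == 4
bound = 1 / math.sqrt(fan_in)                              # 1 / sqrt(4) == 0.5
nn.init.uniform_(bias, -bound, bound)                      # bias ~ U(-0.5, 0.5)
```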
@@ -282,8 +277,8 @@ class SuperMLPv2(SuperModule):
         else:
             hmul = spaces.get_determined_value(self._hidden_multiplier)
         hidden_dim = int(expected_input_dim * hmul)
-        _fc1_weight = self._params["fc1_super_weight"][:hidden_dim, :expected_input_dim]
-        _fc1_bias = self._params["fc1_super_bias"][:hidden_dim]
+        _fc1_weight = self.fc1_super_weight[:hidden_dim, :expected_input_dim]
+        _fc1_bias = self.fc1_super_bias[:hidden_dim]
         x = F.linear(input, _fc1_weight, _fc1_bias)
         x = self.act(x)
         x = self.drop(x)
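The candidate forward keeps the weight-sharing trick unchanged: one over-sized parameter is allocated for the largest candidate, and a sampled child takes the top-left sub-block. A standalone sketch with assumed sizes (the names and dims below are illustrative, not from the repo):

```python
import torch
import torch.nn.functional as F

fc1_super_weight = torch.randn(64, 32)  # (max_hidden, max_in): largest candidate
fc1_super_bias = torch.randn(64)
hidden_dim, in_dim = 48, 32             # hypothetical child choice

x = torch.randn(5, in_dim)
w = fc1_super_weight[:hidden_dim, :in_dim]  # view into the shared super weight
b = fc1_super_bias[:hidden_dim]
out = F.linear(x, w, b)                     # shape (5, hidden_dim)
```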
@@ -292,21 +287,17 @@ class SuperMLPv2(SuperModule):
             out_dim = self.abstract_child["_out_features"].value
         else:
             out_dim = spaces.get_determined_value(self._out_features)
-        _fc2_weight = self._params["fc2_super_weight"][:out_dim, :hidden_dim]
-        _fc2_bias = self._params["fc2_super_bias"][:out_dim]
+        _fc2_weight = self.fc2_super_weight[:out_dim, :hidden_dim]
+        _fc2_bias = self.fc2_super_bias[:out_dim]
         x = F.linear(x, _fc2_weight, _fc2_bias)
         x = self.drop(x)
         return x

     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
-        x = F.linear(
-            input, self._params["fc1_super_weight"], self._params["fc1_super_bias"]
-        )
+        x = F.linear(input, self.fc1_super_weight, self.fc1_super_bias)
         x = self.act(x)
         x = self.drop(x)
-        x = F.linear(
-            x, self._params["fc2_super_weight"], self._params["fc2_super_bias"]
-        )
+        x = F.linear(x, self.fc2_super_weight, self.fc2_super_bias)
         x = self.drop(x)
         return x
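Because the slices are views of the registered parameters, training any child accumulates gradients straight into the shared super weights, and `forward_raw` is simply the path that uses them at full size. A quick check of that behavior, again with assumed shapes:

```python
import torch
import torch.nn.functional as F

w = torch.randn(64, 32, requires_grad=True)  # shared super weight (assumed shape)
x = torch.randn(5, 32)
out = F.linear(x, w[:48, :32])               # a child uses only 48 output units
out.sum().backward()

print(w.grad[:48].abs().sum() > 0)   # True: sliced rows receive gradient
print(w.grad[48:].abs().sum() == 0)  # True: unused rows stay zero
```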