Update SuperMLP

This commit is contained in:
D-X-Y
2021-03-19 23:57:23 +08:00
parent 31b8122cc1
commit 0c56a729ad
13 changed files with 412 additions and 85 deletions

View File

@@ -147,11 +147,18 @@ class SuperMLP(SuperModule):
root_node.append("fc2", space_fc2)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperMLP, self).apply_candidate(abstract_child)
if "fc1" in abstract_child:
self.fc1.apply_candidate(abstract_child["fc1"])
if "fc2" in abstract_child:
self.fc2.apply_candidate(abstract_child["fc2"])
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self._unified_forward(x)
return self._unified_forward(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return self._unified_forward(x)
return self._unified_forward(input)
def _unified_forward(self, x):
x = self.fc1(x)