Update SuperMLP

This commit is contained in:
D-X-Y
2021-03-19 23:57:23 +08:00
parent 31b8122cc1
commit 0c56a729ad
13 changed files with 412 additions and 85 deletions

View File

@@ -147,11 +147,18 @@ class SuperMLP(SuperModule):
root_node.append("fc2", space_fc2)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperMLP, self).apply_candidate(abstract_child)
if "fc1" in abstract_child:
self.fc1.apply_candidate(abstract_child["fc1"])
if "fc2" in abstract_child:
self.fc2.apply_candidate(abstract_child["fc2"])
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self._unified_forward(x)
return self._unified_forward(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return self._unified_forward(x)
return self._unified_forward(input)
def _unified_forward(self, x):
x = self.fc1(x)

View File

@@ -32,7 +32,7 @@ class SuperModule(abc.ABC, nn.Module):
self.apply(_reset_super_run)
def apply_candiate(self, abstract_child):
def apply_candidate(self, abstract_child):
if not isinstance(abstract_child, spaces.VirtualNode):
raise ValueError(
"Invalid abstract child program: {:}".format(abstract_child)