Update SuperMLP
This commit is contained in:
@@ -147,11 +147,18 @@ class SuperMLP(SuperModule):
|
||||
root_node.append("fc2", space_fc2)
|
||||
return root_node
|
||||
|
||||
def apply_candidate(self, abstract_child: spaces.VirtualNode):
|
||||
super(SuperMLP, self).apply_candidate(abstract_child)
|
||||
if "fc1" in abstract_child:
|
||||
self.fc1.apply_candidate(abstract_child["fc1"])
|
||||
if "fc2" in abstract_child:
|
||||
self.fc2.apply_candidate(abstract_child["fc2"])
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return self._unified_forward(x)
|
||||
return self._unified_forward(input)
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return self._unified_forward(x)
|
||||
return self._unified_forward(input)
|
||||
|
||||
def _unified_forward(self, x):
|
||||
x = self.fc1(x)
|
||||
|
Reference in New Issue
Block a user