update GDAS
This commit is contained in:
@@ -81,13 +81,21 @@ class TinyNetworkGDAS(nn.Module):
|
||||
return Structure( genotypes )
|
||||
|
||||
def forward(self, inputs):
|
||||
while True:
|
||||
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
|
||||
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): continue
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell.forward_gdas(feature, self.arch_parameters, self.tau)
|
||||
feature = cell.forward_gdas(feature, hardwts, index)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling( out )
|
||||
out = out.view(out.size(0), -1)
|
||||
|
Reference in New Issue
Block a user