update GDAS
This commit is contained in:
@@ -88,7 +88,9 @@ class TinyNetworkGDAS(nn.Module):
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): continue
|
||||
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
|
||||
continue
|
||||
else: break
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
|
Reference in New Issue
Block a user