update GDAS

This commit is contained in:
D-X-Y
2019-11-19 14:36:42 +11:00
parent 09d68c6375
commit b6c0828382
12 changed files with 3 additions and 1039 deletions

View File

@@ -88,7 +88,9 @@ class TinyNetworkGDAS(nn.Module):
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): continue
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
continue
else: break
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):