update scripts

This commit is contained in:
D-X-Y
2019-10-16 00:09:10 +11:00
parent 7b977c08ec
commit 6814816d5f
5 changed files with 13 additions and 24 deletions

View File

@@ -10,6 +10,7 @@ def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
while True: # a trick to avoid the gumbels bug
gumbels = -torch.empty_like(logits).exponential_().log()
new_logits = (logits + gumbels) / tau
#new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
probs = nn.functional.softmax(new_logits, dim=1)
if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break