update scripts
This commit is contained in:
@@ -10,6 +10,7 @@ def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
|
||||
while True: # a trick to avoid the gumbels bug
|
||||
gumbels = -torch.empty_like(logits).exponential_().log()
|
||||
new_logits = (logits + gumbels) / tau
|
||||
#new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
|
||||
probs = nn.functional.softmax(new_logits, dim=1)
|
||||
if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break
|
||||
|
||||
|
Reference in New Issue
Block a user