update GDAS (TO-FINISH)

This commit is contained in:
D-X-Y
2019-10-16 16:29:57 +11:00
parent 6814816d5f
commit d28826793d
6 changed files with 429 additions and 4 deletions

View File

@@ -9,8 +9,7 @@ def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
else :
while True: # a trick to avoid the gumbels bug
gumbels = -torch.empty_like(logits).exponential_().log()
new_logits = (logits + gumbels) / tau
#new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
probs = nn.functional.softmax(new_logits, dim=1)
if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break