update GDAS

This commit is contained in:
D-X-Y
2019-11-19 11:58:04 +11:00
parent c3672648d7
commit 09d68c6375
20 changed files with 1176 additions and 90 deletions

View File

@@ -47,35 +47,17 @@ class SearchCell(nn.Module):
return nodes[-1]
# GDAS
def forward_gdas(self, inputs, alphas, _tau):
avoid_zero = 0
while True:
gumbels = -torch.empty_like(alphas).exponential_().log()
logits = (alphas.log_softmax(dim=1) + gumbels) / _tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
continue # avoid the numerical error
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = hardwts[ self.edge2index[node_str] ]
argmaxs = index[ self.edge2index[node_str] ].item()
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
inter_nodes.append( weigsum )
nodes.append( sum(inter_nodes) )
avoid_zero += 1
if nodes[-1].sum().item() == 0:
if avoid_zero < 10: continue
else:
warnings.warn('get zero outputs with avoid_zero={:}'.format(avoid_zero))
break
else:
break
def forward_gdas(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = hardwts[ self.edge2index[node_str] ]
argmaxs = index[ self.edge2index[node_str] ].item()
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
inter_nodes.append( weigsum )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# joint