update GDAS
This commit is contained in:
@@ -47,35 +47,17 @@ class SearchCell(nn.Module):
|
||||
return nodes[-1]
|
||||
|
||||
# GDAS
|
||||
def forward_gdas(self, inputs, alphas, _tau):
|
||||
avoid_zero = 0
|
||||
while True:
|
||||
gumbels = -torch.empty_like(alphas).exponential_().log()
|
||||
logits = (alphas.log_softmax(dim=1) + gumbels) / _tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
|
||||
continue # avoid the numerical error
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
weights = hardwts[ self.edge2index[node_str] ]
|
||||
argmaxs = index[ self.edge2index[node_str] ].item()
|
||||
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
|
||||
inter_nodes.append( weigsum )
|
||||
nodes.append( sum(inter_nodes) )
|
||||
avoid_zero += 1
|
||||
if nodes[-1].sum().item() == 0:
|
||||
if avoid_zero < 10: continue
|
||||
else:
|
||||
warnings.warn('get zero outputs with avoid_zero={:}'.format(avoid_zero))
|
||||
break
|
||||
else:
|
||||
break
|
||||
def forward_gdas(self, inputs, hardwts, index):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
weights = hardwts[ self.edge2index[node_str] ]
|
||||
argmaxs = index[ self.edge2index[node_str] ].item()
|
||||
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
|
||||
inter_nodes.append( weigsum )
|
||||
nodes.append( sum(inter_nodes) )
|
||||
return nodes[-1]
|
||||
|
||||
# joint
|
||||
|
Reference in New Issue
Block a user