update SETN
This commit is contained in:
@@ -81,10 +81,10 @@ class Structure:
|
||||
if consider_zero:
|
||||
if op == 'none' or nodes[xin] == '#': x = '#' # zero
|
||||
elif op == 'skip_connect': x = nodes[xin]
|
||||
else: x = nodes[xin] + '@{:}'.format(op)
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
else:
|
||||
if op == 'skip_connect': x = nodes[xin]
|
||||
else: x = nodes[xin] + '@{:}'.format(op)
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
cur_node.append(x)
|
||||
nodes[i_node+1] = '+'.join( sorted(cur_node) )
|
||||
return nodes[ len(self.nodes) ]
|
||||
|
@@ -84,7 +84,6 @@ class TinyNetworkSETN(nn.Module):
|
||||
genotypes.append( tuple(xlist) )
|
||||
return Structure( genotypes )
|
||||
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
@@ -103,6 +102,26 @@ class TinyNetworkSETN(nn.Module):
|
||||
genotypes.append( tuple(xlist) )
|
||||
return Structure( genotypes )
|
||||
|
||||
def get_log_prob(self, arch):
|
||||
with torch.no_grad():
|
||||
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
|
||||
select_logits = []
|
||||
for i, node_info in enumerate(arch.nodes):
|
||||
for op, xin in node_info:
|
||||
node_str = '{:}<-{:}'.format(i+1, xin)
|
||||
op_index = self.op_names.index(op)
|
||||
select_logits.append( logits[self.edge2index[node_str], op_index] )
|
||||
return sum(select_logits).item()
|
||||
|
||||
|
||||
def return_topK(self, K):
|
||||
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
|
||||
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
|
||||
if K < 0 or K >= len(archs): K = len(archs)
|
||||
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
|
||||
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
|
||||
return return_pairs
|
||||
|
||||
|
||||
def forward(self, inputs):
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
|
Reference in New Issue
Block a user