updates for beta

This commit is contained in:
D-X-Y
2019-11-09 16:50:13 +11:00
parent 34ba8053de
commit 975fe4c385
9 changed files with 415 additions and 38 deletions

View File

@@ -83,7 +83,8 @@ class SearchCell(nn.Module):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
#aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) )
inter_nodes.append( aggregation )
nodes.append( sum(inter_nodes) )
return nodes[-1]

View File

@@ -3,7 +3,7 @@
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import torch
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
@@ -87,7 +87,7 @@ class TinyNetworkSETN(nn.Module):
return Structure( genotypes )
def dync_genotype(self):
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
@@ -95,9 +95,12 @@ class TinyNetworkSETN(nn.Module):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = alphas_cpu[ self.edge2index[node_str] ]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[ op_index ]
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[ self.edge2index[node_str] ]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[ op_index ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )