support first-order DARTS on the NASNet search space
This commit is contained in:
@@ -131,10 +131,12 @@ class MixedOp(nn.Module):
|
||||
op = OPS[primitive](C, C, stride, affine, track_running_stats)
|
||||
self._ops.append(op)
|
||||
|
||||
def forward(self, x, weights, index):
|
||||
#return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
def forward_gdas(self, x, weights, index):
|
||||
return self._ops[index](x) * weights[index]
|
||||
|
||||
def forward_darts(self, x, weights):
|
||||
return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
class NASNetSearchCell(nn.Module):
|
||||
@@ -173,7 +175,23 @@ class NASNetSearchCell(nn.Module):
|
||||
op = self.edges[ node_str ]
|
||||
weights = weightss[ self.edge2index[node_str] ]
|
||||
index = indexs[ self.edge2index[node_str] ].item()
|
||||
clist.append( op(h, weights, index) )
|
||||
clist.append( op.forward_gdas(h, weights, index) )
|
||||
states.append( sum(clist) )
|
||||
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
||||
def forward_darts(self, s0, s1, weightss):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
op = self.edges[ node_str ]
|
||||
weights = weightss[ self.edge2index[node_str] ]
|
||||
clist.append( op.forward_darts(h, weights) )
|
||||
states.append( sum(clist) )
|
||||
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
Reference in New Issue
Block a user