support first-order DARTS on the NASNet search space

This commit is contained in:
D-X-Y
2020-01-17 22:14:47 +11:00
parent 56f2161a3f
commit db2760c260
10 changed files with 213 additions and 13 deletions

View File

@@ -131,10 +131,12 @@ class MixedOp(nn.Module):
op = OPS[primitive](C, C, stride, affine, track_running_stats)
self._ops.append(op)
def forward(self, x, weights, index):
#return sum(w * op(x) for w, op in zip(weights, self._ops))
def forward_gdas(self, x, weights, index):
return self._ops[index](x) * weights[index]
def forward_darts(self, x, weights):
return sum(w * op(x) for w, op in zip(weights, self._ops))
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetSearchCell(nn.Module):
@@ -173,7 +175,23 @@ class NASNetSearchCell(nn.Module):
op = self.edges[ node_str ]
weights = weightss[ self.edge2index[node_str] ]
index = indexs[ self.edge2index[node_str] ].item()
clist.append( op(h, weights, index) )
clist.append( op.forward_gdas(h, weights, index) )
states.append( sum(clist) )
return torch.cat(states[-self._multiplier:], dim=1)
def forward_darts(self, s0, s1, weightss):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = '{:}<-{:}'.format(i, j)
op = self.edges[ node_str ]
weights = weightss[ self.edge2index[node_str] ]
clist.append( op.forward_darts(h, weights) )
states.append( sum(clist) )
return torch.cat(states[-self._multiplier:], dim=1)