update codes

This commit is contained in:
D-X-Y
2020-01-12 01:42:17 +11:00
parent 654015bf9d
commit 33384a78af
15 changed files with 288 additions and 21 deletions

View File

@@ -9,10 +9,11 @@ from copy import deepcopy
from ..cell_operations import OPS
class SearchCell(nn.Module):
# This module is used for NAS-Bench-102, represents a small search space with a complete DAG
class NAS102SearchCell(nn.Module):
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
super(SearchCell, self).__init__()
super(NAS102SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
self.edges = nn.ModuleDict()
@@ -74,7 +75,7 @@ class SearchCell(nn.Module):
nodes.append( sum(inter_nodes) )
return nodes[-1]
# uniform random sampling per iteration
# uniform random sampling per iteration, SETN
def forward_urs(self, inputs):
nodes = [inputs]
for i in range(1, self.max_nodes):
@@ -118,3 +119,61 @@ class SearchCell(nn.Module):
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
class MixedOp(nn.Module):
def __init__(self, space, C, stride, affine, track_running_stats):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
for primitive in space:
op = OPS[primitive](C, C, stride, affine, track_running_stats)
self._ops.append(op)
def forward(self, x, weights, index):
#return sum(w * op(x) for w, op in zip(weights, self._ops))
return self._ops[index](x) * weights[index]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetSearchCell(nn.Module):
def __init__(self, space, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
super(NASNetSearchCell, self).__init__()
self.reduction = reduction
self.op_names = deepcopy(space)
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
self._steps = steps
self._multiplier = multiplier
self._ops = nn.ModuleList()
self.edges = nn.ModuleDict()
for i in range(self._steps):
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j)
stride = 2 if reduction and j < 2 else 1
op = MixedOp(space, C, stride, affine, track_running_stats)
self.edges[ node_str ] = op
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
def forward_gdas(self, s0, s1, weightss, indexs):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = '{:}<-{:}'.format(i, j)
op = self.edges[ node_str ]
weights = weightss[ self.edge2index[node_str] ]
index = indexs[ self.edge2index[node_str] ].item()
clist.append( op(h, weights, index) )
states.append( sum(clist) )
return torch.cat(states[-self._multiplier:], dim=1)