update codes
This commit is contained in:
@@ -9,10 +9,11 @@ from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
||||
|
||||
class SearchCell(nn.Module):
|
||||
# This module is used for NAS-Bench-102, represents a small search space with a complete DAG
|
||||
class NAS102SearchCell(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
|
||||
super(SearchCell, self).__init__()
|
||||
super(NAS102SearchCell, self).__init__()
|
||||
|
||||
self.op_names = deepcopy(op_names)
|
||||
self.edges = nn.ModuleDict()
|
||||
@@ -74,7 +75,7 @@ class SearchCell(nn.Module):
|
||||
nodes.append( sum(inter_nodes) )
|
||||
return nodes[-1]
|
||||
|
||||
# uniform random sampling per iteration
|
||||
# uniform random sampling per iteration, SETN
|
||||
def forward_urs(self, inputs):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
@@ -118,3 +119,61 @@ class SearchCell(nn.Module):
|
||||
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
|
||||
nodes.append( sum(inter_nodes) )
|
||||
return nodes[-1]
|
||||
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
|
||||
def __init__(self, space, C, stride, affine, track_running_stats):
|
||||
super(MixedOp, self).__init__()
|
||||
self._ops = nn.ModuleList()
|
||||
for primitive in space:
|
||||
op = OPS[primitive](C, C, stride, affine, track_running_stats)
|
||||
self._ops.append(op)
|
||||
|
||||
def forward(self, x, weights, index):
|
||||
#return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
return self._ops[index](x) * weights[index]
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
class NASNetSearchCell(nn.Module):
|
||||
|
||||
def __init__(self, space, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
|
||||
super(NASNetSearchCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.op_names = deepcopy(space)
|
||||
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
|
||||
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
|
||||
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
self.edges = nn.ModuleDict()
|
||||
for i in range(self._steps):
|
||||
for j in range(2+i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
op = MixedOp(space, C, stride, affine, track_running_stats)
|
||||
self.edges[ node_str ] = op
|
||||
self.edge_keys = sorted(list(self.edges.keys()))
|
||||
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
|
||||
self.num_edges = len(self.edges)
|
||||
|
||||
def forward_gdas(self, s0, s1, weightss, indexs):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
op = self.edges[ node_str ]
|
||||
weights = weightss[ self.edge2index[node_str] ]
|
||||
index = indexs[ self.edge2index[node_str] ].item()
|
||||
clist.append( op(h, weights, index) )
|
||||
states.append( sum(clist) )
|
||||
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
Reference in New Issue
Block a user