update affines for NAS

This commit is contained in:
D-X-Y
2019-12-02 18:03:40 +11:00
parent 487fec21bf
commit d175a361bd
9 changed files with 78 additions and 41 deletions

View File

@@ -19,9 +19,9 @@ class InferCell(nn.Module):
cur_innod = []
for (op_name, op_in) in node_info:
if op_in == 0:
layer = OPS[op_name](C_in , C_out, stride)
layer = OPS[op_name](C_in , C_out, stride, True)
else:
layer = OPS[op_name](C_out, C_out, 1)
layer = OPS[op_name](C_out, C_out, 1, True)
cur_index.append( len(self.layers) )
cur_innod.append( op_in )
self.layers.append( layer )

View File

@@ -22,7 +22,7 @@ class TinyNetwork(nn.Module):
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
else:
cell = InferCell(genotype, C_prev, C_curr, 1)
self.cells.append( cell )

View File

@@ -4,16 +4,16 @@
import torch
import torch.nn as nn
__all__ = ['OPS', 'ReLUConvBN', 'ResNetBasicblock', 'SearchSpaceNames']
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
OPS = {
'none' : lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
'avg_pool_3x3' : lambda C_in, C_out, stride: POOLING(C_in, C_out, stride, 'avg'),
'max_pool_3x3' : lambda C_in, C_out, stride: POOLING(C_in, C_out, stride, 'max'),
'nor_conv_7x7' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1)),
'nor_conv_3x3' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1)),
'nor_conv_1x1' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1)),
'skip_connect' : lambda C_in, C_out, stride: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride),
'none' : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine),
'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine),
'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine),
'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
}
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
@@ -26,12 +26,12 @@ SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out)
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
@@ -40,17 +40,17 @@ class ReLUConvBN(nn.Module):
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride):
def __init__(self, inplanes, planes, stride, affine=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
else:
self.downsample = None
self.in_dim = inplanes
@@ -76,12 +76,12 @@ class ResNetBasicblock(nn.Module):
class POOLING(nn.Module):
def __init__(self, C_in, C_out, stride, mode):
def __init__(self, C_in, C_out, stride, mode, affine=True):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0)
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine)
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
@@ -126,7 +126,7 @@ class Zero(nn.Module):
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride):
def __init__(self, C_in, C_out, stride, affine):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
@@ -141,8 +141,7 @@ class FactorizedReduce(nn.Module):
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
x = self.relu(x)

View File

@@ -23,9 +23,9 @@ class SearchCell(nn.Module):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if j == 0:
xlists = [OPS[op_name](C_in , C_out, stride) for op_name in op_names]
xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names]
else:
xlists = [OPS[op_name](C_in , C_out, 1) for op_name in op_names]
xlists = [OPS[op_name](C_in , C_out, 1, False) for op_name in op_names]
self.edges[ node_str ] = nn.ModuleList( xlists )
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}