NAS-sharing-parameters support 3 datasets / update ops / update pypi

This commit is contained in:
D-X-Y
2020-01-11 00:19:58 +11:00
parent 96152a9904
commit c66afa4df8
17 changed files with 192 additions and 153 deletions

View File

@@ -13,16 +13,22 @@ OPS = {
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
}
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
'aa-nas' : NAS_BENCH_102,
'nas-bench-102': NAS_BENCH_102,
'full' : sorted(list(OPS.keys()))}
'darts' : DARTS_SPACE}
#'full' : sorted(list(OPS.keys()))}
class ReLUConvBN(nn.Module):
@@ -39,6 +45,34 @@ class ReLUConvBN(nn.Module):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
)
def forward(self, x):
return self.op(x)
class DualSepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(DualSepConv, self).__init__()
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats)
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)
def forward(self, x):
x = self.op_a(x)
x = self.op_b(x)
return x
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, affine=True):