NAS-sharing-parameters support 3 datasets / update ops / update pypi
This commit is contained in:
@@ -13,16 +13,22 @@ OPS = {
|
||||
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
|
||||
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
|
||||
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
|
||||
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
|
||||
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
|
||||
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
|
||||
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
|
||||
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
|
||||
}
|
||||
|
||||
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
||||
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
|
||||
|
||||
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
||||
'aa-nas' : NAS_BENCH_102,
|
||||
'nas-bench-102': NAS_BENCH_102,
|
||||
'full' : sorted(list(OPS.keys()))}
|
||||
'darts' : DARTS_SPACE}
|
||||
#'full' : sorted(list(OPS.keys()))}
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
@@ -39,6 +45,34 @@ class ReLUConvBN(nn.Module):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DualSepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||
super(DualSepConv, self).__init__()
|
||||
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats)
|
||||
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.op_a(x)
|
||||
x = self.op_b(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes, stride, affine=True):
|
||||
|
Reference in New Issue
Block a user