Move to xautodl
This commit is contained in:
286
xautodl/models/shape_infers/InferCifarResNet.py
Normal file
286
xautodl/models/shape_infers/InferCifarResNet.py
Normal file
@@ -0,0 +1,286 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[1],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
iCs[1],
|
||||
iCs[2],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferCifarResNet(nn.Module):
|
||||
def __init__(
|
||||
self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual
|
||||
):
|
||||
super(InferCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks)
|
||||
|
||||
self.message = (
|
||||
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
last_channel_idx = 1
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iCs,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(iL + 1, layer_blocks):
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
break
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
263
xautodl/models/shape_infers/InferCifarResNet_depth.py
Normal file
263
xautodl/models/shape_infers/InferCifarResNet_depth.py
Normal file
@@ -0,0 +1,263 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
planes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes * self.expansion:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes * self.expansion
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferDepthCifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
|
||||
super(InferDepthCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks)
|
||||
|
||||
self.message = (
|
||||
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
]
|
||||
)
|
||||
self.channels = [16]
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
planes,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
break
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.channels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
277
xautodl/models/shape_infers/InferCifarResNet_width.py
Normal file
277
xautodl/models/shape_infers/InferCifarResNet_width.py
Normal file
@@ -0,0 +1,277 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[1],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
iCs[1],
|
||||
iCs[2],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferWidthCifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
|
||||
super(InferWidthCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
|
||||
self.message = (
|
||||
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
last_channel_idx = 1
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iCs,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
324
xautodl/models/shape_infers/InferImagenetResNet.py
Normal file
324
xautodl/models/shape_infers/InferImagenetResNet.py
Normal file
@@ -0,0 +1,324 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
num_conv = 1
|
||||
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[1],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
iCs[1],
|
||||
iCs[2],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferImagenetResNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
block_name,
|
||||
layers,
|
||||
xblocks,
|
||||
xchannels,
|
||||
deep_stem,
|
||||
num_classes,
|
||||
zero_init_residual,
|
||||
):
|
||||
super(InferImagenetResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "BasicBlock":
|
||||
block = ResNetBasicblock
|
||||
elif block_name == "Bottleneck":
|
||||
block = ResNetBottleneck
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
assert len(xblocks) == len(
|
||||
layers
|
||||
), "invalid layers : {:} vs xblocks : {:}".format(layers, xblocks)
|
||||
|
||||
self.message = "InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}".format(
|
||||
sum(layers) * block.num_conv, sum(xblocks) * block.num_conv, xblocks
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
if not deep_stem:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
7,
|
||||
2,
|
||||
3,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
last_channel_idx = 1
|
||||
else:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
3,
|
||||
2,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
),
|
||||
ConvBNReLU(
|
||||
xchannels[1],
|
||||
xchannels[2],
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
last_channel_idx = 2
|
||||
self.layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||
for stage, layer_blocks in enumerate(layers):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iCs,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(iL + 1, layer_blocks):
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
break
|
||||
assert last_channel_idx + 1 == len(self.xchannels), "{:} vs {:}".format(
|
||||
last_channel_idx, len(self.xchannels)
|
||||
)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
174
xautodl/models/shape_infers/InferMobileNetV2.py
Normal file
174
xautodl/models/shape_infers/InferMobileNetV2.py
Normal file
@@ -0,0 +1,174 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
|
||||
from torch import nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import parse_channel_info
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size,
|
||||
stride,
|
||||
groups,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.conv = nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(out_planes)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU6(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
if self.bn:
|
||||
out = self.bn(out)
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, channels, stride, expand_ratio, additive):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2], "invalid stride : {:}".format(stride)
|
||||
assert len(channels) in [2, 3], "invalid channels : {:}".format(channels)
|
||||
|
||||
if len(channels) == 2:
|
||||
layers = []
|
||||
else:
|
||||
layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
|
||||
layers.extend(
|
||||
[
|
||||
# dw
|
||||
ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
|
||||
# pw-linear
|
||||
ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
|
||||
]
|
||||
)
|
||||
self.conv = nn.Sequential(*layers)
|
||||
self.additive = additive
|
||||
if self.additive and channels[0] != channels[-1]:
|
||||
self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
|
||||
else:
|
||||
self.shortcut = None
|
||||
self.out_dim = channels[-1]
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
# if self.additive: return additive_func(out, x)
|
||||
if self.shortcut:
|
||||
return out + self.shortcut(x)
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
class InferMobileNetV2(nn.Module):
|
||||
def __init__(self, num_classes, xchannels, xblocks, dropout):
|
||||
super(InferMobileNetV2, self).__init__()
|
||||
block = InvertedResidual
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
assert len(inverted_residual_setting) == len(
|
||||
xblocks
|
||||
), "invalid number of layers : {:} vs {:}".format(
|
||||
len(inverted_residual_setting), len(xblocks)
|
||||
)
|
||||
for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
|
||||
assert block_num <= ir_setting[2], "{:} vs {:}".format(
|
||||
block_num, ir_setting
|
||||
)
|
||||
xchannels = parse_channel_info(xchannels)
|
||||
# for i, chs in enumerate(xchannels):
|
||||
# if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
|
||||
self.xchannels = xchannels
|
||||
self.message = "InferMobileNetV2 : xblocks={:}".format(xblocks)
|
||||
# building first layer
|
||||
features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
|
||||
last_channel_idx = 1
|
||||
|
||||
# building inverted residual blocks
|
||||
for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
additv = True if i > 0 else False
|
||||
module = block(self.xchannels[last_channel_idx], stride, t, additv)
|
||||
features.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(
|
||||
stage,
|
||||
i,
|
||||
n,
|
||||
len(features),
|
||||
self.xchannels[last_channel_idx],
|
||||
stride,
|
||||
t,
|
||||
c,
|
||||
)
|
||||
last_channel_idx += 1
|
||||
if i + 1 == xblocks[stage]:
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(i + 1, n):
|
||||
last_channel_idx += 1
|
||||
self.xchannels[last_channel_idx][0] = module.out_dim
|
||||
break
|
||||
# building last several layers
|
||||
features.append(
|
||||
ConvBNReLU(
|
||||
self.xchannels[last_channel_idx][0],
|
||||
self.xchannels[last_channel_idx][1],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
)
|
||||
)
|
||||
assert last_channel_idx + 2 == len(self.xchannels), "{:} vs {:}".format(
|
||||
last_channel_idx, len(self.xchannels)
|
||||
)
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
features = self.features(inputs)
|
||||
vectors = features.mean([2, 3])
|
||||
predicts = self.classifier(vectors)
|
||||
return features, predicts
|
64
xautodl/models/shape_infers/InferTinyCellNet.py
Normal file
64
xautodl/models/shape_infers/InferTinyCellNet.py
Normal file
@@ -0,0 +1,64 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
from typing import List, Text, Any
|
||||
import torch.nn as nn
|
||||
from models.cell_operations import ResNetBasicblock
|
||||
from models.cell_infers.cells import InferCell
|
||||
|
||||
|
||||
class DynamicShapeTinyNet(nn.Module):
|
||||
def __init__(self, channels: List[int], genotype: Any, num_classes: int):
|
||||
super(DynamicShapeTinyNet, self).__init__()
|
||||
self._channels = channels
|
||||
if len(channels) % 3 != 2:
|
||||
raise ValueError("invalid number of layers : {:}".format(len(channels)))
|
||||
self._num_stage = N = len(channels) // 3
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(channels[0]),
|
||||
)
|
||||
|
||||
# layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
c_prev = channels[0]
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(c_prev, c_curr, 2, True)
|
||||
else:
|
||||
cell = InferCell(genotype, c_prev, c_curr, 1)
|
||||
self.cells.append(cell)
|
||||
c_prev = cell.out_dim
|
||||
self._num_layer = len(self.cells)
|
||||
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(c_prev, num_classes)
|
||||
|
||||
def get_message(self) -> Text:
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_channels}, N={_num_stage}, L={_num_layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
9
xautodl/models/shape_infers/__init__.py
Normal file
9
xautodl/models/shape_infers/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
from .InferCifarResNet_width import InferWidthCifarResNet
|
||||
from .InferImagenetResNet import InferImagenetResNet
|
||||
from .InferCifarResNet_depth import InferDepthCifarResNet
|
||||
from .InferCifarResNet import InferCifarResNet
|
||||
from .InferMobileNetV2 import InferMobileNetV2
|
||||
from .InferTinyCellNet import DynamicShapeTinyNet
|
5
xautodl/models/shape_infers/shared_utils.py
Normal file
5
xautodl/models/shape_infers/shared_utils.py
Normal file
@@ -0,0 +1,5 @@
|
||||
def parse_channel_info(xstring):
|
||||
blocks = xstring.split(" ")
|
||||
blocks = [x.split("-") for x in blocks]
|
||||
blocks = [[int(_) for _ in x] for x in blocks]
|
||||
return blocks
|
Reference in New Issue
Block a user