commit b74255e1f3 (parent c895924c99)
Author: Jack Turner
Date: 2021-02-26 16:12:51 +00:00
74 changed files with 11326 additions and 537 deletions

pycls/models/__init__.py (empty file)

pycls/models/anynet.py

@@ -0,0 +1,406 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""AnyNet models."""
import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg
def get_stem_fun(stem_type):
"""Retrieves the stem function by name."""
stem_funs = {
"res_stem_cifar": ResStemCifar,
"res_stem_in": ResStemIN,
"simple_stem_in": SimpleStemIN,
}
err_str = "Stem type '{}' not supported"
assert stem_type in stem_funs.keys(), err_str.format(stem_type)
return stem_funs[stem_type]
def get_block_fun(block_type):
"""Retrieves the block function by name."""
block_funs = {
"vanilla_block": VanillaBlock,
"res_basic_block": ResBasicBlock,
"res_bottleneck_block": ResBottleneckBlock,
}
err_str = "Block type '{}' not supported"
assert block_type in block_funs.keys(), err_str.format(block_type)
return block_funs[block_type]
class AnyHead(nn.Module):
"""AnyNet head: AvgPool, 1x1."""
def __init__(self, w_in, nc):
super(AnyHead, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(w_in, nc, bias=True)
def forward(self, x):
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
@staticmethod
def complexity(cx, w_in, nc):
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
return cx
class VanillaBlock(nn.Module):
"""Vanilla block: [3x3 conv, BN, Relu] x2."""
def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Vanilla block does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
super(VanillaBlock, self).__init__()
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Vanilla block does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class BasicTransform(nn.Module):
"""Basic transformation: [3x3 conv, BN, Relu] x2."""
def __init__(self, w_in, w_out, stride):
super(BasicTransform, self).__init__()
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResBasicBlock(nn.Module):
"""Residual basic block: x + F(x), F = basic transform."""
def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Basic transform does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
super(ResBasicBlock, self).__init__()
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.f = BasicTransform(w_in, w_out, stride)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Basic transform does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
proj_block = (w_in != w_out) or (stride != 1)
if proj_block:
h, w = cx["h"], cx["w"]
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = h, w # parallel branch
cx = BasicTransform.complexity(cx, w_in, w_out, stride)
return cx
class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU, FC, Sigmoid."""
def __init__(self, w_in, w_se):
super(SE, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.f_ex = nn.Sequential(
nn.Conv2d(w_in, w_se, 1, bias=True),
nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
nn.Conv2d(w_se, w_in, 1, bias=True),
nn.Sigmoid(),
)
def forward(self, x):
return x * self.f_ex(self.avg_pool(x))
@staticmethod
def complexity(cx, w_in, w_se):
h, w = cx["h"], cx["w"]
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
cx["h"], cx["w"] = h, w
return cx
class BottleneckTransform(nn.Module):
"""Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""
def __init__(self, w_in, w_out, stride, bm, gw, se_r):
super(BottleneckTransform, self).__init__()
w_b = int(round(w_out * bm))
g = w_b // gw
self.a = nn.Conv2d(w_in, w_b, 1, stride=1, padding=0, bias=False)
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_b, w_b, 3, stride=stride, padding=1, groups=g, bias=False)
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
if se_r:
w_se = int(round(w_in * se_r))
self.se = SE(w_b, w_se)
self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.c_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm, gw, se_r):
w_b = int(round(w_out * bm))
g = w_b // gw
cx = net.complexity_conv2d(cx, w_in, w_b, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_b)
cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride, 1, g)
cx = net.complexity_batchnorm2d(cx, w_b)
if se_r:
w_se = int(round(w_in * se_r))
cx = SE.complexity(cx, w_b, w_se)
cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResBottleneckBlock(nn.Module):
"""Residual bottleneck block: x + F(x), F = bottleneck transform."""
def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
super(ResBottleneckBlock, self).__init__()
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
proj_block = (w_in != w_out) or (stride != 1)
if proj_block:
h, w = cx["h"], cx["w"]
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = h, w # parallel branch
cx = BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r)
return cx
class ResStemCifar(nn.Module):
"""ResNet stem for CIFAR: 3x3, BN, ReLU."""
def __init__(self, w_in, w_out):
super(ResStemCifar, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResStemIN(nn.Module):
"""ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""
def __init__(self, w_in, w_out):
super(ResStemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_maxpool2d(cx, 3, 2, 1)
return cx
class SimpleStemIN(nn.Module):
"""Simple stem for ImageNet: 3x3, BN, ReLU."""
def __init__(self, w_in, w_out):
super(SimpleStemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class AnyStage(nn.Module):
"""AnyNet stage (sequence of blocks w/ the same output shape)."""
def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
super(AnyStage, self).__init__()
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
name = "b{}".format(i + 1)
self.add_module(name, block_fun(b_w_in, w_out, b_stride, bm, gw, se_r))
def forward(self, x):
for block in self.children():
x = block(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
cx = block_fun.complexity(cx, b_w_in, w_out, b_stride, bm, gw, se_r)
return cx
class AnyNet(nn.Module):
"""AnyNet model."""
@staticmethod
def get_args():
return {
"stem_type": cfg.ANYNET.STEM_TYPE,
"stem_w": cfg.ANYNET.STEM_W,
"block_type": cfg.ANYNET.BLOCK_TYPE,
"ds": cfg.ANYNET.DEPTHS,
"ws": cfg.ANYNET.WIDTHS,
"ss": cfg.ANYNET.STRIDES,
"bms": cfg.ANYNET.BOT_MULS,
"gws": cfg.ANYNET.GROUP_WS,
"se_r": cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
"nc": cfg.MODEL.NUM_CLASSES,
}
def __init__(self, **kwargs):
super(AnyNet, self).__init__()
kwargs = self.get_args() if not kwargs else kwargs
self._construct(**kwargs)
self.apply(net.init_weights)
def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
# Generate dummy bot muls and gs for models that do not use them
bms = bms if bms else [None for _d in ds]
gws = gws if gws else [None for _d in ds]
stage_params = list(zip(ds, ws, ss, bms, gws))
stem_fun = get_stem_fun(stem_type)
self.stem = stem_fun(3, stem_w)
block_fun = get_block_fun(block_type)
prev_w = stem_w
for i, (d, w, s, bm, gw) in enumerate(stage_params):
name = "s{}".format(i + 1)
self.add_module(name, AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r))
prev_w = w
self.head = AnyHead(w_in=prev_w, nc=nc)
def forward(self, x, get_ints=False):
for module in self.children():
x = module(x)
return x
@staticmethod
def complexity(cx, **kwargs):
"""Computes model complexity. If you alter the model, make sure to update."""
kwargs = AnyNet.get_args() if not kwargs else kwargs
return AnyNet._complexity(cx, **kwargs)
@staticmethod
def _complexity(cx, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
bms = bms if bms else [None for _d in ds]
gws = gws if gws else [None for _d in ds]
stage_params = list(zip(ds, ws, ss, bms, gws))
stem_fun = get_stem_fun(stem_type)
cx = stem_fun.complexity(cx, 3, stem_w)
block_fun = get_block_fun(block_type)
prev_w = stem_w
for d, w, s, bm, gw in stage_params:
cx = AnyStage.complexity(cx, prev_w, w, s, d, block_fun, bm, gw, se_r)
prev_w = w
cx = AnyHead.complexity(cx, prev_w, nc)
return cx
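
These models thread a mutable cx dict through every complexity() method. A minimal sketch of the bookkeeping, assuming the dict tracks spatial size plus flops/params/acts the way the pycls.core.net helpers are used above (the helper below is an illustrative stand-in, not the library implementation):

def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
    # Update spatial size, then accumulate flops / params / activations.
    h = (cx["h"] + 2 * padding - k) // stride + 1
    w = (cx["w"] + 2 * padding - k) // stride + 1
    cx["flops"] += k * k * w_in * w_out * h * w // groups + (w_out * h * w if bias else 0)
    cx["params"] += k * k * w_in * w_out // groups + (w_out if bias else 0)
    cx["acts"] += w_out * h * w
    cx["h"], cx["w"] = h, w
    return cx

cx = {"h": 224, "w": 224, "flops": 0, "params": 0, "acts": 0}
cx = complexity_conv2d(cx, 3, 32, 3, 2, 1)  # e.g. SimpleStemIN on a 224x224 input
print(cx["h"], cx["w"])  # 112 112: the stride-2 stem halves the resolution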

pycls/models/common.py

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from pycls.core.config import cfg
def Preprocess(x):
if cfg.TASK == 'jig':
assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
x = x.view([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]])
return x
class Classifier(nn.Module):
def __init__(self, channels, num_classes):
super(Classifier, self).__init__()
if cfg.TASK == 'jig':
self.jig_sq = cfg.JIGSAW_GRID ** 2
self.pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(channels * self.jig_sq, num_classes)
elif cfg.TASK == 'col':
self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
elif cfg.TASK == 'seg':
self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
else:
self.pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(channels, num_classes)
def forward(self, x, shape):
if cfg.TASK == 'jig':
x = self.pooling(x)
x = x.view([x.shape[0] // self.jig_sq, x.shape[1] * self.jig_sq, x.shape[2], x.shape[3]])
x = self.classifier(x.view(x.size(0), -1))
elif cfg.TASK in ['col', 'seg']:
x = self.classifier(x)
x = nn.Upsample(shape, mode='bilinear', align_corners=True)(x)
else:
x = self.pooling(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class ASPP(nn.Module):
def __init__(self, in_channels, out_channels, num_classes, rates):
super(ASPP, self).__init__()
assert len(rates) in [1, 3]
self.rates = rates
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.aspp1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp2 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[0],
padding=rates[0], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
if len(self.rates) == 3:
self.aspp3 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[1],
padding=rates[1], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp4 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[2],
padding=rates[2], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp5 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.classifier = nn.Sequential(
nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, num_classes, 1)
)
def forward(self, x):
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x5 = self.global_pooling(x)
x5 = self.aspp5(x5)
x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
align_corners=True)(x5)
if len(self.rates) == 3:
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x = torch.cat((x1, x2, x3, x4, x5), 1)
else:
x = torch.cat((x1, x2, x5), 1)
x = self.classifier(x)
return x
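
The jigsaw path is easiest to follow with concrete shapes. A sketch assuming cfg.JIGSAW_GRID = 3 (so 9 patches) and a hypothetical 256-channel backbone output:

import torch

N, P, C, H, W = 4, 9, 3, 64, 64        # batch, patches (3x3 grid), channels, patch size
x = torch.randn(N, P, C, H, W)
x = x.view(N * P, C, H, W)             # Preprocess: fold the patch dim into the batch
feat = torch.randn(N * P, 256, 2, 2)   # stand-in for the backbone's output
pooled = torch.nn.AdaptiveAvgPool2d(1)(feat)
pooled = pooled.view(N, 256 * P, 1, 1) # Classifier: unfold patches into channels
print(pooled.shape)                    # torch.Size([4, 2304, 1, 1]) -> Linear(2304, nc)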

pycls/models/effnet.py

@@ -0,0 +1,232 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""EfficientNet models."""
import pycls.core.net as net
import torch
import torch.nn as nn
from pycls.core.config import cfg
class EffHead(nn.Module):
"""EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC."""
def __init__(self, w_in, w_out, nc):
super(EffHead, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False)
self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.conv_swish = Swish()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
if cfg.EN.DROPOUT_RATIO > 0.0:
self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
self.fc = nn.Linear(w_out, nc, bias=True)
def forward(self, x):
x = self.conv_swish(self.conv_bn(self.conv(x)))
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.dropout(x) if hasattr(self, "dropout") else x
x = self.fc(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, nc):
cx = net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_out, nc, 1, 1, 0, bias=True)
return cx
class Swish(nn.Module):
"""Swish activation function: x * sigmoid(x)."""
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * torch.sigmoid(x)
class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""
def __init__(self, w_in, w_se):
super(SE, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.f_ex = nn.Sequential(
nn.Conv2d(w_in, w_se, 1, bias=True),
Swish(),
nn.Conv2d(w_se, w_in, 1, bias=True),
nn.Sigmoid(),
)
def forward(self, x):
return x * self.f_ex(self.avg_pool(x))
@staticmethod
def complexity(cx, w_in, w_se):
h, w = cx["h"], cx["w"]
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
cx["h"], cx["w"] = h, w
return cx
class MBConv(nn.Module):
"""Mobile inverted bottleneck block w/ SE (MBConv)."""
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
# expansion, 3x3 dwise, BN, Swish, SE, 1x1, BN, skip_connection
super(MBConv, self).__init__()
self.exp = None
w_exp = int(w_in * exp_r)
if w_exp != w_in:
self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False)
self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.exp_swish = Swish()
dwise_args = {"groups": w_exp, "padding": (kernel - 1) // 2, "bias": False}
self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args)
self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.dwise_swish = Swish()
self.se = SE(w_exp, int(w_in * se_r))
self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False)
self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
# Skip connection if in and out shapes are the same (MN-V2 style)
self.has_skip = stride == 1 and w_in == w_out
def forward(self, x):
f_x = x
if self.exp:
f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
f_x = self.se(f_x)
f_x = self.lin_proj_bn(self.lin_proj(f_x))
if self.has_skip:
if self.training and cfg.EN.DC_RATIO > 0.0:
f_x = net.drop_connect(f_x, cfg.EN.DC_RATIO)
f_x = x + f_x
return f_x
@staticmethod
def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out):
w_exp = int(w_in * exp_r)
if w_exp != w_in:
cx = net.complexity_conv2d(cx, w_in, w_exp, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_exp)
padding = (kernel - 1) // 2
cx = net.complexity_conv2d(cx, w_exp, w_exp, kernel, stride, padding, w_exp)
cx = net.complexity_batchnorm2d(cx, w_exp)
cx = SE.complexity(cx, w_exp, int(w_in * se_r))
cx = net.complexity_conv2d(cx, w_exp, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class EffStage(nn.Module):
"""EfficientNet stage."""
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
super(EffStage, self).__init__()
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
name = "b{}".format(i + 1)
self.add_module(name, MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out))
def forward(self, x):
for block in self.children():
x = block(x)
return x
@staticmethod
def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out, d):
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
cx = MBConv.complexity(cx, b_w_in, exp_r, kernel, b_stride, se_r, w_out)
return cx
class StemIN(nn.Module):
"""EfficientNet stem for ImageNet: 3x3, BN, Swish."""
def __init__(self, w_in, w_out):
super(StemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.swish = Swish()
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class EffNet(nn.Module):
"""EfficientNet model."""
@staticmethod
def get_args():
return {
"stem_w": cfg.EN.STEM_W,
"ds": cfg.EN.DEPTHS,
"ws": cfg.EN.WIDTHS,
"exp_rs": cfg.EN.EXP_RATIOS,
"se_r": cfg.EN.SE_R,
"ss": cfg.EN.STRIDES,
"ks": cfg.EN.KERNELS,
"head_w": cfg.EN.HEAD_W,
"nc": cfg.MODEL.NUM_CLASSES,
}
def __init__(self):
err_str = "Dataset {} is not supported"
assert cfg.TRAIN.DATASET in ["imagenet"], err_str.format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in ["imagenet"], err_str.format(cfg.TEST.DATASET)
super(EffNet, self).__init__()
self._construct(**EffNet.get_args())
self.apply(net.init_weights)
def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
self.stem = StemIN(3, stem_w)
prev_w = stem_w
for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
name = "s{}".format(i + 1)
self.add_module(name, EffStage(prev_w, exp_r, kernel, stride, se_r, w, d))
prev_w = w
self.head = EffHead(prev_w, head_w, nc)
def forward(self, x):
for module in self.children():
x = module(x)
return x
@staticmethod
def complexity(cx):
"""Computes model complexity. If you alter the model, make sure to update."""
return EffNet._complexity(cx, **EffNet.get_args())
@staticmethod
def _complexity(cx, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
cx = StemIN.complexity(cx, 3, stem_w)
prev_w = stem_w
for d, w, exp_r, stride, kernel in stage_params:
cx = EffStage.complexity(cx, prev_w, exp_r, kernel, stride, se_r, w, d)
prev_w = w
cx = EffHead.complexity(cx, prev_w, head_w, nc)
return cx
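
The channel arithmetic inside MBConv is the part worth pinning down. A worked sketch with w_in=16, exp_r=6.0, se_r=0.25, w_out=24 (EfficientNet-B0-style values, used here only for illustration); note the SE width is computed from the block input w_in, not from the expanded width:

w_in, exp_r, se_r, w_out, stride = 16, 6.0, 0.25, 24, 1
w_exp = int(w_in * exp_r)     # 96 channels after the 1x1 expansion
w_se = int(w_in * se_r)       # 4 channels inside the SE bottleneck
has_skip = stride == 1 and w_in == w_out
print(w_exp, w_se, has_skip)  # 96 4 False: stride 1 but 16 != 24, so no skip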

pycls/models/nas/genotypes.py

@@ -0,0 +1,634 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS genotypes (adopted from DARTS)."""
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
# NASNet ops
NASNET_OPS = [
'skip_connect',
'conv_3x1_1x3',
'conv_7x1_1x7',
'dil_conv_3x3',
'avg_pool_3x3',
'max_pool_3x3',
'max_pool_5x5',
'max_pool_7x7',
'conv_1x1',
'conv_3x3',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
]
# ENAS ops
ENAS_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'avg_pool_3x3',
'max_pool_3x3',
]
# AmoebaNet ops
AMOEBA_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'avg_pool_3x3',
'max_pool_3x3',
'dil_sep_conv_3x3',
'conv_7x1_1x7',
]
# NAO ops
NAO_OPS = [
'skip_connect',
'conv_1x1',
'conv_3x3',
'conv_3x1_1x3',
'conv_7x1_1x7',
'max_pool_2x2',
'max_pool_3x3',
'max_pool_5x5',
'avg_pool_2x2',
'avg_pool_3x3',
'avg_pool_5x5',
]
# PNAS ops
PNAS_OPS = [
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'conv_7x1_1x7',
'skip_connect',
'avg_pool_3x3',
'max_pool_3x3',
'dil_conv_3x3',
]
# DARTS ops
DARTS_OPS = [
'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
]
NASNet = Genotype(
normal=[
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 0),
('avg_pool_3x3', 0),
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 1),
('sep_conv_7x7', 0),
('max_pool_3x3', 1),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('skip_connect', 3),
('avg_pool_3x3', 2),
('sep_conv_3x3', 2),
('max_pool_3x3', 1),
],
reduce_concat=[4, 5, 6],
)
PNASNet = Genotype(
normal=[
('sep_conv_5x5', 0),
('max_pool_3x3', 0),
('sep_conv_7x7', 1),
('max_pool_3x3', 1),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 4),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 0),
('max_pool_3x3', 0),
('sep_conv_7x7', 1),
('max_pool_3x3', 1),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 4),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 1),
],
reduce_concat=[2, 3, 4, 5, 6],
)
AmoebaNet = Genotype(
normal=[
('avg_pool_3x3', 0),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('avg_pool_3x3', 3),
('sep_conv_3x3', 1),
('skip_connect', 1),
('skip_connect', 0),
('avg_pool_3x3', 1),
],
normal_concat=[4, 5, 6],
reduce=[
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_7x7', 2),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('conv_7x1_1x7', 0),
('sep_conv_3x3', 5),
],
reduce_concat=[3, 4, 6]
)
DARTS_V1 = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 0),
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 2)
],
normal_concat=[2, 3, 4, 5],
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('skip_connect', 2),
('max_pool_3x3', 0),
('max_pool_3x3', 0),
('skip_connect', 2),
('skip_connect', 2),
('avg_pool_3x3', 0)
],
reduce_concat=[2, 3, 4, 5]
)
DARTS_V2 = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('skip_connect', 0),
('skip_connect', 0),
('dil_conv_3x3', 2)
],
normal_concat=[2, 3, 4, 5],
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('skip_connect', 2),
('max_pool_3x3', 1),
('max_pool_3x3', 0),
('skip_connect', 2),
('skip_connect', 2),
('max_pool_3x3', 1)
],
reduce_concat=[2, 3, 4, 5]
)
PDARTS = Genotype(
normal=[
('skip_connect', 0),
('dil_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('dil_conv_5x5', 4)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('dil_conv_3x3', 1),
('dil_conv_3x3', 1),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
PCDARTS_C10 = Genotype(
normal=[
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 0),
('dil_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 1),
('avg_pool_3x3', 0),
('dil_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 1),
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2)
],
reduce_concat=range(2, 6)
)
PCDARTS_IN1K = Genotype(
normal=[
('skip_connect', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('skip_connect', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('dil_conv_5x5', 4)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 2),
('max_pool_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_CLS = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('skip_connect', 1),
('max_pool_3x3', 0),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_ROT = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('sep_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_COL = Genotype(
normal=[
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_5x5', 3),
('max_pool_3x3', 0),
('sep_conv_3x3', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_JIG = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 1)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_CLS = Genotype(
normal=[
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('dil_conv_5x5', 3),
('dil_conv_5x5', 2),
('dil_conv_5x5', 4),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_ROT = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('dil_conv_5x5', 2),
('sep_conv_5x5', 0),
('dil_conv_5x5', 3),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('sep_conv_3x3', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_COL = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 2),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('sep_conv_3x3', 4),
('sep_conv_5x5', 1)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_JIG = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 4)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 0),
('skip_connect', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 0),
('sep_conv_5x5', 3),
('sep_conv_5x5', 0),
('sep_conv_5x5', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_SEG = Genotype(
normal=[
('skip_connect', 0),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 0),
('sep_conv_3x3', 4),
('sep_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_ROT = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_5x5', 2),
('sep_conv_5x5', 1),
('sep_conv_5x5', 3),
('dil_conv_5x5', 2),
('sep_conv_5x5', 2),
('sep_conv_5x5', 0)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_COL = Genotype(
normal=[
('dil_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 0),
('sep_conv_5x5', 2),
('dil_conv_3x3', 3),
('skip_connect', 0),
('skip_connect', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_JIG = Genotype(
normal=[
('dil_conv_5x5', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 0),
('dil_conv_5x5', 1)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 1),
('dil_conv_5x5', 2),
('dil_conv_5x5', 2),
('dil_conv_5x5', 0),
('dil_conv_5x5', 3),
('dil_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
# Supported genotypes
GENOTYPES = {
'nas': NASNet,
'pnas': PNASNet,
'amoeba': AmoebaNet,
'darts_v1': DARTS_V1,
'darts_v2': DARTS_V2,
'pdarts': PDARTS,
'pcdarts_c10': PCDARTS_C10,
'pcdarts_in1k': PCDARTS_IN1K,
'unnas_imagenet_cls': UNNAS_IMAGENET_CLS,
'unnas_imagenet_rot': UNNAS_IMAGENET_ROT,
'unnas_imagenet_col': UNNAS_IMAGENET_COL,
'unnas_imagenet_jig': UNNAS_IMAGENET_JIG,
'unnas_imagenet22k_cls': UNNAS_IMAGENET22K_CLS,
'unnas_imagenet22k_rot': UNNAS_IMAGENET22K_ROT,
'unnas_imagenet22k_col': UNNAS_IMAGENET22K_COL,
'unnas_imagenet22k_jig': UNNAS_IMAGENET22K_JIG,
'unnas_cityscapes_seg': UNNAS_CITYSCAPES_SEG,
'unnas_cityscapes_rot': UNNAS_CITYSCAPES_ROT,
'unnas_cityscapes_col': UNNAS_CITYSCAPES_COL,
'unnas_cityscapes_jig': UNNAS_CITYSCAPES_JIG,
'custom': None,
}
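
A genotype is just a flat list of (op, input_index) pairs per cell type, plus the node indices whose outputs are concatenated. A minimal sketch of how one is consumed (mirroring Cell._compile in pycls/models/nas/nas.py further down in this commit):

genotype = GENOTYPES['darts_v2']
op_names, indices = zip(*genotype.normal)
assert len(op_names) == len(indices) == 8  # 4 intermediate nodes x 2 inputs each
print(op_names[0], indices[0])             # first edge: 'sep_conv_3x3' from node 0
print(list(genotype.normal_concat))        # [2, 3, 4, 5]: nodes concatenated as output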

pycls/models/nas/nas.py

@@ -0,0 +1,299 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS network (adopted from DARTS)."""
from torch.autograd import Variable
import torch
import torch.nn as nn
import pycls.core.logging as logging
from pycls.core.config import cfg
from pycls.models.common import Preprocess
from pycls.models.common import Classifier
from pycls.models.nas.genotypes import GENOTYPES
from pycls.models.nas.genotypes import Genotype
from pycls.models.nas.operations import FactorizedReduce
from pycls.models.nas.operations import OPS
from pycls.models.nas.operations import ReLUConvBN
from pycls.models.nas.operations import Identity
logger = logging.get_logger(__name__)
def drop_path(x, drop_prob):
"""Drop path (ported from DARTS)."""
if drop_prob > 0.:
keep_prob = 1.-drop_prob
mask = Variable(
torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
)
x.div_(keep_prob)
x.mul_(mask)
return x
class Cell(nn.Module):
"""NAS cell (ported from DARTS)."""
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(Cell, self).__init__()
logger.info('{}, {}, {}'.format(C_prev_prev, C_prev, C))
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
if reduction:
op_names, indices = zip(*genotype.reduce)
concat = genotype.reduce_concat
else:
op_names, indices = zip(*genotype.normal)
concat = genotype.normal_concat
self._compile(C, op_names, indices, concat, reduction)
def _compile(self, C, op_names, indices, concat, reduction):
assert len(op_names) == len(indices)
self._steps = len(op_names) // 2
self._concat = concat
self.multiplier = len(concat)
self._ops = nn.ModuleList()
for name, index in zip(op_names, indices):
stride = 2 if reduction and index < 2 else 1
op = OPS[name](C, stride, True)
self._ops += [op]
self._indices = indices
def forward(self, s0, s1, drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
h1 = states[self._indices[2*i]]
h2 = states[self._indices[2*i+1]]
op1 = self._ops[2*i]
op2 = self._ops[2*i+1]
h1 = op1(h1)
h2 = op2(h2)
if self.training and drop_prob > 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
s = h1 + h2
states += [s]
return torch.cat([states[i] for i in self._concat], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x
class AuxiliaryHeadImageNet(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 14x14"""
super(AuxiliaryHeadImageNet, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
# Commenting it out for consistency with the experiments in the paper.
# nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x
class NetworkCIFAR(nn.Module):
"""CIFAR network (ported from DARTS)."""
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkCIFAR, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
stem_multiplier = 3
C_curr = stem_multiplier*C
self.stem = nn.Sequential(
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
nn.BatchNorm2d(C_curr)
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
reduction_prev = False
for i in range(layers):
if i in [layers//3, 2*layers//3]:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
if i == 2*layers//3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
self.classifier = Classifier(C_prev, num_classes)
def forward(self, input):
input = Preprocess(input)
logits_aux = None
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2*self._layers//3:
if self._auxiliary and self.training:
logits_aux = self.auxiliary_head(s1)
logits = self.classifier(s1, input.shape[2:])
if self._auxiliary and self.training:
return logits, logits_aux
return logits
class NetworkImageNet(nn.Module):
"""ImageNet network (ported from DARTS)."""
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkImageNet, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
self.stem0 = nn.Sequential(
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C // 2),
nn.ReLU(inplace=True),
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
C_prev_prev, C_prev, C_curr = C, C, C
self.cells = nn.ModuleList()
reduction_prev = True
reduction_layers = [layers//3] if cfg.TASK == 'seg' else [layers//3, 2*layers//3]
for i in range(layers):
if i in reduction_layers:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
if i == 2 * layers // 3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
self.classifier = Classifier(C_prev, num_classes)
def forward(self, input):
input = Preprocess(input)
logits_aux = None
s0 = self.stem0(input)
s1 = self.stem1(s0)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2 * self._layers // 3:
if self._auxiliary and self.training:
logits_aux = self.auxiliary_head(s1)
logits = self.classifier(s1, input.shape[2:])
if self._auxiliary and self.training:
return logits, logits_aux
return logits
class NAS(nn.Module):
"""NAS net wrapper (delegates to nets from DARTS)."""
def __init__(self):
assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
'Testing on {} is not supported'.format(cfg.TEST.DATASET)
assert cfg.NAS.GENOTYPE in GENOTYPES, \
'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
super(NAS, self).__init__()
logger.info('Constructing NAS: {}'.format(cfg.NAS))
# Use a custom or predefined genotype
if cfg.NAS.GENOTYPE == 'custom':
genotype = Genotype(
normal=cfg.NAS.CUSTOM_GENOTYPE[0],
normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
)
else:
genotype = GENOTYPES[cfg.NAS.GENOTYPE]
# Determine the network constructor for dataset
if 'cifar' in cfg.TRAIN.DATASET:
net_ctor = NetworkCIFAR
else:
net_ctor = NetworkImageNet
# Construct the network
self.net_ = net_ctor(
C=cfg.NAS.WIDTH,
num_classes=cfg.MODEL.NUM_CLASSES,
layers=cfg.NAS.DEPTH,
auxiliary=cfg.NAS.AUX,
genotype=genotype
)
# Drop path probability (set / annealed based on epoch)
self.net_.drop_path_prob = 0.0
def set_drop_path_prob(self, drop_path_prob):
self.net_.drop_path_prob = drop_path_prob
def forward(self, x):
return self.net_.forward(x)
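
set_drop_path_prob is a hook for the training loop; DARTS anneals the probability linearly over training. A sketch of the intended usage, assuming illustrative values (the 0.2 base probability and 600 epochs are not config keys confirmed by this diff):

model = NAS()
base_prob, num_epochs = 0.2, 600
for epoch in range(num_epochs):
    model.set_drop_path_prob(base_prob * epoch / num_epochs)
    # ... train one epoch ...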

pycls/models/nas/operations.py

@@ -0,0 +1,201 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS ops (adopted from DARTS)."""
import torch
import torch.nn as nn
OPS = {
'none': lambda C, stride, affine:
Zero(stride),
'avg_pool_2x2': lambda C, stride, affine:
nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False),
'avg_pool_3x3': lambda C, stride, affine:
nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
'avg_pool_5x5': lambda C, stride, affine:
nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False),
'max_pool_2x2': lambda C, stride, affine:
nn.MaxPool2d(2, stride=stride, padding=0),
'max_pool_3x3': lambda C, stride, affine:
nn.MaxPool2d(3, stride=stride, padding=1),
'max_pool_5x5': lambda C, stride, affine:
nn.MaxPool2d(5, stride=stride, padding=2),
'max_pool_7x7': lambda C, stride, affine:
nn.MaxPool2d(7, stride=stride, padding=3),
'skip_connect': lambda C, stride, affine:
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'conv_1x1': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'conv_3x3': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'sep_conv_3x3': lambda C, stride, affine:
SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine:
SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine:
SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine:
DilConv(C, C, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5': lambda C, stride, affine:
DilConv(C, C, 5, stride, 4, 2, affine=affine),
'dil_sep_conv_3x3': lambda C, stride, affine:
DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
'conv_3x1_1x3': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1,3), stride=(1, stride), padding=(0, 1), bias=False),
nn.Conv2d(C, C, (3,1), stride=(stride, 1), padding=(1, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'conv_7x1_1x7': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
}
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_out, kernel_size, stride=stride,
padding=padding, bias=False
),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.op(x)
class DilConv(nn.Module):
def __init__(
self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
):
super(DilConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class DilSepConv(nn.Module):
def __init__(
self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
):
super(DilSepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, stride):
super(Zero, self).__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x.mul(0.)
return x[:,:,::self.stride,::self.stride].mul(0.)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=True):
super(FactorizedReduce, self).__init__()
assert C_out % 2 == 0
self.relu = nn.ReLU(inplace=False)
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.conv_1(x), self.conv_2(y[:,:,1:,1:])], dim=1)
out = self.bn(out)
return out
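
FactorizedReduce halves the resolution without discarding the pixels a single stride-2 1x1 conv would skip: the second conv sees the input shifted by one pixel (via the asymmetric pad and slice), and each branch contributes C_out // 2 channels. A quick shape check:

import torch

fr = FactorizedReduce(C_in=32, C_out=64)
y = fr(torch.randn(2, 32, 16, 16))
print(y.shape)  # torch.Size([2, 64, 8, 8])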

pycls/models/regnet.py

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""RegNet models."""
import numpy as np
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
def quantize_float(f, q):
"""Converts a float to closest non-zero int divisible by q."""
return int(round(f / q) * q)
def adjust_ws_gs_comp(ws, bms, gs):
"""Adjusts the compatibility of widths and groups."""
ws_bot = [int(w * b) for w, b in zip(ws, bms)]
gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
return ws, gs
def get_stages_from_blocks(ws, rs):
"""Gets ws/ds of network at each stage from per block values."""
ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
return s_ws, s_ds
def generate_regnet(w_a, w_0, w_m, d, q=8):
"""Generates per block ws from RegNet parameters."""
assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
ws_cont = np.arange(d) * w_a + w_0
ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
ws = w_0 * np.power(w_m, ks)
ws = np.round(np.divide(ws, q)) * q
num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
return ws, num_stages, max_stage, ws_cont
class RegNet(AnyNet):
"""RegNet model."""
@staticmethod
def get_args():
"""Convert RegNet to AnyNet parameter format."""
# Generate RegNet ws per block
w_a, w_0, w_m, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
ws, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
# Convert to per stage format
s_ws, s_ds = get_stages_from_blocks(ws, ws)
# Use the same gw, bm and ss for each stage
s_gs = [cfg.REGNET.GROUP_W for _ in range(num_stages)]
s_bs = [cfg.REGNET.BOT_MUL for _ in range(num_stages)]
s_ss = [cfg.REGNET.STRIDE for _ in range(num_stages)]
# Adjust the compatibility of ws and gws
s_ws, s_gs = adjust_ws_gs_comp(s_ws, s_bs, s_gs)
# Get AnyNet arguments defining the RegNet
return {
"stem_type": cfg.REGNET.STEM_TYPE,
"stem_w": cfg.REGNET.STEM_W,
"block_type": cfg.REGNET.BLOCK_TYPE,
"ds": s_ds,
"ws": s_ws,
"ss": s_ss,
"bms": s_bs,
"gws": s_gs,
"se_r": cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None,
"nc": cfg.MODEL.NUM_CLASSES,
}
def __init__(self):
kwargs = RegNet.get_args()
super(RegNet, self).__init__(**kwargs)
@staticmethod
def complexity(cx, **kwargs):
"""Computes model complexity. If you alter the model, make sure to update."""
kwargs = RegNet.get_args() if not kwargs else kwargs
return AnyNet.complexity(cx, **kwargs)
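
The width generation is easiest to verify by hand on small numbers. With illustrative parameters w_0=32, w_a=8.0, w_m=2.0, d=4, q=8 (not an actual RegNet config): the linear ramp is ws_cont = [32, 40, 48, 56], the quantized log-space exponents are ks = round(log2(ws_cont / 32)) = [0, 0, 1, 1], so ws = 32 * 2**ks = [32, 32, 64, 64] (already divisible by q). get_stages_from_blocks then recovers two stages of depth 2 with widths [32, 64]:

ws, num_stages, max_stage, ws_cont = generate_regnet(w_a=8.0, w_0=32, w_m=2.0, d=4)
print(ws, num_stages)                  # [32, 32, 64, 64] 2
print(get_stages_from_blocks(ws, ws))  # ([32, 64], [2, 2])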

pycls/models/resnet.py

@@ -0,0 +1,280 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""ResNe(X)t models."""
import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg
# Stage depths for ImageNet models
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
def get_trans_fun(name):
"""Retrieves the transformation function by name."""
trans_funs = {
"basic_transform": BasicTransform,
"bottleneck_transform": BottleneckTransform,
}
err_str = "Transformation function '{}' not supported"
assert name in trans_funs.keys(), err_str.format(name)
return trans_funs[name]
class ResHead(nn.Module):
"""ResNet head: AvgPool, 1x1."""
def __init__(self, w_in, nc):
super(ResHead, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(w_in, nc, bias=True)
def forward(self, x):
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
@staticmethod
def complexity(cx, w_in, nc):
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
return cx
class BasicTransform(nn.Module):
"""Basic transformation: 3x3, BN, ReLU, 3x3, BN."""
def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
err_str = "Basic transform does not support w_b and num_gs options"
assert w_b is None and num_gs == 1, err_str
super(BasicTransform, self).__init__()
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1):
err_str = "Basic transform does not support w_b and num_gs options"
assert w_b is None and num_gs == 1, err_str
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class BottleneckTransform(nn.Module):
"""Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN."""
def __init__(self, w_in, w_out, stride, w_b, num_gs):
super(BottleneckTransform, self).__init__()
# MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
(s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False)
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False)
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.c_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, w_b, num_gs):
(s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
cx = net.complexity_conv2d(cx, w_in, w_b, 1, s1, 0)
cx = net.complexity_batchnorm2d(cx, w_b)
cx = net.complexity_conv2d(cx, w_b, w_b, 3, s3, 1, num_gs)
cx = net.complexity_batchnorm2d(cx, w_b)
cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResBlock(nn.Module):
"""Residual block: x + F(x)."""
def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
super(ResBlock, self).__init__()
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs):
proj_block = (w_in != w_out) or (stride != 1)
if proj_block:
h, w = cx["h"], cx["w"]
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = h, w # parallel branch
cx = trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs)
return cx
class ResStage(nn.Module):
"""Stage of ResNet."""
def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
super(ResStage, self).__init__()
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN)
res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs)
self.add_module("b{}".format(i + 1), res_block)
def forward(self, x):
for block in self.children():
x = block(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1):
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
trans_f = get_trans_fun(cfg.RESNET.TRANS_FUN)
cx = ResBlock.complexity(cx, b_w_in, w_out, b_stride, trans_f, w_b, num_gs)
return cx
class ResStemCifar(nn.Module):
"""ResNet stem for CIFAR: 3x3, BN, ReLU."""
def __init__(self, w_in, w_out):
super(ResStemCifar, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResStemIN(nn.Module):
"""ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""
def __init__(self, w_in, w_out):
super(ResStemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_maxpool2d(cx, 3, 2, 1)
return cx
class ResNet(nn.Module):
"""ResNet model."""
def __init__(self):
datasets = ["cifar10", "imagenet"]
err_str = "Dataset {} is not supported"
assert cfg.TRAIN.DATASET in datasets, err_str.format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in datasets, err_str.format(cfg.TEST.DATASET)
super(ResNet, self).__init__()
if "cifar" in cfg.TRAIN.DATASET:
self._construct_cifar()
else:
self._construct_imagenet()
self.apply(net.init_weights)
def _construct_cifar(self):
err_str = "Model depth should be of the format 6n + 2 for cifar"
assert (cfg.MODEL.DEPTH - 2) % 6 == 0, err_str
d = int((cfg.MODEL.DEPTH - 2) / 6)
self.stem = ResStemCifar(3, 16)
self.s1 = ResStage(16, 16, stride=1, d=d)
self.s2 = ResStage(16, 32, stride=2, d=d)
self.s3 = ResStage(32, 64, stride=2, d=d)
self.head = ResHead(64, nc=cfg.MODEL.NUM_CLASSES)
def _construct_imagenet(self):
g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
(d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
w_b = gw * g
self.stem = ResStemIN(3, 64)
self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g)
self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g)
self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g)
self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g)
self.head = ResHead(2048, nc=cfg.MODEL.NUM_CLASSES)
def forward(self, x):
for module in self.children():
x = module(x)
return x
@staticmethod
def complexity(cx):
"""Computes model complexity. If you alter the model, make sure to update."""
if "cifar" in cfg.TRAIN.DATASET:
d = int((cfg.MODEL.DEPTH - 2) / 6)
cx = ResStemCifar.complexity(cx, 3, 16)
cx = ResStage.complexity(cx, 16, 16, stride=1, d=d)
cx = ResStage.complexity(cx, 16, 32, stride=2, d=d)
cx = ResStage.complexity(cx, 32, 64, stride=2, d=d)
cx = ResHead.complexity(cx, 64, nc=cfg.MODEL.NUM_CLASSES)
else:
g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
(d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
w_b = gw * g
cx = ResStemIN.complexity(cx, 3, 64)
cx = ResStage.complexity(cx, 64, 256, 1, d=d1, w_b=w_b, num_gs=g)
cx = ResStage.complexity(cx, 256, 512, 2, d=d2, w_b=w_b * 2, num_gs=g)
cx = ResStage.complexity(cx, 512, 1024, 2, d=d3, w_b=w_b * 4, num_gs=g)
cx = ResStage.complexity(cx, 1024, 2048, 2, d=d4, w_b=w_b * 8, num_gs=g)
cx = ResHead.complexity(cx, 2048, nc=cfg.MODEL.NUM_CLASSES)
return cx
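
The only non-obvious numbers in _construct_imagenet are the bottleneck widths. A sketch of the arithmetic for ResNet-50 (NUM_GROUPS=1, WIDTH_PER_GROUP=64) and a ResNeXt-50 32x4d-style setting (g=32, gw=4), both used here only as illustrations:

for g, gw in [(1, 64), (32, 4)]:
    w_b = g * gw                                    # 3x3 width in stage s1
    print(g, gw, [w_b, w_b * 2, w_b * 4, w_b * 8])  # doubles per stage: 64..512, then 128..1024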