v2
This commit is contained in:
0
pycls/models/__init__.py
Normal file
0
pycls/models/__init__.py
Normal file
406
pycls/models/anynet.py
Normal file
406
pycls/models/anynet.py
Normal file
@@ -0,0 +1,406 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""AnyNet models."""
|
||||
|
||||
import pycls.core.net as net
|
||||
import torch.nn as nn
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
def get_stem_fun(stem_type):
    """Retrieves the stem function by name.

    Args:
        stem_type: one of "res_stem_cifar", "res_stem_in", "simple_stem_in".

    Returns:
        The stem class registered under `stem_type`.

    Raises:
        AssertionError: if `stem_type` is not a supported stem name.
    """
    stem_funs = {
        "res_stem_cifar": ResStemCifar,
        "res_stem_in": ResStemIN,
        "simple_stem_in": SimpleStemIN,
    }
    if stem_type not in stem_funs:
        # Raised explicitly (instead of via `assert`) so the check is not
        # stripped when Python runs with -O; exception type and message
        # are the same as before.
        raise AssertionError("Stem type '{}' not supported".format(stem_type))
    return stem_funs[stem_type]
|
||||
|
||||
|
||||
def get_block_fun(block_type):
    """Retrieves the block function by name.

    Args:
        block_type: one of "vanilla_block", "res_basic_block",
            "res_bottleneck_block".

    Returns:
        The block class registered under `block_type`.

    Raises:
        AssertionError: if `block_type` is not a supported block name.
    """
    block_funs = {
        "vanilla_block": VanillaBlock,
        "res_basic_block": ResBasicBlock,
        "res_bottleneck_block": ResBottleneckBlock,
    }
    if block_type not in block_funs:
        # Raised explicitly (instead of via `assert`) so the check is not
        # stripped when Python runs with -O; exception type and message
        # are the same as before.
        raise AssertionError("Block type '{}' not supported".format(block_type))
    return block_funs[block_type]
|
||||
|
||||
|
||||
class AnyHead(nn.Module):
    """AnyNet head: global average pooling followed by a linear classifier."""

    def __init__(self, w_in, nc):
        """Builds the head for `w_in` input channels and `nc` output classes."""
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        pooled = self.avg_pool(x)
        # Collapse the pooled 1x1 spatial dims so fc gets one vector per sample.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

    @staticmethod
    def complexity(cx, w_in, nc):
        """Adds the head cost to `cx`; the FC is counted as a 1x1 conv on a 1x1 map."""
        cx["h"] = 1
        cx["w"] = 1
        return net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
|
||||
|
||||
|
||||
class VanillaBlock(nn.Module):
    """Vanilla block: two consecutive [3x3 conv, BN, ReLU] units."""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        """Builds the block; `bm`, `gw` and `se_r` must all be None."""
        assert bm is None and gw is None and se_r is None, (
            "Vanilla block does not support bm, gw, and se_r options"
        )
        super().__init__()
        # First unit applies the stride and changes the width.
        self.a = nn.Conv2d(w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        # Second unit keeps both width and resolution.
        self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        x = self.a_relu(self.a_bn(self.a(x)))
        x = self.b_relu(self.b_bn(self.b(x)))
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        """Adds the cost of both conv + BN pairs to `cx` and returns it."""
        assert bm is None and gw is None and se_r is None, (
            "Vanilla block does not support bm, gw, and se_r options"
        )
        # (input width, stride) of each 3x3 conv; each is followed by a BN.
        for conv_w_in, conv_stride in ((w_in, stride), (w_out, 1)):
            cx = net.complexity_conv2d(cx, conv_w_in, w_out, 3, conv_stride, 1)
            cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
|
||||
|
||||
|
||||
class BasicTransform(nn.Module):
    """Basic transformation: 3x3 conv, BN, ReLU, 3x3 conv, BN."""

    def __init__(self, w_in, w_out, stride):
        super().__init__()
        self.a = nn.Conv2d(w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # Tag the last BN of the transform with a `final_bn` attribute.
        self.b_bn.final_bn = True

    def forward(self, x):
        x = self.a_relu(self.a_bn(self.a(x)))
        return self.b_bn(self.b(x))

    @staticmethod
    def complexity(cx, w_in, w_out, stride):
        """Adds the cost of both conv + BN pairs to `cx` and returns it."""
        # (input width, stride) of each 3x3 conv; each is followed by a BN.
        for conv_w_in, conv_stride in ((w_in, stride), (w_out, 1)):
            cx = net.complexity_conv2d(cx, conv_w_in, w_out, 3, conv_stride, 1)
            cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
|
||||
|
||||
|
||||
class ResBasicBlock(nn.Module):
    """Residual basic block: x + F(x), where F is the basic transform."""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        """Builds the block; `bm`, `gw` and `se_r` must all be None."""
        assert bm is None and gw is None and se_r is None, (
            "Basic transform does not support bm, gw, and se_r options"
        )
        super().__init__()
        # A projection shortcut is needed whenever the output shape differs.
        self.proj_block = not (w_in == w_out and stride == 1)
        if self.proj_block:
            self.proj = nn.Conv2d(
                w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
            )
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = BasicTransform(w_in, w_out, stride)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        shortcut = self.bn(self.proj(x)) if self.proj_block else x
        return self.relu(shortcut + self.f(x))

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        """Adds the cost of the optional projection and the transform to `cx`."""
        assert bm is None and gw is None and se_r is None, (
            "Basic transform does not support bm, gw, and se_r options"
        )
        needs_proj = w_in != w_out or stride != 1
        if needs_proj:
            in_h, in_w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            # The shortcut is a parallel branch: restore the input resolution
            # before counting the main transform.
            cx["h"], cx["w"] = in_h, in_w
        return BasicTransform.complexity(cx, w_in, w_out, stride)
|
||||
|
||||
|
||||
class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU, FC, Sigmoid."""

    def __init__(self, w_in, w_se):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # The two FC layers are implemented as 1x1 convs on the pooled map.
        squeeze = nn.Conv2d(w_in, w_se, 1, bias=True)
        activation = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        excite = nn.Conv2d(w_se, w_in, 1, bias=True)
        self.f_ex = nn.Sequential(squeeze, activation, excite, nn.Sigmoid())

    def forward(self, x):
        gate = self.f_ex(self.avg_pool(x))
        return x * gate

    @staticmethod
    def complexity(cx, w_in, w_se):
        """Adds the cost of the two 1x1 convs (at 1x1 resolution) to `cx`."""
        saved_h, saved_w = cx["h"], cx["w"]
        cx["h"] = 1
        cx["w"] = 1
        cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
        cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
        # Put back the resolution that was in `cx` on entry.
        cx["h"] = saved_h
        cx["w"] = saved_w
        return cx
|
||||
|
||||
|
||||
class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""

    def __init__(self, w_in, w_out, stride, bm, gw, se_r):
        """Builds the transform.

        `bm` scales `w_out` to get the bottleneck width, `gw` is the group
        width of the 3x3 conv, and a truthy `se_r` adds an SE block whose
        width is `w_in * se_r`.
        """
        super().__init__()
        w_b = int(round(w_out * bm))
        num_groups = w_b // gw
        self.a = nn.Conv2d(w_in, w_b, kernel_size=1, stride=1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(
            w_b, w_b, kernel_size=3, stride=stride, padding=1, groups=num_groups, bias=False
        )
        self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        if se_r:
            self.se = SE(w_b, int(round(w_in * se_r)))
        self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # Tag the last BN of the transform with a `final_bn` attribute.
        self.c_bn.final_bn = True

    def forward(self, x):
        x = self.a_relu(self.a_bn(self.a(x)))
        x = self.b_relu(self.b_bn(self.b(x)))
        # `se` is only registered when the block was built with a truthy se_r.
        if hasattr(self, "se"):
            x = self.se(x)
        return self.c_bn(self.c(x))

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm, gw, se_r):
        """Adds the cost of the three convs, their BNs and the optional SE to `cx`."""
        w_b = int(round(w_out * bm))
        num_groups = w_b // gw
        cx = net.complexity_conv2d(cx, w_in, w_b, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride, 1, num_groups)
        cx = net.complexity_batchnorm2d(cx, w_b)
        if se_r:
            cx = SE.complexity(cx, w_b, int(round(w_in * se_r)))
        cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
        return net.complexity_batchnorm2d(cx, w_out)
|
||||
|
||||
|
||||
class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: x + F(x), where F is the bottleneck transform."""

    def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        super().__init__()
        # A projection shortcut is needed whenever the output shape differs.
        self.proj_block = not (w_in == w_out and stride == 1)
        if self.proj_block:
            self.proj = nn.Conv2d(
                w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
            )
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        shortcut = self.bn(self.proj(x)) if self.proj_block else x
        return self.relu(shortcut + self.f(x))

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        """Adds the cost of the optional projection and the transform to `cx`."""
        needs_proj = w_in != w_out or stride != 1
        if needs_proj:
            in_h, in_w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            # The shortcut is a parallel branch: restore the input resolution
            # before counting the main transform.
            cx["h"], cx["w"] = in_h, in_w
        return BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r)
|
||||
|
||||
|
||||
class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR: stride-1 3x3 conv, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super().__init__()
        self.conv = nn.Conv2d(w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Adds the cost of the conv and its BN to `cx` and returns it."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
        return net.complexity_batchnorm2d(cx, w_out)
|
||||
|
||||
|
||||
class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: stride-2 7x7 conv, BN, ReLU, stride-2 MaxPool."""

    def __init__(self, w_in, w_out):
        super().__init__()
        self.conv = nn.Conv2d(w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        return self.pool(x)

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Adds the cost of the conv, BN and max-pool to `cx` and returns it."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return net.complexity_maxpool2d(cx, 3, 2, 1)
|
||||
|
||||
|
||||
class SimpleStemIN(nn.Module):
    """Simple stem for ImageNet: stride-2 3x3 conv, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super().__init__()
        self.conv = nn.Conv2d(w_in, w_out, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Adds the cost of the conv and its BN to `cx` and returns it."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
        return net.complexity_batchnorm2d(cx, w_out)
|
||||
|
||||
|
||||
class AnyStage(nn.Module):
    """AnyNet stage (sequence of blocks w/ the same output shape)."""

    def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        """Registers `d` blocks named b1..bd, each built by `block_fun`.

        Only the first block receives `w_in` and `stride`; the remaining
        blocks map `w_out` to `w_out` with stride 1.
        """
        super().__init__()
        for idx in range(d):
            is_first = idx == 0
            block = block_fun(
                w_in if is_first else w_out,
                w_out,
                stride if is_first else 1,
                bm,
                gw,
                se_r,
            )
            self.add_module("b{}".format(idx + 1), block)

    def forward(self, x):
        for _name, block in self.named_children():
            x = block(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        """Accumulates `block_fun.complexity` over the `d` blocks of the stage."""
        for idx in range(d):
            is_first = idx == 0
            cx = block_fun.complexity(
                cx,
                w_in if is_first else w_out,
                w_out,
                stride if is_first else 1,
                bm,
                gw,
                se_r,
            )
        return cx
|
||||
|
||||
|
||||
class AnyNet(nn.Module):
    """AnyNet model: stem, a sequence of stages, and a classification head."""

    @staticmethod
    def get_args():
        """Collects the constructor arguments from the global config."""
        anynet_cfg = cfg.ANYNET
        args = {
            "stem_type": anynet_cfg.STEM_TYPE,
            "stem_w": anynet_cfg.STEM_W,
            "block_type": anynet_cfg.BLOCK_TYPE,
            "ds": anynet_cfg.DEPTHS,
            "ws": anynet_cfg.WIDTHS,
            "ss": anynet_cfg.STRIDES,
            "bms": anynet_cfg.BOT_MULS,
            "gws": anynet_cfg.GROUP_WS,
            # SE ratio is only forwarded when SE is switched on.
            "se_r": anynet_cfg.SE_R if anynet_cfg.SE_ON else None,
            "nc": cfg.MODEL.NUM_CLASSES,
        }
        return args

    def __init__(self, **kwargs):
        """Builds the model from `kwargs`, or from the config when none are given."""
        super().__init__()
        if not kwargs:
            kwargs = self.get_args()
        self._construct(**kwargs)
        self.apply(net.init_weights)

    def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        # Generate dummy bot muls and gs for models that do not use them
        bms = bms or [None for _ in ds]
        gws = gws or [None for _ in ds]
        stage_params = list(zip(ds, ws, ss, bms, gws))
        stem_cls = get_stem_fun(stem_type)
        self.stem = stem_cls(3, stem_w)
        block_cls = get_block_fun(block_type)
        w_prev = stem_w
        for idx, (d, w, s, bm, gw) in enumerate(stage_params, start=1):
            stage = AnyStage(w_prev, w, s, d, block_cls, bm, gw, se_r)
            self.add_module("s{}".format(idx), stage)
            w_prev = w
        self.head = AnyHead(w_in=w_prev, nc=nc)

    def forward(self, x, get_ints=False):
        # NOTE: `get_ints` is accepted but not used by this implementation.
        for _name, module in self.named_children():
            x = module(x)
        return x

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update."""
        if not kwargs:
            kwargs = AnyNet.get_args()
        return AnyNet._complexity(cx, **kwargs)

    @staticmethod
    def _complexity(cx, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        # Mirrors _construct: stem, then every stage, then the head.
        bms = bms or [None for _ in ds]
        gws = gws or [None for _ in ds]
        stage_params = list(zip(ds, ws, ss, bms, gws))
        stem_cls = get_stem_fun(stem_type)
        cx = stem_cls.complexity(cx, 3, stem_w)
        block_cls = get_block_fun(block_type)
        w_prev = stem_w
        for d, w, s, bm, gw in stage_params:
            cx = AnyStage.complexity(cx, w_prev, w, s, d, block_cls, bm, gw, se_r)
            w_prev = w
        return AnyHead.complexity(cx, w_prev, nc)
|
108
pycls/models/common.py
Normal file
108
pycls/models/common.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
def Preprocess(x):
    """Reshapes the input for the jigsaw task; other tasks pass through unchanged.

    For cfg.TASK == 'jig' the input must be 5-D with cfg.JIGSAW_GRID ** 2
    entries along dim 1; dims 0 and 1 are merged into a single leading dim.
    """
    if cfg.TASK != 'jig':
        return x
    assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
    assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
    n, tiles, d2, d3, d4 = x.shape
    return x.view([n * tiles, d2, d3, d4])
|
||||
|
||||
|
||||
class Classifier(nn.Module):
    """Task-dependent prediction head selected by cfg.TASK.

    'jig' pools and classifies the concatenated tile features, 'col' uses a
    1x1 conv, 'seg' uses ASPP, and any other task uses pooling + linear.
    """

    def __init__(self, channels, num_classes):
        super().__init__()
        task = cfg.TASK
        if task == 'jig':
            self.jig_sq = cfg.JIGSAW_GRID ** 2
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(channels * self.jig_sq, num_classes)
        elif task == 'col':
            self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
        elif task == 'seg':
            self.classifier = ASPP(
                channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES
            )
        else:
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(channels, num_classes)

    def forward(self, x, shape):
        """Returns predictions; `shape` is only used by the 'col' and 'seg' tasks."""
        task = cfg.TASK
        if task in ('col', 'seg'):
            # Dense prediction: resize the output map to the requested size.
            dense = self.classifier(x)
            return nn.functional.interpolate(
                dense, size=shape, mode='bilinear', align_corners=True
            )
        x = self.pooling(x)
        if task == 'jig':
            # Move the jig_sq tiles of each sample from the batch dim into channels.
            x = x.view([
                x.shape[0] // self.jig_sq,
                x.shape[1] * self.jig_sq,
                x.shape[2],
                x.shape[3],
            ])
        return self.classifier(x.view(x.size(0), -1))
|
||||
|
||||
|
||||
class ASPP(nn.Module):
    """Parallel conv branches plus a global-pooling branch, fused by a 1x1 classifier.

    Branches: a 1x1 conv, one dilated 3x3 conv per entry of `rates`
    (1 or 3 entries), and a globally pooled 1x1 conv branch. Their outputs
    are concatenated along the channel dim and fed to the classifier.
    """

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size, rate=None):
        """Builds conv -> BN -> ReLU; `rate` sets both dilation and padding."""
        if rate is None:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False)
        else:
            conv = nn.Conv2d(
                in_channels, out_channels, kernel_size,
                dilation=rate, padding=rate, bias=False,
            )
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

    def __init__(self, in_channels, out_channels, num_classes, rates):
        super().__init__()
        assert len(rates) in [1, 3]
        self.rates = rates
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.aspp1 = self._conv_bn_relu(in_channels, out_channels, 1)
        self.aspp2 = self._conv_bn_relu(in_channels, out_channels, 3, rates[0])
        if len(self.rates) == 3:
            self.aspp3 = self._conv_bn_relu(in_channels, out_channels, 3, rates[1])
            self.aspp4 = self._conv_bn_relu(in_channels, out_channels, 3, rates[2])
        # Applied to the globally pooled features.
        self.aspp5 = self._conv_bn_relu(in_channels, out_channels, 1)
        # One input slice per branch: len(rates) dilated + 1x1 + pooled.
        fused_channels = out_channels * (len(rates) + 2)
        self.classifier = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, num_classes, 1),
        )

    def forward(self, x):
        branches = [self.aspp1(x), self.aspp2(x)]
        pooled = self.aspp5(self.global_pooling(x))
        # Bring the pooled branch back to the spatial size of the input.
        pooled = nn.functional.interpolate(
            pooled, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=True
        )
        if len(self.rates) == 3:
            branches.append(self.aspp3(x))
            branches.append(self.aspp4(x))
        branches.append(pooled)
        return self.classifier(torch.cat(branches, 1))
|
232
pycls/models/effnet.py
Normal file
232
pycls/models/effnet.py
Normal file
@@ -0,0 +1,232 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""EfficientNet models."""
|
||||
|
||||
import pycls.core.net as net
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
class EffHead(nn.Module):
    """EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC."""

    def __init__(self, w_in, w_out, nc):
        super().__init__()
        self.conv = nn.Conv2d(w_in, w_out, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.conv_swish = Swish()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Dropout is only registered when a positive ratio is configured.
        if cfg.EN.DROPOUT_RATIO > 0.0:
            self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
        self.fc = nn.Linear(w_out, nc, bias=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_bn(x)
        x = self.conv_swish(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        return self.fc(x)

    @staticmethod
    def complexity(cx, w_in, w_out, nc):
        """Adds the cost of the 1x1 conv, its BN and the FC to `cx`."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        # The FC is counted as a 1x1 conv on a 1x1 map.
        cx["h"] = 1
        cx["w"] = 1
        return net.complexity_conv2d(cx, w_out, nc, 1, 1, 0, bias=True)
|
||||
|
||||
|
||||
class Swish(nn.Module):
    """Swish activation function: x * sigmoid(x)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
|
||||
|
||||
|
||||
class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""

    def __init__(self, w_in, w_se):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # The two FC layers are implemented as 1x1 convs on the pooled map.
        squeeze = nn.Conv2d(w_in, w_se, 1, bias=True)
        activation = Swish()
        excite = nn.Conv2d(w_se, w_in, 1, bias=True)
        self.f_ex = nn.Sequential(squeeze, activation, excite, nn.Sigmoid())

    def forward(self, x):
        gate = self.f_ex(self.avg_pool(x))
        return x * gate

    @staticmethod
    def complexity(cx, w_in, w_se):
        """Adds the cost of the two 1x1 convs (at 1x1 resolution) to `cx`."""
        saved_h, saved_w = cx["h"], cx["w"]
        cx["h"] = 1
        cx["w"] = 1
        cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
        cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
        # Put back the resolution that was in `cx` on entry.
        cx["h"] = saved_h
        cx["w"] = saved_w
        return cx
|
||||
|
||||
|
||||
class MBConv(nn.Module):
    """Mobile inverted bottleneck block w/ SE (MBConv)."""

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
        # expansion, dwise conv, BN, Swish, SE, 1x1, BN, skip_connection
        super().__init__()
        self.exp = None
        w_exp = int(w_in * exp_r)
        # The expansion conv is skipped when it would not change the width.
        if w_exp != w_in:
            self.exp = nn.Conv2d(w_in, w_exp, kernel_size=1, stride=1, padding=0, bias=False)
            self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
            self.exp_swish = Swish()
        # groups == channels makes this a depthwise conv.
        self.dwise = nn.Conv2d(
            w_exp,
            w_exp,
            kernel,
            stride=stride,
            groups=w_exp,
            padding=(kernel - 1) // 2,
            bias=False,
        )
        self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.dwise_swish = Swish()
        self.se = SE(w_exp, int(w_in * se_r))
        self.lin_proj = nn.Conv2d(w_exp, w_out, kernel_size=1, stride=1, padding=0, bias=False)
        self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # Skip connection if in and out shapes are the same (MN-V2 style)
        self.has_skip = stride == 1 and w_in == w_out

    def forward(self, x):
        f_x = x
        if self.exp is not None:
            f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
        f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
        f_x = self.se(f_x)
        f_x = self.lin_proj_bn(self.lin_proj(f_x))
        if not self.has_skip:
            return f_x
        # Drop-connect is applied to the residual branch in training only.
        if self.training and cfg.EN.DC_RATIO > 0.0:
            f_x = net.drop_connect(f_x, cfg.EN.DC_RATIO)
        return x + f_x

    @staticmethod
    def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out):
        """Adds the cost of the block's convs, BNs and SE to `cx`."""
        w_exp = int(w_in * exp_r)
        if w_exp != w_in:
            cx = net.complexity_conv2d(cx, w_in, w_exp, 1, 1, 0)
            cx = net.complexity_batchnorm2d(cx, w_exp)
        pad = (kernel - 1) // 2
        cx = net.complexity_conv2d(cx, w_exp, w_exp, kernel, stride, pad, w_exp)
        cx = net.complexity_batchnorm2d(cx, w_exp)
        cx = SE.complexity(cx, w_exp, int(w_in * se_r))
        cx = net.complexity_conv2d(cx, w_exp, w_out, 1, 1, 0)
        return net.complexity_batchnorm2d(cx, w_out)
|
||||
|
||||
|
||||
class EffStage(nn.Module):
    """EfficientNet stage: `d` MBConv blocks registered as b1..bd."""

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
        super().__init__()
        for idx in range(d):
            # Only the first block sees the stage input width and stride.
            is_first = idx == 0
            block = MBConv(
                w_in if is_first else w_out,
                exp_r,
                kernel,
                stride if is_first else 1,
                se_r,
                w_out,
            )
            self.add_module("b{}".format(idx + 1), block)

    def forward(self, x):
        for _name, block in self.named_children():
            x = block(x)
        return x

    @staticmethod
    def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out, d):
        """Accumulates MBConv.complexity over the `d` blocks of the stage."""
        for idx in range(d):
            is_first = idx == 0
            cx = MBConv.complexity(
                cx,
                w_in if is_first else w_out,
                exp_r,
                kernel,
                stride if is_first else 1,
                se_r,
                w_out,
            )
        return cx
|
||||
|
||||
|
||||
class StemIN(nn.Module):
    """EfficientNet stem for ImageNet: stride-2 3x3 conv, BN, Swish."""

    def __init__(self, w_in, w_out):
        super().__init__()
        self.conv = nn.Conv2d(w_in, w_out, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.swish = Swish()

    def forward(self, x):
        return self.swish(self.bn(self.conv(x)))

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Adds the cost of the conv and its BN to `cx` and returns it."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
        return net.complexity_batchnorm2d(cx, w_out)
|
||||
|
||||
|
||||
class EffNet(nn.Module):
    """EfficientNet model: stem, a sequence of MBConv stages, and a head."""

    @staticmethod
    def get_args():
        """Collects the constructor arguments from the global config."""
        en_cfg = cfg.EN
        args = {
            "stem_w": en_cfg.STEM_W,
            "ds": en_cfg.DEPTHS,
            "ws": en_cfg.WIDTHS,
            "exp_rs": en_cfg.EXP_RATIOS,
            "se_r": en_cfg.SE_R,
            "ss": en_cfg.STRIDES,
            "ks": en_cfg.KERNELS,
            "head_w": en_cfg.HEAD_W,
            "nc": cfg.MODEL.NUM_CLASSES,
        }
        return args

    def __init__(self):
        # Only ImageNet is accepted for both the train and the test dataset.
        supported = ["imagenet"]
        assert cfg.TRAIN.DATASET in supported, (
            "Dataset {} is not supported".format(cfg.TRAIN.DATASET)
        )
        assert cfg.TEST.DATASET in supported, (
            "Dataset {} is not supported".format(cfg.TEST.DATASET)
        )
        super().__init__()
        self._construct(**EffNet.get_args())
        self.apply(net.init_weights)

    def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
        stage_params = list(zip(ds, ws, exp_rs, ss, ks))
        self.stem = StemIN(3, stem_w)
        w_prev = stem_w
        for idx, (d, w, exp_r, stride, kernel) in enumerate(stage_params, start=1):
            stage = EffStage(w_prev, exp_r, kernel, stride, se_r, w, d)
            self.add_module("s{}".format(idx), stage)
            w_prev = w
        self.head = EffHead(w_prev, head_w, nc)

    def forward(self, x):
        for _name, module in self.named_children():
            x = module(x)
        return x

    @staticmethod
    def complexity(cx):
        """Computes model complexity. If you alter the model, make sure to update."""
        return EffNet._complexity(cx, **EffNet.get_args())

    @staticmethod
    def _complexity(cx, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
        # Mirrors _construct: stem, then every stage, then the head.
        stage_params = list(zip(ds, ws, exp_rs, ss, ks))
        cx = StemIN.complexity(cx, 3, stem_w)
        w_prev = stem_w
        for d, w, exp_r, stride, kernel in stage_params:
            cx = EffStage.complexity(cx, w_prev, exp_r, kernel, stride, se_r, w, d)
            w_prev = w
        return EffHead.complexity(cx, w_prev, head_w, nc)
|
634
pycls/models/nas/genotypes.py
Normal file
634
pycls/models/nas/genotypes.py
Normal file
@@ -0,0 +1,634 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""NAS genotypes (adopted from DARTS)."""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
# A genotype describes one normal and one reduction cell. Each cell is a flat
# list of (op name, input state index) tuples plus the indices of the states
# that get concatenated into the cell output. The tuples are laid out two per
# line below because they are consumed pairwise (one pair per intermediate
# state).
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')


# NASNet ops
NASNET_OPS = [
    'skip_connect', 'conv_3x1_1x3', 'conv_7x1_1x7', 'dil_conv_3x3',
    'avg_pool_3x3', 'max_pool_3x3', 'max_pool_5x5', 'max_pool_7x7',
    'conv_1x1', 'conv_3x3',
    'sep_conv_3x3', 'sep_conv_5x5', 'sep_conv_7x7',
]

# ENAS ops
ENAS_OPS = [
    'skip_connect', 'sep_conv_3x3', 'sep_conv_5x5',
    'avg_pool_3x3', 'max_pool_3x3',
]

# AmoebaNet ops
AMOEBA_OPS = [
    'skip_connect', 'sep_conv_3x3', 'sep_conv_5x5', 'sep_conv_7x7',
    'avg_pool_3x3', 'max_pool_3x3', 'dil_sep_conv_3x3', 'conv_7x1_1x7',
]

# NAO ops
NAO_OPS = [
    'skip_connect', 'conv_1x1', 'conv_3x3', 'conv_3x1_1x3', 'conv_7x1_1x7',
    'max_pool_2x2', 'max_pool_3x3', 'max_pool_5x5',
    'avg_pool_2x2', 'avg_pool_3x3', 'avg_pool_5x5',
]

# PNAS ops
PNAS_OPS = [
    'sep_conv_3x3', 'sep_conv_5x5', 'sep_conv_7x7', 'conv_7x1_1x7',
    'skip_connect', 'avg_pool_3x3', 'max_pool_3x3', 'dil_conv_3x3',
]

# DARTS ops
DARTS_OPS = [
    'none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect',
    'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5',
]


NASNet = Genotype(
    normal=[
        ('sep_conv_5x5', 1), ('sep_conv_3x3', 0),
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 0),
        ('avg_pool_3x3', 1), ('skip_connect', 0),
        ('avg_pool_3x3', 0), ('avg_pool_3x3', 0),
        ('sep_conv_3x3', 1), ('skip_connect', 1),
    ],
    normal_concat=[2, 3, 4, 5, 6],
    reduce=[
        ('sep_conv_5x5', 1), ('sep_conv_7x7', 0),
        ('max_pool_3x3', 1), ('sep_conv_7x7', 0),
        ('avg_pool_3x3', 1), ('sep_conv_5x5', 0),
        ('skip_connect', 3), ('avg_pool_3x3', 2),
        ('sep_conv_3x3', 2), ('max_pool_3x3', 1),
    ],
    reduce_concat=[4, 5, 6],
)


PNASNet = Genotype(
    normal=[
        ('sep_conv_5x5', 0), ('max_pool_3x3', 0),
        ('sep_conv_7x7', 1), ('max_pool_3x3', 1),
        ('sep_conv_5x5', 1), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 4), ('max_pool_3x3', 1),
        ('sep_conv_3x3', 0), ('skip_connect', 1),
    ],
    normal_concat=[2, 3, 4, 5, 6],
    reduce=[
        ('sep_conv_5x5', 0), ('max_pool_3x3', 0),
        ('sep_conv_7x7', 1), ('max_pool_3x3', 1),
        ('sep_conv_5x5', 1), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 4), ('max_pool_3x3', 1),
        ('sep_conv_3x3', 0), ('skip_connect', 1),
    ],
    reduce_concat=[2, 3, 4, 5, 6],
)


AmoebaNet = Genotype(
    normal=[
        ('avg_pool_3x3', 0), ('max_pool_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_5x5', 2),
        ('sep_conv_3x3', 0), ('avg_pool_3x3', 3),
        ('sep_conv_3x3', 1), ('skip_connect', 1),
        ('skip_connect', 0), ('avg_pool_3x3', 1),
    ],
    normal_concat=[4, 5, 6],
    reduce=[
        ('avg_pool_3x3', 0), ('sep_conv_3x3', 1),
        ('max_pool_3x3', 0), ('sep_conv_7x7', 2),
        ('sep_conv_7x7', 0), ('avg_pool_3x3', 1),
        ('max_pool_3x3', 0), ('max_pool_3x3', 1),
        ('conv_7x1_1x7', 0), ('sep_conv_3x3', 5),
    ],
    reduce_concat=[3, 4, 6],
)


DARTS_V1 = Genotype(
    normal=[
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('skip_connect', 0), ('sep_conv_3x3', 1),
        ('skip_connect', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('skip_connect', 2),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0), ('max_pool_3x3', 1),
        ('skip_connect', 2), ('max_pool_3x3', 0),
        ('max_pool_3x3', 0), ('skip_connect', 2),
        ('skip_connect', 2), ('avg_pool_3x3', 0),
    ],
    reduce_concat=[2, 3, 4, 5],
)


DARTS_V2 = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('skip_connect', 0),
        ('skip_connect', 0), ('dil_conv_3x3', 2),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0), ('max_pool_3x3', 1),
        ('skip_connect', 2), ('max_pool_3x3', 1),
        ('max_pool_3x3', 0), ('skip_connect', 2),
        ('skip_connect', 2), ('max_pool_3x3', 1),
    ],
    reduce_concat=[2, 3, 4, 5],
)

PDARTS = Genotype(
    normal=[
        ('skip_connect', 0), ('dil_conv_3x3', 1),
        ('skip_connect', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 0), ('dil_conv_5x5', 4),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('avg_pool_3x3', 0), ('sep_conv_5x5', 1),
        ('sep_conv_3x3', 0), ('dil_conv_5x5', 2),
        ('max_pool_3x3', 0), ('dil_conv_3x3', 1),
        ('dil_conv_3x3', 1), ('dil_conv_5x5', 3),
    ],
    reduce_concat=range(2, 6),
)

PCDARTS_C10 = Genotype(
    normal=[
        ('sep_conv_3x3', 1), ('skip_connect', 0),
        ('sep_conv_3x3', 0), ('dil_conv_3x3', 1),
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 1),
        ('avg_pool_3x3', 0), ('dil_conv_3x3', 1),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_5x5', 1), ('max_pool_3x3', 0),
        ('sep_conv_5x5', 1), ('sep_conv_5x5', 2),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
    ],
    reduce_concat=range(2, 6),
)

PCDARTS_IN1K = Genotype(
    normal=[
        ('skip_connect', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0), ('skip_connect', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1), ('dil_conv_5x5', 4),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_3x3', 0), ('skip_connect', 1),
        ('dil_conv_5x5', 2), ('max_pool_3x3', 1),
        ('sep_conv_3x3', 2), ('sep_conv_3x3', 1),
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 3),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_IMAGENET_CLS = Genotype(
    normal=[
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 2),
        ('sep_conv_5x5', 1), ('sep_conv_3x3', 0),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0), ('skip_connect', 1),
        ('max_pool_3x3', 0), ('dil_conv_5x5', 2),
        ('max_pool_3x3', 0), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 4), ('dil_conv_5x5', 3),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_IMAGENET_ROT = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 2), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 4), ('sep_conv_5x5', 2),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_IMAGENET_COL = Genotype(
    normal=[
        ('skip_connect', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('skip_connect', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 2),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0), ('sep_conv_3x3', 1),
        ('max_pool_3x3', 0), ('sep_conv_3x3', 1),
        ('max_pool_3x3', 0), ('sep_conv_5x5', 3),
        ('max_pool_3x3', 0), ('sep_conv_3x3', 4),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_IMAGENET_JIG = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1), ('sep_conv_5x5', 0),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 1),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_IMAGENET22K_CLS = Genotype(
    normal=[
        ('sep_conv_3x3', 1), ('skip_connect', 0),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0), ('max_pool_3x3', 1),
        ('dil_conv_5x5', 2), ('max_pool_3x3', 0),
        ('dil_conv_5x5', 3), ('dil_conv_5x5', 2),
        ('dil_conv_5x5', 4), ('dil_conv_5x5', 3),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_IMAGENET22K_ROT = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0), ('sep_conv_5x5', 1),
        ('dil_conv_5x5', 2), ('sep_conv_5x5', 0),
        ('dil_conv_5x5', 3), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 4), ('sep_conv_3x3', 3),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_IMAGENET22K_COL = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 2), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 3), ('sep_conv_3x3', 0),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0), ('skip_connect', 1),
        ('dil_conv_5x5', 2), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 3), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 4), ('sep_conv_5x5', 1),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_IMAGENET22K_JIG = Genotype(
    normal=[
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 4),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_5x5', 0), ('skip_connect', 1),
        ('sep_conv_5x5', 0), ('sep_conv_3x3', 2),
        ('sep_conv_5x5', 0), ('sep_conv_5x5', 3),
        ('sep_conv_5x5', 0), ('sep_conv_5x5', 4),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_CITYSCAPES_SEG = Genotype(
    normal=[
        ('skip_connect', 0), ('sep_conv_5x5', 1),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_3x3', 0), ('avg_pool_3x3', 1),
        ('avg_pool_3x3', 1), ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 2), ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 4), ('sep_conv_5x5', 2),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_CITYSCAPES_ROT = Genotype(
    normal=[
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 2), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0), ('sep_conv_5x5', 1),
        ('sep_conv_5x5', 2), ('sep_conv_5x5', 1),
        ('sep_conv_5x5', 3), ('dil_conv_5x5', 2),
        ('sep_conv_5x5', 2), ('sep_conv_5x5', 0),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_CITYSCAPES_COL = Genotype(
    normal=[
        ('dil_conv_3x3', 1), ('sep_conv_3x3', 0),
        ('skip_connect', 0), ('sep_conv_5x5', 2),
        ('dil_conv_3x3', 3), ('skip_connect', 0),
        ('skip_connect', 0), ('sep_conv_3x3', 1),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('avg_pool_3x3', 1), ('avg_pool_3x3', 0),
        ('avg_pool_3x3', 1), ('avg_pool_3x3', 0),
        ('avg_pool_3x3', 1), ('avg_pool_3x3', 0),
        ('avg_pool_3x3', 1), ('skip_connect', 4),
    ],
    reduce_concat=range(2, 6),
)

UNNAS_CITYSCAPES_JIG = Genotype(
    normal=[
        ('dil_conv_5x5', 1), ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0), ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 0), ('dil_conv_5x5', 1),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('avg_pool_3x3', 0), ('skip_connect', 1),
        ('dil_conv_5x5', 1), ('dil_conv_5x5', 2),
        ('dil_conv_5x5', 2), ('dil_conv_5x5', 0),
        ('dil_conv_5x5', 3), ('dil_conv_5x5', 2),
    ],
    reduce_concat=range(2, 6),
)


# Supported genotypes, keyed by config name.
# 'custom' maps to None: it has no predefined genotype here.
GENOTYPES = {
    'nas': NASNet,
    'pnas': PNASNet,
    'amoeba': AmoebaNet,
    'darts_v1': DARTS_V1,
    'darts_v2': DARTS_V2,
    'pdarts': PDARTS,
    'pcdarts_c10': PCDARTS_C10,
    'pcdarts_in1k': PCDARTS_IN1K,
    'unnas_imagenet_cls': UNNAS_IMAGENET_CLS,
    'unnas_imagenet_rot': UNNAS_IMAGENET_ROT,
    'unnas_imagenet_col': UNNAS_IMAGENET_COL,
    'unnas_imagenet_jig': UNNAS_IMAGENET_JIG,
    'unnas_imagenet22k_cls': UNNAS_IMAGENET22K_CLS,
    'unnas_imagenet22k_rot': UNNAS_IMAGENET22K_ROT,
    'unnas_imagenet22k_col': UNNAS_IMAGENET22K_COL,
    'unnas_imagenet22k_jig': UNNAS_IMAGENET22K_JIG,
    'unnas_cityscapes_seg': UNNAS_CITYSCAPES_SEG,
    'unnas_cityscapes_rot': UNNAS_CITYSCAPES_ROT,
    'unnas_cityscapes_col': UNNAS_CITYSCAPES_COL,
    'unnas_cityscapes_jig': UNNAS_CITYSCAPES_JIG,
    'custom': None,
}
|
299
pycls/models/nas/nas.py
Normal file
299
pycls/models/nas/nas.py
Normal file
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""NAS network (adopted from DARTS)."""
|
||||
|
||||
from torch.autograd import Variable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import pycls.core.logging as logging
|
||||
|
||||
from pycls.core.config import cfg
|
||||
from pycls.models.common import Preprocess
|
||||
from pycls.models.common import Classifier
|
||||
from pycls.models.nas.genotypes import GENOTYPES
|
||||
from pycls.models.nas.genotypes import Genotype
|
||||
from pycls.models.nas.operations import FactorizedReduce
|
||||
from pycls.models.nas.operations import OPS
|
||||
from pycls.models.nas.operations import ReLUConvBN
|
||||
from pycls.models.nas.operations import Identity
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
    """Drop path (ported from DARTS).

    Zeroes whole samples of `x` (dim 0) with probability `drop_prob` and
    rescales the rest by 1 / (1 - drop_prob). `x` is modified in place and
    returned; it is returned untouched when `drop_prob` is not positive.

    The mask has shape (N, 1, 1, 1), so one keep/drop decision is broadcast
    over the remaining three dims of `x`.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # Allocate the mask on x's device rather than hard-coding
        # torch.cuda.FloatTensor: this works on CPU-only builds and on
        # non-default GPUs. The mask stays float32, as before.
        mask = torch.empty(
            x.size(0), 1, 1, 1, dtype=torch.float32, device=x.device
        ).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x
|
||||
|
||||
|
||||
class Cell(nn.Module):
    """NAS cell (ported from DARTS).

    Two inputs (the outputs of the two preceding cells) are brought to C
    channels, then each step adds the results of two ops applied to earlier
    states. The states listed in the genotype's concat field are concatenated
    along the channel dim to form the output.
    """

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        logger.info('{}, {}, {}'.format(C_prev_prev, C_prev, C))

        # The older input is reduced with FactorizedReduce when the previous
        # cell was a reduction cell; otherwise a 1x1 ReLU-Conv-BN is used.
        self.preprocess0 = (
            FactorizedReduce(C_prev_prev, C)
            if reduction_prev
            else ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        )
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        # Pick the reduce or the normal description of the genotype.
        edges = genotype.reduce if reduction else genotype.normal
        concat = genotype.reduce_concat if reduction else genotype.normal_concat
        op_names, indices = zip(*edges)
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        """Instantiates one op per (name, index) edge and stores the wiring."""
        assert len(op_names) == len(indices)
        # Two edges feed every intermediate state.
        self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)

        self._ops = nn.ModuleList()
        for op_name, src in zip(op_names, indices):
            # In a reduction cell only edges reading the two cell inputs
            # (state indices 0 and 1) use stride 2.
            stride = 2 if (reduction and src < 2) else 1
            self._ops.append(OPS[op_name](C, stride, True))
        self._indices = indices

    def forward(self, s0, s1, drop_prob):
        """Runs the cell; drop path is applied only in training mode."""
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        for step in range(self._steps):
            first, second = 2 * step, 2 * step + 1
            in_a = states[self._indices[first]]
            in_b = states[self._indices[second]]
            op_a, op_b = self._ops[first], self._ops[second]
            out_a = op_a(in_a)
            out_b = op_b(in_b)
            if self.training and drop_prob > 0.:
                # Identity edges are never dropped.
                if not isinstance(op_a, Identity):
                    out_a = drop_path(out_a, drop_prob)
                if not isinstance(op_b, Identity):
                    out_b = drop_path(out_b, drop_prob)
            states.append(out_a + out_b)
        selected = [states[i] for i in self._concat]
        return torch.cat(selected, dim=1)
|
||||
|
||||
|
||||
class AuxiliaryHeadCIFAR(nn.Module):
    """Auxiliary classifier head used by the CIFAR network."""

    def __init__(self, C, num_classes):
        """Builds the head; the layer sizes assume an 8x8 input feature map."""
        super(AuxiliaryHeadCIFAR, self).__init__()
        layers = [
            nn.ReLU(inplace=True),
            # 8x8 input -> 2x2 after this pooling
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 2x2 kernel on the 2x2 map -> 1x1, i.e. 768 features per sample
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)
|
||||
|
||||
|
||||
class AuxiliaryHeadImageNet(nn.Module):
    """Auxiliary classifier head used by the ImageNet network."""

    def __init__(self, C, num_classes):
        """Builds the head.

        NOTE(review): the original docstring said "assuming input size 14x14",
        but the 768-feature classifier only fits when the pooled map is 2x2,
        i.e. for a 7x7 (or 8x8) input -- confirm against the caller.
        """
        super(AuxiliaryHeadImageNet, self).__init__()
        layers = [
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            # NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
            # Commenting it out for consistency with the experiments in the paper.
            # nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)
|
||||
|
||||
|
||||
class NetworkCIFAR(nn.Module):
    """CIFAR network (ported from DARTS).

    A conv-BN stem followed by `layers` cells; the cells at indices
    layers//3 and 2*layers//3 are reduction cells that double the width.
    In training mode with `auxiliary` set, forward returns
    (logits, logits_aux); otherwise it returns logits only.
    """

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkCIFAR, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary
        # forward() reads this attribute on every call. Give it a safe default
        # so the network also works when nobody sets it from outside (it used
        # to raise AttributeError in that case). Callers may still overwrite it.
        self.drop_path_prob = 0.0

        stem_multiplier = 3
        C_curr = stem_multiplier*C
        self.stem = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )

        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in [layers//3, 2*layers//3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            # A cell outputs `multiplier` concatenated states of width C_curr.
            C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
            if i == 2*layers//3:
                # Width of the feature map the auxiliary head is attached to.
                C_to_auxiliary = C_prev

        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        input = Preprocess(input)
        logits_aux = None
        # Both cell inputs start out as the stem output.
        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2*self._layers//3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        # The classifier also receives the spatial size of the preprocessed input.
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits
|
||||
|
||||
|
||||
class NetworkImageNet(nn.Module):
    """ImageNet network (ported from DARTS).

    Two strided stems followed by `layers` cells. Reduction cells sit at
    layers//3 and 2*layers//3, or only at layers//3 when cfg.TASK == 'seg'.
    In training mode with `auxiliary` set, forward returns
    (logits, logits_aux); otherwise it returns logits only.
    """

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkImageNet, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary
        # forward() reads this attribute on every call. Give it a safe default
        # so the network also works when nobody sets it from outside (it used
        # to raise AttributeError in that case). Callers may still overwrite it.
        self.drop_path_prob = 0.0

        self.stem0 = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        self.stem1 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        C_prev_prev, C_prev, C_curr = C, C, C

        self.cells = nn.ModuleList()
        # stem1 halves the resolution of stem0's output, so the first cell
        # must treat its older input as coming from a reduction.
        reduction_prev = True
        reduction_layers = [layers//3] if cfg.TASK == 'seg' else [layers//3, 2*layers//3]
        for i in range(layers):
            if i in reduction_layers:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            # A cell outputs `multiplier` concatenated states of width C_curr.
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                # Width of the feature map the auxiliary head is attached to.
                C_to_auxiliary = C_prev

        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        input = Preprocess(input)
        logits_aux = None
        s0 = self.stem0(input)
        # stem1 starts with an in-place ReLU, so s0 is rectified here as well.
        s1 = self.stem1(s0)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2 * self._layers // 3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        # The classifier also receives the spatial size of the preprocessed input.
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits
|
||||
|
||||
|
||||
class NAS(nn.Module):
    """NAS net wrapper (delegates to nets from DARTS).

    Reads everything from the global cfg: the genotype (predefined or
    custom), width, depth, number of classes and whether to use the
    auxiliary head. The wrapped network is stored as `self.net_`.
    """

    def __init__(self):
        assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Testing on {} is not supported'.format(cfg.TEST.DATASET)
        assert cfg.NAS.GENOTYPE in GENOTYPES, \
            'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
        super(NAS, self).__init__()
        logger.info('Constructing NAS: {}'.format(cfg.NAS))
        # Use a custom or predefined genotype
        if cfg.NAS.GENOTYPE == 'custom':
            # CUSTOM_GENOTYPE is indexed as
            # [normal, normal_concat, reduce, reduce_concat].
            genotype = Genotype(
                normal=cfg.NAS.CUSTOM_GENOTYPE[0],
                normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
                reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
                reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
            )
        else:
            genotype = GENOTYPES[cfg.NAS.GENOTYPE]
        # Determine the network constructor for dataset
        if 'cifar' in cfg.TRAIN.DATASET:
            net_ctor = NetworkCIFAR
        else:
            net_ctor = NetworkImageNet
        # Construct the network
        self.net_ = net_ctor(
            C=cfg.NAS.WIDTH,
            num_classes=cfg.MODEL.NUM_CLASSES,
            layers=cfg.NAS.DEPTH,
            auxiliary=cfg.NAS.AUX,
            genotype=genotype
        )
        # Drop path probability (set / annealed based on epoch)
        self.net_.drop_path_prob = 0.0

    def set_drop_path_prob(self, drop_path_prob):
        """Sets the drop path probability used by the wrapped network."""
        self.net_.drop_path_prob = drop_path_prob

    def forward(self, x):
        # Call the module, not its .forward(): calling .forward() directly
        # skips any forward hooks registered on the wrapped network.
        return self.net_(x)
|
201
pycls/models/nas/operations.py
Normal file
201
pycls/models/nas/operations.py
Normal file
@@ -0,0 +1,201 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
"""NAS ops (adopted from DARTS)."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Every OPS entry is a factory called as factory(C, stride, affine) that
# returns a freshly built module. The private helpers below create those
# factories for the op families that differ only in kernel size and padding.


def _zero(C, stride, affine):
    """Factory for the 'none' op."""
    return Zero(stride)


def _skip(C, stride, affine):
    """Factory for 'skip_connect': identity, or FactorizedReduce when strided."""
    if stride == 1:
        return Identity()
    return FactorizedReduce(C, C, affine=affine)


def _avg_pool(kernel, pad):
    """Returns a factory for average pooling with the given kernel/padding."""
    def build(C, stride, affine):
        return nn.AvgPool2d(kernel, stride=stride, padding=pad, count_include_pad=False)
    return build


def _max_pool(kernel, pad):
    """Returns a factory for max pooling with the given kernel/padding."""
    def build(C, stride, affine):
        return nn.MaxPool2d(kernel, stride=stride, padding=pad)
    return build


def _plain_conv(kernel, pad):
    """Returns a factory for ReLU -> square conv -> BN."""
    def build(C, stride, affine):
        return nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, kernel, stride=stride, padding=pad, bias=False),
            nn.BatchNorm2d(C, affine=affine)
        )
    return build


def _factorized_conv(kernel, pad):
    """Returns a factory for ReLU -> (1 x k) conv -> (k x 1) conv -> BN."""
    def build(C, stride, affine):
        return nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, (1, kernel), stride=(1, stride), padding=(0, pad), bias=False),
            nn.Conv2d(C, C, (kernel, 1), stride=(stride, 1), padding=(pad, 0), bias=False),
            nn.BatchNorm2d(C, affine=affine)
        )
    return build


def _sep_conv(kernel, pad):
    """Returns a factory for SepConv with the given kernel/padding."""
    def build(C, stride, affine):
        return SepConv(C, C, kernel, stride, pad, affine=affine)
    return build


def _dil_conv(kernel, pad):
    """Returns a factory for DilConv with dilation 2."""
    def build(C, stride, affine):
        return DilConv(C, C, kernel, stride, pad, 2, affine=affine)
    return build


def _dil_sep_conv_3x3(C, stride, affine):
    """Factory for 'dil_sep_conv_3x3'."""
    return DilSepConv(C, C, 3, stride, 2, 2, affine=affine)


OPS = {
    'none': _zero,
    'avg_pool_2x2': _avg_pool(2, 0),
    'avg_pool_3x3': _avg_pool(3, 1),
    'avg_pool_5x5': _avg_pool(5, 2),
    'max_pool_2x2': _max_pool(2, 0),
    'max_pool_3x3': _max_pool(3, 1),
    'max_pool_5x5': _max_pool(5, 2),
    'max_pool_7x7': _max_pool(7, 3),
    'skip_connect': _skip,
    'conv_1x1': _plain_conv(1, 0),
    'conv_3x3': _plain_conv(3, 1),
    'sep_conv_3x3': _sep_conv(3, 1),
    'sep_conv_5x5': _sep_conv(5, 2),
    'sep_conv_7x7': _sep_conv(7, 3),
    'dil_conv_3x3': _dil_conv(3, 2),
    'dil_conv_5x5': _dil_conv(5, 4),
    'dil_sep_conv_3x3': _dil_sep_conv_3x3,
    'conv_3x1_1x3': _factorized_conv(3, 1),
    'conv_7x1_1x7': _factorized_conv(7, 3),
}
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
    """ReLU -> Conv2d (no bias) -> BatchNorm2d, stored as `self.op`."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(ReLUConvBN, self).__init__()
        conv = nn.Conv2d(
            C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False
        )
        norm = nn.BatchNorm2d(C_out, affine=affine)
        self.op = nn.Sequential(nn.ReLU(inplace=False), conv, norm)

    def forward(self, x):
        out = self.op(x)
        return out
|
||||
|
||||
|
||||
class DilConv(nn.Module):
    """ReLU -> dilated depthwise conv -> 1x1 pointwise conv -> BatchNorm2d."""

    def __init__(
        self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
    ):
        super(DilConv, self).__init__()
        # groups=C_in makes the first conv depthwise (one filter per channel).
        depthwise = nn.Conv2d(
            C_in, C_in, kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=C_in, bias=False
        )
        pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False)
        layers = [
            nn.ReLU(inplace=False),
            depthwise,
            pointwise,
            nn.BatchNorm2d(C_out, affine=affine),
        ]
        self.op = nn.Sequential(*layers)

    def forward(self, x):
        out = self.op(x)
        return out
|
||||
|
||||
|
||||
class SepConv(nn.Module):
    """Two stacked ReLU -> depthwise conv -> 1x1 conv -> BN groups.

    Only the first depthwise conv uses `stride`; the second always has
    stride 1. The channel count changes to C_out in the last 1x1 conv.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(SepConv, self).__init__()
        first_half = [
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
        ]
        second_half = [
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=1,
                padding=padding, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        ]
        self.op = nn.Sequential(*first_half, *second_half)

    def forward(self, x):
        out = self.op(x)
        return out
|
||||
|
||||
|
||||
class DilSepConv(nn.Module):
    """Two stacked ReLU -> dilated depthwise conv -> 1x1 conv -> BN groups.

    Only the first depthwise conv uses `stride`; both use `dilation`.
    """

    def __init__(
        self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
    ):
        super(DilSepConv, self).__init__()
        first_half = [
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
        ]
        second_half = [
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=1,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        ]
        self.op = nn.Sequential(*first_half, *second_half)

    def forward(self, x):
        out = self.op(x)
        return out
|
||||
|
||||
|
||||
class Identity(nn.Module):
    """Pass-through op: forward returns the very tensor it was given."""

    # No custom __init__ is needed; nn.Module's own initializer is enough.

    def forward(self, x):
        return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
    """Op that outputs zeros, subsampled spatially by `stride`."""

    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def forward(self, x):
        step = self.stride
        # For stride > 1 keep every step-th row and column so the output has
        # the spatial size a strided op would produce.
        sampled = x if step == 1 else x[:, :, ::step, ::step]
        return sampled.mul(0.)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
    """Halve spatial resolution with two stride-2 1x1 convs on offset grids.

    Each conv produces C_out // 2 channels; the second conv reads the input
    shifted by one pixel in both spatial dimensions. The two results are
    concatenated on the channel axis and batch-normalised.
    """

    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        assert C_out % 2 == 0
        half_width = C_out // 2
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, half_width, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, half_width, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        # Zero-pad one column on the right and one row at the bottom so the
        # shifted view below keeps the same spatial size as the input.
        self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)

    def forward(self, x):
        """Apply ReLU, both strided branches, channel concat and BatchNorm."""
        activated = self.relu(x)
        shifted = self.pad(activated)[:, :, 1:, 1:]
        branches = [self.conv_1(activated), self.conv_2(shifted)]
        return self.bn(torch.cat(branches, dim=1))
|
89
pycls/models/regnet.py
Normal file
89
pycls/models/regnet.py
Normal file
@@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""RegNet models."""
|
||||
|
||||
import numpy as np
|
||||
from pycls.core.config import cfg
|
||||
from pycls.models.anynet import AnyNet
|
||||
|
||||
|
||||
def quantize_float(f, q):
    """Rounds f to a multiple of q and returns it as an int.

    The multiple count comes from Python's built-in round(), so exact ties
    go to the even multiple and the result is 0 when f / q rounds to 0.
    """
    n_multiples = round(f / q)
    return int(n_multiples * q)
|
||||
|
||||
|
||||
def adjust_ws_gs_comp(ws, bms, gs):
    """Adjusts widths and group widths so they are mutually compatible.

    For each stage the bottleneck width int(w * b) is computed, the group
    width is capped at that bottleneck width, the bottleneck width is
    quantized to a multiple of the group width, and the stage width is then
    recovered as int(bottleneck / b). Returns the adjusted (ws, gs) lists.
    """
    bottleneck_ws = [int(width * mult) for width, mult in zip(ws, bms)]
    # A group cannot be wider than the bottleneck it partitions.
    gs = [min(group_w, bot_w) for group_w, bot_w in zip(gs, bottleneck_ws)]
    bottleneck_ws = [
        quantize_float(bot_w, group_w) for bot_w, group_w in zip(bottleneck_ws, gs)
    ]
    ws = [int(bot_w / mult) for bot_w, mult in zip(bottleneck_ws, bms)]
    return ws, gs
|
||||
|
||||
|
||||
def get_stages_from_blocks(ws, rs):
    """Gets ws/ds of network at each stage from per block values.

    A new stage starts at every block whose (w, r) pair differs from the
    previous block's pair. Returns the width of the first block of each
    stage and the number of blocks in each stage.
    """
    # Pad with 0 on both ends so the first block is compared against 0 and a
    # final sentinel entry closes the last stage.
    current_ws, previous_ws = ws + [0], [0] + ws
    current_rs, previous_rs = rs + [0], [0] + rs
    starts = []
    for w, wp, r, rp in zip(current_ws, previous_ws, current_rs, previous_rs):
        starts.append(w != wp or r != rp)
    s_ws = [w for w, is_start in zip(ws, starts[:-1]) if is_start]
    # Distance between consecutive stage starts is the stage depth.
    start_idx = [i for i, is_start in enumerate(starts) if is_start]
    s_ds = np.diff(start_idx).tolist()
    return s_ws, s_ds
|
||||
|
||||
|
||||
def generate_regnet(w_a, w_0, w_m, d, q=8):
    """Generates per block ws from RegNet parameters.

    Returns a 4-tuple: the quantized per-block widths (list of int), the
    number of distinct quantized widths, ks.max() + 1 where ks are the
    rounded log_{w_m} exponents, and the continuous (linear) per-block
    widths as a list.
    """
    assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
    # Linear width ramp: w_0, w_0 + w_a, w_0 + 2 * w_a, ...
    linear_ws = np.arange(d) * w_a + w_0
    # Snap each width to the nearest power-of-w_m multiple of w_0.
    ks = np.round(np.log(linear_ws / w_0) / np.log(w_m))
    snapped_ws = w_0 * np.power(w_m, ks)
    # Round to a multiple of q.
    snapped_ws = np.round(np.divide(snapped_ws, q)) * q
    num_stages = len(np.unique(snapped_ws))
    max_stage = ks.max() + 1
    return snapped_ws.astype(int).tolist(), num_stages, max_stage, linear_ws.tolist()
|
||||
|
||||
|
||||
class RegNet(AnyNet):
    """RegNet model: an AnyNet whose arguments are derived from cfg.REGNET."""

    @staticmethod
    def get_args():
        """Convert RegNet to AnyNet parameter format.

        Reads cfg.REGNET and cfg.MODEL.NUM_CLASSES and returns the keyword
        arguments dict passed to AnyNet.
        """
        reg = cfg.REGNET
        # Generate RegNet ws per block
        ws, num_stages, _, _ = generate_regnet(reg.WA, reg.W0, reg.WM, reg.DEPTH)
        # Convert to per stage format
        s_ws, s_ds = get_stages_from_blocks(ws, ws)
        # Use the same gw, bm and ss for each stage
        s_gs = [reg.GROUP_W] * num_stages
        s_bs = [reg.BOT_MUL] * num_stages
        s_ss = [reg.STRIDE] * num_stages
        # Adjust the compatibility of ws and gws
        s_ws, s_gs = adjust_ws_gs_comp(s_ws, s_bs, s_gs)
        # Squeeze-excitation ratio is only forwarded when SE is enabled
        se_r = reg.SE_R if reg.SE_ON else None
        # Get AnyNet arguments defining the RegNet
        return {
            "stem_type": reg.STEM_TYPE,
            "stem_w": reg.STEM_W,
            "block_type": reg.BLOCK_TYPE,
            "ds": s_ds,
            "ws": s_ws,
            "ss": s_ss,
            "bms": s_bs,
            "gws": s_gs,
            "se_r": se_r,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self):
        super(RegNet, self).__init__(**RegNet.get_args())

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update."""
        # Fall back to the cfg-derived arguments when none are supplied.
        if not kwargs:
            kwargs = RegNet.get_args()
        return AnyNet.complexity(cx, **kwargs)
|
280
pycls/models/resnet.py
Normal file
280
pycls/models/resnet.py
Normal file
@@ -0,0 +1,280 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""ResNe(X)t models."""
|
||||
|
||||
import pycls.core.net as net
|
||||
import torch.nn as nn
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
# Stage depths (number of blocks in stages s1..s4) for ImageNet models, keyed by cfg.MODEL.DEPTH
|
||||
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
|
||||
|
||||
|
||||
def get_trans_fun(name):
    """Retrieves the transformation function by name.

    Supported names are "basic_transform" and "bottleneck_transform"; any
    other name fails the assertion below.
    """
    registry = {
        "basic_transform": BasicTransform,
        "bottleneck_transform": BottleneckTransform,
    }
    assert name in registry.keys(), "Transformation function '{}' not supported".format(
        name
    )
    return registry[name]
|
||||
|
||||
|
||||
class ResHead(nn.Module):
    """ResNet head: AvgPool, 1x1."""

    def __init__(self, w_in, nc):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        """Pool to 1x1, flatten per sample, and apply the linear classifier."""
        pooled = self.avg_pool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

    @staticmethod
    def complexity(cx, w_in, nc):
        """Accumulate the head's complexity into cx and return it."""
        # After global pooling the spatial size is 1x1; the linear layer is
        # counted as a 1x1 conv with bias.
        cx["h"] = 1
        cx["w"] = 1
        return net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
|
||||
|
||||
|
||||
class BasicTransform(nn.Module):
    """Basic transformation: 3x3, BN, ReLU, 3x3, BN."""

    def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        super().__init__()
        # NOTE: forward() runs the children in registration order, so the
        # attribute assignment order below defines the computation.
        # 3x3 conv (carries the stride), BN, ReLU
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        # 3x3 conv, BN (flagged as the final BN of the transform)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_bn.final_bn = True

    def forward(self, x):
        """Apply every registered child module in order."""
        out = x
        for op in self.children():
            out = op(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1):
        """Accumulate the transform's conv/BN complexity into cx and return it."""
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        # First 3x3 conv + BN
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        # Second 3x3 conv + BN
        cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
        return net.complexity_batchnorm2d(cx, w_out)
|
||||
|
||||
|
||||
class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN."""

    def __init__(self, w_in, w_out, stride, w_b, num_gs):
        super().__init__()
        # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
        if cfg.RESNET.STRIDE_1X1:
            s1, s3 = stride, 1
        else:
            s1, s3 = 1, stride
        # NOTE: forward() runs the children in registration order, so the
        # attribute assignment order below defines the computation.
        # 1x1 reduce to bottleneck width, BN, ReLU
        self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        # 3x3 grouped conv, BN, ReLU
        self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False)
        self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        # 1x1 expand to output width, BN (flagged as the final BN)
        self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.c_bn.final_bn = True

    def forward(self, x):
        """Apply every registered child module in order."""
        out = x
        for op in self.children():
            out = op(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out, stride, w_b, num_gs):
        """Accumulate the transform's conv/BN complexity into cx and return it."""
        if cfg.RESNET.STRIDE_1X1:
            s1, s3 = stride, 1
        else:
            s1, s3 = 1, stride
        # 1x1 reduce + BN
        cx = net.complexity_conv2d(cx, w_in, w_b, 1, s1, 0)
        cx = net.complexity_batchnorm2d(cx, w_b)
        # 3x3 grouped conv + BN
        cx = net.complexity_conv2d(cx, w_b, w_b, 3, s3, 1, num_gs)
        cx = net.complexity_batchnorm2d(cx, w_b)
        # 1x1 expand + BN
        cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
        return net.complexity_batchnorm2d(cx, w_out)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block: x + F(x)."""

    def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
        super().__init__()
        # Use skip connection with projection if shape changes
        needs_projection = (w_in != w_out) or (stride != 1)
        self.proj_block = needs_projection
        if needs_projection:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        """Add the (optionally projected) shortcut to F(x) and apply ReLU."""
        # The shortcut is evaluated before the residual branch.
        shortcut = self.bn(self.proj(x)) if self.proj_block else x
        return self.relu(shortcut + self.f(x))

    @staticmethod
    def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs):
        """Accumulate the block's complexity into cx and return it."""
        if (w_in != w_out) or (stride != 1):
            # The projection runs in parallel with the transform, so restore
            # the input resolution after accounting for it.
            in_h, in_w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = in_h, in_w  # parallel branch
        return trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs)
|
||||
|
||||
|
||||
class ResStage(nn.Module):
    """Stage of ResNet."""

    def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
        super().__init__()
        for idx in range(d):
            is_first = idx == 0
            # Only the first block changes resolution and input width.
            block_stride = stride if is_first else 1
            block_w_in = w_in if is_first else w_out
            trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN)
            block = ResBlock(block_w_in, w_out, block_stride, trans_fun, w_b, num_gs)
            self.add_module("b{}".format(idx + 1), block)

    def forward(self, x):
        """Apply the stage's blocks in registration order."""
        out = x
        for blk in self.children():
            out = blk(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1):
        """Accumulate the complexity of all d blocks into cx and return it."""
        for idx in range(d):
            is_first = idx == 0
            block_stride = stride if is_first else 1
            block_w_in = w_in if is_first else w_out
            trans_f = get_trans_fun(cfg.RESNET.TRANS_FUN)
            cx = ResBlock.complexity(
                cx, block_w_in, w_out, block_stride, trans_f, w_b, num_gs
            )
        return cx
|
||||
|
||||
|
||||
class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR: 3x3, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super().__init__()
        # forward() runs the children in this registration order.
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        """Apply every registered child module in order."""
        out = x
        for op in self.children():
            out = op(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Accumulate the stem's conv/BN complexity into cx and return it."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
        return net.complexity_batchnorm2d(cx, w_out)
|
||||
|
||||
|
||||
class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""

    def __init__(self, w_in, w_out):
        super().__init__()
        # forward() runs the children in this registration order.
        self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        """Apply every registered child module in order."""
        out = x
        for op in self.children():
            out = op(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Accumulate the stem's conv/BN/pool complexity into cx and return it."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return net.complexity_maxpool2d(cx, 3, 2, 1)
|
||||
|
||||
|
||||
class ResNet(nn.Module):
    """ResNet model.

    The CIFAR or ImageNet variant is built depending on whether
    cfg.TRAIN.DATASET contains "cifar".
    """

    def __init__(self):
        datasets = ["cifar10", "imagenet"]
        err_str = "Dataset {} is not supported"
        assert cfg.TRAIN.DATASET in datasets, err_str.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in datasets, err_str.format(cfg.TEST.DATASET)
        super().__init__()
        build = (
            self._construct_cifar
            if "cifar" in cfg.TRAIN.DATASET
            else self._construct_imagenet
        )
        build()
        self.apply(net.init_weights)

    def _construct_cifar(self):
        """Build the three-stage CIFAR network (depth must be 6n + 2)."""
        err_str = "Model depth should be of the format 6n + 2 for cifar"
        assert (cfg.MODEL.DEPTH - 2) % 6 == 0, err_str
        blocks_per_stage = int((cfg.MODEL.DEPTH - 2) / 6)
        # forward() runs the children in this registration order.
        self.stem = ResStemCifar(3, 16)
        self.s1 = ResStage(16, 16, stride=1, d=blocks_per_stage)
        self.s2 = ResStage(16, 32, stride=2, d=blocks_per_stage)
        self.s3 = ResStage(32, 64, stride=2, d=blocks_per_stage)
        self.head = ResHead(64, nc=cfg.MODEL.NUM_CLASSES)

    def _construct_imagenet(self):
        """Build the four-stage ImageNet network for cfg.MODEL.DEPTH."""
        g = cfg.RESNET.NUM_GROUPS
        gw = cfg.RESNET.WIDTH_PER_GROUP
        d1, d2, d3, d4 = _IN_STAGE_DS[cfg.MODEL.DEPTH]
        # Bottleneck width of the first stage; doubled at every later stage.
        w_b = gw * g
        # forward() runs the children in this registration order.
        self.stem = ResStemIN(3, 64)
        self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g)
        self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g)
        self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g)
        self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g)
        self.head = ResHead(2048, nc=cfg.MODEL.NUM_CLASSES)

    def forward(self, x):
        """Apply stem, stages and head in registration order."""
        out = x
        for part in self.children():
            out = part(out)
        return out

    @staticmethod
    def complexity(cx):
        """Computes model complexity. If you alter the model, make sure to update."""
        nc = cfg.MODEL.NUM_CLASSES
        if "cifar" in cfg.TRAIN.DATASET:
            blocks_per_stage = int((cfg.MODEL.DEPTH - 2) / 6)
            cx = ResStemCifar.complexity(cx, 3, 16)
            cx = ResStage.complexity(cx, 16, 16, stride=1, d=blocks_per_stage)
            cx = ResStage.complexity(cx, 16, 32, stride=2, d=blocks_per_stage)
            cx = ResStage.complexity(cx, 32, 64, stride=2, d=blocks_per_stage)
            return ResHead.complexity(cx, 64, nc=nc)
        g = cfg.RESNET.NUM_GROUPS
        gw = cfg.RESNET.WIDTH_PER_GROUP
        d1, d2, d3, d4 = _IN_STAGE_DS[cfg.MODEL.DEPTH]
        w_b = gw * g
        cx = ResStemIN.complexity(cx, 3, 64)
        cx = ResStage.complexity(cx, 64, 256, 1, d=d1, w_b=w_b, num_gs=g)
        cx = ResStage.complexity(cx, 256, 512, 2, d=d2, w_b=w_b * 2, num_gs=g)
        cx = ResStage.complexity(cx, 512, 1024, 2, d=d3, w_b=w_b * 4, num_gs=g)
        cx = ResStage.complexity(cx, 1024, 2048, 2, d=d4, w_b=w_b * 8, num_gs=g)
        return ResHead.complexity(cx, 2048, nc=nc)
|
Reference in New Issue
Block a user