Support GDAS (FRC), see details in docs/CVPR-2019-GDAS.md
@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn
 
-__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
+__all__ = ['OPS', 'RAW_OP_CLASSES', 'ResNetBasicblock', 'SearchSpaceNames']
 
 OPS = {
   'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
@@ -175,7 +175,7 @@ class FactorizedReduce(nn.Module):
         self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine))
       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
     elif stride == 1:
-      self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
+      self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=not affine)
     else:
       raise ValueError('Invalid stride : {:}'.format(stride))
     self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
@@ -256,41 +256,44 @@ def drop_path(x, drop_prob):
 # Searching for A Robust Neural Architecture in Four GPU Hours
 class GDAS_Reduction_Cell(nn.Module):
 
-  def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier, affine, track_running_stats):
+  def __init__(self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats):
     super(GDAS_Reduction_Cell, self).__init__()
     if reduction_prev:
       self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats)
     else:
       self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats)
     self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats)
-    self.multiplier = multiplier
 
     self.reduction = True
     self.ops1 = nn.ModuleList(
                   [nn.Sequential(
                       nn.ReLU(inplace=False),
-                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
-                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
-                      nn.BatchNorm2d(C, affine=True),
+                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=not affine),
+                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=not affine),
+                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats),
                       nn.ReLU(inplace=False),
-                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
-                      nn.BatchNorm2d(C, affine=True)),
+                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
+                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats)),
                    nn.Sequential(
                       nn.ReLU(inplace=False),
-                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
-                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
-                      nn.BatchNorm2d(C, affine=True),
+                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=not affine),
+                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=not affine),
+                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats),
                       nn.ReLU(inplace=False),
-                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
-                      nn.BatchNorm2d(C, affine=True))])
+                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
+                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats))])
 
     self.ops2 = nn.ModuleList(
                   [nn.Sequential(
-                      nn.MaxPool2d(3, stride=1, padding=1),
-                      nn.BatchNorm2d(C, affine=True)),
+                      nn.MaxPool2d(3, stride=2, padding=1),
+                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats)),
                    nn.Sequential(
                       nn.MaxPool2d(3, stride=2, padding=1),
-                      nn.BatchNorm2d(C, affine=True))])
+                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats))])
 
+  @property
+  def multiplier(self):
+    return 4
+
   def forward(self, s0, s1, drop_prob = -1):
     s0 = self.preprocess0(s0)
@@ -307,3 +310,10 @@ class GDAS_Reduction_Cell(nn.Module):
     if self.training and drop_prob > 0.:
       X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
     return torch.cat([X0, X1, X2, X3], dim=1)
+
+
+# To manage the useful classes in this file.
+RAW_OP_CLASSES = {
+  'gdas_reduction': GDAS_Reduction_Cell
+}
+
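Note (not part of the commit): below is a minimal usage sketch of the refactored interface, assuming the module shown in this diff is importable as cell_operations (the file path is not visible here, so that import is a guess). It illustrates the two API changes above: multiplier is no longer a constructor argument of GDAS_Reduction_Cell but a fixed property returning 4, and the new RAW_OP_CLASSES dict maps 'gdas_reduction' to the cell class.

# Usage sketch only -- not part of this commit. The import path below is an
# assumption; the diff does not show which file defines these symbols.
import torch
from cell_operations import RAW_OP_CLASSES  # hypothetical module path

# Look up the reduction cell class through the new registry.
cell_cls = RAW_OP_CLASSES['gdas_reduction']

# New signature: `multiplier` is no longer a constructor argument; it is now a
# read-only property that always returns 4.
cell = cell_cls(C_prev_prev=48, C_prev=48, C=16,
                reduction_prev=False, affine=False, track_running_stats=True)

s0 = torch.randn(2, 48, 32, 32)  # features from cell k-2
s1 = torch.randn(2, 48, 32, 32)  # features from cell k-1
out = cell(s0, s1)               # reduction cell halves the spatial size

# Four C-channel branches are concatenated, so out has C * cell.multiplier channels.
print(out.shape, cell.multiplier)  # expected: torch.Size([2, 64, 16, 16]) 4

Fixing multiplier as a property matches the forward pass above, which always concatenates exactly four intermediate tensors.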