init
89
lib/nas/CifarNet.py
Normal file
@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
from .construct_utils import Cell, Transition


class AuxiliaryHeadCIFAR(nn.Module):

  def __init__(self, C, num_classes):
    """assuming input size 8x8"""
    super(AuxiliaryHeadCIFAR, self).__init__()
    self.features = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
      nn.Conv2d(C, 128, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 768, 2, bias=False),
      nn.BatchNorm2d(768),
      nn.ReLU(inplace=True)
    )
    self.classifier = nn.Linear(768, num_classes)

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x.view(x.size(0), -1))
    return x


class NetworkCIFAR(nn.Module):

  def __init__(self, C, num_classes, layers, auxiliary, genotype):
    super(NetworkCIFAR, self).__init__()
    self._layers = layers

    stem_multiplier = 3
    C_curr = stem_multiplier * C
    self.stem = nn.Sequential(
      nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
      nn.BatchNorm2d(C_curr)
    )

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    self.cells = nn.ModuleList()
    reduction_prev = False
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      if reduction and genotype.reduce is None:
        cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev)
      else:
        cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells.append(cell)
      C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
      if i == 2*layers//3:
        C_to_auxiliary = C_prev

    if auxiliary:
      self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
    else:
      self.auxiliary_head = None
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.drop_path_prob = -1

  def update_drop_path(self, drop_path_prob):
    self.drop_path_prob = drop_path_prob

  def auxiliary_param(self):
    if self.auxiliary_head is None: return []
    else: return list(self.auxiliary_head.parameters())

  def forward(self, inputs):
    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
      if i == 2*self._layers//3:
        if self.auxiliary_head and self.training:
          logits_aux = self.auxiliary_head(s1)
    out = self.global_pooling(s1)
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    if self.auxiliary_head and self.training:
      return logits, logits_aux
    else:
      return logits
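A minimal usage sketch for the evaluation network above (not part of the commit). It assumes the package is importable as lib.nas and uses the DARTS_V2 genotype defined in lib/nas/genotypes.py further down; any of the other genotypes would work the same way.

# Hypothetical usage sketch: build NetworkCIFAR from a genotype and run it.
import torch
from lib.nas import NetworkCIFAR
from lib.nas.genotypes import DARTS_V2

# 36 initial channels, 10 classes, 20 cells, auxiliary head enabled.
model = NetworkCIFAR(C=36, num_classes=10, layers=20, auxiliary=True, genotype=DARTS_V2)
model.update_drop_path(0.2)            # enable stochastic drop-path during training
model.train()
images = torch.randn(2, 3, 32, 32)     # CIFAR-sized inputs
logits, logits_aux = model(images)     # two outputs only when training with an auxiliary head
model.eval()
logits = model(images)                 # a single output at evaluation time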
101
lib/nas/ImageNet.py
Normal file
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
from .construct_utils import Cell, Transition


class AuxiliaryHeadImageNet(nn.Module):

  def __init__(self, C, num_classes):
    """assuming input size 14x14"""
    super(AuxiliaryHeadImageNet, self).__init__()
    self.features = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
      nn.Conv2d(C, 128, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 768, 2, bias=False),
      # NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
      # Commenting it out for consistency with the experiments in the paper.
      # nn.BatchNorm2d(768),
      nn.ReLU(inplace=True)
    )
    self.classifier = nn.Linear(768, num_classes)

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x.view(x.size(0), -1))
    return x


class NetworkImageNet(nn.Module):

  def __init__(self, C, num_classes, layers, auxiliary, genotype):
    super(NetworkImageNet, self).__init__()
    self._layers = layers

    self.stem0 = nn.Sequential(
      nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(C // 2),
      nn.ReLU(inplace=True),
      nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(C),
    )

    self.stem1 = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(C),
    )

    C_prev_prev, C_prev, C_curr = C, C, C

    self.cells = nn.ModuleList()
    reduction_prev = True
    for i in range(layers):
      if i in [layers // 3, 2 * layers // 3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      if reduction and genotype.reduce is None:
        cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev)
      else:
        cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells += [cell]
      C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
      if i == 2 * layers // 3:
        C_to_auxiliary = C_prev

    if auxiliary:
      self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
    else:
      self.auxiliary_head = None
    self.global_pooling = nn.AvgPool2d(7)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.drop_path_prob = -1

  def update_drop_path(self, drop_path_prob):
    self.drop_path_prob = drop_path_prob

  def auxiliary_param(self):
    if self.auxiliary_head is None: return []
    else: return list(self.auxiliary_head.parameters())

  def forward(self, input):
    s0 = self.stem0(input)
    s1 = self.stem1(s0)
    for i, cell in enumerate(self.cells):
      s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
      #print ('{:} : {:} - {:}'.format(i, s0.size(), s1.size()))
      if i == 2 * self._layers // 3:
        if self.auxiliary_head and self.training:
          logits_aux = self.auxiliary_head(s1)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0), -1))
    if self.auxiliary_head and self.training:
      return logits, logits_aux
    else:
      return logits
27
lib/nas/SE_Module.py
Normal file
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# Squeeze and Excitation module


class SqEx(nn.Module):

  def __init__(self, n_features, reduction=16):
    super(SqEx, self).__init__()

    if n_features % reduction != 0:
      raise ValueError('n_features must be divisible by reduction (default = 16)')

    self.linear1 = nn.Linear(n_features, n_features // reduction, bias=True)
    self.nonlin1 = nn.ReLU(inplace=True)
    self.linear2 = nn.Linear(n_features // reduction, n_features, bias=True)
    self.nonlin2 = nn.Sigmoid()

  def forward(self, x):
    y = F.avg_pool2d(x, kernel_size=x.size()[2:4])
    y = y.permute(0, 2, 3, 1)
    y = self.nonlin1(self.linear1(y))
    y = self.nonlin2(self.linear2(y))
    y = y.permute(0, 3, 1, 2)
    y = x * y
    return y
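A short sketch of how the squeeze-and-excitation block is applied (not part of the commit): it recalibrates channels of a feature map and preserves its shape.

# Hypothetical usage sketch: channel recalibration with SqEx.
import torch
from lib.nas.SE_Module import SqEx

se = SqEx(n_features=64, reduction=16)
feat = torch.randn(4, 64, 16, 16)
out = se(feat)                    # same shape as the input: (4, 64, 16, 16)
assert out.shape == feat.shape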
18
lib/nas/__init__.py
Normal file
@@ -0,0 +1,18 @@
from .model_search import Network
from .model_search_v1 import NetworkV1
from .model_search_f1 import NetworkF1
# acceleration model
from .model_search_f1_acc2 import NetworkFACC1
from .model_search_acc2 import NetworkACC2
from .model_search_v3 import NetworkV3
from .model_search_v4 import NetworkV4
from .model_search_v5 import NetworkV5
from .CifarNet import NetworkCIFAR
from .ImageNet import NetworkImageNet

# genotypes
from .genotypes import DARTS_V1, DARTS_V2
from .genotypes import NASNet, PNASNet, AmoebaNet, ENASNet
from .genotypes import DMS_V1, DMS_F1, GDAS_CC

from .construct_utils import return_alphas_str
151
lib/nas/construct_utils.py
Normal file
@@ -0,0 +1,151 @@
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from .operations import OPS, FactorizedReduce, ReLUConvBN, Identity


def random_select(length, ratio):
  clist = []
  index = random.randint(0, length-1)
  for i in range(length):
    if i == index or random.random() < ratio:
      clist.append(1)
    else:
      clist.append(0)
  return clist


def all_select(length):
  return [1 for i in range(length)]


def drop_path(x, drop_prob):
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    mask = x.new_zeros(x.size(0), 1, 1, 1)
    mask = mask.bernoulli_(keep_prob)
    x.div_(keep_prob)
    x.mul_(mask)
  return x


def return_alphas_str(basemodel):
  string = 'normal : {:}'.format(F.softmax(basemodel.alphas_normal, dim=-1))
  if hasattr(basemodel, 'alphas_reduce'):
    string = string + '\nreduce : {:}'.format(F.softmax(basemodel.alphas_reduce, dim=-1))
  return string


class Cell(nn.Module):

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    print(C_prev_prev, C_prev, C)

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

    if reduction:
      op_names, indices, values = zip(*genotype.reduce)
      concat = genotype.reduce_concat
    else:
      op_names, indices, values = zip(*genotype.normal)
      concat = genotype.normal_concat
    self._compile(C, op_names, indices, values, concat, reduction)

  def _compile(self, C, op_names, indices, values, concat, reduction):
    assert len(op_names) == len(indices)
    self._steps = len(op_names) // 2
    self._concat = concat
    self.multiplier = len(concat)

    self._ops = nn.ModuleList()
    for name, index in zip(op_names, indices):
      stride = 2 if reduction and index < 2 else 1
      op = OPS[name](C, stride, True)
      self._ops.append(op)
    self._indices = indices
    self._values = values

  def forward(self, s0, s1, drop_prob):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    for i in range(self._steps):
      h1 = states[self._indices[2*i]]
      h2 = states[self._indices[2*i+1]]
      op1 = self._ops[2*i]
      op2 = self._ops[2*i+1]
      h1 = op1(h1)
      h2 = op2(h2)
      if self.training and drop_prob > 0.:
        if not isinstance(op1, Identity):
          h1 = drop_path(h1, drop_prob)
        if not isinstance(op2, Identity):
          h2 = drop_path(h2, drop_prob)

      s = h1 + h2

      states += [s]
    return torch.cat([states[i] for i in self._concat], dim=1)


class Transition(nn.Module):

  def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier=4):
    super(Transition, self).__init__()
    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
    self.multiplier  = multiplier

    self.reduction = True
    self.ops1 = nn.ModuleList(
                  [nn.Sequential(
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
                      nn.BatchNorm2d(C, affine=True),
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
                      nn.BatchNorm2d(C, affine=True)),
                   nn.Sequential(
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
                      nn.BatchNorm2d(C, affine=True),
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
                      nn.BatchNorm2d(C, affine=True))])

    self.ops2 = nn.ModuleList(
                  [nn.Sequential(
                      nn.MaxPool2d(3, stride=1, padding=1),
                      nn.BatchNorm2d(C, affine=True)),
                   nn.Sequential(
                      nn.MaxPool2d(3, stride=2, padding=1),
                      nn.BatchNorm2d(C, affine=True))])

  def forward(self, s0, s1, drop_prob = -1):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    X0 = self.ops1[0](s0)
    X1 = self.ops1[1](s1)
    if self.training and drop_prob > 0.:
      X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)

    X2 = self.ops2[0](X0+X1)
    X3 = self.ops2[1](s1)
    if self.training and drop_prob > 0.:
      X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
    return torch.cat([X0, X1, X2, X3], dim=1)
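drop_path above zeroes entire samples in a batch and rescales the survivors so the expected activation is unchanged; a quick check of that behaviour (not part of the commit):

# Hypothetical check: drop_path keeps the expected value of the activations.
import torch
from lib.nas.construct_utils import drop_path

x = torch.ones(1000, 3, 4, 4)
y = drop_path(x.clone(), 0.3)      # clone first: drop_path modifies its input in place
kept = (y.view(1000, -1).abs().sum(dim=1) > 0).float().mean()
print(kept)                        # roughly 0.7 of the samples survive, each scaled by 1/0.7
print(y.mean())                    # close to 1.0 in expectation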
203
lib/nas/genotypes.py
Normal file
@@ -0,0 +1,203 @@
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

PRIMITIVES = [
  'none',
  'max_pool_3x3',
  'avg_pool_3x3',
  'skip_connect',
  'sep_conv_3x3',
  'sep_conv_5x5',
  'dil_conv_3x3',
  'dil_conv_5x5'
]

NASNet = Genotype(
  normal = [
    ('sep_conv_5x5', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('sep_conv_5x5', 0, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('avg_pool_3x3', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('avg_pool_3x3', 0, 1.0),
    ('avg_pool_3x3', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('skip_connect', 1, 1.0),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    ('sep_conv_5x5', 1, 1.0),
    ('sep_conv_7x7', 0, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_7x7', 0, 1.0),
    ('avg_pool_3x3', 1, 1.0),
    ('sep_conv_5x5', 0, 1.0),
    ('skip_connect', 3, 1.0),
    ('avg_pool_3x3', 2, 1.0),
    ('sep_conv_3x3', 2, 1.0),
    ('max_pool_3x3', 1, 1.0),
  ],
  reduce_concat = [4, 5, 6],
)

AmoebaNet = Genotype(
  normal = [
    ('avg_pool_3x3', 0, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('sep_conv_5x5', 2, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('avg_pool_3x3', 3, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('skip_connect', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('avg_pool_3x3', 1, 1.0),
  ],
  normal_concat = [4, 5, 6],
  reduce = [
    ('avg_pool_3x3', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('sep_conv_7x7', 2, 1.0),
    ('sep_conv_7x7', 0, 1.0),
    ('avg_pool_3x3', 1, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('conv_7x1_1x7', 0, 1.0),
    ('sep_conv_3x3', 5, 1.0),
  ],
  reduce_concat = [3, 4, 6]
)

DARTS_V1 = Genotype(
  normal=[
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('skip_connect', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('skip_connect', 2, 1.0)],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    ('max_pool_3x3', 0, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('skip_connect', 2, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('skip_connect', 2, 1.0),
    ('skip_connect', 2, 1.0),
    ('avg_pool_3x3', 0, 1.0)],
  reduce_concat=[2, 3, 4, 5]
)

DARTS_V2 = Genotype(
  normal=[
    ('sep_conv_3x3', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('skip_connect', 0, 1.0),
    ('dil_conv_3x3', 2, 1.0)],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    ('max_pool_3x3', 0, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('skip_connect', 2, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('skip_connect', 2, 1.0),
    ('skip_connect', 2, 1.0),
    ('max_pool_3x3', 1, 1.0)],
  reduce_concat=[2, 3, 4, 5]
)

PNASNet = Genotype(
  normal = [
    ('sep_conv_5x5', 0, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('sep_conv_7x7', 1, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_5x5', 1, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 4, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('skip_connect', 1, 1.0),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    ('sep_conv_5x5', 0, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('sep_conv_7x7', 1, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_5x5', 1, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 4, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('skip_connect', 1, 1.0),
  ],
  reduce_concat = [2, 3, 4, 5, 6],
)

# https://arxiv.org/pdf/1802.03268.pdf
ENASNet = Genotype(
  normal = [
    ('sep_conv_3x3', 1, 1.0),
    ('skip_connect', 1, 1.0),
    ('sep_conv_5x5', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('avg_pool_3x3', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('avg_pool_3x3', 1, 1.0),
    ('sep_conv_5x5', 1, 1.0),
    ('avg_pool_3x3', 0, 1.0),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    ('sep_conv_5x5', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0), # 2
    ('sep_conv_3x3', 1, 1.0),
    ('avg_pool_3x3', 1, 1.0), # 3
    ('sep_conv_3x3', 1, 1.0),
    ('avg_pool_3x3', 1, 1.0), # 4
    ('avg_pool_3x3', 1, 1.0),
    ('sep_conv_5x5', 4, 1.0), # 5
    ('sep_conv_3x3', 5, 1.0),
    ('sep_conv_5x5', 0, 1.0),
  ],
  reduce_concat = [2, 3, 4, 5, 6],
)

DARTS = DARTS_V2

# Searched over both the normal and the reduction cell
DMS_V1 = Genotype(
  normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)],
  normal_concat=range(2, 6),
  reduce=[('sep_conv_5x5', 0, 0.12862831354141235), ('sep_conv_3x3', 1, 0.12783904373645782), ('sep_conv_5x5', 2, 0.12725995481014252), ('sep_conv_5x5', 1, 0.12705285847187042), ('dil_conv_5x5', 2, 0.12797553837299347), ('sep_conv_3x3', 1, 0.12737272679805756), ('sep_conv_5x5', 0, 0.12833961844444275), ('sep_conv_5x5', 1, 0.12758426368236542)],
  reduce_concat=range(2, 6)
)

# Searched over the normal cell with a fixed reduction cell
DMS_F1 = Genotype(
  normal=[('skip_connect', 0, 0.16), ('skip_connect', 1, 0.13), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.16), ('sep_conv_3x3', 2, 0.15)],
  normal_concat=[2, 3, 4, 5],
  reduce=None,
  reduce_concat=[2, 3, 4, 5],
)

# Combine DMS_V1 and DMS_F1
GDAS_CC = Genotype(
  normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)],
  normal_concat=range(2, 6),
  reduce=None,
  reduce_concat=range(2, 6)
)
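Each genotype entry is an (operation, input-node index, weight) triple; Cell._compile in construct_utils.py unpacks them with zip, and every two consecutive entries feed one new intermediate node. A small illustration (not part of the commit):

# Hypothetical illustration: how Cell.__init__ unpacks a genotype.
from lib.nas.genotypes import DARTS_V2

op_names, indices, values = zip(*DARTS_V2.normal)
# The first two entries describe node 2: sep_conv_3x3 applied to state 0 and
# sep_conv_3x3 applied to state 1; normal_concat lists the nodes concatenated as output.
print(op_names[:2], indices[:2], list(DARTS_V2.normal_concat))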
19
lib/nas/head_utils.py
Normal file
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn


class ImageNetHEAD(nn.Sequential):
  def __init__(self, C, stride=2):
    super(ImageNetHEAD, self).__init__()
    self.add_module('conv1', nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False))
    self.add_module('bn1'  , nn.BatchNorm2d(C // 2))
    self.add_module('relu1', nn.ReLU(inplace=True))
    self.add_module('conv2', nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False))
    self.add_module('bn2'  , nn.BatchNorm2d(C))


class CifarHEAD(nn.Sequential):
  def __init__(self, C):
    super(CifarHEAD, self).__init__()
    self.add_module('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
    self.add_module('bn', nn.BatchNorm2d(C))
166
lib/nas/model_search.py
Normal file
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from .head_utils import CifarHEAD, ImageNetHEAD
from .operations import OPS, FactorizedReduce, ReLUConvBN
from .genotypes import PRIMITIVES, Genotype


class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      self._ops.append(op)

  def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops))


class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride)
        self._ops.append(op)

  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        x = self._ops[offset+j](h, weights[offset+j])
        clist.append(x)
      s = sum(clist)
      offset += len(states)
      states.append(s)

    return torch.cat(states[-self._multiplier:], dim=1)


class Network(nn.Module):

  def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3, head='cifar'):
    super(Network, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._steps = steps
    self._multiplier = multiplier

    C_curr = stem_multiplier*C
    if head == 'cifar':
      self.stem = nn.Sequential(
        nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
        nn.BatchNorm2d(C_curr)
      )
    elif head == 'imagenet':
      self.stem = ImageNetHEAD(C_curr, stride=1)
    else:
      raise ValueError('Invalid head : {:}'.format(head))

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    reduction_prev, cells = False, []
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      cells.append(cell)
      C_prev_prev, C_prev = C_prev, multiplier*C_curr
    self.cells = nn.ModuleList(cells)

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

    # initialize architecture parameters
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = Parameter(torch.Tensor(k, num_ops))
    self.alphas_reduce = Parameter(torch.Tensor(k, num_ops))
    nn.init.normal_(self.alphas_normal, 0, 0.001)
    nn.init.normal_(self.alphas_reduce, 0, 0.001)

  def set_tau(self, tau):
    return -1

  def get_tau(self):
    return -1

  def arch_parameters(self):
    return [self.alphas_normal, self.alphas_reduce]

  def base_parameters(self):
    lists = list(self.stem.parameters()) + list(self.cells.parameters())
    lists += list(self.global_pooling.parameters())
    lists += list(self.classifier.parameters())
    return lists

  def forward(self, inputs):
    batch, C, H, W = inputs.size()
    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if cell.reduction:
        weights = F.softmax(self.alphas_reduce, dim=-1)
      else:
        weights = F.softmax(self.alphas_normal, dim=-1)
      s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    out = out.view(batch, -1)
    logits = self.classifier(out)
    return logits

  def genotype(self):

    def _parse(weights):
      gene, n, start = [], 2, 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j, float(W[j][k_best])))
        start = end
        n += 1
      return gene

    with torch.no_grad():
      gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())
      gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    return genotype
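The search-time training script is not part of this commit; the sketch below only illustrates, under that assumption, how the two parameter groups exposed by Network might be optimized in a DARTS-style alternation (weights on a training batch, architecture parameters on a validation batch). Optimizer choices and hyper-parameters here are placeholders, not values taken from this repository.

# Hypothetical training sketch: alternate updates of weights and architecture parameters.
import torch
import torch.nn as nn
from lib.nas import Network

model = Network(C=16, num_classes=10, layers=8, head='cifar')
criterion = nn.CrossEntropyLoss()
w_optimizer = torch.optim.SGD(model.base_parameters(), lr=0.025, momentum=0.9, weight_decay=3e-4)
a_optimizer = torch.optim.Adam(model.arch_parameters(), lr=3e-4, weight_decay=1e-3)

def search_step(train_batch, valid_batch):
  # architecture step on the validation batch
  x_val, y_val = valid_batch
  a_optimizer.zero_grad()
  criterion(model(x_val), y_val).backward()
  a_optimizer.step()
  # weight step on the training batch
  x_trn, y_trn = train_batch
  w_optimizer.zero_grad()
  criterion(model(x_trn), y_trn).backward()
  w_optimizer.step()

# after search, the discrete architecture is read off the softened alphas:
# print(model.genotype())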
180
lib/nas/model_search_acc2.py
Normal file
@@ -0,0 +1,180 @@
# gumbel softmax
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .operations import OPS, FactorizedReduce, ReLUConvBN
from .genotypes import PRIMITIVES, Genotype


class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      self._ops.append(op)

  def forward(self, x, weights, cpu_weights):
    use_sum = sum([abs(_) > 1e-10 for _ in cpu_weights])
    if use_sum > 3:
      return sum(w * op(x) for w, op in zip(weights, self._ops))
    else:
      clist = []
      for j, cpu_weight in enumerate(cpu_weights):
        if abs(cpu_weight) > 1e-10:
          clist.append(weights[j] * self._ops[j](x))
      assert len(clist) > 0, 'invalid length : {:}'.format(cpu_weights)
      return sum(clist)


class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride)
        self._ops.append(op)

  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    cpu_weights = weights.tolist()
    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j])
        clist.append(x)
      s = sum(clist)
      offset += len(states)
      states.append(s)

    return torch.cat(states[-self._multiplier:], dim=1)


class NetworkACC2(nn.Module):

  def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3):
    super(NetworkACC2, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._steps = steps
    self._multiplier = multiplier

    C_curr = stem_multiplier*C
    self.stem = nn.Sequential(
      nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
      nn.BatchNorm2d(C_curr)
    )

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    reduction_prev, cells = False, []
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      cells.append(cell)
      C_prev_prev, C_prev = C_prev, multiplier*C_curr
    self.cells = nn.ModuleList(cells)

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.tau = 5
    self.use_gumbel = True

    # initialize architecture parameters
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = Parameter(torch.Tensor(k, num_ops))
    self.alphas_reduce = Parameter(torch.Tensor(k, num_ops))
    nn.init.normal_(self.alphas_normal, 0, 0.001)
    nn.init.normal_(self.alphas_reduce, 0, 0.001)

  def set_gumbel(self, use_gumbel):
    self.use_gumbel = use_gumbel

  def set_tau(self, tau):
    self.tau = tau

  def get_tau(self):
    return self.tau

  def arch_parameters(self):
    return [self.alphas_normal, self.alphas_reduce]

  def base_parameters(self):
    lists = list(self.stem.parameters()) + list(self.cells.parameters())
    lists += list(self.global_pooling.parameters())
    lists += list(self.classifier.parameters())
    return lists

  def forward(self, inputs):
    batch, C, H, W = inputs.size()
    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if cell.reduction:
        if self.use_gumbel : weights = F.gumbel_softmax(self.alphas_reduce, self.tau, True)
        else               : weights = F.softmax(self.alphas_reduce, dim=-1)
      else:
        if self.use_gumbel : weights = F.gumbel_softmax(self.alphas_normal, self.tau, True)
        else               : weights = F.softmax(self.alphas_normal, dim=-1)

      s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    out = out.view(batch, -1)
    logits = self.classifier(out)
    return logits

  def genotype(self):

    def _parse(weights):
      gene, n, start = [], 2, 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j, float(W[j][k_best])))
        start = end
        n += 1
      return gene

    with torch.no_grad():
      gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())
      gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    return genotype
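NetworkACC2 replaces the plain softmax with hard Gumbel-softmax samples, which is why its MixedOp can skip operations whose weight is numerically zero. A small illustration of that sparsity (not part of the commit):

# Hypothetical illustration: hard Gumbel-softmax yields one-hot rows, so most ops can be skipped.
import torch
import torch.nn.functional as F

alphas = torch.zeros(14, 8, requires_grad=True)   # 14 edges, 8 candidate operations
weights = F.gumbel_softmax(alphas, tau=5, hard=True)
print(weights[0])                                 # a one-hot row; gradients still flow to alphas
print((weights.abs() > 1e-10).sum(dim=-1))        # exactly one active operation per edge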
167
lib/nas/model_search_f1.py
Normal file
@@ -0,0 +1,167 @@
# share parameters
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .operations import OPS, FactorizedReduce, ReLUConvBN
from .construct_utils import Transition
from .genotypes import PRIMITIVES, Genotype


class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      self._ops.append(op)

  def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops))


class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride)
        self._ops.append(op)

  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        x = self._ops[offset+j](h, weights[offset+j])
        clist.append(x)
      s = sum(clist)
      offset += len(states)
      states.append(s)

    return torch.cat(states[-self._multiplier:], dim=1)


class NetworkF1(nn.Module):

  def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3):
    super(NetworkF1, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._steps = steps
    self._multiplier = multiplier

    C_curr = stem_multiplier*C
    self.stem = nn.Sequential(
      nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
      nn.BatchNorm2d(C_curr)
    )

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    reduction_prev, cells = False, []
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      if reduction:
        cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev, multiplier)
      else:
        cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      cells.append(cell)
      C_prev_prev, C_prev = C_prev, multiplier*C_curr
    self.cells = nn.ModuleList(cells)

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

    # initialize architecture parameters
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = Parameter(torch.Tensor(k, num_ops))
    #self.alphas_reduce = Parameter(torch.Tensor(k, num_ops))
    nn.init.normal_(self.alphas_normal, 0, 0.001)
    #nn.init.normal_(self.alphas_reduce, 0, 0.001)

  def set_tau(self, tau):
    return -1

  def get_tau(self):
    return -1

  def arch_parameters(self):
    return [self.alphas_normal]

  def base_parameters(self):
    lists = list(self.stem.parameters()) + list(self.cells.parameters())
    lists += list(self.global_pooling.parameters())
    lists += list(self.classifier.parameters())
    return lists

  def forward(self, inputs):
    batch, C, H, W = inputs.size()
    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if cell.reduction:
        s0, s1 = s1, cell(s0, s1)
      else:
        weights = F.softmax(self.alphas_normal, dim=-1)
        s0, s1 = s1, cell(s0, s1, weights)
      #print('{:} : s0 : {:}, s1 : {:}'.format(i, s0.size(), s1.size()))
    out = self.global_pooling(s1)
    out = out.view(batch, -1)
    logits = self.classifier(out)
    return logits

  def genotype(self):

    def _parse(weights):
      gene, n, start = [], 2, 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j, float(W[j][k_best])))
        start = end
        n += 1
      return gene

    with torch.no_grad():
      gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())
      #gene_reduce = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=None, reduce_concat=concat
    )
    return genotype
183
lib/nas/model_search_f1_acc2.py
Normal file
@@ -0,0 +1,183 @@
# share parameters
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .operations import OPS, FactorizedReduce, ReLUConvBN
from .construct_utils import Transition
from .genotypes import PRIMITIVES, Genotype


class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      self._ops.append(op)

  def forward(self, x, weights, cpu_weights):
    use_sum = sum([abs(_) > 1e-10 for _ in cpu_weights])
    if use_sum > 3:
      return sum(w * op(x) for w, op in zip(weights, self._ops))
    else:
      clist = []
      for j, cpu_weight in enumerate(cpu_weights):
        if abs(cpu_weight) > 1e-10:
          clist.append(weights[j] * self._ops[j](x))
      assert len(clist) > 0, 'invalid length : {:}'.format(cpu_weights)
      return sum(clist)


class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride)
        self._ops.append(op)

  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    cpu_weights = weights.tolist()
    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j])
        clist.append(x)
      s = sum(clist)
      offset += len(states)
      states.append(s)

    return torch.cat(states[-self._multiplier:], dim=1)


class NetworkFACC1(nn.Module):

  def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3):
    super(NetworkFACC1, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._steps = steps
    self._multiplier = multiplier
    self.tau = 5
    self.use_gumbel = True

    C_curr = stem_multiplier*C
    self.stem = nn.Sequential(
      nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
      nn.BatchNorm2d(C_curr)
    )

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    reduction_prev, cells = False, []
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      if reduction:
        cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev, multiplier)
      else:
        cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      cells.append(cell)
      C_prev_prev, C_prev = C_prev, multiplier*C_curr
    self.cells = nn.ModuleList(cells)

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

    # initialize architecture parameters
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = Parameter(torch.Tensor(k, num_ops))
    #self.alphas_reduce = Parameter(torch.Tensor(k, num_ops))
    nn.init.normal_(self.alphas_normal, 0, 0.001)
    #nn.init.normal_(self.alphas_reduce, 0, 0.001)

  def set_gumbel(self, use_gumbel):
    self.use_gumbel = use_gumbel

  def set_tau(self, tau):
    self.tau = tau

  def get_tau(self):
    return self.tau

  def arch_parameters(self):
    return [self.alphas_normal]

  def base_parameters(self):
    lists = list(self.stem.parameters()) + list(self.cells.parameters())
    lists += list(self.global_pooling.parameters())
    lists += list(self.classifier.parameters())
    return lists

  def forward(self, inputs):
    batch, C, H, W = inputs.size()
    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if cell.reduction:
        s0, s1 = s1, cell(s0, s1)
      else:
        if self.use_gumbel : weights = F.gumbel_softmax(self.alphas_normal, self.tau, True)
        else               : weights = F.softmax(self.alphas_normal, dim=-1)
        s0, s1 = s1, cell(s0, s1, weights)
      #print('{:} : s0 : {:}, s1 : {:}'.format(i, s0.size(), s1.size()))
    out = self.global_pooling(s1)
    out = out.view(batch, -1)
    logits = self.classifier(out)
    return logits

  def genotype(self):

    def _parse(weights):
      gene, n, start = [], 2, 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j, float(W[j][k_best])))
        start = end
        n += 1
      return gene

    with torch.no_grad():
      gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())
      #gene_reduce = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=None, reduce_concat=concat
    )
    return genotype
161
lib/nas/model_search_v1.py
Normal file
@@ -0,0 +1,161 @@
# share parameters
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from .operations import OPS, FactorizedReduce, ReLUConvBN
from .genotypes import PRIMITIVES, Genotype


class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      self._ops.append(op)

  def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops))


class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride)
        self._ops.append(op)

  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        x = self._ops[offset+j](h, weights[offset+j])
        clist.append(x)
      s = sum(clist)
      offset += len(states)
      states.append(s)

    return torch.cat(states[-self._multiplier:], dim=1)


class NetworkV1(nn.Module):

  def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3):
    super(NetworkV1, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._steps = steps
    self._multiplier = multiplier

    C_curr = stem_multiplier*C
    self.stem = nn.Sequential(
      nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
      nn.BatchNorm2d(C_curr)
    )

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    reduction_prev, cells = False, []
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      cells.append(cell)
      C_prev_prev, C_prev = C_prev, multiplier*C_curr
    self.cells = nn.ModuleList(cells)

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

    # initialize architecture parameters
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = Parameter(torch.Tensor(k, num_ops))
    #self.alphas_reduce = Parameter(torch.Tensor(k, num_ops))
    nn.init.normal_(self.alphas_normal, 0, 0.001)
    #nn.init.normal_(self.alphas_reduce, 0, 0.001)

  def set_tau(self, tau):
    return -1

  def get_tau(self):
    return -1

  def arch_parameters(self):
    return [self.alphas_normal]

  def base_parameters(self):
    lists = list(self.stem.parameters()) + list(self.cells.parameters())
    lists += list(self.global_pooling.parameters())
    lists += list(self.classifier.parameters())
    return lists

  def forward(self, inputs):
    batch, C, H, W = inputs.size()
    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      # normal and reduction cells share the same architecture parameters
      weights = F.softmax(self.alphas_normal, dim=-1)
      s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    out = out.view(batch, -1)
    logits = self.classifier(out)
    return logits

  def genotype(self):

    def _parse(weights):
      gene, n, start = [], 2, 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j, float(W[j][k_best])))
        start = end
        n += 1
      return gene

    with torch.no_grad():
      gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())
      gene_reduce = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    return genotype
171
lib/nas/model_search_v3.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# random selection
|
||||
import torch
|
||||
import random
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
from .operations import OPS, FactorizedReduce, ReLUConvBN
|
||||
from .genotypes import PRIMITIVES, Genotype
|
||||
from .construct_utils import random_select, all_select
|
||||
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
|
||||
def __init__(self, C, stride):
|
||||
super(MixedOp, self).__init__()
|
||||
self._ops = nn.ModuleList()
|
||||
for primitive in PRIMITIVES:
|
||||
op = OPS[primitive](C, stride, False)
|
||||
self._ops.append(op)
|
||||
|
||||
def forward(self, x, weights, cpu_weights):
|
||||
return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
|
||||
|
||||
class Cell(nn.Module):
|
||||
|
||||
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(Cell, self).__init__()
|
||||
self.reduction = reduction
|
||||
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
for i in range(self._steps):
|
||||
for j in range(2+i):
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
op = MixedOp(C, stride)
|
||||
self._ops.append(op)
|
||||
|
||||
def forward(self, s0, s1, weights):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
cpu_weights = weights.tolist()
|
||||
states = [s0, s1]
|
||||
offset = 0
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
if i == 0:
|
||||
indicator = all_select( len(states) )
|
||||
else:
|
||||
indicator = random_select( len(states), 0.5 )
|
||||
for j, h in enumerate(states):
|
||||
if indicator[j] == 0: continue
|
||||
x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j])
|
||||
clist.append( x )
|
||||
s = sum(clist) / sum(indicator)
|
||||
offset += len(states)
|
||||
states.append(s)
|
||||
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
||||
|
||||
class NetworkV3(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3):
|
||||
super(NetworkV3, self).__init__()
|
||||
self._C = C
|
||||
self._num_classes = num_classes
|
||||
self._layers = layers
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
|
||||
C_curr = stem_multiplier*C
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr)
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
|
||||
reduction_prev, cells = False, []
|
||||
for i in range(layers):
|
||||
if i in [layers//3, 2*layers//3]:
|
||||
C_curr *= 2
|
||||
reduction = True
|
||||
else:
|
||||
reduction = False
|
||||
cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
cells.append( cell )
|
||||
C_prev_prev, C_prev = C_prev, multiplier*C_curr
|
||||
self.cells = nn.ModuleList(cells)
|
||||
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.tau = 5
|
||||
|
||||
# initialize architecture parameters
|
||||
k = sum(1 for i in range(self._steps) for n in range(2+i))
|
||||
num_ops = len(PRIMITIVES)
|
||||
|
||||
self.alphas_normal = Parameter(torch.Tensor(k, num_ops))
|
||||
self.alphas_reduce = Parameter(torch.Tensor(k, num_ops))
|
||||
nn.init.normal_(self.alphas_normal, 0, 0.001)
|
||||
nn.init.normal_(self.alphas_reduce, 0, 0.001)
|
||||
|
||||
def set_tau(self, tau):
|
||||
self.tau = tau
|
||||
|
||||
def get_tau(self):
|
||||
return self.tau
|
||||
|
||||
def arch_parameters(self):
|
||||
return [self.alphas_normal, self.alphas_reduce]
|
||||
|
||||
def base_parameters(self):
|
||||
lists = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
lists += list(self.global_pooling.parameters())
|
||||
lists += list(self.classifier.parameters())
|
||||
return lists
|
||||
|
||||
def forward(self, inputs):
|
||||
batch, C, H, W = inputs.size()
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
weights = F.softmax(self.alphas_reduce, dim=-1)
|
||||
else:
|
||||
weights = F.softmax(self.alphas_reduce, dim=-1)
|
||||
s0, s1 = s1, cell(s0, s1, weights)
|
||||
out = self.global_pooling(s1)
|
||||
out = out.view(batch, -1)
|
||||
logits = self.classifier(out)
|
||||
return logits
|
||||
|
||||
  def genotype(self):

    def _parse(weights):
      # for each intermediate node, keep the two strongest incoming edges
      # (ignoring 'none') and the best non-'none' op on each of them
      gene, n, start = [], 2, 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j, float(W[j][k_best])))
        start = end
        n += 1
      return gene

    with torch.no_grad():
      gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())
      gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    return genotype
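A minimal usage sketch for the search network above (not part of this commit): it assumes CIFAR-10-sized inputs and a first-order alternation between architecture and weight updates, relying only on the base_parameters()/arch_parameters() split defined above; every hyper-parameter is illustrative.

# hypothetical driver for NetworkV3 -- illustrative only
import torch
import torch.nn as nn

model     = NetworkV3(C=16, num_classes=10, layers=8)
criterion = nn.CrossEntropyLoss()
w_optim   = torch.optim.SGD(model.base_parameters(), lr=0.025, momentum=0.9, weight_decay=3e-4)
a_optim   = torch.optim.Adam(model.arch_parameters(), lr=3e-4, weight_decay=1e-3)

def search_step(train_batch, valid_batch):
  # architecture step on the validation batch
  x_val, y_val = valid_batch
  a_optim.zero_grad()
  criterion(model(x_val), y_val).backward()
  a_optim.step()
  # weight step on the training batch
  x_trn, y_trn = train_batch
  w_optim.zero_grad()
  criterion(model(x_trn), y_trn).backward()
  w_optim.step()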
176
lib/nas/model_search_v4.py
Normal file
@@ -0,0 +1,176 @@
# random selection: during the forward pass each MixedOp keeps a random
# subset of its candidate operations instead of mixing all of them
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from .operations import OPS, FactorizedReduce, ReLUConvBN
from .genotypes import PRIMITIVES, Genotype
from .construct_utils import random_select, all_select

class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      self._ops.append(op)

  def forward(self, x, weights, cpu_weights):
    # random_select is expected to return a 0/1 keep mask: only the selected
    # ops contribute, and the mixture is renormalized over their weights
    indicators = random_select(len(cpu_weights), 0.5)
    clist, ws = [], []
    for w, indicator, op in zip(weights, indicators, self._ops):
      if indicator:
        clist.append(w * op(x))
        ws.append(w)
    return sum(clist) / sum(ws)

class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride)
        self._ops.append(op)

  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    cpu_weights = weights.tolist()
    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      clist = []
      # the first node uses every input edge; later nodes sub-sample their
      # incoming edges at random and average over the selected ones
      if i == 0:
        indicator = all_select(len(states))
      else:
        indicator = random_select(len(states), 0.5)
      for j, h in enumerate(states):
        if indicator[j] == 0: continue
        x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j])
        clist.append(x)
      s = sum(clist) / sum(indicator)
      offset += len(states)
      states.append(s)

    return torch.cat(states[-self._multiplier:], dim=1)

class NetworkV4(nn.Module):

  def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3):
    super(NetworkV4, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._steps = steps
    self._multiplier = multiplier

    C_curr = stem_multiplier * C
    self.stem = nn.Sequential(
      nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
      nn.BatchNorm2d(C_curr)
    )

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    reduction_prev, cells = False, []
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      cells.append(cell)
      C_prev_prev, C_prev = C_prev, multiplier * C_curr
    self.cells = nn.ModuleList(cells)

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.tau = 5

    # initialize architecture parameters
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = Parameter(torch.Tensor(k, num_ops))
    self.alphas_reduce = Parameter(torch.Tensor(k, num_ops))
    nn.init.normal_(self.alphas_normal, 0, 0.001)
    nn.init.normal_(self.alphas_reduce, 0, 0.001)

  def set_tau(self, tau):
    self.tau = tau

  def get_tau(self):
    return self.tau

  def arch_parameters(self):
    return [self.alphas_normal, self.alphas_reduce]

  def base_parameters(self):
    lists = list(self.stem.parameters()) + list(self.cells.parameters())
    lists += list(self.global_pooling.parameters())
    lists += list(self.classifier.parameters())
    return lists

  def forward(self, inputs):
    batch, C, H, W = inputs.size()
    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if cell.reduction:
        weights = F.softmax(self.alphas_reduce, dim=-1)
      else:
        weights = F.softmax(self.alphas_normal, dim=-1)
      s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    out = out.view(batch, -1)
    logits = self.classifier(out)
    return logits

  def genotype(self):

    def _parse(weights):
      gene, n, start = [], 2, 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j, float(W[j][k_best])))
        start = end
        n += 1
      return gene

    with torch.no_grad():
      gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())
      gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    return genotype
174
lib/nas/model_search_v5.py
Normal file
@@ -0,0 +1,174 @@
# gumbel softmax: architecture weights are sampled with F.gumbel_softmax
# (hard, temperature tau), so each edge activates a single sampled op
# (in the straight-through sense) per forward pass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from .operations import OPS, FactorizedReduce, ReLUConvBN
from .genotypes import PRIMITIVES, Genotype
from .construct_utils import random_select, all_select

class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      self._ops.append(op)

  def forward(self, x, weights, cpu_weights):
    # only evaluate ops whose (hard gumbel-softmax) weight is non-zero
    clist = []
    for j, cpu_weight in enumerate(cpu_weights):
      if abs(cpu_weight) > 1e-10:
        clist.append(weights[j] * self._ops[j](x))
    assert len(clist) > 0, 'invalid length : {:}'.format(cpu_weights)
    if len(clist) == 1:
      return clist[0]
    else:
      return sum(clist)

class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride)
        self._ops.append(op)

  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    cpu_weights = weights.tolist()
    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      clist = []
      # the first node uses every input edge; later nodes keep a random
      # subset of incoming edges (random_select with 0.6) and sum over them
      if i == 0:
        indicator = all_select(len(states))
      else:
        indicator = random_select(len(states), 0.6)

      for j, h in enumerate(states):
        if indicator[j] == 0: continue
        x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j])
        clist.append(x)
      s = sum(clist)
      offset += len(states)
      states.append(s)

    return torch.cat(states[-self._multiplier:], dim=1)

class NetworkV5(nn.Module):

  def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3):
    super(NetworkV5, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._steps = steps
    self._multiplier = multiplier

    C_curr = stem_multiplier * C
    self.stem = nn.Sequential(
      nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
      nn.BatchNorm2d(C_curr)
    )

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    reduction_prev, cells = False, []
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      cells.append(cell)
      C_prev_prev, C_prev = C_prev, multiplier * C_curr
    self.cells = nn.ModuleList(cells)

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.tau = 5

    # initialize architecture parameters
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = Parameter(torch.Tensor(k, num_ops))
    self.alphas_reduce = Parameter(torch.Tensor(k, num_ops))
    nn.init.normal_(self.alphas_normal, 0, 0.001)
    nn.init.normal_(self.alphas_reduce, 0, 0.001)

  def set_tau(self, tau):
    self.tau = tau

  def get_tau(self):
    return self.tau

  def arch_parameters(self):
    return [self.alphas_normal, self.alphas_reduce]

  def base_parameters(self):
    lists = list(self.stem.parameters()) + list(self.cells.parameters())
    lists += list(self.global_pooling.parameters())
    lists += list(self.classifier.parameters())
    return lists

  def forward(self, inputs):
    batch, C, H, W = inputs.size()
    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if cell.reduction:
        weights = F.gumbel_softmax(self.alphas_reduce, self.tau, True)
      else:
        weights = F.gumbel_softmax(self.alphas_normal, self.tau, True)
      s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    out = out.view(batch, -1)
    logits = self.classifier(out)
    return logits

  def genotype(self):

    def _parse(weights):
      gene, n, start = [], 2, 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j, float(W[j][k_best])))
        start = end
        n += 1
      return gene

    with torch.no_grad():
      gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy())
      gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    return genotype
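A sketch of how the Gumbel-softmax variant above might be driven (not part of this commit): the temperature schedule and epoch count are assumptions, and the per-epoch weight/architecture updates are elided.

# hypothetical tau-annealing loop for NetworkV5 -- illustrative only
model = NetworkV5(C=16, num_classes=10, layers=8)
tau_max, tau_min, total_epochs = 10.0, 0.1, 50
for epoch in range(total_epochs):
  # linearly decay the sampling temperature from tau_max to tau_min
  model.set_tau(tau_max - (tau_max - tau_min) * epoch / (total_epochs - 1))
  # ... weight / architecture updates for this epoch go here ...
print(model.genotype())   # discrete cell description at the end of search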
122
lib/nas/operations.py
Normal file
@@ -0,0 +1,122 @@
import torch
import torch.nn as nn

# candidate operations: each entry maps (C, stride, affine) to an nn.Module
OPS = {
  'none'         : lambda C, stride, affine: Zero(stride),
  'avg_pool_3x3' : lambda C, stride, affine: nn.Sequential(
                     nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
                     nn.BatchNorm2d(C, affine=False)),
  'max_pool_3x3' : lambda C, stride, affine: nn.Sequential(
                     nn.MaxPool2d(3, stride=stride, padding=1),
                     nn.BatchNorm2d(C, affine=False)),
  'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
  'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
  'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
  'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
  'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
  'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
  'conv_7x1_1x7' : lambda C, stride, affine: Conv717(C, C, stride, affine),
}

class Conv717(nn.Module):

  def __init__(self, C_in, C_out, stride, affine):
    super(Conv717, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_out, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
      nn.Conv2d(C_out, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
      nn.BatchNorm2d(C_out, affine=affine)
    )

  def forward(self, x):
    return self.op(x)

class ReLUConvBN(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(ReLUConvBN, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
      nn.BatchNorm2d(C_out, affine=affine)
    )

  def forward(self, x):
    return self.op(x)

class DilConv(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
    super(DilConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, affine=affine),
    )

  def forward(self, x):
    return self.op(x)

class SepConv(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(SepConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_in, affine=affine),
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, affine=affine),
    )

  def forward(self, x):
    return self.op(x)

class Identity(nn.Module):

  def __init__(self):
    super(Identity, self).__init__()

  def forward(self, x):
    return x

class Zero(nn.Module):

  def __init__(self, stride):
    super(Zero, self).__init__()
    self.stride = stride

  def forward(self, x):
    if self.stride == 1:
      return x.mul(0.)
    return x[:, :, ::self.stride, ::self.stride].mul(0.)

class FactorizedReduce(nn.Module):

  def __init__(self, C_in, C_out, affine=True):
    super(FactorizedReduce, self).__init__()
    assert C_out % 2 == 0
    self.relu = nn.ReLU(inplace=False)
    self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.bn = nn.BatchNorm2d(C_out, affine=affine)
    self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)

  def forward(self, x):
    x = self.relu(x)
    # the second path is shifted by one pixel (pad then crop) so the two
    # strided 1x1 convolutions sample complementary positions before concat
    y = self.pad(x)
    out = torch.cat([self.conv_1(x), self.conv_2(y[:, :, 1:, 1:])], dim=1)
    out = self.bn(out)
    return out
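A small shape check, not part of this commit, showing the contract the search cells above rely on: every candidate in OPS maps an (N, C, H, W) tensor to (N, C, H/stride, W/stride).

# hypothetical smoke test for the OPS table -- illustrative only
if __name__ == '__main__':
  x = torch.randn(2, 16, 32, 32)
  for name, builder in OPS.items():
    for stride in (1, 2):
      y = builder(16, stride, False)(x)
      assert y.shape == (2, 16, 32 // stride, 32 // stride), (name, stride, y.shape)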