Add more algorithms
76  lib/nas_infer_model/DXYs/CifarNet.py  Normal file
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
from .construct_utils import drop_path
from .head_utils import CifarHEAD, AuxiliaryHeadCIFAR
from .base_cells import InferCell


class NetworkCIFAR(nn.Module):

  def __init__(self, C, N, stem_multiplier, auxiliary, genotype, num_classes):
    super(NetworkCIFAR, self).__init__()
    self._C = C
    self._layerN = N
    self._stem_multiplier = stem_multiplier

    C_curr = self._stem_multiplier * C
    self.stem = CifarHEAD(C_curr)

    # schedule: N normal cells, one reduction cell, N normal cells, one reduction cell, N normal cells
    layer_channels   = [C    ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
    block_indexs     = [0    ] * N + [-1  ] + [1   ] * N + [-1  ] + [2   ] * N
    block2index      = {0:[], 1:[], 2:[]}

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    reduction_prev, spatial, dims = False, 1, []
    self.auxiliary_index = None
    self.auxiliary_head  = None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells.append( cell )
      C_prev_prev, C_prev = C_prev, cell._multiplier * C_curr
      if reduction and C_curr == C*4:
        # the auxiliary head is attached after the second (last) reduction cell
        if auxiliary:
          self.auxiliary_head  = AuxiliaryHeadCIFAR(C_prev, num_classes)
          self.auxiliary_index = index

      if reduction: spatial *= 2
      dims.append( (C_prev, spatial) )

    self._Layer = len(self.cells)

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.drop_path_prob = -1

  def update_drop_path(self, drop_path_prob):
    self.drop_path_prob = drop_path_prob

  def auxiliary_param(self):
    if self.auxiliary_head is None: return []
    else: return list( self.auxiliary_head.parameters() )

  def get_message(self):
    return self.extra_repr()

  def extra_repr(self):
    return ('{name}(C={_C}, N={_layerN}, L={_Layer}, stem={_stem_multiplier}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, inputs):
    stem_feature, logits_aux = self.stem(inputs), None
    cell_results = [stem_feature, stem_feature]
    for i, cell in enumerate(self.cells):
      cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
      cell_results.append( cell_feature )

      if self.auxiliary_index is not None and i == self.auxiliary_index and self.training:
        logits_aux = self.auxiliary_head( cell_results[-1] )
    out = self.global_pooling( cell_results[-1] )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    if logits_aux is None: return out, logits
    else                 : return out, [logits, logits_aux]
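A minimal usage sketch for NetworkCIFAR (not part of the files above). It assumes the repository's lib/ directory is on sys.path and uses common DARTS-style CIFAR settings (C=36, N=6, stem_multiplier=3), which are assumptions rather than values fixed by this file:

import torch
from nas_infer_model.DXYs import CifarNet, Networks   # assumes lib/ is on sys.path

genotype = Networks['DARTS_V2']
net = CifarNet(C=36, N=6, stem_multiplier=3, auxiliary=True, genotype=genotype, num_classes=10)
net.update_drop_path(0.2)   # enable drop-path regularization during training
net.train()
features, logits = net(torch.randn(2, 3, 32, 32))
# in training mode with auxiliary=True, logits is [main_logits, aux_logits]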
77  lib/nas_infer_model/DXYs/ImageNet.py  Normal file
@@ -0,0 +1,77 @@
import torch
import torch.nn as nn
from .construct_utils import drop_path
from .base_cells import InferCell
from .head_utils import ImageNetHEAD, AuxiliaryHeadImageNet


class NetworkImageNet(nn.Module):

  def __init__(self, C, N, auxiliary, genotype, num_classes):
    super(NetworkImageNet, self).__init__()
    self._C = C
    self._layerN = N
    layer_channels   = [C    ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
    self.stem0 = nn.Sequential(
      nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(C // 2),
      nn.ReLU(inplace=True),
      nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(C),
    )

    self.stem1 = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(C),
    )

    C_prev_prev, C_prev, C_curr, reduction_prev = C, C, C, True

    self.cells = nn.ModuleList()
    self.auxiliary_index = None
    for i, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells += [cell]
      C_prev_prev, C_prev = C_prev, cell._multiplier * C_curr
      if reduction and C_curr == C*4:
        C_to_auxiliary = C_prev
        self.auxiliary_index = i

    self._NNN = len(self.cells)
    if auxiliary:
      self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
    else:
      self.auxiliary_head = None
    self.global_pooling = nn.AvgPool2d(7)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.drop_path_prob = -1

  def update_drop_path(self, drop_path_prob):
    self.drop_path_prob = drop_path_prob

  def extra_repr(self):
    return ('{name}(C={_C}, N=[{_layerN}, {_NNN}], aux-index={auxiliary_index}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__))

  def get_message(self):
    return self.extra_repr()

  def auxiliary_param(self):
    if self.auxiliary_head is None: return []
    else: return list( self.auxiliary_head.parameters() )

  def forward(self, inputs):
    s0 = self.stem0(inputs)
    s1 = self.stem1(s0)
    logits_aux = None
    for i, cell in enumerate(self.cells):
      s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
      if i == self.auxiliary_index and self.auxiliary_head and self.training:
        logits_aux = self.auxiliary_head(s1)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0), -1))

    if logits_aux is None: return out, logits
    else                 : return out, [logits, logits_aux]
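Because forward() returns either plain logits or [logits, logits_aux], callers have to handle both shapes. The sketch below shows one possible training-loss computation; the 0.4 auxiliary weight is the value commonly used for NASNet/DARTS-style auxiliary towers, not something this commit prescribes:

import torch.nn as nn

criterion = nn.CrossEntropyLoss()

def compute_loss(network, images, targets, aux_weight=0.4):
  _, logits = network(images)
  if isinstance(logits, (list, tuple)):   # training with an auxiliary head attached
    main_logits, aux_logits = logits
    return criterion(main_logits, targets) + aux_weight * criterion(aux_logits, targets)
  else:                                   # eval mode, or auxiliary=False
    return criterion(logits, targets)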
4  lib/nas_infer_model/DXYs/__init__.py  Normal file
@@ -0,0 +1,4 @@
# Performance-Aware Template Network for One-Shot Neural Architecture Search
from .CifarNet import NetworkCIFAR as CifarNet
from .ImageNet import NetworkImageNet as ImageNet
from .genotypes import Networks
173  lib/nas_infer_model/DXYs/base_cells.py  Normal file
@@ -0,0 +1,173 @@
import math
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .construct_utils import drop_path
from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN


class MixedOp(nn.Module):

  def __init__(self, C, stride, PRIMITIVES):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    self.name2idx = {}
    for idx, primitive in enumerate(PRIMITIVES):
      op = OPS[primitive](C, C, stride, False)
      self._ops.append(op)
      assert primitive not in self.name2idx, '{:} has already been added'.format(primitive)
      self.name2idx[primitive] = idx

  def forward(self, x, weights, op_name):
    # op_name given: run a single named op; otherwise mix all candidate ops by weights,
    # or return every candidate output when weights is also None.
    if op_name is None:
      if weights is None:
        return [op(x) for op in self._ops]
      else:
        return sum(w * op(x) for w, op in zip(weights, self._ops))
    else:
      op_index = self.name2idx[op_name]
      return self._ops[op_index](x)


class SearchCell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual):
    super(SearchCell, self).__init__()
    self.reduction = reduction
    self.PRIMITIVES = deepcopy(PRIMITIVES)

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier
    self._use_residual = use_residual

    self._ops = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride, self.PRIMITIVES)
        self._ops.append(op)

  def extra_repr(self):
    return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes):
    # modes[0]: a fixed genotype for evaluation, or None during search;
    # modes[1]: 'normal' (weights + connections) or 'only_W' (operation weights only).
    if modes[0] is None:
      if modes[1] == 'normal':
        output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob)
      elif modes[1] == 'only_W':
        output = self.__forwardOnlyW(S0, S1, drop_prob)
    else:
      test_genotype = modes[0]
      if self.reduction: operations, concats = test_genotype.reduce, test_genotype.reduce_concat
      else             : operations, concats = test_genotype.normal, test_genotype.normal_concat
      s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
      states, offset = [s0, s1], 0
      assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations))
      for i, (opA, opB) in enumerate(operations):
        A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0])
        B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0])
        state = A + B
        offset += len(states)
        states.append(state)
      output = torch.cat([states[i] for i in concats], dim=1)
    if self._use_residual and S1.size() == output.size():
      return S1 + output
    else: return output

  def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob):
    s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
    states, offset = [s0, s1], 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        x = self._ops[offset+j](h, weights[offset+j], None)
        if self.training and drop_prob > 0.:
          x = drop_path(x, math.pow(drop_prob, 1./len(states)))
        clist.append( x )
      connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0)
      state = sum(w * node for w, node in zip(connection, clist))
      offset += len(states)
      states.append(state)
    return torch.cat(states[-self._multiplier:], dim=1)

  def __forwardOnlyW(self, S0, S1, drop_prob):
    s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
    states, offset = [s0, s1], 0
    for i in range(self._steps):
      clist = []
      for j, h in enumerate(states):
        xs = self._ops[offset+j](h, None, None)
        clist += xs
      if self.training and drop_prob > 0.:
        xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist]
      else: xlist = clist
      state = sum(xlist) * 2 / len(xlist)
      offset += len(states)
      states.append(state)
    return torch.cat(states[-self._multiplier:], dim=1)


class InferCell(nn.Module):

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(InferCell, self).__init__()
    print(C_prev_prev, C_prev, C)

    if reduction_prev is None:
      self.preprocess0 = Identity()
    elif reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

    if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat
    else        : step_ops, concat = genotype.normal, genotype.normal_concat
    self._steps = len(step_ops)
    self._concat = concat
    self._multiplier = len(concat)
    self._ops = nn.ModuleList()
    self._indices = []
    for operations in step_ops:
      for name, index in operations:
        stride = 2 if reduction and index < 2 else 1
        if reduction_prev is None and index == 0:
          op = OPS[name](C_prev_prev, C, stride, True)
        else:
          op = OPS[name](C, C, stride, True)
        self._ops.append( op )
        self._indices.append( index )

  def extra_repr(self):
    return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, S0, S1, drop_prob):
    s0 = self.preprocess0(S0)
    s1 = self.preprocess1(S1)

    states = [s0, s1]
    for i in range(self._steps):
      h1 = states[self._indices[2*i]]
      h2 = states[self._indices[2*i+1]]
      op1 = self._ops[2*i]
      op2 = self._ops[2*i+1]
      h1 = op1(h1)
      h2 = op2(h2)
      if self.training and drop_prob > 0.:
        if not isinstance(op1, Identity):
          h1 = drop_path(h1, drop_prob)
        if not isinstance(op2, Identity):
          h2 = drop_path(h2, drop_prob)

      state = h1 + h2
      states += [state]
    output = torch.cat([states[i] for i in self._concat], dim=1)
    return output
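Each entry of genotype.normal / genotype.reduce is a pair of (op_name, input_index) tuples, and InferCell flattens them into self._ops and self._indices in construction order. A torch-free sketch of that bookkeeping (illustrative only; the two steps below are made up):

step_ops = [
  (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)),   # step 0 reads nodes 0 and 1
  (('skip_connect', 0), ('dil_conv_3x3', 2)),   # step 1 may also read node 2 (output of step 0)
]
ops, indices = [], []
for operations in step_ops:
  for name, index in operations:
    ops.append(name)        # InferCell stores OPS[name](...) at this position
    indices.append(index)   # which earlier node feeds this op
print(ops)      # ['sep_conv_3x3', 'sep_conv_3x3', 'skip_connect', 'dil_conv_3x3']
print(indices)  # [0, 1, 0, 2]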
60  lib/nas_infer_model/DXYs/construct_utils.py  Normal file
@@ -0,0 +1,60 @@
import torch
import torch.nn.functional as F


def drop_path(x, drop_prob):
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    mask = x.new_zeros(x.size(0), 1, 1, 1)
    mask = mask.bernoulli_(keep_prob)
    x = torch.div(x, keep_prob)
    x.mul_(mask)
  return x


def return_alphas_str(basemodel):
  if hasattr(basemodel, 'alphas_normal'):
    string = 'normal [{:}] : \n-->>{:}'.format(basemodel.alphas_normal.size(), F.softmax(basemodel.alphas_normal, dim=-1) )
  else: string = ''
  if hasattr(basemodel, 'alphas_reduce'):
    string = string + '\nreduce : {:}'.format( F.softmax(basemodel.alphas_reduce, dim=-1) )

  if hasattr(basemodel, 'get_adjacency'):
    adjacency = basemodel.get_adjacency()
    for i in range( len(adjacency) ):
      weight = F.softmax( basemodel.connect_normal[str(i)], dim=-1 )
      adj = torch.mm(weight, adjacency[i]).view(-1)
      adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()]
      string = string + '\nnormal--{:}-->{:}'.format(i, ', '.join(adj))
    for i in range( len(adjacency) ):
      weight = F.softmax( basemodel.connect_reduce[str(i)], dim=-1 )
      adj = torch.mm(weight, adjacency[i]).view(-1)
      adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()]
      string = string + '\nreduce--{:}-->{:}'.format(i, ', '.join(adj))

  if hasattr(basemodel, 'alphas_connect'):
    weight = F.softmax(basemodel.alphas_connect, dim=-1).cpu()
    ZERO = ['{:.3f}'.format(x) for x in weight[:,0].tolist()]
    IDEN = ['{:.3f}'.format(x) for x in weight[:,1].tolist()]
    string = string + '\nconnect [{:}] : \n ->{:}\n ->{:}'.format( list(basemodel.alphas_connect.size()), ZERO, IDEN )
  else:
    string = string + '\nconnect = None'

  if hasattr(basemodel, 'get_gcn_out'):
    outputs = basemodel.get_gcn_out(True)
    for i, output in enumerate(outputs):
      string = string + '\nnormal:[{:}] : {:}'.format(i, F.softmax(output, dim=-1) )

  return string


def remove_duplicate_archs(all_archs):
  archs = []
  str_archs = ['{:}'.format(x) for x in all_archs]
  for i, arch_x in enumerate(str_archs):
    choose = True
    for j in range(i):
      if arch_x == str_archs[j]:
        choose = False; break
    if choose: archs.append(all_archs[i])
  return archs
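drop_path zeroes an entire sample's path with probability drop_prob and rescales the survivors by 1/keep_prob, so the expected activation stays unchanged. A quick sanity-check sketch (assumes lib/ is on sys.path):

import torch
from nas_infer_model.DXYs.construct_utils import drop_path   # assumed import path

torch.manual_seed(0)
x = torch.ones(10000, 1, 1, 1)
y = drop_path(x, 0.3)
print(y.mean().item())   # close to 1.0: ~30% of samples are zeroed, survivors scaled by 1/0.7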
172  lib/nas_infer_model/DXYs/genotypes.py  Normal file
@@ -0,0 +1,172 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat connectN connects')
#Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

PRIMITIVES_small = [
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'conv_3x1_1x3',
]

PRIMITIVES_large = [
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
    'conv_3x1_1x3',
]

PRIMITIVES_huge = [
    'skip_connect',
    'nor_conv_1x1',
    'max_pool_3x3',
    'avg_pool_3x3',
    'nor_conv_3x3',
    'sep_conv_3x3',
    'dil_conv_3x3',
    'conv_3x1_1x3',
    'sep_conv_5x5',
    'dil_conv_5x5',
    'sep_conv_7x7',
    'conv_7x1_1x7',
    'att_squeeze',
]

PRIMITIVES = {'small': PRIMITIVES_small,
              'large': PRIMITIVES_large,
              'huge' : PRIMITIVES_huge}

NASNet = Genotype(
  normal = [
    (('sep_conv_5x5', 1), ('sep_conv_3x3', 0)),
    (('sep_conv_5x5', 0), ('sep_conv_3x3', 0)),
    (('avg_pool_3x3', 1), ('skip_connect', 0)),
    (('avg_pool_3x3', 0), ('avg_pool_3x3', 0)),
    (('sep_conv_3x3', 1), ('skip_connect', 1)),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    (('sep_conv_5x5', 1), ('sep_conv_7x7', 0)),
    (('max_pool_3x3', 1), ('sep_conv_7x7', 0)),
    (('avg_pool_3x3', 1), ('sep_conv_5x5', 0)),
    (('skip_connect', 3), ('avg_pool_3x3', 2)),
    (('sep_conv_3x3', 2), ('max_pool_3x3', 1)),
  ],
  reduce_concat = [4, 5, 6],
  connectN=None, connects=None,
)

PNASNet = Genotype(
  normal = [
    (('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
    (('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
    (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
    (('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
    (('sep_conv_3x3', 0), ('skip_connect', 1)),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    (('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
    (('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
    (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
    (('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
    (('sep_conv_3x3', 0), ('skip_connect', 1)),
  ],
  reduce_concat = [2, 3, 4, 5, 6],
  connectN=None, connects=None,
)


DARTS_V1 = Genotype(
  normal=[
    (('sep_conv_3x3', 1), ('sep_conv_3x3', 0)), # step 1
    (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 2
    (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 3
    (('sep_conv_3x3', 0), ('skip_connect', 2))  # step 4
  ],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
    (('skip_connect', 2), ('max_pool_3x3', 0)), # step 2
    (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
    (('skip_connect', 2), ('avg_pool_3x3', 0))  # step 4
  ],
  reduce_concat=[2, 3, 4, 5],
  connectN=None, connects=None,
)

# DARTS: Differentiable Architecture Search, ICLR 2019
DARTS_V2 = Genotype(
  normal=[
    (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 1
    (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 2
    (('sep_conv_3x3', 1), ('skip_connect', 0)), # step 3
    (('skip_connect', 0), ('dil_conv_3x3', 2))  # step 4
  ],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
    (('skip_connect', 2), ('max_pool_3x3', 1)), # step 2
    (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
    (('skip_connect', 2), ('max_pool_3x3', 1))  # step 4
  ],
  reduce_concat=[2, 3, 4, 5],
  connectN=None, connects=None,
)


# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
SETN = Genotype(
  normal=[
    (('skip_connect', 0), ('sep_conv_5x5', 1)),
    (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
    (('sep_conv_5x5', 1), ('sep_conv_5x5', 3)),
    (('max_pool_3x3', 1), ('conv_3x1_1x3', 4))],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    (('sep_conv_3x3', 0), ('sep_conv_5x5', 1)),
    (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
    (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
    (('avg_pool_3x3', 0), ('skip_connect', 1))],
  reduce_concat=[2, 3, 4, 5],
  connectN=None, connects=None
)


# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
GDAS_V1 = Genotype(
  normal=[
    (('skip_connect', 0), ('skip_connect', 1)),
    (('skip_connect', 0), ('sep_conv_5x5', 2)),
    (('sep_conv_3x3', 3), ('skip_connect', 0)),
    (('sep_conv_5x5', 4), ('sep_conv_3x3', 3))],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
    (('sep_conv_5x5', 2), ('sep_conv_5x5', 1)),
    (('dil_conv_5x5', 2), ('sep_conv_3x3', 1)),
    (('sep_conv_5x5', 0), ('sep_conv_5x5', 1))],
  reduce_concat=[2, 3, 4, 5],
  connectN=None, connects=None
)


Networks = {'DARTS_V1': DARTS_V1,
            'DARTS_V2': DARTS_V2,
            'DARTS'   : DARTS_V2,
            'NASNet'  : NASNet,
            'GDAS_V1' : GDAS_V1,
            'PNASNet' : PNASNet,
            'SETN'    : SETN,
           }
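The Networks dictionary is the lookup table that obtain_nas_infer_model uses to resolve config.genotype. A short inspection sketch (assumes lib/ is on sys.path):

from nas_infer_model.DXYs.genotypes import Networks   # assumed import path

geno = Networks['DARTS']    # alias of DARTS_V2
print(len(geno.normal))     # 4 steps per normal cell
print(geno.normal_concat)   # [2, 3, 4, 5] -> four intermediate nodes concatenated (multiplier 4)
print(geno.reduce[0])       # (('max_pool_3x3', 0), ('max_pool_3x3', 1))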
65  lib/nas_infer_model/DXYs/head_utils.py  Normal file
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn


class ImageNetHEAD(nn.Sequential):
  def __init__(self, C, stride=2):
    super(ImageNetHEAD, self).__init__()
    self.add_module('conv1', nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False))
    self.add_module('bn1'  , nn.BatchNorm2d(C // 2))
    self.add_module('relu1', nn.ReLU(inplace=True))
    self.add_module('conv2', nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False))
    self.add_module('bn2'  , nn.BatchNorm2d(C))


class CifarHEAD(nn.Sequential):
  def __init__(self, C):
    super(CifarHEAD, self).__init__()
    self.add_module('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
    self.add_module('bn', nn.BatchNorm2d(C))


class AuxiliaryHeadCIFAR(nn.Module):

  def __init__(self, C, num_classes):
    """assuming input size 8x8"""
    super(AuxiliaryHeadCIFAR, self).__init__()
    self.features = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
      nn.Conv2d(C, 128, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 768, 2, bias=False),
      nn.BatchNorm2d(768),
      nn.ReLU(inplace=True)
    )
    self.classifier = nn.Linear(768, num_classes)

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x.view(x.size(0), -1))
    return x


class AuxiliaryHeadImageNet(nn.Module):

  def __init__(self, C, num_classes):
    """assuming input size 14x14"""
    super(AuxiliaryHeadImageNet, self).__init__()
    self.features = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
      nn.Conv2d(C, 128, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 768, 2, bias=False),
      nn.BatchNorm2d(768),
      nn.ReLU(inplace=True)
    )
    self.classifier = nn.Linear(768, num_classes)

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x.view(x.size(0), -1))
    return x
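AuxiliaryHeadCIFAR assumes an 8x8 input: AvgPool2d(5, stride=3) maps 8x8 to 2x2 and the following 2x2 convolution reduces that to 1x1 before the linear classifier. A hedged shape-check sketch (the channel count 576 is only an example; assumes lib/ is on sys.path):

import torch
from nas_infer_model.DXYs.head_utils import AuxiliaryHeadCIFAR   # assumed import path

head = AuxiliaryHeadCIFAR(C=576, num_classes=10)   # 576 channels is an arbitrary example
out = head(torch.randn(2, 576, 8, 8))
print(out.shape)   # torch.Size([2, 10])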
16  lib/nas_infer_model/__init__.py  Normal file
@@ -0,0 +1,16 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch

def obtain_nas_infer_model(config):
  if config.arch == 'dxys':
    from .DXYs import CifarNet, ImageNet, Networks
    genotype = Networks[config.genotype]
    if config.dataset == 'cifar':
      return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
    elif config.dataset == 'imagenet':
      return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num)
    else: raise ValueError('invalid dataset : {:}'.format(config.dataset))
  else:
    raise ValueError('invalid nas arch type : {:}'.format(config.arch))
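obtain_nas_infer_model only reads a handful of attributes from config: arch, genotype, dataset, ichannel, layers, stem_multi (CIFAR only), auxiliary, and class_num. A minimal sketch of such a config follows; the concrete values are assumptions for illustration (assumes lib/ is on sys.path):

from argparse import Namespace
from nas_infer_model import obtain_nas_infer_model   # assumed import path

config = Namespace(arch='dxys', genotype='SETN', dataset='cifar',
                   ichannel=36, layers=6, stem_multi=3,
                   auxiliary=True, class_num=10)
network = obtain_nas_infer_model(config)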
180  lib/nas_infer_model/operations.py  Normal file
@@ -0,0 +1,180 @@
import torch
import torch.nn as nn

OPS = {
  'none'         : lambda C_in, C_out, stride, affine: Zero(stride),
  'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
  'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
  'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), affine),
  'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), affine),
  'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), affine),
  'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
  'sep_conv_3x3' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine=affine),
  'sep_conv_5x5' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine=affine),
  'sep_conv_7x7' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 7, stride, 3, affine=affine),
  'dil_conv_3x3' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine=affine),
  'dil_conv_5x5' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine=affine),
  'conv_7x1_1x7' : lambda C_in, C_out, stride, affine: Conv717(C_in, C_out, stride, affine),
  'conv_3x1_1x3' : lambda C_in, C_out, stride, affine: Conv313(C_in, C_out, stride, affine)
}


class POOLING(nn.Module):

  def __init__(self, C_in, C_out, stride, mode):
    super(POOLING, self).__init__()
    if C_in == C_out:
      self.preprocess = None
    else:
      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0)
    if mode == 'avg'  : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
    elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)

  def forward(self, inputs):
    if self.preprocess is not None:
      x = self.preprocess(inputs)
    else: x = inputs
    return self.op(x)


class Conv313(nn.Module):

  def __init__(self, C_in, C_out, stride, affine):
    super(Conv313, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in , C_out, (1,3), stride=(1, stride), padding=(0, 1), bias=False),
      nn.Conv2d(C_out, C_out, (3,1), stride=(stride, 1), padding=(1, 0), bias=False),
      nn.BatchNorm2d(C_out, affine=affine)
    )

  def forward(self, x):
    return self.op(x)


class Conv717(nn.Module):

  def __init__(self, C_in, C_out, stride, affine):
    super(Conv717, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in , C_out, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
      nn.Conv2d(C_out, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
      nn.BatchNorm2d(C_out, affine=affine)
    )

  def forward(self, x):
    return self.op(x)


class ReLUConvBN(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(ReLUConvBN, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
      nn.BatchNorm2d(C_out, affine=affine)
    )

  def forward(self, x):
    return self.op(x)


class DilConv(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
    super(DilConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, affine=affine),
    )

  def forward(self, x):
    return self.op(x)


class SepConv(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(SepConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_in, affine=affine),
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, affine=affine),
    )

  def forward(self, x):
    return self.op(x)


class Identity(nn.Module):

  def __init__(self):
    super(Identity, self).__init__()

  def forward(self, x):
    return x


class Zero(nn.Module):

  def __init__(self, stride):
    super(Zero, self).__init__()
    self.stride = stride

  def forward(self, x):
    if self.stride == 1:
      return x.mul(0.)
    return x[:,:,::self.stride,::self.stride].mul(0.)

  def extra_repr(self):
    return 'stride={stride}'.format(**self.__dict__)


class FactorizedReduce(nn.Module):

  def __init__(self, C_in, C_out, stride, affine=True):
    super(FactorizedReduce, self).__init__()
    self.stride = stride
    self.C_in   = C_in
    self.C_out  = C_out
    self.relu   = nn.ReLU(inplace=False)
    if stride == 2:
      #assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
      C_outs = [C_out // 2, C_out - C_out // 2]
      self.convs = nn.ModuleList()
      for i in range(2):
        self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) )
      self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
    elif stride == 4:
      assert C_out % 4 == 0, 'C_out : {:}'.format(C_out)
      self.convs = nn.ModuleList()
      for i in range(4):
        self.convs.append( nn.Conv2d(C_in, C_out // 4, 1, stride=stride, padding=0, bias=False) )
      self.pad = nn.ConstantPad2d((0, 3, 0, 3), 0)
    else:
      raise ValueError('Invalid stride : {:}'.format(stride))

    self.bn = nn.BatchNorm2d(C_out, affine=affine)

  def forward(self, x):
    x = self.relu(x)
    y = self.pad(x)
    if self.stride == 2:
      out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
    else:
      out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:-2,1:-2]),
                       self.convs[2](y[:,:,2:-1,2:-1]), self.convs[3](y[:,:,3:,3:])], dim=1)
    out = self.bn(out)
    return out

  def extra_repr(self):
    return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
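Every entry in OPS shares the (C_in, C_out, stride, affine) constructor signature, which is what makes the candidate operations interchangeable inside a cell. A hedged sketch checking that stride-2 instances halve the spatial size while mapping C_in to C_out (assumes lib/ is on sys.path):

import torch
from nas_infer_model.operations import OPS   # assumed import path

x = torch.randn(2, 16, 32, 32)
for name in ['sep_conv_3x3', 'dil_conv_5x5', 'skip_connect', 'avg_pool_3x3']:
  op = OPS[name](16, 16, 2, True)   # a stride-2 instance of each candidate op
  print(name, op(x).shape)          # each prints torch.Size([2, 16, 16, 16])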