Add more algorithms
This commit is contained in:
3
others/GDAS/paddlepaddle/lib/models/__init__.py
Normal file
3
others/GDAS/paddlepaddle/lib/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .genotypes import Networks
|
||||
from .nas_net import NASCifarNet
|
||||
from .resnet import resnet_cifar
|
175
others/GDAS/paddlepaddle/lib/models/genotypes.py
Normal file
175
others/GDAS/paddlepaddle/lib/models/genotypes.py
Normal file
@@ -0,0 +1,175 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from collections import namedtuple
|
||||
|
||||
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
NASNet = Genotype(
|
||||
normal = [
|
||||
(('sep_conv_5x5', 1), ('sep_conv_3x3', 0)),
|
||||
(('sep_conv_5x5', 0), ('sep_conv_3x3', 0)),
|
||||
(('avg_pool_3x3', 1), ('skip_connect', 0)),
|
||||
(('avg_pool_3x3', 0), ('avg_pool_3x3', 0)),
|
||||
(('sep_conv_3x3', 1), ('skip_connect', 1)),
|
||||
],
|
||||
normal_concat = [2, 3, 4, 5, 6],
|
||||
reduce = [
|
||||
(('sep_conv_5x5', 1), ('sep_conv_7x7', 0)),
|
||||
(('max_pool_3x3', 1), ('sep_conv_7x7', 0)),
|
||||
(('avg_pool_3x3', 1), ('sep_conv_5x5', 0)),
|
||||
(('skip_connect', 3), ('avg_pool_3x3', 2)),
|
||||
(('sep_conv_3x3', 2), ('max_pool_3x3', 1)),
|
||||
],
|
||||
reduce_concat = [4, 5, 6],
|
||||
)
|
||||
|
||||
|
||||
# Progressive Neural Architecture Search, ECCV 2018
|
||||
PNASNet = Genotype(
|
||||
normal = [
|
||||
(('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
|
||||
(('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_3x3', 0), ('skip_connect', 1)),
|
||||
],
|
||||
normal_concat = [2, 3, 4, 5, 6],
|
||||
reduce = [
|
||||
(('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
|
||||
(('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_3x3', 0), ('skip_connect', 1)),
|
||||
],
|
||||
reduce_concat = [2, 3, 4, 5, 6],
|
||||
)
|
||||
|
||||
|
||||
# Regularized Evolution for Image Classifier Architecture Search, AAAI 2019
|
||||
AmoebaNet = Genotype(
|
||||
normal = [
|
||||
(('avg_pool_3x3', 0), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_3x3', 0), ('sep_conv_5x5', 2)),
|
||||
(('sep_conv_3x3', 0), ('avg_pool_3x3', 3)),
|
||||
(('sep_conv_3x3', 1), ('skip_connect', 1)),
|
||||
(('skip_connect', 0), ('avg_pool_3x3', 1)),
|
||||
],
|
||||
normal_concat = [4, 5, 6],
|
||||
reduce = [
|
||||
(('avg_pool_3x3', 0), ('sep_conv_3x3', 1)),
|
||||
(('max_pool_3x3', 0), ('sep_conv_7x7', 2)),
|
||||
(('sep_conv_7x7', 0), ('avg_pool_3x3', 1)),
|
||||
(('max_pool_3x3', 0), ('max_pool_3x3', 1)),
|
||||
(('conv_7x1_1x7', 0), ('sep_conv_3x3', 5)),
|
||||
],
|
||||
reduce_concat = [3, 4, 6]
|
||||
)
|
||||
|
||||
|
||||
# Efficient Neural Architecture Search via Parameter Sharing, ICML 2018
|
||||
ENASNet = Genotype(
|
||||
normal = [
|
||||
(('sep_conv_3x3', 1), ('skip_connect', 1)),
|
||||
(('sep_conv_5x5', 1), ('skip_connect', 0)),
|
||||
(('avg_pool_3x3', 0), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_3x3', 0), ('avg_pool_3x3', 1)),
|
||||
(('sep_conv_5x5', 1), ('avg_pool_3x3', 0)),
|
||||
],
|
||||
normal_concat = [2, 3, 4, 5, 6],
|
||||
reduce = [
|
||||
(('sep_conv_5x5', 0), ('sep_conv_3x3', 1)), # 2
|
||||
(('sep_conv_3x3', 1), ('avg_pool_3x3', 1)), # 3
|
||||
(('sep_conv_3x3', 1), ('avg_pool_3x3', 1)), # 4
|
||||
(('avg_pool_3x3', 1), ('sep_conv_5x5', 4)), # 5
|
||||
(('sep_conv_3x3', 5), ('sep_conv_5x5', 0)),
|
||||
],
|
||||
reduce_concat = [2, 3, 4, 5, 6],
|
||||
)
|
||||
|
||||
|
||||
# DARTS: Differentiable Architecture Search, ICLR 2019
|
||||
DARTS_V1 = Genotype(
|
||||
normal=[
|
||||
(('sep_conv_3x3', 1), ('sep_conv_3x3', 0)), # step 1
|
||||
(('skip_connect', 0), ('sep_conv_3x3', 1)), # step 2
|
||||
(('skip_connect', 0), ('sep_conv_3x3', 1)), # step 3
|
||||
(('sep_conv_3x3', 0), ('skip_connect', 2)) # step 4
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
|
||||
(('skip_connect', 2), ('max_pool_3x3', 0)), # step 2
|
||||
(('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
|
||||
(('skip_connect', 2), ('avg_pool_3x3', 0)) # step 4
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
)
|
||||
|
||||
|
||||
# DARTS: Differentiable Architecture Search, ICLR 2019
|
||||
DARTS_V2 = Genotype(
|
||||
normal=[
|
||||
(('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 1
|
||||
(('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 2
|
||||
(('sep_conv_3x3', 1), ('skip_connect', 0)), # step 3
|
||||
(('skip_connect', 0), ('dil_conv_3x3', 2)) # step 4
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
|
||||
(('skip_connect', 2), ('max_pool_3x3', 1)), # step 2
|
||||
(('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
|
||||
(('skip_connect', 2), ('max_pool_3x3', 1)) # step 4
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
)
|
||||
|
||||
|
||||
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
|
||||
SETN = Genotype(
|
||||
normal=[
|
||||
(('skip_connect', 0), ('sep_conv_5x5', 1)),
|
||||
(('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_5x5', 1), ('sep_conv_5x5', 3)),
|
||||
(('max_pool_3x3', 1), ('conv_3x1_1x3', 4))],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('sep_conv_3x3', 0), ('sep_conv_5x5', 1)),
|
||||
(('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
|
||||
(('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
|
||||
(('avg_pool_3x3', 0), ('skip_connect', 1))],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
)
|
||||
|
||||
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
|
||||
GDAS_V1 = Genotype(
|
||||
normal=[
|
||||
(('skip_connect', 0), ('skip_connect', 1)),
|
||||
(('skip_connect', 0), ('sep_conv_5x5', 2)),
|
||||
(('sep_conv_3x3', 3), ('skip_connect', 0)),
|
||||
(('sep_conv_5x5', 4), ('sep_conv_3x3', 3))],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_5x5', 2), ('sep_conv_5x5', 1)),
|
||||
(('dil_conv_5x5', 2), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_5x5', 0), ('sep_conv_5x5', 1))],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
)
|
||||
|
||||
|
||||
Networks = {'DARTS_V1' : DARTS_V1,
|
||||
'DARTS_V2' : DARTS_V2,
|
||||
'DARTS' : DARTS_V2,
|
||||
'NASNet' : NASNet,
|
||||
'ENASNet' : ENASNet,
|
||||
'AmoebaNet': AmoebaNet,
|
||||
'GDAS_V1' : GDAS_V1,
|
||||
'PNASNet' : PNASNet,
|
||||
'SETN' : SETN,
|
||||
}
|
79
others/GDAS/paddlepaddle/lib/models/nas_net.py
Normal file
79
others/GDAS/paddlepaddle/lib/models/nas_net.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from .operations import OPS
|
||||
|
||||
|
||||
def AuxiliaryHeadCIFAR(inputs, C, class_num):
|
||||
print ('AuxiliaryHeadCIFAR : inputs-shape : {:}'.format(inputs.shape))
|
||||
temp = fluid.layers.relu(inputs)
|
||||
temp = fluid.layers.pool2d(temp, pool_size=5, pool_stride=3, pool_padding=0, pool_type='avg')
|
||||
temp = fluid.layers.conv2d(temp, filter_size=1, num_filters=128, stride=1, padding=0, act=None, bias_attr=False)
|
||||
temp = fluid.layers.batch_norm(input=temp, act='relu', bias_attr=None)
|
||||
temp = fluid.layers.conv2d(temp, filter_size=1, num_filters=768, stride=2, padding=0, act=None, bias_attr=False)
|
||||
temp = fluid.layers.batch_norm(input=temp, act='relu', bias_attr=None)
|
||||
print ('AuxiliaryHeadCIFAR : last---shape : {:}'.format(temp.shape))
|
||||
predict = fluid.layers.fc(input=temp, size=class_num, act='softmax')
|
||||
return predict
|
||||
|
||||
|
||||
def InferCell(name, inputs_prev_prev, inputs_prev, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
print ('[{:}] C_prev_prev={:} C_prev={:}, C={:}, reduction_prev={:}, reduction={:}'.format(name, C_prev_prev, C_prev, C, reduction_prev, reduction))
|
||||
print ('inputs_prev_prev : {:}'.format(inputs_prev_prev.shape))
|
||||
print ('inputs_prev : {:}'.format(inputs_prev.shape))
|
||||
inputs_prev_prev = OPS['skip_connect'](inputs_prev_prev, C_prev_prev, C, 2 if reduction_prev else 1)
|
||||
inputs_prev = OPS['skip_connect'](inputs_prev, C_prev, C, 1)
|
||||
print ('inputs_prev_prev : {:}'.format(inputs_prev_prev.shape))
|
||||
print ('inputs_prev : {:}'.format(inputs_prev.shape))
|
||||
if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat
|
||||
else : step_ops, concat = genotype.normal, genotype.normal_concat
|
||||
states = [inputs_prev_prev, inputs_prev]
|
||||
for istep, operations in enumerate(step_ops):
|
||||
op_a, op_b = operations
|
||||
# the first operation
|
||||
#print ('-->>[{:}/{:}] [{:}] + [{:}]'.format(istep, len(step_ops), op_a, op_b))
|
||||
stride = 2 if reduction and op_a[1] < 2 else 1
|
||||
tensor1 = OPS[ op_a[0] ](states[op_a[1]], C, C, stride)
|
||||
stride = 2 if reduction and op_b[1] < 2 else 1
|
||||
tensor2 = OPS[ op_b[0] ](states[op_b[1]], C, C, stride)
|
||||
state = fluid.layers.elementwise_add(x=tensor1, y=tensor2, act=None)
|
||||
assert tensor1.shape == tensor2.shape, 'invalid shape {:} vs. {:}'.format(tensor1.shape, tensor2.shape)
|
||||
print ('-->>[{:}/{:}] tensor={:} from {:} + {:}'.format(istep, len(step_ops), state.shape, tensor1.shape, tensor2.shape))
|
||||
states.append( state )
|
||||
states_to_cat = [states[x] for x in concat]
|
||||
outputs = fluid.layers.concat(states_to_cat, axis=1)
|
||||
print ('-->> output-shape : {:} from concat={:}'.format(outputs.shape, concat))
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
# NASCifarNet(inputs, 36, 6, 3, 10, 'xxx', True)
|
||||
def NASCifarNet(ipt, C, N, stem_multiplier, class_num, genotype, auxiliary):
|
||||
# cifar head module
|
||||
C_curr = stem_multiplier * C
|
||||
stem = fluid.layers.conv2d(ipt, filter_size=3, num_filters=C_curr, stride=1, padding=1, act=None, bias_attr=False)
|
||||
stem = fluid.layers.batch_norm(input=stem, act=None, bias_attr=None)
|
||||
print ('stem-shape : {:}'.format(stem.shape))
|
||||
# N + 1 + N + 1 + N cells
|
||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
|
||||
reduction_prev = False
|
||||
auxiliary_pred = None
|
||||
|
||||
cell_results = [stem, stem]
|
||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||
xstr = '{:02d}/{:02d}'.format(index, len(layer_channels))
|
||||
cell_result = InferCell(xstr, cell_results[-2], cell_results[-1], genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
C_prev_prev, C_prev = C_prev, cell_result.shape[1]
|
||||
cell_results.append( cell_result )
|
||||
if auxiliary and reduction and C_curr == C*4:
|
||||
auxiliary_pred = AuxiliaryHeadCIFAR(cell_result, C_prev, class_num)
|
||||
|
||||
global_P = fluid.layers.pool2d(input=cell_results[-1], pool_size=8, pool_type='avg', pool_stride=1)
|
||||
predicts = fluid.layers.fc(input=global_P, size=class_num, act='softmax')
|
||||
print ('predict-shape : {:}'.format(predicts.shape))
|
||||
if auxiliary_pred is None:
|
||||
return predicts
|
||||
else:
|
||||
return [predicts, auxiliary_pred]
|
91
others/GDAS/paddlepaddle/lib/models/operations.py
Normal file
91
others/GDAS/paddlepaddle/lib/models/operations.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
OPS = {
|
||||
'none' : lambda inputs, C_in, C_out, stride: ZERO(inputs, stride),
|
||||
'avg_pool_3x3' : lambda inputs, C_in, C_out, stride: POOL_3x3(inputs, C_in, C_out, stride, 'avg'),
|
||||
'max_pool_3x3' : lambda inputs, C_in, C_out, stride: POOL_3x3(inputs, C_in, C_out, stride, 'max'),
|
||||
'skip_connect' : lambda inputs, C_in, C_out, stride: Identity(inputs, C_in, C_out, stride),
|
||||
'sep_conv_3x3' : lambda inputs, C_in, C_out, stride: SepConv(inputs, C_in, C_out, 3, stride, 1),
|
||||
'sep_conv_5x5' : lambda inputs, C_in, C_out, stride: SepConv(inputs, C_in, C_out, 5, stride, 2),
|
||||
'sep_conv_7x7' : lambda inputs, C_in, C_out, stride: SepConv(inputs, C_in, C_out, 7, stride, 3),
|
||||
'dil_conv_3x3' : lambda inputs, C_in, C_out, stride: DilConv(inputs, C_in, C_out, 3, stride, 2, 2),
|
||||
'dil_conv_5x5' : lambda inputs, C_in, C_out, stride: DilConv(inputs, C_in, C_out, 5, stride, 4, 2),
|
||||
'conv_3x1_1x3' : lambda inputs, C_in, C_out, stride: Conv313(inputs, C_in, C_out, stride),
|
||||
'conv_7x1_1x7' : lambda inputs, C_in, C_out, stride: Conv717(inputs, C_in, C_out, stride),
|
||||
}
|
||||
|
||||
|
||||
def ReLUConvBN(inputs, C_in, C_out, kernel, stride, padding):
|
||||
temp = fluid.layers.relu(inputs)
|
||||
temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_out, stride=stride, padding=padding, act=None, bias_attr=False)
|
||||
temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None)
|
||||
return temp
|
||||
|
||||
|
||||
def ZERO(inputs, stride):
|
||||
if stride == 1:
|
||||
return inputs * 0
|
||||
elif stride == 2:
|
||||
return fluid.layers.pool2d(inputs, filter_size=2, pool_stride=2, pool_padding=0, pool_type='avg') * 0
|
||||
else:
|
||||
raise ValueError('invalid stride of {:} not [1, 2]'.format(stride))
|
||||
|
||||
|
||||
def Identity(inputs, C_in, C_out, stride):
|
||||
if C_in == C_out and stride == 1:
|
||||
return inputs
|
||||
elif stride == 1:
|
||||
return ReLUConvBN(inputs, C_in, C_out, 1, 1, 0)
|
||||
else:
|
||||
temp1 = fluid.layers.relu(inputs)
|
||||
temp2 = fluid.layers.pad2d(input=temp1, paddings=[0, 1, 0, 1], mode='reflect')
|
||||
temp2 = fluid.layers.slice(temp2, axes=[0, 1, 2, 3], starts=[0, 0, 1, 1], ends=[999, 999, 999, 999])
|
||||
temp1 = fluid.layers.conv2d(temp1, filter_size=1, num_filters=C_out//2, stride=stride, padding=0, act=None, bias_attr=False)
|
||||
temp2 = fluid.layers.conv2d(temp2, filter_size=1, num_filters=C_out-C_out//2, stride=stride, padding=0, act=None, bias_attr=False)
|
||||
temp = fluid.layers.concat([temp1,temp2], axis=1)
|
||||
return fluid.layers.batch_norm(input=temp, act=None, bias_attr=None)
|
||||
|
||||
|
||||
def POOL_3x3(inputs, C_in, C_out, stride, mode):
|
||||
if C_in == C_out:
|
||||
xinputs = inputs
|
||||
else:
|
||||
xinputs = ReLUConvBN(inputs, C_in, C_out, 1, 1, 0)
|
||||
return fluid.layers.pool2d(xinputs, pool_size=3, pool_stride=stride, pool_padding=1, pool_type=mode)
|
||||
|
||||
|
||||
def SepConv(inputs, C_in, C_out, kernel, stride, padding):
|
||||
temp = fluid.layers.relu(inputs)
|
||||
temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_in , stride=stride, padding=padding, act=None, bias_attr=False)
|
||||
temp = fluid.layers.conv2d(temp, filter_size= 1, num_filters=C_in , stride= 1, padding= 0, act=None, bias_attr=False)
|
||||
temp = fluid.layers.batch_norm(input=temp, act='relu', bias_attr=None)
|
||||
temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_in , stride= 1, padding=padding, act=None, bias_attr=False)
|
||||
temp = fluid.layers.conv2d(temp, filter_size= 1, num_filters=C_out, stride= 1, padding= 0, act=None, bias_attr=False)
|
||||
temp = fluid.layers.batch_norm(input=temp, act=None , bias_attr=None)
|
||||
return temp
|
||||
|
||||
|
||||
def DilConv(inputs, C_in, C_out, kernel, stride, padding, dilation):
|
||||
temp = fluid.layers.relu(inputs)
|
||||
temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_in , stride=stride, padding=padding, dilation=dilation, act=None, bias_attr=False)
|
||||
temp = fluid.layers.conv2d(temp, filter_size= 1, num_filters=C_out, stride= 1, padding= 0, act=None, bias_attr=False)
|
||||
temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None)
|
||||
return temp
|
||||
|
||||
|
||||
def Conv313(inputs, C_in, C_out, stride):
|
||||
temp = fluid.layers.relu(inputs)
|
||||
temp = fluid.layers.conv2d(temp, filter_size=(1,3), num_filters=C_out, stride=(1,stride), padding=(0,1), act=None, bias_attr=False)
|
||||
temp = fluid.layers.conv2d(temp, filter_size=(3,1), num_filters=C_out, stride=(stride,1), padding=(1,0), act=None, bias_attr=False)
|
||||
temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None)
|
||||
return temp
|
||||
|
||||
|
||||
def Conv717(inputs, C_in, C_out, stride):
|
||||
temp = fluid.layers.relu(inputs)
|
||||
temp = fluid.layers.conv2d(temp, filter_size=(1,7), num_filters=C_out, stride=(1,stride), padding=(0,3), act=None, bias_attr=False)
|
||||
temp = fluid.layers.conv2d(temp, filter_size=(7,1), num_filters=C_out, stride=(stride,1), padding=(3,0), act=None, bias_attr=False)
|
||||
temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None)
|
||||
return temp
|
65
others/GDAS/paddlepaddle/lib/models/resnet.py
Normal file
65
others/GDAS/paddlepaddle/lib/models/resnet.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
def conv_bn_layer(input,
|
||||
ch_out,
|
||||
filter_size,
|
||||
stride,
|
||||
padding,
|
||||
act='relu',
|
||||
bias_attr=False):
|
||||
tmp = fluid.layers.conv2d(
|
||||
input=input,
|
||||
filter_size=filter_size,
|
||||
num_filters=ch_out,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
act=None,
|
||||
bias_attr=bias_attr)
|
||||
return fluid.layers.batch_norm(input=tmp, act=act)
|
||||
|
||||
|
||||
def shortcut(input, ch_in, ch_out, stride):
|
||||
if stride == 2:
|
||||
temp = fluid.layers.pool2d(input, pool_size=2, pool_type='avg', pool_stride=2)
|
||||
temp = fluid.layers.conv2d(temp , filter_size=1, num_filters=ch_out, stride=1, padding=0, act=None, bias_attr=None)
|
||||
return temp
|
||||
elif ch_in != ch_out:
|
||||
return conv_bn_layer(input, ch_out, 1, stride, 0, None, None)
|
||||
else:
|
||||
return input
|
||||
|
||||
|
||||
def basicblock(input, ch_in, ch_out, stride):
|
||||
tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
|
||||
tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True)
|
||||
short = shortcut(input, ch_in, ch_out, stride)
|
||||
return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')
|
||||
|
||||
|
||||
def layer_warp(block_func, input, ch_in, ch_out, count, stride):
|
||||
tmp = block_func(input, ch_in, ch_out, stride)
|
||||
for i in range(1, count):
|
||||
tmp = block_func(tmp, ch_out, ch_out, 1)
|
||||
return tmp
|
||||
|
||||
|
||||
def resnet_cifar(ipt, depth, class_num):
|
||||
# depth should be one of 20, 32, 44, 56, 110, 1202
|
||||
assert (depth - 2) % 6 == 0
|
||||
n = (depth - 2) // 6
|
||||
print('[resnet] depth : {:}, class_num : {:}'.format(depth, class_num))
|
||||
conv1 = conv_bn_layer(ipt, ch_out=16, filter_size=3, stride=1, padding=1)
|
||||
print('conv-1 : shape = {:}'.format(conv1.shape))
|
||||
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
|
||||
print('res--1 : shape = {:}'.format(res1.shape))
|
||||
res2 = layer_warp(basicblock, res1 , 16, 32, n, 2)
|
||||
print('res--2 : shape = {:}'.format(res2.shape))
|
||||
res3 = layer_warp(basicblock, res2 , 32, 64, n, 2)
|
||||
print('res--3 : shape = {:}'.format(res3.shape))
|
||||
pool = fluid.layers.pool2d(input=res3, pool_size=8, pool_type='avg', pool_stride=1)
|
||||
print('pool : shape = {:}'.format(pool.shape))
|
||||
predict = fluid.layers.fc(input=pool, size=class_num, act='softmax')
|
||||
print('predict: shape = {:}'.format(predict.shape))
|
||||
return predict
|
Reference in New Issue
Block a user