update TF models (beta version)
This commit is contained in:
120
lib/tf_models/cell_operations.py
Normal file
120
lib/tf_models/cell_operations.py
Normal file
@@ -0,0 +1,120 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import tensorflow as tf
|
||||
|
||||
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
|
||||
|
||||
OPS = {
|
||||
'none' : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
|
||||
'avg_pool_3x3': lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg', affine),
|
||||
'nor_conv_1x1': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 1, stride, affine),
|
||||
'nor_conv_3x3': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 3, stride, affine),
|
||||
'nor_conv_5x5': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 5, stride, affine),
|
||||
'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride)
|
||||
}
|
||||
|
||||
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
|
||||
SearchSpaceNames = {
|
||||
'nas-bench-102': NAS_BENCH_102,
|
||||
}
|
||||
|
||||
|
||||
class POOLING(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, mode, affine):
|
||||
super(POOLING, self).__init__()
|
||||
if C_in == C_out:
|
||||
self.preprocess = None
|
||||
else:
|
||||
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, affine)
|
||||
if mode == 'avg' : self.op = tf.keras.layers.AvgPool2D((3,3), strides=stride, padding='same')
|
||||
elif mode == 'max': self.op = tf.keras.layers.MaxPool2D((3,3), strides=stride, padding='same')
|
||||
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
|
||||
|
||||
def call(self, inputs, training):
|
||||
if self.preprocess: x = self.preprocess(inputs)
|
||||
else : x = inputs
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(tf.keras.layers.Layer):
|
||||
def __init__(self, C_in, C_out, stride):
|
||||
super(Identity, self).__init__()
|
||||
if C_in != C_out or stride != 1:
|
||||
self.layer = tf.keras.layers.Conv2D(C_out, 3, stride, padding='same', use_bias=False)
|
||||
else:
|
||||
self.layer = None
|
||||
|
||||
def call(self, inputs, training):
|
||||
x = inputs
|
||||
if self.layer is not None:
|
||||
x = self.layer(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class Zero(tf.keras.layers.Layer):
|
||||
def __init__(self, C_in, C_out, stride):
|
||||
super(Zero, self).__init__()
|
||||
if C_in != C_out:
|
||||
self.layer = tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False)
|
||||
elif stride != 1:
|
||||
self.layer = tf.keras.layers.AvgPool2D((stride,stride), None, padding="same")
|
||||
else:
|
||||
self.layer = None
|
||||
|
||||
def call(self, inputs, training):
|
||||
x = tf.zeros_like(inputs)
|
||||
if self.layer is not None:
|
||||
x = self.layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class ReLUConvBN(tf.keras.layers.Layer):
|
||||
def __init__(self, C_in, C_out, kernel_size, strides, affine):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.C_in = C_in
|
||||
self.relu = tf.keras.activations.relu
|
||||
self.conv = tf.keras.layers.Conv2D(C_out, kernel_size, strides, padding='same', use_bias=False)
|
||||
self.bn = tf.keras.layers.BatchNormalization(center=affine, scale=affine)
|
||||
|
||||
def call(self, inputs, training):
|
||||
x = self.relu(inputs)
|
||||
x = self.conv(x)
|
||||
x = self.bn(x, training)
|
||||
return x
|
||||
|
||||
|
||||
class ResNetBasicblock(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, inplanes, planes, stride, affine=True):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, affine)
|
||||
self.conv_b = ReLUConvBN( planes, planes, 3, 1, affine)
|
||||
if stride == 2:
|
||||
self.downsample = tf.keras.Sequential([
|
||||
tf.keras.layers.AvgPool2D((stride,stride), None, padding="same"),
|
||||
tf.keras.layers.Conv2D(planes, 1, 1, padding='same', use_bias=False)])
|
||||
elif inplanes != planes:
|
||||
self.downsample = ReLUConvBN(inplanes, planes, 1, stride, affine)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.addition = tf.keras.layers.Add()
|
||||
self.in_dim = inplanes
|
||||
self.out_dim = planes
|
||||
self.stride = stride
|
||||
self.num_conv = 2
|
||||
|
||||
def call(self, inputs, training):
|
||||
|
||||
basicblock = self.conv_a(inputs, training)
|
||||
basicblock = self.conv_b(basicblock, training)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
return self.addition([residual, basicblock])
|
Reference in New Issue
Block a user