update tf-GDAS
This commit is contained in:
@@ -11,7 +11,7 @@ OPS = {
|
||||
'nor_conv_1x1': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 1, stride, affine),
|
||||
'nor_conv_3x3': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 3, stride, affine),
|
||||
'nor_conv_5x5': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 5, stride, affine),
|
||||
'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride)
|
||||
'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride) if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine)
|
||||
}
|
||||
|
||||
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
@@ -87,6 +87,36 @@ class ReLUConvBN(tf.keras.layers.Layer):
|
||||
return x
|
||||
|
||||
|
||||
class FactorizedReduce(tf.keras.layers.Layer):
|
||||
def __init__(self, C_in, C_out, stride, affine):
|
||||
assert output_filters % 2 == 0, ('Need even number of filters when using this factorized reduction.')
|
||||
self.stride == stride
|
||||
self.relu = tf.keras.activations.relu
|
||||
if stride == 1:
|
||||
self.layer = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv2D(C_out, 1, strides, padding='same', use_bias=False),
|
||||
tf.keras.layers.BatchNormalization(center=affine, scale=affine)])
|
||||
elif stride == 2:
|
||||
stride_spec = [1, stride, stride, 1] # data_format == 'NHWC'
|
||||
self.layer1 = tf.keras.layers.Conv2D(C_out//2, 1, strides, padding='same', use_bias=False)
|
||||
self.layer2 = tf.keras.layers.Conv2D(C_out//2, 1, strides, padding='same', use_bias=False)
|
||||
self.bn = tf.keras.layers.BatchNormalization(center=affine, scale=affine)
|
||||
else:
|
||||
raise ValueError('invalid stride={:}'.format(stride))
|
||||
|
||||
def call(self, inputs, training):
|
||||
x = self.relu(inputs)
|
||||
if self.stride == 1:
|
||||
return self.layer(x, training)
|
||||
else:
|
||||
path1 = x
|
||||
path2 = tf.pad(x, [[0, 0], [0, 1], [0, 1], [0, 0]])[:, 1:, 1:, :] # data_format == 'NHWC'
|
||||
x1 = self.layer1(path1)
|
||||
x2 = self.layer2(path2)
|
||||
final_path = tf.concat(values=[x1, x2], axis=3)
|
||||
return self.bn(final_path)
|
||||
|
||||
|
||||
class ResNetBasicblock(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, inplanes, planes, stride, affine=True):
|
||||
|
Reference in New Issue
Block a user