update TF models (beta version)

This commit is contained in:
D-X-Y
2020-01-05 22:19:38 +11:00
parent e6ca3628ce
commit 5ac5060a33
18 changed files with 1253 additions and 44 deletions

32
lib/tf_models/__init__.py Normal file
View File

@@ -0,0 +1,32 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from os import path as osp
__all__ = ['get_cell_based_tiny_net', 'get_search_spaces']
# the cell-based NAS models
def get_cell_based_tiny_net(config):
group_names = ['GDAS']
if config.name in group_names:
from .cell_searchs import nas_super_nets
from .cell_operations import SearchSpaceNames
if isinstance(config.space, str): search_space = SearchSpaceNames[config.space]
else: search_space = config.space
return nas_super_nets[config.name](
config.C, config.N, config.max_nodes,
config.num_classes, search_space, config.affine)
else:
raise ValueError('invalid network name : {:}'.format(config.name))
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name):
if xtype == 'cell':
from .cell_operations import SearchSpaceNames
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
return SearchSpaceNames[name]
else:
raise ValueError('invalid search-space type is {:}'.format(xtype))

View File

@@ -0,0 +1,120 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import tensorflow as tf
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
OPS = {
'none' : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
'avg_pool_3x3': lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg', affine),
'nor_conv_1x1': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 1, stride, affine),
'nor_conv_3x3': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 3, stride, affine),
'nor_conv_5x5': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 5, stride, affine),
'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride)
}
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
SearchSpaceNames = {
'nas-bench-102': NAS_BENCH_102,
}
class POOLING(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, stride, mode, affine):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, affine)
if mode == 'avg' : self.op = tf.keras.layers.AvgPool2D((3,3), strides=stride, padding='same')
elif mode == 'max': self.op = tf.keras.layers.MaxPool2D((3,3), strides=stride, padding='same')
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
def call(self, inputs, training):
if self.preprocess: x = self.preprocess(inputs)
else : x = inputs
return self.op(x)
class Identity(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, stride):
super(Identity, self).__init__()
if C_in != C_out or stride != 1:
self.layer = tf.keras.layers.Conv2D(C_out, 3, stride, padding='same', use_bias=False)
else:
self.layer = None
def call(self, inputs, training):
x = inputs
if self.layer is not None:
x = self.layer(x)
return x
class Zero(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
if C_in != C_out:
self.layer = tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False)
elif stride != 1:
self.layer = tf.keras.layers.AvgPool2D((stride,stride), None, padding="same")
else:
self.layer = None
def call(self, inputs, training):
x = tf.zeros_like(inputs)
if self.layer is not None:
x = self.layer(x)
return x
class ReLUConvBN(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, kernel_size, strides, affine):
super(ReLUConvBN, self).__init__()
self.C_in = C_in
self.relu = tf.keras.activations.relu
self.conv = tf.keras.layers.Conv2D(C_out, kernel_size, strides, padding='same', use_bias=False)
self.bn = tf.keras.layers.BatchNormalization(center=affine, scale=affine)
def call(self, inputs, training):
x = self.relu(inputs)
x = self.conv(x)
x = self.bn(x, training)
return x
class ResNetBasicblock(tf.keras.layers.Layer):
def __init__(self, inplanes, planes, stride, affine=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, affine)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, affine)
if stride == 2:
self.downsample = tf.keras.Sequential([
tf.keras.layers.AvgPool2D((stride,stride), None, padding="same"),
tf.keras.layers.Conv2D(planes, 1, 1, padding='same', use_bias=False)])
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, stride, affine)
else:
self.downsample = None
self.addition = tf.keras.layers.Add()
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def call(self, inputs, training):
basicblock = self.conv_a(inputs, training)
basicblock = self.conv_b(basicblock, training)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
return self.addition([residual, basicblock])

View File

@@ -0,0 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .search_model_gdas import TinyNetworkGDAS
nas_super_nets = {'GDAS': TinyNetworkGDAS}

View File

@@ -0,0 +1,50 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, random
import tensorflow as tf
from copy import deepcopy
from ..cell_operations import OPS
class SearchCell(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False):
super(SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
self.max_nodes = max_nodes
self.in_dim = C_in
self.out_dim = C_out
self.edge_keys = []
for i in range(1, max_nodes):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if j == 0:
xlists = [OPS[op_name](C_in , C_out, stride, affine) for op_name in op_names]
else:
xlists = [OPS[op_name](C_in , C_out, 1, affine) for op_name in op_names]
for k, op in enumerate(xlists):
setattr(self, '{:}.{:}'.format(node_str, k), op)
self.edge_keys.append( node_str )
self.edge_keys = sorted(self.edge_keys)
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edge_keys)
def call(self, inputs, weightss, training):
w_lst = tf.split(weightss, self.num_edges, 0)
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
edge_idx = self.edge2index[node_str]
op_outps = []
for k, op_name in enumerate(self.op_names):
op = getattr(self, '{:}.{:}'.format(node_str, k))
op_outps.append( op(nodes[j], training) )
stack_op_outs = tf.stack(op_outps, axis=-1)
weighted_sums = tf.math.multiply(stack_op_outs, w_lst[edge_idx])
inter_nodes.append( tf.math.reduce_sum(weighted_sums, axis=-1) )
nodes.append( tf.math.add_n(inter_nodes) )
return nodes[-1]

View File

@@ -0,0 +1,99 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import tensorflow as tf
import numpy as np
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import SearchCell
def sample_gumbel(shape, eps=1e-20):
U = tf.random.uniform(shape, minval=0, maxval=1)
return -tf.math.log(-tf.math.log(U + eps) + eps)
def gumbel_softmax(logits, temperature):
gumbel_softmax_sample = logits + sample_gumbel(tf.shape(logits))
y = tf.nn.softmax(gumbel_softmax_sample / temperature)
return y
class TinyNetworkGDAS(tf.keras.Model):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine):
super(TinyNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = tf.keras.Sequential([
tf.keras.layers.Conv2D(16, 3, 1, padding='same', use_bias=False),
tf.keras.layers.BatchNormalization()], name='stem')
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell_prefix = 'cell-{:03d}'.format(index)
#with tf.name_scope(cell_prefix) as scope:
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
C_prev = cell.out_dim
setattr(self, cell_prefix, cell)
self.num_layers = len(layer_reductions)
self.op_names = deepcopy( search_space )
self.edge2index = edge2index
self.num_edge = num_edge
self.lastact = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(),
tf.keras.layers.ReLU(),
tf.keras.layers.GlobalAvgPool2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact')
#self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
arch_init = tf.random_normal_initializer(mean=0, stddev=0.001)
self.arch_parameters = tf.Variable(initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'), trainable=True, name='arch-encoding')
def get_alphas(self):
xlist = self.trainable_variables
return [x for x in xlist if 'arch-encoding' in x.name]
def get_weights(self):
xlist = self.trainable_variables
return [x for x in xlist if 'arch-encoding' not in x.name]
def get_np_alphas(self):
arch_nps = self.arch_parameters.numpy()
arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True)
return arch_ops
def genotype(self):
genotypes, arch_nps = [], self.arch_parameters.numpy()
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = arch_nps[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return genotypes
#
def call(self, inputs, tau, training):
weightss = tf.cond(tau < 0, lambda: tf.nn.softmax(self.arch_parameters, axis=1),
lambda: gumbel_softmax(tf.math.log_softmax(self.arch_parameters, axis=1), tau))
feature = self.stem(inputs, training)
for idx in range(self.num_layers):
cell = getattr(self, 'cell-{:03d}'.format(idx))
if isinstance(cell, SearchCell):
feature = cell.call(feature, weightss, training)
else:
feature = cell(feature, training)
logits = self.lastact(feature, training)
return logits