update TF models (beta version)
This commit is contained in:
50
lib/tf_models/cell_searchs/search_cells.py
Normal file
50
lib/tf_models/cell_searchs/search_cells.py
Normal file
@@ -0,0 +1,50 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, random
|
||||
import tensorflow as tf
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
||||
|
||||
class SearchCell(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False):
|
||||
super(SearchCell, self).__init__()
|
||||
|
||||
self.op_names = deepcopy(op_names)
|
||||
self.max_nodes = max_nodes
|
||||
self.in_dim = C_in
|
||||
self.out_dim = C_out
|
||||
self.edge_keys = []
|
||||
for i in range(1, max_nodes):
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
if j == 0:
|
||||
xlists = [OPS[op_name](C_in , C_out, stride, affine) for op_name in op_names]
|
||||
else:
|
||||
xlists = [OPS[op_name](C_in , C_out, 1, affine) for op_name in op_names]
|
||||
for k, op in enumerate(xlists):
|
||||
setattr(self, '{:}.{:}'.format(node_str, k), op)
|
||||
self.edge_keys.append( node_str )
|
||||
self.edge_keys = sorted(self.edge_keys)
|
||||
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
|
||||
self.num_edges = len(self.edge_keys)
|
||||
|
||||
def call(self, inputs, weightss, training):
|
||||
w_lst = tf.split(weightss, self.num_edges, 0)
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
edge_idx = self.edge2index[node_str]
|
||||
op_outps = []
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
op = getattr(self, '{:}.{:}'.format(node_str, k))
|
||||
op_outps.append( op(nodes[j], training) )
|
||||
stack_op_outs = tf.stack(op_outps, axis=-1)
|
||||
weighted_sums = tf.math.multiply(stack_op_outs, w_lst[edge_idx])
|
||||
inter_nodes.append( tf.math.reduce_sum(weighted_sums, axis=-1) )
|
||||
nodes.append( tf.math.add_n(inter_nodes) )
|
||||
return nodes[-1]
|
Reference in New Issue
Block a user