Initial commit

This commit is contained in:
jack-willturner
2020-06-03 12:59:01 +01:00
commit 357e877e8d
68 changed files with 7189 additions and 0 deletions

View File

@@ -0,0 +1,24 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# The macro structure is defined in NAS-Bench-201
from .search_model_darts import TinyNetworkDarts
from .search_model_gdas import TinyNetworkGDAS
from .search_model_setn import TinyNetworkSETN
from .search_model_enas import TinyNetworkENAS
from .search_model_random import TinyNetworkRANDOM
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
# NASNet-based macro structure
from .search_model_gdas_nasnet import NASNetworkGDAS
from .search_model_darts_nasnet import NASNetworkDARTS
nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
"DARTS-V2": TinyNetworkDarts,
"GDAS": TinyNetworkGDAS,
"SETN": TinyNetworkSETN,
"ENAS": TinyNetworkENAS,
"RANDOM": TinyNetworkRANDOM}
nasnet_super_nets = {"GDAS": NASNetworkGDAS,
"DARTS": NASNetworkDARTS}

View File

@@ -0,0 +1,12 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from search_model_enas_utils import Controller
def main():
controller = Controller(6, 4)
predictions = controller()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,199 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from copy import deepcopy
def get_combination(space, num):
combs = []
for i in range(num):
if i == 0:
for func in space:
combs.append( [(func, i)] )
else:
new_combs = []
for string in combs:
for func in space:
xstring = string + [(func, i)]
new_combs.append( xstring )
combs = new_combs
return combs
class Structure:
def __init__(self, genotype):
assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
self.node_num = len(genotype) + 1
self.nodes = []
self.node_N = []
for idx, node_info in enumerate(genotype):
assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
for node_in in node_info:
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
self.node_N.append( len(node_info) )
self.nodes.append( tuple(deepcopy(node_info)) )
def tolist(self, remove_str):
# convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
# note that we re-order the input node in this function
# return the-genotype-list and success [if unsuccess, it is not a connectivity]
genotypes = []
for node_info in self.nodes:
node_info = list( node_info )
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
if len(node_info) == 0: return None, False
genotypes.append( node_info )
return genotypes, True
def node(self, index):
assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
return self.nodes[index]
def tostr(self):
strings = []
for node_info in self.nodes:
string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
string = '|{:}|'.format(string)
strings.append( string )
return '+'.join(strings)
def check_valid(self):
nodes = {0: True}
for i, node_info in enumerate(self.nodes):
sums = []
for op, xin in node_info:
if op == 'none' or nodes[xin] is False: x = False
else: x = True
sums.append( x )
nodes[i+1] = sum(sums) > 0
return nodes[len(self.nodes)]
def to_unique_str(self, consider_zero=False):
# this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
# two operations are special, i.e., none and skip_connect
nodes = {0: '0'}
for i_node, node_info in enumerate(self.nodes):
cur_node = []
for op, xin in node_info:
if consider_zero is None:
x = '('+nodes[xin]+')' + '@{:}'.format(op)
elif consider_zero:
if op == 'none' or nodes[xin] == '#': x = '#' # zero
elif op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
else:
if op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
cur_node.append(x)
nodes[i_node+1] = '+'.join( sorted(cur_node) )
return nodes[ len(self.nodes) ]
def check_valid_op(self, op_names):
for node_info in self.nodes:
for inode_edge in node_info:
#assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
if inode_edge[0] not in op_names: return False
return True
def __repr__(self):
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
def __len__(self):
return len(self.nodes) + 1
def __getitem__(self, index):
return self.nodes[index]
@staticmethod
def str2structure(xstr):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
genotypes.append( input_infos )
return Structure( genotypes )
@staticmethod
def str2fullstructure(xstr, default_name='none'):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
all_in_nodes= list(x[1] for x in input_infos)
for j in range(i):
if j not in all_in_nodes: input_infos.append((default_name, j))
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
genotypes.append( tuple(node_info) )
return Structure( genotypes )
@staticmethod
def gen_all(search_space, num, return_ori):
assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
all_archs = get_combination(search_space, 1)
for i, arch in enumerate(all_archs):
all_archs[i] = [ tuple(arch) ]
for inode in range(2, num):
cur_nodes = get_combination(search_space, inode)
new_all_archs = []
for previous_arch in all_archs:
for cur_node in cur_nodes:
new_all_archs.append( previous_arch + [tuple(cur_node)] )
all_archs = new_all_archs
if return_ori:
return all_archs
else:
return [Structure(x) for x in all_archs]
ResNet_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 1), ), # node-2
(('skip_connect', 0), ('skip_connect', 2))] # node-3
)
AllConv3x3_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
)
AllFull_CODE = Structure(
[(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
)
AllConv1x1_CODE = Structure(
[(('nor_conv_1x1', 0), ), # node-1
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
)
AllIdentity_CODE = Structure(
[(('skip_connect', 0), ), # node-1
(('skip_connect', 0), ('skip_connect', 1)), # node-2
(('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
)
architectures = {'resnet' : ResNet_CODE,
'all_c3x3': AllConv3x3_CODE,
'all_c1x1': AllConv1x1_CODE,
'all_idnt': AllIdentity_CODE,
'all_full': AllFull_CODE}

View File

@@ -0,0 +1,197 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, random, torch
import warnings
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from ..cell_operations import OPS
# This module is used for NAS-Bench-201, represents a small search space with a complete DAG
class NAS201SearchCell(nn.Module):
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
super(NAS201SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
self.edges = nn.ModuleDict()
self.max_nodes = max_nodes
self.in_dim = C_in
self.out_dim = C_out
for i in range(1, max_nodes):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if j == 0:
xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names]
else:
xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names]
self.edges[ node_str ] = nn.ModuleList( xlists )
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
def extra_repr(self):
string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
return string
def forward(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# GDAS
def forward_gdas(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = hardwts[ self.edge2index[node_str] ]
argmaxs = index[ self.edge2index[node_str] ].item()
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
inter_nodes.append( weigsum )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# joint
def forward_joint(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
#aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) )
inter_nodes.append( aggregation )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# uniform random sampling per iteration, SETN
def forward_urs(self, inputs):
nodes = [inputs]
for i in range(1, self.max_nodes):
while True: # to avoid select zero for all ops
sops, has_non_zero = [], False
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
candidates = self.edges[node_str]
select_op = random.choice(candidates)
sops.append( select_op )
if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True
if has_non_zero: break
inter_nodes = []
for j, select_op in enumerate(sops):
inter_nodes.append( select_op(nodes[j]) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# select the argmax
def forward_select(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) )
#inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# forward with a specific structure
def forward_dynamic(self, inputs, structure):
nodes = [inputs]
for i in range(1, self.max_nodes):
cur_op_node = structure.nodes[i-1]
inter_nodes = []
for op_name, j in cur_op_node:
node_str = '{:}<-{:}'.format(i, j)
op_index = self.op_names.index( op_name )
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
class MixedOp(nn.Module):
def __init__(self, space, C, stride, affine, track_running_stats):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
for primitive in space:
op = OPS[primitive](C, C, stride, affine, track_running_stats)
self._ops.append(op)
def forward_gdas(self, x, weights, index):
return self._ops[index](x) * weights[index]
def forward_darts(self, x, weights):
return sum(w * op(x) for w, op in zip(weights, self._ops))
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetSearchCell(nn.Module):
def __init__(self, space, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
super(NASNetSearchCell, self).__init__()
self.reduction = reduction
self.op_names = deepcopy(space)
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
self._steps = steps
self._multiplier = multiplier
self._ops = nn.ModuleList()
self.edges = nn.ModuleDict()
for i in range(self._steps):
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j) # indicate the edge from node-(j) to node-(i+2)
stride = 2 if reduction and j < 2 else 1
op = MixedOp(space, C, stride, affine, track_running_stats)
self.edges[ node_str ] = op
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
def forward_gdas(self, s0, s1, weightss, indexs):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = '{:}<-{:}'.format(i, j)
op = self.edges[ node_str ]
weights = weightss[ self.edge2index[node_str] ]
index = indexs[ self.edge2index[node_str] ].item()
clist.append( op.forward_gdas(h, weights, index) )
states.append( sum(clist) )
return torch.cat(states[-self._multiplier:], dim=1)
def forward_darts(self, s0, s1, weightss):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = '{:}<-{:}'.format(i, j)
op = self.edges[ node_str ]
weights = weightss[ self.edge2index[node_str] ]
clist.append( op.forward_darts(h, weights) )
states.append( sum(clist) )
return torch.cat(states[-self._multiplier:], dim=1)

View File

@@ -0,0 +1,97 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
########################################################
# DARTS: Differentiable Architecture Search, ICLR 2019 #
########################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkDarts(nn.Module):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkDarts, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def get_alphas(self):
return [self.arch_parameters]
def show_alphas(self):
with torch.no_grad():
return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() )
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
with torch.no_grad():
weights = self.arch_parameters[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def forward(self, inputs):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell(feature, alphas)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,108 @@
####################
# DARTS, ICLR 2019 #
####################
import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkDARTS(nn.Module):
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
super(NASNetworkDARTS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C*stem_multiplier))
# config for each layer
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
def get_weights(self) -> List[torch.nn.Parameter]:
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def get_alphas(self) -> List[torch.nn.Parameter]:
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self) -> Text:
with torch.no_grad():
A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
return '{:}\n{:}'.format(A, B)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self) -> Text:
return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self) -> Dict[Text, List]:
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j)
ws = weights[ self.edge2index[node_str] ]
for k, op_name in enumerate(self.op_names):
if op_name == 'none': continue
edges.append( (op_name, j, ws[k]) )
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append( tuple(selected_edges) )
return gene
with torch.no_grad():
gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)),
'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))}
def forward(self, inputs):
normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction: ww = reduce_w
else : ww = normal_w
s0, s1 = s1, cell.forward_darts(s0, s1, ww)
out = self.lastact(s1)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,94 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
from .search_model_enas_utils import Controller
class TinyNetworkENAS(nn.Module):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkENAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
# to maintain the sampled architecture
self.sampled_arch = None
def update_arch(self, _arch):
if _arch is None:
self.sampled_arch = None
elif isinstance(_arch, Structure):
self.sampled_arch = _arch
elif isinstance(_arch, (list, tuple)):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
op_index = _arch[ self.edge2index[node_str] ]
op_name = self.op_names[ op_index ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
self.sampled_arch = Structure(genotypes)
else:
raise ValueError('invalid type of input architecture : {:}'.format(_arch))
return self.sampled_arch
def create_controller(self):
return Controller(len(self.edge2index), len(self.op_names))
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_dynamic(feature, self.sampled_arch)
else: feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,55 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
class Controller(nn.Module):
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
super(Controller, self).__init__()
# assign the attributes
self.num_edge = num_edge
self.num_ops = num_ops
self.lstm_size = lstm_size
self.lstm_N = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
# create parameters
self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size)))
self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
nn.init.uniform_(self.input_vars , -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
nn.init.uniform_(self.w_embd.weight , -0.1, 0.1)
nn.init.uniform_(self.w_pred.weight , -0.1, 0.1)
def forward(self):
inputs, h0 = self.input_vars, None
log_probs, entropys, sampled_arch = [], [], []
for iedge in range(self.num_edge):
outputs, h0 = self.w_lstm(inputs, h0)
logits = self.w_pred(outputs)
logits = logits / self.temperature
logits = self.tanh_constant * torch.tanh(logits)
# distribution
op_distribution = Categorical(logits=logits)
op_index = op_distribution.sample()
sampled_arch.append( op_index.item() )
op_log_prob = op_distribution.log_prob(op_index)
log_probs.append( op_log_prob.view(-1) )
op_entropy = op_distribution.entropy()
entropys.append( op_entropy.view(-1) )
# obtain the input embedding for the next step
inputs = self.w_embd(op_index)
return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch

View File

@@ -0,0 +1,111 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkGDAS(nn.Module):
#def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.tau = 10
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def get_alphas(self):
return [self.arch_parameters]
def show_alphas(self):
with torch.no_grad():
return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() )
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
with torch.no_grad():
weights = self.arch_parameters[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def forward(self, inputs):
while True:
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
continue
else: break
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_gdas(feature, hardwts, index)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,125 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkGDAS(nn.Module):
def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
super(NASNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C*stem_multiplier))
# config for each layer
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.tau = 10
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def get_alphas(self):
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self):
with torch.no_grad():
A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
return '{:}\n{:}'.format(A, B)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j)
ws = weights[ self.edge2index[node_str] ]
for k, op_name in enumerate(self.op_names):
if op_name == 'none': continue
edges.append( (op_name, j, ws[k]) )
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append( tuple(selected_edges) )
return gene
with torch.no_grad():
gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)),
'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))}
def forward(self, inputs):
def get_gumbel_prob(xins):
while True:
gumbels = -torch.empty_like(xins).exponential_().log()
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
continue
else: break
return hardwts, index
normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction: hardwts, index = reduce_hardwts, reduce_index
else : hardwts, index = normal_hardwts, normal_index
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,81 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkRANDOM(nn.Module):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkRANDOM, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_cache = None
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def random_genotype(self, set_cache):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
op_name = random.choice( self.op_names )
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
arch = Structure( genotypes )
if set_cache: self.arch_cache = arch
return arch
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_dynamic(feature, self.arch_cache)
else: feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,152 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkSETN(nn.Module):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkSETN, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.mode = 'urs'
self.dynamic_cell = None
def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ['urs', 'joint', 'select', 'dynamic']
self.mode = mode
if mode == 'dynamic': self.dynamic_cell = deepcopy( dynamic_cell )
else : self.dynamic_cell = None
def get_cal_mode(self):
return self.mode
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def get_alphas(self):
return [self.arch_parameters]
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
with torch.no_grad():
weights = self.arch_parameters[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[ self.edge2index[node_str] ]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[ op_index ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def get_log_prob(self, arch):
with torch.no_grad():
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
select_logits = []
for i, node_info in enumerate(arch.nodes):
for op, xin in node_info:
node_str = '{:}<-{:}'.format(i+1, xin)
op_index = self.op_names.index(op)
select_logits.append( logits[self.edge2index[node_str], op_index] )
return sum(select_logits).item()
def return_topK(self, K):
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
if K < 0 or K >= len(archs): K = len(archs)
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
return return_pairs
def forward(self, inputs):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
with torch.no_grad():
alphas_cpu = alphas.detach().cpu()
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
if self.mode == 'urs':
feature = cell.forward_urs(feature)
elif self.mode == 'select':
feature = cell.forward_select(feature, alphas_cpu)
elif self.mode == 'joint':
feature = cell.forward_joint(feature, alphas)
elif self.mode == 'dynamic':
feature = cell.forward_dynamic(feature, self.dynamic_cell)
else: raise ValueError('invalid mode={:}'.format(self.mode))
else: feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,139 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkSETN(nn.Module):
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
super(NASNetworkSETN, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C*stem_multiplier))
# config for each layer
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.mode = 'urs'
self.dynamic_cell = None
def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ['urs', 'joint', 'select', 'dynamic']
self.mode = mode
if mode == 'dynamic':
self.dynamic_cell = deepcopy(dynamic_cell)
else:
self.dynamic_cell = None
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def get_alphas(self):
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self):
with torch.no_grad():
A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
return '{:}\n{:}'.format(A, B)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[ self.edge2index[node_str] ]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[ op_index ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j)
ws = weights[ self.edge2index[node_str] ]
for k, op_name in enumerate(self.op_names):
if op_name == 'none': continue
edges.append( (op_name, j, ws[k]) )
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append( tuple(selected_edges) )
return gene
with torch.no_grad():
gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)),
'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))}
def forward(self, inputs):
normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
# [TODO]
raise NotImplementedError
if cell.reduction: hardwts, index = reduce_hardwts, reduce_index
else : hardwts, index = normal_hardwts, normal_index
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits