add autodl

This commit is contained in:
mhz
2024-08-25 18:02:31 +02:00
parent 192f286cfb
commit a0a25f291c
431 changed files with 50646 additions and 8 deletions

View File

@@ -0,0 +1,33 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# The macro structure is defined in NAS-Bench-201
from .search_model_darts import TinyNetworkDarts
from .search_model_gdas import TinyNetworkGDAS
from .search_model_setn import TinyNetworkSETN
from .search_model_enas import TinyNetworkENAS
from .search_model_random import TinyNetworkRANDOM
from .generic_model import GenericNAS201Model
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
# NASNet-based macro structure
from .search_model_gdas_nasnet import NASNetworkGDAS
from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC
from .search_model_darts_nasnet import NASNetworkDARTS
nas201_super_nets = {
"DARTS-V1": TinyNetworkDarts,
"DARTS-V2": TinyNetworkDarts,
"GDAS": TinyNetworkGDAS,
"SETN": TinyNetworkSETN,
"ENAS": TinyNetworkENAS,
"RANDOM": TinyNetworkRANDOM,
"generic": GenericNAS201Model,
}
nasnet_super_nets = {
"GDAS": NASNetworkGDAS,
"GDAS_FRC": NASNetworkGDAS_FRC,
"DARTS": NASNetworkDARTS,
}

View File

@@ -0,0 +1,14 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from search_model_enas_utils import Controller
def main():
controller = Controller(6, 4)
predictions = controller()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,366 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
#####################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from typing import Text
from torch.distributions.categorical import Categorical
from ..cell_operations import ResNetBasicblock, drop_path
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class Controller(nn.Module):
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
def __init__(
self,
edge2index,
op_names,
max_nodes,
lstm_size=32,
lstm_num_layers=2,
tanh_constant=2.5,
temperature=5.0,
):
super(Controller, self).__init__()
# assign the attributes
self.max_nodes = max_nodes
self.num_edge = len(edge2index)
self.edge2index = edge2index
self.num_ops = len(op_names)
self.op_names = op_names
self.lstm_size = lstm_size
self.lstm_N = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
# create parameters
self.register_parameter(
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
)
self.w_lstm = nn.LSTM(
input_size=self.lstm_size,
hidden_size=self.lstm_size,
num_layers=self.lstm_N,
)
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
nn.init.uniform_(self.input_vars, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
def convert_structure(self, _arch):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_index = _arch[self.edge2index[node_str]]
op_name = self.op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def forward(self):
inputs, h0 = self.input_vars, None
log_probs, entropys, sampled_arch = [], [], []
for iedge in range(self.num_edge):
outputs, h0 = self.w_lstm(inputs, h0)
logits = self.w_pred(outputs)
logits = logits / self.temperature
logits = self.tanh_constant * torch.tanh(logits)
# distribution
op_distribution = Categorical(logits=logits)
op_index = op_distribution.sample()
sampled_arch.append(op_index.item())
op_log_prob = op_distribution.log_prob(op_index)
log_probs.append(op_log_prob.view(-1))
op_entropy = op_distribution.entropy()
entropys.append(op_entropy.view(-1))
# obtain the input embedding for the next step
inputs = self.w_embd(op_index)
return (
torch.sum(torch.cat(log_probs)),
torch.sum(torch.cat(entropys)),
self.convert_structure(sampled_arch),
)
class GenericNAS201Model(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(GenericNAS201Model, self).__init__()
self._C = C
self._layerN = N
self._max_nodes = max_nodes
self._stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self._cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self._cells.append(cell)
C_prev = cell.out_dim
self._op_names = deepcopy(search_space)
self._Layer = len(self._cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(
nn.BatchNorm2d(
C_prev, affine=affine, track_running_stats=track_running_stats
),
nn.ReLU(inplace=True),
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self._num_edge = num_edge
# algorithm related
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self._mode = None
self.dynamic_cell = None
self._tau = None
self._algo = None
self._drop_path = None
self.verbose = False
def set_algo(self, algo: Text):
# used for searching
assert self._algo is None, "This functioin can only be called once."
self._algo = algo
if algo == "enas":
self.controller = Controller(
self.edge2index, self._op_names, self._max_nodes
)
else:
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(self._num_edge, len(self._op_names))
)
if algo == "gdas":
self._tau = 10
def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ["gdas", "enas", "urs", "joint", "select", "dynamic"]
self._mode = mode
if mode == "dynamic":
self.dynamic_cell = deepcopy(dynamic_cell)
else:
self.dynamic_cell = None
def set_drop_path(self, progress, drop_path_rate):
if drop_path_rate is None:
self._drop_path = None
elif progress is None:
self._drop_path = drop_path_rate
else:
self._drop_path = progress * drop_path_rate
@property
def mode(self):
return self._mode
@property
def drop_path(self):
return self._drop_path
@property
def weights(self):
xlist = list(self._stem.parameters())
xlist += list(self._cells.parameters())
xlist += list(self.lastact.parameters())
xlist += list(self.global_pooling.parameters())
xlist += list(self.classifier.parameters())
return xlist
def set_tau(self, tau):
self._tau = tau
@property
def tau(self):
return self._tau
@property
def alphas(self):
if self._algo == "enas":
return list(self.controller.parameters())
else:
return [self.arch_parameters]
@property
def message(self):
string = self.extra_repr()
for i, cell in enumerate(self._cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self._cells), cell.extra_repr()
)
return string
def show_alphas(self):
with torch.no_grad():
if self._algo == "enas":
return "w_pred :\n{:}".format(self.controller.w_pred.weight)
else:
return "arch-parameters :\n{:}".format(
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})".format(
name=self.__class__.__name__, **self.__dict__
)
@property
def genotype(self):
genotypes = []
for i in range(1, self._max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self._op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self._max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if use_random:
op_name = random.choice(self._op_names)
else:
weights = alphas_cpu[self.edge2index[node_str]]
op_index = torch.multinomial(weights, 1).item()
op_name = self._op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def get_log_prob(self, arch):
with torch.no_grad():
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
select_logits = []
for i, node_info in enumerate(arch.nodes):
for op, xin in node_info:
node_str = "{:}<-{:}".format(i + 1, xin)
op_index = self._op_names.index(op)
select_logits.append(logits[self.edge2index[node_str], op_index])
return sum(select_logits).item()
def return_topK(self, K, use_random=False):
archs = Structure.gen_all(self._op_names, self._max_nodes, False)
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
if K < 0 or K >= len(archs):
K = len(archs)
if use_random:
return random.sample(archs, K)
else:
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
return return_pairs
def normalize_archp(self):
if self.mode == "gdas":
while True:
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
with torch.no_grad():
hardwts_cpu = hardwts.detach().cpu()
return hardwts, hardwts_cpu, index, "GUMBEL"
else:
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
index = alphas.max(-1, keepdim=True)[1]
with torch.no_grad():
alphas_cpu = alphas.detach().cpu()
return alphas, alphas_cpu, index, "SOFTMAX"
def forward(self, inputs):
alphas, alphas_cpu, index, verbose_str = self.normalize_archp()
feature = self._stem(inputs)
for i, cell in enumerate(self._cells):
if isinstance(cell, SearchCell):
if self.mode == "urs":
feature = cell.forward_urs(feature)
if self.verbose:
verbose_str += "-forward_urs"
elif self.mode == "select":
feature = cell.forward_select(feature, alphas_cpu)
if self.verbose:
verbose_str += "-forward_select"
elif self.mode == "joint":
feature = cell.forward_joint(feature, alphas)
if self.verbose:
verbose_str += "-forward_joint"
elif self.mode == "dynamic":
feature = cell.forward_dynamic(feature, self.dynamic_cell)
if self.verbose:
verbose_str += "-forward_dynamic"
elif self.mode == "gdas":
feature = cell.forward_gdas(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas"
elif self.mode == "gdas_v1":
feature = cell.forward_gdas_v1(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas_v1"
else:
raise ValueError("invalid mode={:}".format(self.mode))
else:
feature = cell(feature)
if self.drop_path is not None:
feature = drop_path(feature, self.drop_path)
if self.verbose and random.random() < 0.001:
print(verbose_str)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,274 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from copy import deepcopy
def get_combination(space, num):
combs = []
for i in range(num):
if i == 0:
for func in space:
combs.append([(func, i)])
else:
new_combs = []
for string in combs:
for func in space:
xstring = string + [(func, i)]
new_combs.append(xstring)
combs = new_combs
return combs
class Structure:
def __init__(self, genotype):
assert isinstance(genotype, list) or isinstance(
genotype, tuple
), "invalid class of genotype : {:}".format(type(genotype))
self.node_num = len(genotype) + 1
self.nodes = []
self.node_N = []
for idx, node_info in enumerate(genotype):
assert isinstance(node_info, list) or isinstance(
node_info, tuple
), "invalid class of node_info : {:}".format(type(node_info))
assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info))
for node_in in node_info:
assert isinstance(node_in, list) or isinstance(
node_in, tuple
), "invalid class of in-node : {:}".format(type(node_in))
assert (
len(node_in) == 2 and node_in[1] <= idx
), "invalid in-node : {:}".format(node_in)
self.node_N.append(len(node_info))
self.nodes.append(tuple(deepcopy(node_info)))
def tolist(self, remove_str):
# convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
# note that we re-order the input node in this function
# return the-genotype-list and success [if unsuccess, it is not a connectivity]
genotypes = []
for node_info in self.nodes:
node_info = list(node_info)
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
if len(node_info) == 0:
return None, False
genotypes.append(node_info)
return genotypes, True
def node(self, index):
assert index > 0 and index <= len(self), "invalid index={:} < {:}".format(
index, len(self)
)
return self.nodes[index]
def tostr(self):
strings = []
for node_info in self.nodes:
string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info])
string = "|{:}|".format(string)
strings.append(string)
return "+".join(strings)
def check_valid(self):
nodes = {0: True}
for i, node_info in enumerate(self.nodes):
sums = []
for op, xin in node_info:
if op == "none" or nodes[xin] is False:
x = False
else:
x = True
sums.append(x)
nodes[i + 1] = sum(sums) > 0
return nodes[len(self.nodes)]
def to_unique_str(self, consider_zero=False):
# this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
# two operations are special, i.e., none and skip_connect
nodes = {0: "0"}
for i_node, node_info in enumerate(self.nodes):
cur_node = []
for op, xin in node_info:
if consider_zero is None:
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
elif consider_zero:
if op == "none" or nodes[xin] == "#":
x = "#" # zero
elif op == "skip_connect":
x = nodes[xin]
else:
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
else:
if op == "skip_connect":
x = nodes[xin]
else:
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
cur_node.append(x)
nodes[i_node + 1] = "+".join(sorted(cur_node))
return nodes[len(self.nodes)]
def check_valid_op(self, op_names):
for node_info in self.nodes:
for inode_edge in node_info:
# assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
if inode_edge[0] not in op_names:
return False
return True
def __repr__(self):
return "{name}({node_num} nodes with {node_info})".format(
name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__
)
def __len__(self):
return len(self.nodes) + 1
def __getitem__(self, index):
return self.nodes[index]
@staticmethod
def str2structure(xstr):
if isinstance(xstr, Structure):
return xstr
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
type(xstr)
)
nodestrs = xstr.split("+")
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != "", node_str.split("|")))
for xinput in inputs:
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
xinput
)
inputs = (xi.split("~") for xi in inputs)
input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs)
genotypes.append(input_infos)
return Structure(genotypes)
@staticmethod
def str2fullstructure(xstr, default_name="none"):
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
type(xstr)
)
nodestrs = xstr.split("+")
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != "", node_str.split("|")))
for xinput in inputs:
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
xinput
)
inputs = (xi.split("~") for xi in inputs)
input_infos = list((op, int(IDX)) for (op, IDX) in inputs)
all_in_nodes = list(x[1] for x in input_infos)
for j in range(i):
if j not in all_in_nodes:
input_infos.append((default_name, j))
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
genotypes.append(tuple(node_info))
return Structure(genotypes)
@staticmethod
def gen_all(search_space, num, return_ori):
assert isinstance(search_space, list) or isinstance(
search_space, tuple
), "invalid class of search-space : {:}".format(type(search_space))
assert (
num >= 2
), "There should be at least two nodes in a neural cell instead of {:}".format(
num
)
all_archs = get_combination(search_space, 1)
for i, arch in enumerate(all_archs):
all_archs[i] = [tuple(arch)]
for inode in range(2, num):
cur_nodes = get_combination(search_space, inode)
new_all_archs = []
for previous_arch in all_archs:
for cur_node in cur_nodes:
new_all_archs.append(previous_arch + [tuple(cur_node)])
all_archs = new_all_archs
if return_ori:
return all_archs
else:
return [Structure(x) for x in all_archs]
ResNet_CODE = Structure(
[
(("nor_conv_3x3", 0),), # node-1
(("nor_conv_3x3", 1),), # node-2
(("skip_connect", 0), ("skip_connect", 2)),
] # node-3
)
AllConv3x3_CODE = Structure(
[
(("nor_conv_3x3", 0),), # node-1
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1)), # node-2
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
] # node-3
)
AllFull_CODE = Structure(
[
(
("skip_connect", 0),
("nor_conv_1x1", 0),
("nor_conv_3x3", 0),
("avg_pool_3x3", 0),
), # node-1
(
("skip_connect", 0),
("nor_conv_1x1", 0),
("nor_conv_3x3", 0),
("avg_pool_3x3", 0),
("skip_connect", 1),
("nor_conv_1x1", 1),
("nor_conv_3x3", 1),
("avg_pool_3x3", 1),
), # node-2
(
("skip_connect", 0),
("nor_conv_1x1", 0),
("nor_conv_3x3", 0),
("avg_pool_3x3", 0),
("skip_connect", 1),
("nor_conv_1x1", 1),
("nor_conv_3x3", 1),
("avg_pool_3x3", 1),
("skip_connect", 2),
("nor_conv_1x1", 2),
("nor_conv_3x3", 2),
("avg_pool_3x3", 2),
),
] # node-3
)
AllConv1x1_CODE = Structure(
[
(("nor_conv_1x1", 0),), # node-1
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1)), # node-2
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1), ("nor_conv_1x1", 2)),
] # node-3
)
AllIdentity_CODE = Structure(
[
(("skip_connect", 0),), # node-1
(("skip_connect", 0), ("skip_connect", 1)), # node-2
(("skip_connect", 0), ("skip_connect", 1), ("skip_connect", 2)),
] # node-3
)
architectures = {
"resnet": ResNet_CODE,
"all_c3x3": AllConv3x3_CODE,
"all_c1x1": AllConv1x1_CODE,
"all_idnt": AllIdentity_CODE,
"all_full": AllFull_CODE,
}

View File

@@ -0,0 +1,267 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, random, torch
import warnings
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from ..cell_operations import OPS
# This module is used for NAS-Bench-201, represents a small search space with a complete DAG
class NAS201SearchCell(nn.Module):
def __init__(
self,
C_in,
C_out,
stride,
max_nodes,
op_names,
affine=False,
track_running_stats=True,
):
super(NAS201SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
self.edges = nn.ModuleDict()
self.max_nodes = max_nodes
self.in_dim = C_in
self.out_dim = C_out
for i in range(1, max_nodes):
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if j == 0:
xlists = [
OPS[op_name](C_in, C_out, stride, affine, track_running_stats)
for op_name in op_names
]
else:
xlists = [
OPS[op_name](C_in, C_out, 1, affine, track_running_stats)
for op_name in op_names
]
self.edges[node_str] = nn.ModuleList(xlists)
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
def extra_repr(self):
string = "info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}".format(
**self.__dict__
)
return string
def forward(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = weightss[self.edge2index[node_str]]
inter_nodes.append(
sum(
layer(nodes[j]) * w
for layer, w in zip(self.edges[node_str], weights)
)
)
nodes.append(sum(inter_nodes))
return nodes[-1]
# GDAS
def forward_gdas(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = hardwts[self.edge2index[node_str]]
argmaxs = index[self.edge2index[node_str]].item()
weigsum = sum(
weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie]
for _ie, edge in enumerate(self.edges[node_str])
)
inter_nodes.append(weigsum)
nodes.append(sum(inter_nodes))
return nodes[-1]
# GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119
def forward_gdas_v1(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = hardwts[self.edge2index[node_str]]
argmaxs = index[self.edge2index[node_str]].item()
weigsum = weights[argmaxs] * self.edges[node_str](nodes[j])
inter_nodes.append(weigsum)
nodes.append(sum(inter_nodes))
return nodes[-1]
# joint
def forward_joint(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = weightss[self.edge2index[node_str]]
# aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
aggregation = sum(
layer(nodes[j]) * w
for layer, w in zip(self.edges[node_str], weights)
)
inter_nodes.append(aggregation)
nodes.append(sum(inter_nodes))
return nodes[-1]
# uniform random sampling per iteration, SETN
def forward_urs(self, inputs):
nodes = [inputs]
for i in range(1, self.max_nodes):
while True: # to avoid select zero for all ops
sops, has_non_zero = [], False
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
candidates = self.edges[node_str]
select_op = random.choice(candidates)
sops.append(select_op)
if not hasattr(select_op, "is_zero") or select_op.is_zero is False:
has_non_zero = True
if has_non_zero:
break
inter_nodes = []
for j, select_op in enumerate(sops):
inter_nodes.append(select_op(nodes[j]))
nodes.append(sum(inter_nodes))
return nodes[-1]
# select the argmax
def forward_select(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = weightss[self.edge2index[node_str]]
inter_nodes.append(
self.edges[node_str][weights.argmax().item()](nodes[j])
)
# inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
nodes.append(sum(inter_nodes))
return nodes[-1]
# forward with a specific structure
def forward_dynamic(self, inputs, structure):
nodes = [inputs]
for i in range(1, self.max_nodes):
cur_op_node = structure.nodes[i - 1]
inter_nodes = []
for op_name, j in cur_op_node:
node_str = "{:}<-{:}".format(i, j)
op_index = self.op_names.index(op_name)
inter_nodes.append(self.edges[node_str][op_index](nodes[j]))
nodes.append(sum(inter_nodes))
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class MixedOp(nn.Module):
def __init__(self, space, C, stride, affine, track_running_stats):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
for primitive in space:
op = OPS[primitive](C, C, stride, affine, track_running_stats)
self._ops.append(op)
def forward_gdas(self, x, weights, index):
return self._ops[index](x) * weights[index]
def forward_darts(self, x, weights):
return sum(w * op(x) for w, op in zip(weights, self._ops))
class NASNetSearchCell(nn.Module):
def __init__(
self,
space,
steps,
multiplier,
C_prev_prev,
C_prev,
C,
reduction,
reduction_prev,
affine,
track_running_stats,
):
super(NASNetSearchCell, self).__init__()
self.reduction = reduction
self.op_names = deepcopy(space)
if reduction_prev:
self.preprocess0 = OPS["skip_connect"](
C_prev_prev, C, 2, affine, track_running_stats
)
else:
self.preprocess0 = OPS["nor_conv_1x1"](
C_prev_prev, C, 1, affine, track_running_stats
)
self.preprocess1 = OPS["nor_conv_1x1"](
C_prev, C, 1, affine, track_running_stats
)
self._steps = steps
self._multiplier = multiplier
self._ops = nn.ModuleList()
self.edges = nn.ModuleDict()
for i in range(self._steps):
for j in range(2 + i):
node_str = "{:}<-{:}".format(
i, j
) # indicate the edge from node-(j) to node-(i+2)
stride = 2 if reduction and j < 2 else 1
op = MixedOp(space, C, stride, affine, track_running_stats)
self.edges[node_str] = op
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
@property
def multiplier(self):
return self._multiplier
def forward_gdas(self, s0, s1, weightss, indexs):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = "{:}<-{:}".format(i, j)
op = self.edges[node_str]
weights = weightss[self.edge2index[node_str]]
index = indexs[self.edge2index[node_str]].item()
clist.append(op.forward_gdas(h, weights, index))
states.append(sum(clist))
return torch.cat(states[-self._multiplier :], dim=1)
def forward_darts(self, s0, s1, weightss):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = "{:}<-{:}".format(i, j)
op = self.edges[node_str]
weights = weightss[self.edge2index[node_str]]
clist.append(op.forward_darts(h, weights))
states.append(sum(clist))
return torch.cat(states[-self._multiplier :], dim=1)

View File

@@ -0,0 +1,122 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
########################################################
# DARTS: Differentiable Architecture Search, ICLR 2019 #
########################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkDarts(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkDarts, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def get_alphas(self):
return [self.arch_parameters]
def show_alphas(self):
with torch.no_grad():
return "arch-parameters :\n{:}".format(
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def forward(self, inputs):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell(feature, alphas)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,178 @@
####################
# DARTS, ICLR 2019 #
####################
import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkDARTS(nn.Module):
def __init__(
self,
C: int,
N: int,
steps: int,
multiplier: int,
stem_multiplier: int,
num_classes: int,
search_space: List[Text],
affine: bool,
track_running_stats: bool,
):
super(NASNetworkDARTS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
# config for each layer
layer_channels = (
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
)
layer_reductions = (
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
cell = SearchCell(
search_space,
steps,
multiplier,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.arch_reduce_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
def get_weights(self) -> List[torch.nn.Parameter]:
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def get_alphas(self) -> List[torch.nn.Parameter]:
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self) -> Text:
with torch.no_grad():
A = "arch-normal-parameters :\n{:}".format(
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
)
B = "arch-reduce-parameters :\n{:}".format(
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
)
return "{:}\n{:}".format(A, B)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self) -> Text:
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self) -> Dict[Text, List]:
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2 + i):
node_str = "{:}<-{:}".format(i, j)
ws = weights[self.edge2index[node_str]]
for k, op_name in enumerate(self.op_names):
if op_name == "none":
continue
edges.append((op_name, j, ws[k]))
# (TODO) xuanyidong:
# Here the selected two edges might come from the same input node.
# And this case could be a problem that two edges will collapse into a single one
# due to our assumption -- at most one edge from an input node during evaluation.
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append(tuple(selected_edges))
return gene
with torch.no_grad():
gene_normal = _parse(
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
)
gene_reduce = _parse(
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
)
return {
"normal": gene_normal,
"normal_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
"reduce": gene_reduce,
"reduce_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}
def forward(self, inputs):
normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction:
ww = reduce_w
else:
ww = normal_w
s0, s1 = s1, cell.forward_darts(s0, s1, ww)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,114 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
from .search_model_enas_utils import Controller
class TinyNetworkENAS(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkENAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
# to maintain the sampled architecture
self.sampled_arch = None
def update_arch(self, _arch):
if _arch is None:
self.sampled_arch = None
elif isinstance(_arch, Structure):
self.sampled_arch = _arch
elif isinstance(_arch, (list, tuple)):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_index = _arch[self.edge2index[node_str]]
op_name = self.op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
self.sampled_arch = Structure(genotypes)
else:
raise ValueError("invalid type of input architecture : {:}".format(_arch))
return self.sampled_arch
def create_controller(self):
return Controller(len(self.edge2index), len(self.op_names))
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_dynamic(feature, self.sampled_arch)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,74 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
class Controller(nn.Module):
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
def __init__(
self,
num_edge,
num_ops,
lstm_size=32,
lstm_num_layers=2,
tanh_constant=2.5,
temperature=5.0,
):
super(Controller, self).__init__()
# assign the attributes
self.num_edge = num_edge
self.num_ops = num_ops
self.lstm_size = lstm_size
self.lstm_N = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
# create parameters
self.register_parameter(
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
)
self.w_lstm = nn.LSTM(
input_size=self.lstm_size,
hidden_size=self.lstm_size,
num_layers=self.lstm_N,
)
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
nn.init.uniform_(self.input_vars, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
def forward(self):
inputs, h0 = self.input_vars, None
log_probs, entropys, sampled_arch = [], [], []
for iedge in range(self.num_edge):
outputs, h0 = self.w_lstm(inputs, h0)
logits = self.w_pred(outputs)
logits = logits / self.temperature
logits = self.tanh_constant * torch.tanh(logits)
# distribution
op_distribution = Categorical(logits=logits)
op_index = op_distribution.sample()
sampled_arch.append(op_index.item())
op_log_prob = op_distribution.log_prob(op_index)
log_probs.append(op_log_prob.view(-1))
op_entropy = op_distribution.entropy()
entropys.append(op_entropy.view(-1))
# obtain the input embedding for the next step
inputs = self.w_embd(op_index)
return (
torch.sum(torch.cat(log_probs)),
torch.sum(torch.cat(entropys)),
sampled_arch,
)

View File

@@ -0,0 +1,142 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkGDAS(nn.Module):
# def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.tau = 10
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def get_alphas(self):
return [self.arch_parameters]
def show_alphas(self):
with torch.no_grad():
return "arch-parameters :\n{:}".format(
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def forward(self, inputs):
while True:
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_gdas(feature, hardwts, index)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,200 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .search_cells import NASNetSearchCell as SearchCell
from ..cell_operations import RAW_OP_CLASSES
# The macro structure is based on NASNet
class NASNetworkGDAS_FRC(nn.Module):
def __init__(
self,
C,
N,
steps,
multiplier,
stem_multiplier,
num_classes,
search_space,
affine,
track_running_stats,
):
super(NASNetworkGDAS_FRC, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
# config for each layer
layer_channels = (
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
)
layer_reductions = (
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = RAW_OP_CLASSES["gdas_reduction"](
C_prev_prev,
C_prev,
C_curr,
reduction_prev,
affine,
track_running_stats,
)
else:
cell = SearchCell(
search_space,
steps,
multiplier,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
reduction
or num_edge == cell.num_edges
and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = (
C_prev,
cell.multiplier * C_curr,
reduction,
)
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.tau = 10
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def get_alphas(self):
return [self.arch_parameters]
def show_alphas(self):
with torch.no_grad():
A = "arch-normal-parameters :\n{:}".format(
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)
return "{:}".format(A)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2 + i):
node_str = "{:}<-{:}".format(i, j)
ws = weights[self.edge2index[node_str]]
for k, op_name in enumerate(self.op_names):
if op_name == "none":
continue
edges.append((op_name, j, ws[k]))
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append(tuple(selected_edges))
return gene
with torch.no_grad():
gene_normal = _parse(
torch.softmax(self.arch_parameters, dim=-1).cpu().numpy()
)
return {
"normal": gene_normal,
"normal_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}
def forward(self, inputs):
def get_gumbel_prob(xins):
while True:
gumbels = -torch.empty_like(xins).exponential_().log()
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
return hardwts, index
hardwts, index = get_gumbel_prob(self.arch_parameters)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction:
s0, s1 = s1, cell(s0, s1)
else:
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,197 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkGDAS(nn.Module):
def __init__(
self,
C,
N,
steps,
multiplier,
stem_multiplier,
num_classes,
search_space,
affine,
track_running_stats,
):
super(NASNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
# config for each layer
layer_channels = (
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
)
layer_reductions = (
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
cell = SearchCell(
search_space,
steps,
multiplier,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.arch_reduce_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.tau = 10
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def get_alphas(self):
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self):
with torch.no_grad():
A = "arch-normal-parameters :\n{:}".format(
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
)
B = "arch-reduce-parameters :\n{:}".format(
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
)
return "{:}\n{:}".format(A, B)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2 + i):
node_str = "{:}<-{:}".format(i, j)
ws = weights[self.edge2index[node_str]]
for k, op_name in enumerate(self.op_names):
if op_name == "none":
continue
edges.append((op_name, j, ws[k]))
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append(tuple(selected_edges))
return gene
with torch.no_grad():
gene_normal = _parse(
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
)
gene_reduce = _parse(
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
)
return {
"normal": gene_normal,
"normal_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
"reduce": gene_reduce,
"reduce_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}
def forward(self, inputs):
def get_gumbel_prob(xins):
while True:
gumbels = -torch.empty_like(xins).exponential_().log()
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
return hardwts, index
normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction:
hardwts, index = reduce_hardwts, reduce_index
else:
hardwts, index = normal_hardwts, normal_index
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,102 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkRANDOM(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkRANDOM, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_cache = None
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def random_genotype(self, set_cache):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = random.choice(self.op_names)
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
arch = Structure(genotypes)
if set_cache:
self.arch_cache = arch
return arch
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_dynamic(feature, self.arch_cache)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,178 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkSETN(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkSETN, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.mode = "urs"
self.dynamic_cell = None
def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ["urs", "joint", "select", "dynamic"]
self.mode = mode
if mode == "dynamic":
self.dynamic_cell = deepcopy(dynamic_cell)
else:
self.dynamic_cell = None
def get_cal_mode(self):
return self.mode
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def get_alphas(self):
return [self.arch_parameters]
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[self.edge2index[node_str]]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def get_log_prob(self, arch):
with torch.no_grad():
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
select_logits = []
for i, node_info in enumerate(arch.nodes):
for op, xin in node_info:
node_str = "{:}<-{:}".format(i + 1, xin)
op_index = self.op_names.index(op)
select_logits.append(logits[self.edge2index[node_str], op_index])
return sum(select_logits).item()
def return_topK(self, K):
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
if K < 0 or K >= len(archs):
K = len(archs)
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
return return_pairs
def forward(self, inputs):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
with torch.no_grad():
alphas_cpu = alphas.detach().cpu()
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
if self.mode == "urs":
feature = cell.forward_urs(feature)
elif self.mode == "select":
feature = cell.forward_select(feature, alphas_cpu)
elif self.mode == "joint":
feature = cell.forward_joint(feature, alphas)
elif self.mode == "dynamic":
feature = cell.forward_dynamic(feature, self.dynamic_cell)
else:
raise ValueError("invalid mode={:}".format(self.mode))
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,205 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkSETN(nn.Module):
def __init__(
self,
C: int,
N: int,
steps: int,
multiplier: int,
stem_multiplier: int,
num_classes: int,
search_space: List[Text],
affine: bool,
track_running_stats: bool,
):
super(NASNetworkSETN, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
# config for each layer
layer_channels = (
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
)
layer_reductions = (
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
cell = SearchCell(
search_space,
steps,
multiplier,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.arch_reduce_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.mode = "urs"
self.dynamic_cell = None
def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ["urs", "joint", "select", "dynamic"]
self.mode = mode
if mode == "dynamic":
self.dynamic_cell = deepcopy(dynamic_cell)
else:
self.dynamic_cell = None
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def get_alphas(self):
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self):
with torch.no_grad():
A = "arch-normal-parameters :\n{:}".format(
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
)
B = "arch-reduce-parameters :\n{:}".format(
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
)
return "{:}\n{:}".format(A, B)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[self.edge2index[node_str]]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2 + i):
node_str = "{:}<-{:}".format(i, j)
ws = weights[self.edge2index[node_str]]
for k, op_name in enumerate(self.op_names):
if op_name == "none":
continue
edges.append((op_name, j, ws[k]))
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append(tuple(selected_edges))
return gene
with torch.no_grad():
gene_normal = _parse(
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
)
gene_reduce = _parse(
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
)
return {
"normal": gene_normal,
"normal_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
"reduce": gene_reduce,
"reduce_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}
def forward(self, inputs):
normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
# [TODO]
raise NotImplementedError
if cell.reduction:
hardwts, index = reduce_hardwts, reduce_index
else:
hardwts, index = normal_hardwts, normal_index
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits