add autodl
This commit is contained in:
33
AutoDL-Projects/xautodl/models/cell_searchs/__init__.py
Normal file
33
AutoDL-Projects/xautodl/models/cell_searchs/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# The macro structure is defined in NAS-Bench-201
|
||||
from .search_model_darts import TinyNetworkDarts
|
||||
from .search_model_gdas import TinyNetworkGDAS
|
||||
from .search_model_setn import TinyNetworkSETN
|
||||
from .search_model_enas import TinyNetworkENAS
|
||||
from .search_model_random import TinyNetworkRANDOM
|
||||
from .generic_model import GenericNAS201Model
|
||||
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
|
||||
|
||||
# NASNet-based macro structure
|
||||
from .search_model_gdas_nasnet import NASNetworkGDAS
|
||||
from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC
|
||||
from .search_model_darts_nasnet import NASNetworkDARTS
|
||||
|
||||
|
||||
nas201_super_nets = {
|
||||
"DARTS-V1": TinyNetworkDarts,
|
||||
"DARTS-V2": TinyNetworkDarts,
|
||||
"GDAS": TinyNetworkGDAS,
|
||||
"SETN": TinyNetworkSETN,
|
||||
"ENAS": TinyNetworkENAS,
|
||||
"RANDOM": TinyNetworkRANDOM,
|
||||
"generic": GenericNAS201Model,
|
||||
}
|
||||
|
||||
nasnet_super_nets = {
|
||||
"GDAS": NASNetworkGDAS,
|
||||
"GDAS_FRC": NASNetworkGDAS_FRC,
|
||||
"DARTS": NASNetworkDARTS,
|
||||
}
|
14
AutoDL-Projects/xautodl/models/cell_searchs/_test_module.py
Normal file
14
AutoDL-Projects/xautodl/models/cell_searchs/_test_module.py
Normal file
@@ -0,0 +1,14 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
from search_model_enas_utils import Controller
|
||||
|
||||
|
||||
def main():
|
||||
controller = Controller(6, 4)
|
||||
predictions = controller()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
366
AutoDL-Projects/xautodl/models/cell_searchs/generic_model.py
Normal file
366
AutoDL-Projects/xautodl/models/cell_searchs/generic_model.py
Normal file
@@ -0,0 +1,366 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
|
||||
#####################################################
|
||||
import torch, random
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import Text
|
||||
from torch.distributions.categorical import Categorical
|
||||
|
||||
from ..cell_operations import ResNetBasicblock, drop_path
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class Controller(nn.Module):
|
||||
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
|
||||
def __init__(
|
||||
self,
|
||||
edge2index,
|
||||
op_names,
|
||||
max_nodes,
|
||||
lstm_size=32,
|
||||
lstm_num_layers=2,
|
||||
tanh_constant=2.5,
|
||||
temperature=5.0,
|
||||
):
|
||||
super(Controller, self).__init__()
|
||||
# assign the attributes
|
||||
self.max_nodes = max_nodes
|
||||
self.num_edge = len(edge2index)
|
||||
self.edge2index = edge2index
|
||||
self.num_ops = len(op_names)
|
||||
self.op_names = op_names
|
||||
self.lstm_size = lstm_size
|
||||
self.lstm_N = lstm_num_layers
|
||||
self.tanh_constant = tanh_constant
|
||||
self.temperature = temperature
|
||||
# create parameters
|
||||
self.register_parameter(
|
||||
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
|
||||
)
|
||||
self.w_lstm = nn.LSTM(
|
||||
input_size=self.lstm_size,
|
||||
hidden_size=self.lstm_size,
|
||||
num_layers=self.lstm_N,
|
||||
)
|
||||
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
|
||||
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
|
||||
|
||||
nn.init.uniform_(self.input_vars, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
|
||||
|
||||
def convert_structure(self, _arch):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_index = _arch[self.edge2index[node_str]]
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def forward(self):
|
||||
|
||||
inputs, h0 = self.input_vars, None
|
||||
log_probs, entropys, sampled_arch = [], [], []
|
||||
for iedge in range(self.num_edge):
|
||||
outputs, h0 = self.w_lstm(inputs, h0)
|
||||
|
||||
logits = self.w_pred(outputs)
|
||||
logits = logits / self.temperature
|
||||
logits = self.tanh_constant * torch.tanh(logits)
|
||||
# distribution
|
||||
op_distribution = Categorical(logits=logits)
|
||||
op_index = op_distribution.sample()
|
||||
sampled_arch.append(op_index.item())
|
||||
|
||||
op_log_prob = op_distribution.log_prob(op_index)
|
||||
log_probs.append(op_log_prob.view(-1))
|
||||
op_entropy = op_distribution.entropy()
|
||||
entropys.append(op_entropy.view(-1))
|
||||
|
||||
# obtain the input embedding for the next step
|
||||
inputs = self.w_embd(op_index)
|
||||
return (
|
||||
torch.sum(torch.cat(log_probs)),
|
||||
torch.sum(torch.cat(entropys)),
|
||||
self.convert_structure(sampled_arch),
|
||||
)
|
||||
|
||||
|
||||
class GenericNAS201Model(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(GenericNAS201Model, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._max_nodes = max_nodes
|
||||
self._stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self._cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self._cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self._op_names = deepcopy(search_space)
|
||||
self._Layer = len(self._cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(
|
||||
nn.BatchNorm2d(
|
||||
C_prev, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self._num_edge = num_edge
|
||||
# algorithm related
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self._mode = None
|
||||
self.dynamic_cell = None
|
||||
self._tau = None
|
||||
self._algo = None
|
||||
self._drop_path = None
|
||||
self.verbose = False
|
||||
|
||||
def set_algo(self, algo: Text):
|
||||
# used for searching
|
||||
assert self._algo is None, "This functioin can only be called once."
|
||||
self._algo = algo
|
||||
if algo == "enas":
|
||||
self.controller = Controller(
|
||||
self.edge2index, self._op_names, self._max_nodes
|
||||
)
|
||||
else:
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(self._num_edge, len(self._op_names))
|
||||
)
|
||||
if algo == "gdas":
|
||||
self._tau = 10
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||
assert mode in ["gdas", "enas", "urs", "joint", "select", "dynamic"]
|
||||
self._mode = mode
|
||||
if mode == "dynamic":
|
||||
self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else:
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_drop_path(self, progress, drop_path_rate):
|
||||
if drop_path_rate is None:
|
||||
self._drop_path = None
|
||||
elif progress is None:
|
||||
self._drop_path = drop_path_rate
|
||||
else:
|
||||
self._drop_path = progress * drop_path_rate
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
return self._mode
|
||||
|
||||
@property
|
||||
def drop_path(self):
|
||||
return self._drop_path
|
||||
|
||||
@property
|
||||
def weights(self):
|
||||
xlist = list(self._stem.parameters())
|
||||
xlist += list(self._cells.parameters())
|
||||
xlist += list(self.lastact.parameters())
|
||||
xlist += list(self.global_pooling.parameters())
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self._tau = tau
|
||||
|
||||
@property
|
||||
def tau(self):
|
||||
return self._tau
|
||||
|
||||
@property
|
||||
def alphas(self):
|
||||
if self._algo == "enas":
|
||||
return list(self.controller.parameters())
|
||||
else:
|
||||
return [self.arch_parameters]
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self._cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self._cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
if self._algo == "enas":
|
||||
return "w_pred :\n{:}".format(self.controller.w_pred.weight)
|
||||
else:
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
@property
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self._max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self._op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
for i in range(1, self._max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if use_random:
|
||||
op_name = random.choice(self._op_names)
|
||||
else:
|
||||
weights = alphas_cpu[self.edge2index[node_str]]
|
||||
op_index = torch.multinomial(weights, 1).item()
|
||||
op_name = self._op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def get_log_prob(self, arch):
|
||||
with torch.no_grad():
|
||||
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
|
||||
select_logits = []
|
||||
for i, node_info in enumerate(arch.nodes):
|
||||
for op, xin in node_info:
|
||||
node_str = "{:}<-{:}".format(i + 1, xin)
|
||||
op_index = self._op_names.index(op)
|
||||
select_logits.append(logits[self.edge2index[node_str], op_index])
|
||||
return sum(select_logits).item()
|
||||
|
||||
def return_topK(self, K, use_random=False):
|
||||
archs = Structure.gen_all(self._op_names, self._max_nodes, False)
|
||||
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
|
||||
if K < 0 or K >= len(archs):
|
||||
K = len(archs)
|
||||
if use_random:
|
||||
return random.sample(archs, K)
|
||||
else:
|
||||
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
|
||||
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
|
||||
return return_pairs
|
||||
|
||||
def normalize_archp(self):
|
||||
if self.mode == "gdas":
|
||||
while True:
|
||||
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
|
||||
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (
|
||||
(torch.isinf(gumbels).any())
|
||||
or (torch.isinf(probs).any())
|
||||
or (torch.isnan(probs).any())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
with torch.no_grad():
|
||||
hardwts_cpu = hardwts.detach().cpu()
|
||||
return hardwts, hardwts_cpu, index, "GUMBEL"
|
||||
else:
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
index = alphas.max(-1, keepdim=True)[1]
|
||||
with torch.no_grad():
|
||||
alphas_cpu = alphas.detach().cpu()
|
||||
return alphas, alphas_cpu, index, "SOFTMAX"
|
||||
|
||||
def forward(self, inputs):
|
||||
alphas, alphas_cpu, index, verbose_str = self.normalize_archp()
|
||||
feature = self._stem(inputs)
|
||||
for i, cell in enumerate(self._cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
if self.mode == "urs":
|
||||
feature = cell.forward_urs(feature)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_urs"
|
||||
elif self.mode == "select":
|
||||
feature = cell.forward_select(feature, alphas_cpu)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_select"
|
||||
elif self.mode == "joint":
|
||||
feature = cell.forward_joint(feature, alphas)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_joint"
|
||||
elif self.mode == "dynamic":
|
||||
feature = cell.forward_dynamic(feature, self.dynamic_cell)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_dynamic"
|
||||
elif self.mode == "gdas":
|
||||
feature = cell.forward_gdas(feature, alphas, index)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_gdas"
|
||||
elif self.mode == "gdas_v1":
|
||||
feature = cell.forward_gdas_v1(feature, alphas, index)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_gdas_v1"
|
||||
else:
|
||||
raise ValueError("invalid mode={:}".format(self.mode))
|
||||
else:
|
||||
feature = cell(feature)
|
||||
if self.drop_path is not None:
|
||||
feature = drop_path(feature, self.drop_path)
|
||||
if self.verbose and random.random() < 0.001:
|
||||
print(verbose_str)
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
return out, logits
|
274
AutoDL-Projects/xautodl/models/cell_searchs/genotypes.py
Normal file
274
AutoDL-Projects/xautodl/models/cell_searchs/genotypes.py
Normal file
@@ -0,0 +1,274 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
def get_combination(space, num):
|
||||
combs = []
|
||||
for i in range(num):
|
||||
if i == 0:
|
||||
for func in space:
|
||||
combs.append([(func, i)])
|
||||
else:
|
||||
new_combs = []
|
||||
for string in combs:
|
||||
for func in space:
|
||||
xstring = string + [(func, i)]
|
||||
new_combs.append(xstring)
|
||||
combs = new_combs
|
||||
return combs
|
||||
|
||||
|
||||
class Structure:
|
||||
def __init__(self, genotype):
|
||||
assert isinstance(genotype, list) or isinstance(
|
||||
genotype, tuple
|
||||
), "invalid class of genotype : {:}".format(type(genotype))
|
||||
self.node_num = len(genotype) + 1
|
||||
self.nodes = []
|
||||
self.node_N = []
|
||||
for idx, node_info in enumerate(genotype):
|
||||
assert isinstance(node_info, list) or isinstance(
|
||||
node_info, tuple
|
||||
), "invalid class of node_info : {:}".format(type(node_info))
|
||||
assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info))
|
||||
for node_in in node_info:
|
||||
assert isinstance(node_in, list) or isinstance(
|
||||
node_in, tuple
|
||||
), "invalid class of in-node : {:}".format(type(node_in))
|
||||
assert (
|
||||
len(node_in) == 2 and node_in[1] <= idx
|
||||
), "invalid in-node : {:}".format(node_in)
|
||||
self.node_N.append(len(node_info))
|
||||
self.nodes.append(tuple(deepcopy(node_info)))
|
||||
|
||||
def tolist(self, remove_str):
|
||||
# convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
|
||||
# note that we re-order the input node in this function
|
||||
# return the-genotype-list and success [if unsuccess, it is not a connectivity]
|
||||
genotypes = []
|
||||
for node_info in self.nodes:
|
||||
node_info = list(node_info)
|
||||
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
|
||||
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
|
||||
if len(node_info) == 0:
|
||||
return None, False
|
||||
genotypes.append(node_info)
|
||||
return genotypes, True
|
||||
|
||||
def node(self, index):
|
||||
assert index > 0 and index <= len(self), "invalid index={:} < {:}".format(
|
||||
index, len(self)
|
||||
)
|
||||
return self.nodes[index]
|
||||
|
||||
def tostr(self):
|
||||
strings = []
|
||||
for node_info in self.nodes:
|
||||
string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info])
|
||||
string = "|{:}|".format(string)
|
||||
strings.append(string)
|
||||
return "+".join(strings)
|
||||
|
||||
def check_valid(self):
|
||||
nodes = {0: True}
|
||||
for i, node_info in enumerate(self.nodes):
|
||||
sums = []
|
||||
for op, xin in node_info:
|
||||
if op == "none" or nodes[xin] is False:
|
||||
x = False
|
||||
else:
|
||||
x = True
|
||||
sums.append(x)
|
||||
nodes[i + 1] = sum(sums) > 0
|
||||
return nodes[len(self.nodes)]
|
||||
|
||||
def to_unique_str(self, consider_zero=False):
|
||||
# this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
|
||||
# two operations are special, i.e., none and skip_connect
|
||||
nodes = {0: "0"}
|
||||
for i_node, node_info in enumerate(self.nodes):
|
||||
cur_node = []
|
||||
for op, xin in node_info:
|
||||
if consider_zero is None:
|
||||
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
|
||||
elif consider_zero:
|
||||
if op == "none" or nodes[xin] == "#":
|
||||
x = "#" # zero
|
||||
elif op == "skip_connect":
|
||||
x = nodes[xin]
|
||||
else:
|
||||
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
|
||||
else:
|
||||
if op == "skip_connect":
|
||||
x = nodes[xin]
|
||||
else:
|
||||
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
|
||||
cur_node.append(x)
|
||||
nodes[i_node + 1] = "+".join(sorted(cur_node))
|
||||
return nodes[len(self.nodes)]
|
||||
|
||||
def check_valid_op(self, op_names):
|
||||
for node_info in self.nodes:
|
||||
for inode_edge in node_info:
|
||||
# assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
|
||||
if inode_edge[0] not in op_names:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({node_num} nodes with {node_info})".format(
|
||||
name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.nodes) + 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.nodes[index]
|
||||
|
||||
@staticmethod
|
||||
def str2structure(xstr):
|
||||
if isinstance(xstr, Structure):
|
||||
return xstr
|
||||
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
|
||||
type(xstr)
|
||||
)
|
||||
nodestrs = xstr.split("+")
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
inputs = list(filter(lambda x: x != "", node_str.split("|")))
|
||||
for xinput in inputs:
|
||||
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
|
||||
xinput
|
||||
)
|
||||
inputs = (xi.split("~") for xi in inputs)
|
||||
input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs)
|
||||
genotypes.append(input_infos)
|
||||
return Structure(genotypes)
|
||||
|
||||
@staticmethod
|
||||
def str2fullstructure(xstr, default_name="none"):
|
||||
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
|
||||
type(xstr)
|
||||
)
|
||||
nodestrs = xstr.split("+")
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
inputs = list(filter(lambda x: x != "", node_str.split("|")))
|
||||
for xinput in inputs:
|
||||
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
|
||||
xinput
|
||||
)
|
||||
inputs = (xi.split("~") for xi in inputs)
|
||||
input_infos = list((op, int(IDX)) for (op, IDX) in inputs)
|
||||
all_in_nodes = list(x[1] for x in input_infos)
|
||||
for j in range(i):
|
||||
if j not in all_in_nodes:
|
||||
input_infos.append((default_name, j))
|
||||
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
|
||||
genotypes.append(tuple(node_info))
|
||||
return Structure(genotypes)
|
||||
|
||||
@staticmethod
|
||||
def gen_all(search_space, num, return_ori):
|
||||
assert isinstance(search_space, list) or isinstance(
|
||||
search_space, tuple
|
||||
), "invalid class of search-space : {:}".format(type(search_space))
|
||||
assert (
|
||||
num >= 2
|
||||
), "There should be at least two nodes in a neural cell instead of {:}".format(
|
||||
num
|
||||
)
|
||||
all_archs = get_combination(search_space, 1)
|
||||
for i, arch in enumerate(all_archs):
|
||||
all_archs[i] = [tuple(arch)]
|
||||
|
||||
for inode in range(2, num):
|
||||
cur_nodes = get_combination(search_space, inode)
|
||||
new_all_archs = []
|
||||
for previous_arch in all_archs:
|
||||
for cur_node in cur_nodes:
|
||||
new_all_archs.append(previous_arch + [tuple(cur_node)])
|
||||
all_archs = new_all_archs
|
||||
if return_ori:
|
||||
return all_archs
|
||||
else:
|
||||
return [Structure(x) for x in all_archs]
|
||||
|
||||
|
||||
ResNet_CODE = Structure(
|
||||
[
|
||||
(("nor_conv_3x3", 0),), # node-1
|
||||
(("nor_conv_3x3", 1),), # node-2
|
||||
(("skip_connect", 0), ("skip_connect", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllConv3x3_CODE = Structure(
|
||||
[
|
||||
(("nor_conv_3x3", 0),), # node-1
|
||||
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1)), # node-2
|
||||
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllFull_CODE = Structure(
|
||||
[
|
||||
(
|
||||
("skip_connect", 0),
|
||||
("nor_conv_1x1", 0),
|
||||
("nor_conv_3x3", 0),
|
||||
("avg_pool_3x3", 0),
|
||||
), # node-1
|
||||
(
|
||||
("skip_connect", 0),
|
||||
("nor_conv_1x1", 0),
|
||||
("nor_conv_3x3", 0),
|
||||
("avg_pool_3x3", 0),
|
||||
("skip_connect", 1),
|
||||
("nor_conv_1x1", 1),
|
||||
("nor_conv_3x3", 1),
|
||||
("avg_pool_3x3", 1),
|
||||
), # node-2
|
||||
(
|
||||
("skip_connect", 0),
|
||||
("nor_conv_1x1", 0),
|
||||
("nor_conv_3x3", 0),
|
||||
("avg_pool_3x3", 0),
|
||||
("skip_connect", 1),
|
||||
("nor_conv_1x1", 1),
|
||||
("nor_conv_3x3", 1),
|
||||
("avg_pool_3x3", 1),
|
||||
("skip_connect", 2),
|
||||
("nor_conv_1x1", 2),
|
||||
("nor_conv_3x3", 2),
|
||||
("avg_pool_3x3", 2),
|
||||
),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllConv1x1_CODE = Structure(
|
||||
[
|
||||
(("nor_conv_1x1", 0),), # node-1
|
||||
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1)), # node-2
|
||||
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1), ("nor_conv_1x1", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllIdentity_CODE = Structure(
|
||||
[
|
||||
(("skip_connect", 0),), # node-1
|
||||
(("skip_connect", 0), ("skip_connect", 1)), # node-2
|
||||
(("skip_connect", 0), ("skip_connect", 1), ("skip_connect", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
architectures = {
|
||||
"resnet": ResNet_CODE,
|
||||
"all_c3x3": AllConv3x3_CODE,
|
||||
"all_c1x1": AllConv1x1_CODE,
|
||||
"all_idnt": AllIdentity_CODE,
|
||||
"all_full": AllFull_CODE,
|
||||
}
|
267
AutoDL-Projects/xautodl/models/cell_searchs/search_cells.py
Normal file
267
AutoDL-Projects/xautodl/models/cell_searchs/search_cells.py
Normal file
@@ -0,0 +1,267 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, random, torch
|
||||
import warnings
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
||||
|
||||
# This module is used for NAS-Bench-201, represents a small search space with a complete DAG
|
||||
class NAS201SearchCell(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C_in,
|
||||
C_out,
|
||||
stride,
|
||||
max_nodes,
|
||||
op_names,
|
||||
affine=False,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(NAS201SearchCell, self).__init__()
|
||||
|
||||
self.op_names = deepcopy(op_names)
|
||||
self.edges = nn.ModuleDict()
|
||||
self.max_nodes = max_nodes
|
||||
self.in_dim = C_in
|
||||
self.out_dim = C_out
|
||||
for i in range(1, max_nodes):
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if j == 0:
|
||||
xlists = [
|
||||
OPS[op_name](C_in, C_out, stride, affine, track_running_stats)
|
||||
for op_name in op_names
|
||||
]
|
||||
else:
|
||||
xlists = [
|
||||
OPS[op_name](C_in, C_out, 1, affine, track_running_stats)
|
||||
for op_name in op_names
|
||||
]
|
||||
self.edges[node_str] = nn.ModuleList(xlists)
|
||||
self.edge_keys = sorted(list(self.edges.keys()))
|
||||
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
|
||||
self.num_edges = len(self.edges)
|
||||
|
||||
def extra_repr(self):
|
||||
string = "info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}".format(
|
||||
**self.__dict__
|
||||
)
|
||||
return string
|
||||
|
||||
def forward(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
inter_nodes.append(
|
||||
sum(
|
||||
layer(nodes[j]) * w
|
||||
for layer, w in zip(self.edges[node_str], weights)
|
||||
)
|
||||
)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# GDAS
|
||||
def forward_gdas(self, inputs, hardwts, index):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = hardwts[self.edge2index[node_str]]
|
||||
argmaxs = index[self.edge2index[node_str]].item()
|
||||
weigsum = sum(
|
||||
weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie]
|
||||
for _ie, edge in enumerate(self.edges[node_str])
|
||||
)
|
||||
inter_nodes.append(weigsum)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119
|
||||
def forward_gdas_v1(self, inputs, hardwts, index):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = hardwts[self.edge2index[node_str]]
|
||||
argmaxs = index[self.edge2index[node_str]].item()
|
||||
weigsum = weights[argmaxs] * self.edges[node_str](nodes[j])
|
||||
inter_nodes.append(weigsum)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# joint
|
||||
def forward_joint(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
# aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
|
||||
aggregation = sum(
|
||||
layer(nodes[j]) * w
|
||||
for layer, w in zip(self.edges[node_str], weights)
|
||||
)
|
||||
inter_nodes.append(aggregation)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# uniform random sampling per iteration, SETN
|
||||
def forward_urs(self, inputs):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
while True: # to avoid select zero for all ops
|
||||
sops, has_non_zero = [], False
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
candidates = self.edges[node_str]
|
||||
select_op = random.choice(candidates)
|
||||
sops.append(select_op)
|
||||
if not hasattr(select_op, "is_zero") or select_op.is_zero is False:
|
||||
has_non_zero = True
|
||||
if has_non_zero:
|
||||
break
|
||||
inter_nodes = []
|
||||
for j, select_op in enumerate(sops):
|
||||
inter_nodes.append(select_op(nodes[j]))
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# select the argmax
|
||||
def forward_select(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
inter_nodes.append(
|
||||
self.edges[node_str][weights.argmax().item()](nodes[j])
|
||||
)
|
||||
# inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# forward with a specific structure
|
||||
def forward_dynamic(self, inputs, structure):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
cur_op_node = structure.nodes[i - 1]
|
||||
inter_nodes = []
|
||||
for op_name, j in cur_op_node:
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_index = self.op_names.index(op_name)
|
||||
inter_nodes.append(self.edges[node_str][op_index](nodes[j]))
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
def __init__(self, space, C, stride, affine, track_running_stats):
|
||||
super(MixedOp, self).__init__()
|
||||
self._ops = nn.ModuleList()
|
||||
for primitive in space:
|
||||
op = OPS[primitive](C, C, stride, affine, track_running_stats)
|
||||
self._ops.append(op)
|
||||
|
||||
def forward_gdas(self, x, weights, index):
|
||||
return self._ops[index](x) * weights[index]
|
||||
|
||||
def forward_darts(self, x, weights):
|
||||
return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
|
||||
|
||||
class NASNetSearchCell(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
):
|
||||
super(NASNetSearchCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.op_names = deepcopy(space)
|
||||
if reduction_prev:
|
||||
self.preprocess0 = OPS["skip_connect"](
|
||||
C_prev_prev, C, 2, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
self.preprocess0 = OPS["nor_conv_1x1"](
|
||||
C_prev_prev, C, 1, affine, track_running_stats
|
||||
)
|
||||
self.preprocess1 = OPS["nor_conv_1x1"](
|
||||
C_prev, C, 1, affine, track_running_stats
|
||||
)
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
self.edges = nn.ModuleDict()
|
||||
for i in range(self._steps):
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(
|
||||
i, j
|
||||
) # indicate the edge from node-(j) to node-(i+2)
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
op = MixedOp(space, C, stride, affine, track_running_stats)
|
||||
self.edges[node_str] = op
|
||||
self.edge_keys = sorted(list(self.edges.keys()))
|
||||
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
|
||||
self.num_edges = len(self.edges)
|
||||
|
||||
@property
|
||||
def multiplier(self):
|
||||
return self._multiplier
|
||||
|
||||
def forward_gdas(self, s0, s1, weightss, indexs):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op = self.edges[node_str]
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
index = indexs[self.edge2index[node_str]].item()
|
||||
clist.append(op.forward_gdas(h, weights, index))
|
||||
states.append(sum(clist))
|
||||
|
||||
return torch.cat(states[-self._multiplier :], dim=1)
|
||||
|
||||
def forward_darts(self, s0, s1, weightss):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op = self.edges[node_str]
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
clist.append(op.forward_darts(h, weights))
|
||||
states.append(sum(clist))
|
||||
|
||||
return torch.cat(states[-self._multiplier :], dim=1)
|
@@ -0,0 +1,122 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
########################################################
|
||||
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
||||
########################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkDarts(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkDarts, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self.op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def forward(self, inputs):
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell(feature, alphas)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
@@ -0,0 +1,178 @@
|
||||
####################
|
||||
# DARTS, ICLR 2019 #
|
||||
####################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import List, Text, Dict
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkDARTS(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C: int,
|
||||
N: int,
|
||||
steps: int,
|
||||
multiplier: int,
|
||||
stem_multiplier: int,
|
||||
num_classes: int,
|
||||
search_space: List[Text],
|
||||
affine: bool,
|
||||
track_running_stats: bool,
|
||||
):
|
||||
super(NASNetworkDARTS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_normal_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.arch_reduce_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
|
||||
def get_weights(self) -> List[torch.nn.Parameter]:
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self) -> List[torch.nn.Parameter]:
|
||||
return [self.arch_normal_parameters, self.arch_reduce_parameters]
|
||||
|
||||
def show_alphas(self) -> Text:
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
|
||||
)
|
||||
B = "arch-reduce-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}\n{:}".format(A, B)
|
||||
|
||||
def get_message(self) -> Text:
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self) -> Text:
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self) -> Dict[Text, List]:
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
# (TODO) xuanyidong:
|
||||
# Here the selected two edges might come from the same input node.
|
||||
# And this case could be a problem that two edges will collapse into a single one
|
||||
# due to our assumption -- at most one edge from an input node during evaluation.
|
||||
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
gene_reduce = _parse(
|
||||
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
"reduce": gene_reduce,
|
||||
"reduce_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
|
||||
reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
ww = reduce_w
|
||||
else:
|
||||
ww = normal_w
|
||||
s0, s1 = s1, cell.forward_darts(s0, s1, ww)
|
||||
out = self.lastact(s1)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
114
AutoDL-Projects/xautodl/models/cell_searchs/search_model_enas.py
Normal file
114
AutoDL-Projects/xautodl/models/cell_searchs/search_model_enas.py
Normal file
@@ -0,0 +1,114 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##########################################################################
|
||||
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
|
||||
##########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
from .search_model_enas_utils import Controller
|
||||
|
||||
|
||||
class TinyNetworkENAS(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkENAS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
# to maintain the sampled architecture
|
||||
self.sampled_arch = None
|
||||
|
||||
def update_arch(self, _arch):
|
||||
if _arch is None:
|
||||
self.sampled_arch = None
|
||||
elif isinstance(_arch, Structure):
|
||||
self.sampled_arch = _arch
|
||||
elif isinstance(_arch, (list, tuple)):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_index = _arch[self.edge2index[node_str]]
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
self.sampled_arch = Structure(genotypes)
|
||||
else:
|
||||
raise ValueError("invalid type of input architecture : {:}".format(_arch))
|
||||
return self.sampled_arch
|
||||
|
||||
def create_controller(self):
|
||||
return Controller(len(self.edge2index), len(self.op_names))
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell.forward_dynamic(feature, self.sampled_arch)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
@@ -0,0 +1,74 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##########################################################################
|
||||
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
|
||||
##########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributions.categorical import Categorical
|
||||
|
||||
|
||||
class Controller(nn.Module):
|
||||
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
|
||||
def __init__(
|
||||
self,
|
||||
num_edge,
|
||||
num_ops,
|
||||
lstm_size=32,
|
||||
lstm_num_layers=2,
|
||||
tanh_constant=2.5,
|
||||
temperature=5.0,
|
||||
):
|
||||
super(Controller, self).__init__()
|
||||
# assign the attributes
|
||||
self.num_edge = num_edge
|
||||
self.num_ops = num_ops
|
||||
self.lstm_size = lstm_size
|
||||
self.lstm_N = lstm_num_layers
|
||||
self.tanh_constant = tanh_constant
|
||||
self.temperature = temperature
|
||||
# create parameters
|
||||
self.register_parameter(
|
||||
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
|
||||
)
|
||||
self.w_lstm = nn.LSTM(
|
||||
input_size=self.lstm_size,
|
||||
hidden_size=self.lstm_size,
|
||||
num_layers=self.lstm_N,
|
||||
)
|
||||
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
|
||||
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
|
||||
|
||||
nn.init.uniform_(self.input_vars, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
|
||||
|
||||
def forward(self):
|
||||
|
||||
inputs, h0 = self.input_vars, None
|
||||
log_probs, entropys, sampled_arch = [], [], []
|
||||
for iedge in range(self.num_edge):
|
||||
outputs, h0 = self.w_lstm(inputs, h0)
|
||||
|
||||
logits = self.w_pred(outputs)
|
||||
logits = logits / self.temperature
|
||||
logits = self.tanh_constant * torch.tanh(logits)
|
||||
# distribution
|
||||
op_distribution = Categorical(logits=logits)
|
||||
op_index = op_distribution.sample()
|
||||
sampled_arch.append(op_index.item())
|
||||
|
||||
op_log_prob = op_distribution.log_prob(op_index)
|
||||
log_probs.append(op_log_prob.view(-1))
|
||||
op_entropy = op_distribution.entropy()
|
||||
entropys.append(op_entropy.view(-1))
|
||||
|
||||
# obtain the input embedding for the next step
|
||||
inputs = self.w_embd(op_index)
|
||||
return (
|
||||
torch.sum(torch.cat(log_probs)),
|
||||
torch.sum(torch.cat(entropys)),
|
||||
sampled_arch,
|
||||
)
|
142
AutoDL-Projects/xautodl/models/cell_searchs/search_model_gdas.py
Normal file
142
AutoDL-Projects/xautodl/models/cell_searchs/search_model_gdas.py
Normal file
@@ -0,0 +1,142 @@
|
||||
###########################################################################
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||
###########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkGDAS(nn.Module):
|
||||
|
||||
# def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkGDAS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.tau = 10
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self.tau = tau
|
||||
|
||||
def get_tau(self):
|
||||
return self.tau
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self.op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def forward(self, inputs):
|
||||
while True:
|
||||
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
|
||||
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (
|
||||
(torch.isinf(gumbels).any())
|
||||
or (torch.isinf(probs).any())
|
||||
or (torch.isnan(probs).any())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell.forward_gdas(feature, hardwts, index)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
@@ -0,0 +1,200 @@
|
||||
###########################################################################
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||
###########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
from ..cell_operations import RAW_OP_CLASSES
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkGDAS_FRC(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C,
|
||||
N,
|
||||
steps,
|
||||
multiplier,
|
||||
stem_multiplier,
|
||||
num_classes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
):
|
||||
super(NASNetworkGDAS_FRC, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = RAW_OP_CLASSES["gdas_reduction"](
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
reduction
|
||||
or num_edge == cell.num_edges
|
||||
and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = (
|
||||
C_prev,
|
||||
cell.multiplier * C_curr,
|
||||
reduction,
|
||||
)
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.tau = 10
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self.tau = tau
|
||||
|
||||
def get_tau(self):
|
||||
return self.tau
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}".format(A)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
|
||||
def forward(self, inputs):
|
||||
def get_gumbel_prob(xins):
|
||||
while True:
|
||||
gumbels = -torch.empty_like(xins).exponential_().log()
|
||||
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (
|
||||
(torch.isinf(gumbels).any())
|
||||
or (torch.isinf(probs).any())
|
||||
or (torch.isnan(probs).any())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
return hardwts, index
|
||||
|
||||
hardwts, index = get_gumbel_prob(self.arch_parameters)
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
s0, s1 = s1, cell(s0, s1)
|
||||
else:
|
||||
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
|
||||
out = self.lastact(s1)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
@@ -0,0 +1,197 @@
|
||||
###########################################################################
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||
###########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkGDAS(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C,
|
||||
N,
|
||||
steps,
|
||||
multiplier,
|
||||
stem_multiplier,
|
||||
num_classes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
):
|
||||
super(NASNetworkGDAS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_normal_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.arch_reduce_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.tau = 10
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self.tau = tau
|
||||
|
||||
def get_tau(self):
|
||||
return self.tau
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_normal_parameters, self.arch_reduce_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
|
||||
)
|
||||
B = "arch-reduce-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}\n{:}".format(A, B)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
gene_reduce = _parse(
|
||||
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
"reduce": gene_reduce,
|
||||
"reduce_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
|
||||
def forward(self, inputs):
|
||||
def get_gumbel_prob(xins):
|
||||
while True:
|
||||
gumbels = -torch.empty_like(xins).exponential_().log()
|
||||
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (
|
||||
(torch.isinf(gumbels).any())
|
||||
or (torch.isinf(probs).any())
|
||||
or (torch.isnan(probs).any())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
return hardwts, index
|
||||
|
||||
normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
|
||||
reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
hardwts, index = reduce_hardwts, reduce_index
|
||||
else:
|
||||
hardwts, index = normal_hardwts, normal_index
|
||||
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
|
||||
out = self.lastact(s1)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
@@ -0,0 +1,102 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##############################################################################
|
||||
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
|
||||
##############################################################################
|
||||
import torch, random
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkRANDOM(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkRANDOM, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_cache = None
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def random_genotype(self, set_cache):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_name = random.choice(self.op_names)
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
arch = Structure(genotypes)
|
||||
if set_cache:
|
||||
self.arch_cache = arch
|
||||
return arch
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell.forward_dynamic(feature, self.arch_cache)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
return out, logits
|
178
AutoDL-Projects/xautodl/models/cell_searchs/search_model_setn.py
Normal file
178
AutoDL-Projects/xautodl/models/cell_searchs/search_model_setn.py
Normal file
@@ -0,0 +1,178 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
######################################################################################
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
||||
######################################################################################
|
||||
import torch, random
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkSETN(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkSETN, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.mode = "urs"
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||
assert mode in ["urs", "joint", "select", "dynamic"]
|
||||
self.mode = mode
|
||||
if mode == "dynamic":
|
||||
self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else:
|
||||
self.dynamic_cell = None
|
||||
|
||||
def get_cal_mode(self):
|
||||
return self.mode
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self.op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if use_random:
|
||||
op_name = random.choice(self.op_names)
|
||||
else:
|
||||
weights = alphas_cpu[self.edge2index[node_str]]
|
||||
op_index = torch.multinomial(weights, 1).item()
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def get_log_prob(self, arch):
|
||||
with torch.no_grad():
|
||||
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
|
||||
select_logits = []
|
||||
for i, node_info in enumerate(arch.nodes):
|
||||
for op, xin in node_info:
|
||||
node_str = "{:}<-{:}".format(i + 1, xin)
|
||||
op_index = self.op_names.index(op)
|
||||
select_logits.append(logits[self.edge2index[node_str], op_index])
|
||||
return sum(select_logits).item()
|
||||
|
||||
def return_topK(self, K):
|
||||
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
|
||||
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
|
||||
if K < 0 or K >= len(archs):
|
||||
K = len(archs)
|
||||
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
|
||||
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
|
||||
return return_pairs
|
||||
|
||||
def forward(self, inputs):
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
with torch.no_grad():
|
||||
alphas_cpu = alphas.detach().cpu()
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
if self.mode == "urs":
|
||||
feature = cell.forward_urs(feature)
|
||||
elif self.mode == "select":
|
||||
feature = cell.forward_select(feature, alphas_cpu)
|
||||
elif self.mode == "joint":
|
||||
feature = cell.forward_joint(feature, alphas)
|
||||
elif self.mode == "dynamic":
|
||||
feature = cell.forward_dynamic(feature, self.dynamic_cell)
|
||||
else:
|
||||
raise ValueError("invalid mode={:}".format(self.mode))
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
@@ -0,0 +1,205 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
######################################################################################
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
||||
######################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import List, Text, Dict
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkSETN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C: int,
|
||||
N: int,
|
||||
steps: int,
|
||||
multiplier: int,
|
||||
stem_multiplier: int,
|
||||
num_classes: int,
|
||||
search_space: List[Text],
|
||||
affine: bool,
|
||||
track_running_stats: bool,
|
||||
):
|
||||
super(NASNetworkSETN, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_normal_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.arch_reduce_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.mode = "urs"
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||
assert mode in ["urs", "joint", "select", "dynamic"]
|
||||
self.mode = mode
|
||||
if mode == "dynamic":
|
||||
self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else:
|
||||
self.dynamic_cell = None
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_normal_parameters, self.arch_reduce_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
|
||||
)
|
||||
B = "arch-reduce-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}\n{:}".format(A, B)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if use_random:
|
||||
op_name = random.choice(self.op_names)
|
||||
else:
|
||||
weights = alphas_cpu[self.edge2index[node_str]]
|
||||
op_index = torch.multinomial(weights, 1).item()
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
gene_reduce = _parse(
|
||||
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
"reduce": gene_reduce,
|
||||
"reduce_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
|
||||
def forward(self, inputs):
|
||||
normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
|
||||
reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
# [TODO]
|
||||
raise NotImplementedError
|
||||
if cell.reduction:
|
||||
hardwts, index = reduce_hardwts, reduce_index
|
||||
else:
|
||||
hardwts, index = normal_hardwts, normal_index
|
||||
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
|
||||
out = self.lastact(s1)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
Reference in New Issue
Block a user