add autodl
5
AutoDL-Projects/xautodl/models/cell_infers/__init__.py
Normal file
@@ -0,0 +1,5 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork
from .nasnet_cifar import NASNetonCIFAR
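
The two classes re-exported here form the package's public surface; downstream code can import them as, for example, from xautodl.models.cell_infers import TinyNetwork, NASNetonCIFAR.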
155
AutoDL-Projects/xautodl/models/cell_infers/cells.py
Normal file
@@ -0,0 +1,155 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################

import torch
import torch.nn as nn
from copy import deepcopy

from xautodl.models.cell_operations import OPS


# Cell for NAS-Bench-201
class InferCell(nn.Module):
    def __init__(
        self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True
    ):
        super(InferCell, self).__init__()

        self.layers = nn.ModuleList()
        self.node_IN = []  # node_IN[k]: input-node indices feeding node k+1
        self.node_IX = []  # node_IX[k]: indices into self.layers of node k+1's ops
        self.genotype = deepcopy(genotype)
        for i in range(1, len(genotype)):
            node_info = genotype[i - 1]
            cur_index = []
            cur_innod = []
            for (op_name, op_in) in node_info:
                # only edges from the cell input (node 0) change channels/stride
                if op_in == 0:
                    layer = OPS[op_name](
                        C_in, C_out, stride, affine, track_running_stats
                    )
                else:
                    layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats)
                cur_index.append(len(self.layers))
                cur_innod.append(op_in)
                self.layers.append(layer)
            self.node_IX.append(cur_index)
            self.node_IN.append(cur_innod)
        self.nodes = len(genotype)
        self.in_dim = C_in
        self.out_dim = C_out

    def extra_repr(self):
        string = "info :: nodes={nodes}, inC={in_dim}, outC={out_dim}".format(
            **self.__dict__
        )
        laystr = []
        for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
            y = [
                "I{:}-L{:}".format(_ii, _il)
                for _il, _ii in zip(node_layers, node_innods)
            ]
            x = "{:}<-({:})".format(i + 1, ",".join(y))
            laystr.append(x)
        return (
            string
            + ", [{:}]".format(" | ".join(laystr))
            + ", {:}".format(self.genotype.tostr())
        )

    def forward(self, inputs):
        nodes = [inputs]
        for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
            # each node is the sum of its incoming edge operations
            node_feature = sum(
                self.layers[_il](nodes[_ii])
                for _il, _ii in zip(node_layers, node_innods)
            )
            nodes.append(node_feature)
        return nodes[-1]


# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):
    def __init__(
        self,
        genotype,
        C_prev_prev,
        C_prev,
        C,
        reduction,
        reduction_prev,
        affine,
        track_running_stats,
    ):
        super(NASNetInferCell, self).__init__()
        self.reduction = reduction
        if reduction_prev:
            self.preprocess0 = OPS["skip_connect"](
                C_prev_prev, C, 2, affine, track_running_stats
            )
        else:
            self.preprocess0 = OPS["nor_conv_1x1"](
                C_prev_prev, C, 1, affine, track_running_stats
            )
        self.preprocess1 = OPS["nor_conv_1x1"](
            C_prev, C, 1, affine, track_running_stats
        )

        if not reduction:
            nodes, concats = genotype["normal"], genotype["normal_concat"]
        else:
            nodes, concats = genotype["reduce"], genotype["reduce_concat"]
        self._multiplier = len(concats)
        self._concats = concats
        self._steps = len(nodes)
        self._nodes = nodes
        self.edges = nn.ModuleDict()  # keyed "<node><-<input>"
        for i, node in enumerate(nodes):
            for in_node in node:
                name, j = in_node[0], in_node[1]
                stride = 2 if reduction and j < 2 else 1
                node_str = "{:}<-{:}".format(i + 2, j)
                self.edges[node_str] = OPS[name](
                    C, C, stride, affine, track_running_stats
                )

    # [TODO] to support drop_prob in this function..
    def forward(self, s0, s1, unused_drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for i, node in enumerate(self._nodes):
            clist = []
            for in_node in node:
                name, j = in_node[0], in_node[1]
                node_str = "{:}<-{:}".format(i + 2, j)
                op = self.edges[node_str]
                clist.append(op(states[j]))
            states.append(sum(clist))
        return torch.cat([states[x] for x in self._concats], dim=1)


class AuxiliaryHeadCIFAR(nn.Module):
    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(
                5, stride=3, padding=0, count_include_pad=False
            ),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x
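
For reference, a minimal usage sketch for InferCell (not part of this commit). It assumes the CellStructure genotype class that xautodl.models exports elsewhere in this repo, with its str2structure parser; the architecture string is an arbitrary NAS-Bench-201-style example, and the op names and OPS signature match what the cell above uses:

import torch
from xautodl.models import CellStructure  # assumed: genotype parser from this repo
from xautodl.models.cell_infers.cells import InferCell

# an arbitrary three-node NAS-Bench-201-style architecture string
arch = "|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|nor_conv_1x1~2|"
genotype = CellStructure.str2structure(arch)
cell = InferCell(genotype, C_in=16, C_out=16, stride=1)
x = torch.randn(2, 16, 32, 32)
print(cell(x).shape)  # torch.Size([2, 16, 32, 32]); stride 1 preserves the size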
118
AutoDL-Projects/xautodl/models/cell_infers/nasnet_cifar.py
Normal file
@@ -0,0 +1,118 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy

from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR


# The macro structure is based on NASNet
class NASNetonCIFAR(nn.Module):
    def __init__(
        self,
        C,
        N,
        stem_multiplier,
        num_classes,
        genotype,
        auxiliary,
        affine=True,
        track_running_stats=True,
    ):
        super(NASNetonCIFAR, self).__init__()
        self._C = C
        self._layerN = N
        self.stem = nn.Sequential(
            nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C * stem_multiplier),
        )

        # config for each layer
        layer_channels = (
            [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
        )
        layer_reductions = (
            [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
        )

        C_prev_prev, C_prev, C_curr, reduction_prev = (
            C * stem_multiplier,
            C * stem_multiplier,
            C,
            False,
        )
        self.auxiliary_index = None
        self.auxiliary_head = None
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            cell = InferCell(
                genotype,
                C_prev_prev,
                C_prev,
                C_curr,
                reduction,
                reduction_prev,
                affine,
                track_running_stats,
            )
            self.cells.append(cell)
            C_prev_prev, C_prev, reduction_prev = (
                C_prev,
                cell._multiplier * C_curr,
                reduction,
            )
            if reduction and C_curr == C * 4 and auxiliary:
                self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes)
                self.auxiliary_index = index
        self._Layer = len(self.cells)
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        self.drop_path_prob = -1

    def update_drop_path(self, drop_path_prob):
        self.drop_path_prob = drop_path_prob

    def auxiliary_param(self):
        if self.auxiliary_head is None:
            return []
        else:
            return list(self.auxiliary_head.parameters())

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
            )
        return string

    def extra_repr(self):
        return "{name}(C={_C}, N={_layerN}, L={_Layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def forward(self, inputs):
        stem_feature, logits_aux = self.stem(inputs), None
        cell_results = [stem_feature, stem_feature]
        for i, cell in enumerate(self.cells):
            cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
            cell_results.append(cell_feature)
            if (
                self.auxiliary_index is not None
                and i == self.auxiliary_index
                and self.training
            ):
                logits_aux = self.auxiliary_head(cell_results[-1])
        out = self.lastact(cell_results[-1])
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        if logits_aux is None:
            return out, logits
        else:
            return out, [logits, logits_aux]
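
For reference, a minimal usage sketch for NASNetonCIFAR (not part of this commit). The toy genotype below is hypothetical, written only to exercise the code; real genotypes come from a search run. Its dict layout (per-node lists of (op_name, input_index) pairs plus concat indices) mirrors exactly what NASNetInferCell reads above:

import torch
from xautodl.models.cell_infers.nasnet_cifar import NASNetonCIFAR

# hypothetical two-node cell: node 2 reads nodes 0 and 1, node 3 reads nodes 0 and 2
toy_genotype = {
    "normal": [
        (("skip_connect", 0), ("nor_conv_3x3", 1)),
        (("nor_conv_3x3", 0), ("skip_connect", 2)),
    ],
    "normal_concat": [2, 3],
    "reduce": [
        (("nor_conv_3x3", 0), ("skip_connect", 1)),
        (("skip_connect", 0), ("nor_conv_3x3", 2)),
    ],
    "reduce_concat": [2, 3],
}
net = NASNetonCIFAR(
    C=16, N=2, stem_multiplier=3, num_classes=10,
    genotype=toy_genotype, auxiliary=False,
)
out, logits = net(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])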
63
AutoDL-Projects/xautodl/models/cell_infers/tiny_network.py
Normal file
@@ -0,0 +1,63 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell


# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):
    def __init__(self, C, N, genotype, num_classes):
        super(TinyNetwork, self).__init__()
        self._C = C
        self._layerN = N

        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
        )

        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev = C
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2, True)
            else:
                cell = InferCell(genotype, C_prev, C_curr, 1)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self._Layer = len(self.cells)

        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
            )
        return string

    def extra_repr(self):
        return "{name}(C={_C}, N={_layerN}, L={_Layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def forward(self, inputs):
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
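
For reference, a minimal usage sketch for TinyNetwork (not part of this commit), again assuming the CellStructure parser exported by xautodl.models in this repo; the architecture string is an arbitrary NAS-Bench-201-style example:

import torch
from xautodl.models import CellStructure  # assumed: genotype parser from this repo
from xautodl.models.cell_infers import TinyNetwork

arch = "|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|nor_conv_1x1~2|"
net = TinyNetwork(C=16, N=5, genotype=CellStructure.str2structure(arch), num_classes=10)
out, logits = net(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])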