Use black for lib/models

D-X-Y
2021-05-12 16:28:05 +08:00
parent d51e5fdc7f
commit f1c47af5fa
42 changed files with 7552 additions and 4688 deletions
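For reference, the reformatting in this commit corresponds to running the Black formatter over lib/models (e.g. python -m black lib/models). A minimal sketch of the same pass through Black's Python API is shown below; it is an illustration assuming Black is installed with its default settings, not a script from this repository.

import black
from pathlib import Path

mode = black.Mode()  # Black defaults: 88-character lines, double quotes
for path in Path("lib/models").rglob("*.py"):
    src = path.read_text()
    formatted = black.format_str(src, mode=mode)  # format one module's source
    if formatted != src:
        path.write_text(formatted)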

View File

@@ -0,0 +1,179 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path

lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from log_utils import time_string
from log_utils import AverageMeter, convert_secs2time
from utils import split_str2indexes

from procedures.advanced_main import basic_train_fn, basic_eval_fn
from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from datasets.synthetic_core import get_synthetic_env
from models.xcore import get_model
from xlayers import super_core, trunc_normal_

from lfna_utils import lfna_setup, train_model, TimeData

# from lfna_models import HyperNet_VX as HyperNet
from lfna_models import HyperNet


def main(args):
    logger, env_info, model_kwargs = lfna_setup(args)
    dynamic_env = env_info["dynamic_env"]
    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
    criterion = torch.nn.MSELoss()

    logger.log("There are {:} weights.".format(model.get_w_container().numel()))

    shape_container = model.get_w_container().to_shape_container()
    hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim)
    total_bar = env_info["total"] - 1
    task_embeds = []
    for i in range(total_bar):
        task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim)))
    for task_embed in task_embeds:
        trunc_normal_(task_embed, std=0.02)

    parameters = list(hypernet.parameters()) + task_embeds
    optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[
            int(args.epochs * 0.8),
            int(args.epochs * 0.9),
        ],
        gamma=0.1,
    )

    # LFNA meta-training
    loss_meter = AverageMeter()
    per_epoch_time, start_time = AverageMeter(), time.time()
    for iepoch in range(args.epochs):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        head_str = (
            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
            + need_time
        )

        limit_bar = float(iepoch + 1) / args.epochs * total_bar
        limit_bar = min(max(0, int(limit_bar)), total_bar)
        losses = []
        for ibatch in range(args.meta_batch):
            cur_time = random.randint(0, limit_bar)
            cur_task_embed = task_embeds[cur_time]
            cur_container = hypernet(cur_task_embed)
            cur_x = env_info["{:}-x".format(cur_time)]
            cur_y = env_info["{:}-y".format(cur_time)]
            cur_dataset = TimeData(cur_time, cur_x, cur_y)

            preds = model.forward_with_container(cur_dataset.x, cur_container)
            optimizer.zero_grad()
            loss = criterion(preds, cur_dataset.y)
            losses.append(loss)

        final_loss = torch.stack(losses).mean()
        final_loss.backward()
        optimizer.step()
        lr_scheduler.step()

        loss_meter.update(final_loss.item())
        if iepoch % 200 == 0:
            logger.log(
                head_str
                + "meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}, limit={:}".format(
                    loss_meter.avg,
                    loss_meter.val,
                    min(lr_scheduler.get_last_lr()),
                    len(losses),
                    limit_bar,
                )
            )
            save_checkpoint(
                {
                    "hypernet": hypernet.state_dict(),
                    "task_embed": task_embed,
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "iepoch": iepoch,
                },
                logger.path("model"),
                logger,
            )
            loss_meter.reset()

        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()

    print(model)
    print(hypernet)

    logger.log("-" * 200 + "\n")
    logger.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Use the data in the past.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/lfna-synthetic/lfna-tall-hpnet",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic environment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        required=True,
        help="The hidden dimension.",
    )
    #####
    parser.add_argument(
        "--init_lr",
        type=float,
        default=0.1,
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--meta_batch",
        type=int,
        default=64,
        help="The batch size for the meta-model",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=2000,
        help="The total number of epochs.",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")

    args = parser.parse_args()
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    args.task_dim = args.hidden_dim
    args.save_dir = "{:}-{:}-d{:}".format(
        args.save_dir, args.env_version, args.hidden_dim
    )
    main(args)
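The script above learns one embedding per time step plus a hypernetwork that maps each embedding to the weights of the target model. A minimal, self-contained sketch of that pattern follows; ToyHyperNet and the shapes below are illustrative assumptions, not the repository's HyperNet or weight-container API.

import torch
import torch.nn as nn


class ToyHyperNet(nn.Module):
    # Maps a task embedding to a flat vector of target-model weights.
    def __init__(self, task_dim, num_weights, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(task_dim, hidden), nn.ReLU(), nn.Linear(hidden, num_weights)
        )

    def forward(self, task_embed):
        return self.net(task_embed).view(-1)


task_dim, in_dim, out_dim = 8, 4, 1
task_embed = nn.Parameter(torch.randn(1, task_dim) * 0.02)  # one per time step
hyper = ToyHyperNet(task_dim, num_weights=in_dim * out_dim + out_dim)

flat = hyper(task_embed)
weight = flat[: in_dim * out_dim].view(out_dim, in_dim)
bias = flat[in_dim * out_dim :]

x = torch.randn(16, in_dim)
pred = torch.nn.functional.linear(x, weight, bias)  # target model runs on generated weights
loss = pred.pow(2).mean()
loss.backward()  # gradients reach both the hypernet and the task embedding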

View File

@@ -41,10 +41,14 @@ def main(args):
     shape_container = model.get_w_container().to_shape_container()
     hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim)
     # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim))
-    task_embed = torch.nn.Parameter(torch.Tensor(1, args.task_dim))
-    trunc_normal_(task_embed, std=0.02)
+    total_bar = 10
+    task_embeds = []
+    for i in range(total_bar):
+        task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim)))
+    for task_embed in task_embeds:
+        trunc_normal_(task_embed, std=0.02)
 
-    parameters = list(hypernet.parameters()) + [task_embed]
+    parameters = list(hypernet.parameters()) + task_embeds
     optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True)
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer,
@@ -56,7 +60,6 @@ def main(args):
     )
 
     # total_bar = env_info["total"] - 1
-    total_bar = 1
     # LFNA meta-training
     loss_meter = AverageMeter()
     per_epoch_time, start_time = AverageMeter(), time.time()
@@ -74,7 +77,7 @@ def main(args):
         # for ibatch in range(args.meta_batch):
         for cur_time in range(total_bar):
             # cur_time = random.randint(0, total_bar)
-            cur_task_embed = task_embed
+            cur_task_embed = task_embeds[cur_time]
             cur_container = hypernet(cur_task_embed)
             cur_x = env_info["{:}-x".format(cur_time)]
             cur_y = env_info["{:}-y".format(cur_time)]
@@ -98,7 +101,7 @@ def main(args):
                 + "meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format(
                     loss_meter.avg,
                     loss_meter.val,
-                    min(lr_scheduler.get_lr()),
+                    min(lr_scheduler.get_last_lr()),
                     len(losses),
                 )
             )
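One substantive change in this hunk is replacing lr_scheduler.get_lr() with lr_scheduler.get_last_lr(): in recent PyTorch releases get_lr() is intended for internal use and warns when called directly, while get_last_lr() returns the learning rate(s) most recently computed by step(). A small standalone check (printed values are approximate):

import torch

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[2, 3], gamma=0.1)
for _ in range(4):
    opt.step()  # optimizer step first, then the scheduler
    sched.step()
    print(min(sched.get_last_lr()))  # 0.1, 0.01, 0.001, 0.001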

View File

@@ -28,6 +28,15 @@ class HyperNet(super_core.SuperModule):
         )
         trunc_normal_(self._super_layer_embed, std=0.02)
 
+        model_kwargs = dict(
+            input_dim=layer_embeding + task_embedding,
+            output_dim=max(self._numel_per_layer),
+            hidden_dims=[layer_embeding * 4] * 4,
+            act_cls="gelu",
+            norm_cls="layer_norm_1d",
+        )
+        self._generator = get_model(dict(model_type="norm_mlp"), **model_kwargs)
+        """
         model_kwargs = dict(
             input_dim=layer_embeding + task_embedding,
             output_dim=max(self._numel_per_layer),
@@ -36,6 +45,7 @@ class HyperNet(super_core.SuperModule):
             norm_cls="identity",
         )
         self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+        """
         self._return_container = return_container
 
         print("generator: {:}".format(self._generator))

View File

@@ -11,111 +11,145 @@ from models.cell_operations import OPS
# Cell for NAS-Bench-201
class InferCell(nn.Module):
    def __init__(
        self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True
    ):
        super(InferCell, self).__init__()

        self.layers = nn.ModuleList()
        self.node_IN = []
        self.node_IX = []
        self.genotype = deepcopy(genotype)
        for i in range(1, len(genotype)):
            node_info = genotype[i - 1]
            cur_index = []
            cur_innod = []
            for (op_name, op_in) in node_info:
                if op_in == 0:
                    layer = OPS[op_name](
                        C_in, C_out, stride, affine, track_running_stats
                    )
                else:
                    layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats)
                cur_index.append(len(self.layers))
                cur_innod.append(op_in)
                self.layers.append(layer)
            self.node_IX.append(cur_index)
            self.node_IN.append(cur_innod)
        self.nodes = len(genotype)
        self.in_dim = C_in
        self.out_dim = C_out

    def extra_repr(self):
        string = "info :: nodes={nodes}, inC={in_dim}, outC={out_dim}".format(
            **self.__dict__
        )
        laystr = []
        for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
            y = [
                "I{:}-L{:}".format(_ii, _il)
                for _il, _ii in zip(node_layers, node_innods)
            ]
            x = "{:}<-({:})".format(i + 1, ",".join(y))
            laystr.append(x)
        return (
            string
            + ", [{:}]".format(" | ".join(laystr))
            + ", {:}".format(self.genotype.tostr())
        )

    def forward(self, inputs):
        nodes = [inputs]
        for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
            node_feature = sum(
                self.layers[_il](nodes[_ii])
                for _il, _ii in zip(node_layers, node_innods)
            )
            nodes.append(node_feature)
        return nodes[-1]


# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):
    def __init__(
        self,
        genotype,
        C_prev_prev,
        C_prev,
        C,
        reduction,
        reduction_prev,
        affine,
        track_running_stats,
    ):
        super(NASNetInferCell, self).__init__()
        self.reduction = reduction
        if reduction_prev:
            self.preprocess0 = OPS["skip_connect"](
                C_prev_prev, C, 2, affine, track_running_stats
            )
        else:
            self.preprocess0 = OPS["nor_conv_1x1"](
                C_prev_prev, C, 1, affine, track_running_stats
            )
        self.preprocess1 = OPS["nor_conv_1x1"](
            C_prev, C, 1, affine, track_running_stats
        )

        if not reduction:
            nodes, concats = genotype["normal"], genotype["normal_concat"]
        else:
            nodes, concats = genotype["reduce"], genotype["reduce_concat"]
        self._multiplier = len(concats)
        self._concats = concats
        self._steps = len(nodes)
        self._nodes = nodes
        self.edges = nn.ModuleDict()
        for i, node in enumerate(nodes):
            for in_node in node:
                name, j = in_node[0], in_node[1]
                stride = 2 if reduction and j < 2 else 1
                node_str = "{:}<-{:}".format(i + 2, j)
                self.edges[node_str] = OPS[name](
                    C, C, stride, affine, track_running_stats
                )

    # [TODO] to support drop_prob in this function..
    def forward(self, s0, s1, unused_drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for i, node in enumerate(self._nodes):
            clist = []
            for in_node in node:
                name, j = in_node[0], in_node[1]
                node_str = "{:}<-{:}".format(i + 2, j)
                op = self.edges[node_str]
                clist.append(op(states[j]))
            states.append(sum(clist))
        return torch.cat([states[x] for x in self._concats], dim=1)


class AuxiliaryHeadCIFAR(nn.Module):
    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(
                5, stride=3, padding=0, count_include_pad=False
            ),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x

View File

@@ -9,63 +9,109 @@ from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
# The macro structure is based on NASNet # The macro structure is based on NASNet
class NASNetonCIFAR(nn.Module): class NASNetonCIFAR(nn.Module):
def __init__(
self,
C,
N,
stem_multiplier,
num_classes,
genotype,
auxiliary,
affine=True,
track_running_stats=True,
):
super(NASNetonCIFAR, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
def __init__(self, C, N, stem_multiplier, num_classes, genotype, auxiliary, affine=True, track_running_stats=True): # config for each layer
super(NASNetonCIFAR, self).__init__() layer_channels = (
self._C = C [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
self._layerN = N )
self.stem = nn.Sequential( layer_reductions = (
nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False), [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
nn.BatchNorm2d(C*stem_multiplier)) )
# config for each layer
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False C_prev_prev, C_prev, C_curr, reduction_prev = (
self.auxiliary_index = None C * stem_multiplier,
self.auxiliary_head = None C * stem_multiplier,
self.cells = nn.ModuleList() C,
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): False,
cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) )
self.cells.append( cell ) self.auxiliary_index = None
C_prev_prev, C_prev, reduction_prev = C_prev, cell._multiplier*C_curr, reduction self.auxiliary_head = None
if reduction and C_curr == C*4 and auxiliary: self.cells = nn.ModuleList()
self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) for index, (C_curr, reduction) in enumerate(
self.auxiliary_index = index zip(layer_channels, layer_reductions)
self._Layer = len(self.cells) ):
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) cell = InferCell(
self.global_pooling = nn.AdaptiveAvgPool2d(1) genotype,
self.classifier = nn.Linear(C_prev, num_classes) C_prev_prev,
self.drop_path_prob = -1 C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = (
C_prev,
cell._multiplier * C_curr,
reduction,
)
if reduction and C_curr == C * 4 and auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes)
self.auxiliary_index = index
self._Layer = len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.drop_path_prob = -1
def update_drop_path(self, drop_path_prob): def update_drop_path(self, drop_path_prob):
self.drop_path_prob = drop_path_prob self.drop_path_prob = drop_path_prob
def auxiliary_param(self): def auxiliary_param(self):
if self.auxiliary_head is None: return [] if self.auxiliary_head is None:
else: return list( self.auxiliary_head.parameters() ) return []
else:
return list(self.auxiliary_head.parameters())
def get_message(self): def get_message(self):
string = self.extra_repr() string = self.extra_repr()
for i, cell in enumerate(self.cells): for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) string += "\n {:02d}/{:02d} :: {:}".format(
return string i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self): def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) return "{name}(C={_C}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def forward(self, inputs): def forward(self, inputs):
stem_feature, logits_aux = self.stem(inputs), None stem_feature, logits_aux = self.stem(inputs), None
cell_results = [stem_feature, stem_feature] cell_results = [stem_feature, stem_feature]
for i, cell in enumerate(self.cells): for i, cell in enumerate(self.cells):
cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
cell_results.append( cell_feature ) cell_results.append(cell_feature)
if self.auxiliary_index is not None and i == self.auxiliary_index and self.training: if (
logits_aux = self.auxiliary_head( cell_results[-1] ) self.auxiliary_index is not None
out = self.lastact(cell_results[-1]) and i == self.auxiliary_index
out = self.global_pooling( out ) and self.training
out = out.view(out.size(0), -1) ):
logits = self.classifier(out) logits_aux = self.auxiliary_head(cell_results[-1])
if logits_aux is None: return out, logits out = self.lastact(cell_results[-1])
else: return out, [logits, logits_aux] out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
if logits_aux is None:
return out, logits
else:
return out, [logits, logits_aux]

View File

@@ -8,51 +8,56 @@ from .cells import InferCell
# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):
    def __init__(self, C, N, genotype, num_classes):
        super(TinyNetwork, self).__init__()
        self._C = C
        self._layerN = N

        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
        )

        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev = C
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2, True)
            else:
                cell = InferCell(genotype, C_prev, C_curr, 1)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self._Layer = len(self.cells)

        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
            )
        return string

    def extra_repr(self):
        return "{name}(C={_C}, N={_layerN}, L={_Layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def forward(self, inputs):
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits

View File

@@ -4,315 +4,550 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
__all__ = ['OPS', 'RAW_OP_CLASSES', 'ResNetBasicblock', 'SearchSpaceNames'] __all__ = ["OPS", "RAW_OP_CLASSES", "ResNetBasicblock", "SearchSpaceNames"]
OPS = { OPS = {
'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride), "none": lambda C_in, C_out, stride, affine, track_running_stats: Zero(
'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats), C_in, C_out, stride
'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats), ),
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats), "avg_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING(
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), C_in, C_out, stride, "avg", affine, track_running_stats
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats), ),
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), "max_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING(
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats), C_in, C_out, stride, "max", affine, track_running_stats
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats), ),
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats), "nor_conv_7x7": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), C_in,
C_out,
(7, 7),
(stride, stride),
(3, 3),
(1, 1),
affine,
track_running_stats,
),
"nor_conv_3x3": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
C_in,
C_out,
(3, 3),
(stride, stride),
(1, 1),
(1, 1),
affine,
track_running_stats,
),
"nor_conv_1x1": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
C_in,
C_out,
(1, 1),
(stride, stride),
(0, 0),
(1, 1),
affine,
track_running_stats,
),
"dua_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(
C_in,
C_out,
(3, 3),
(stride, stride),
(1, 1),
(1, 1),
affine,
track_running_stats,
),
"dua_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(
C_in,
C_out,
(5, 5),
(stride, stride),
(2, 2),
(1, 1),
affine,
track_running_stats,
),
"dil_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: SepConv(
C_in,
C_out,
(3, 3),
(stride, stride),
(2, 2),
(2, 2),
affine,
track_running_stats,
),
"dil_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: SepConv(
C_in,
C_out,
(5, 5),
(stride, stride),
(4, 4),
(2, 2),
affine,
track_running_stats,
),
"skip_connect": lambda C_in, C_out, stride, affine, track_running_stats: Identity()
if stride == 1 and C_in == C_out
else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
} }
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] CONNECT_NAS_BENCHMARK = ["none", "skip_connect", "nor_conv_3x3"]
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] NAS_BENCH_201 = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3'] DARTS_SPACE = [
"none",
"skip_connect",
"dua_sepc_3x3",
"dua_sepc_5x5",
"dil_sepc_3x3",
"dil_sepc_5x5",
"avg_pool_3x3",
"max_pool_3x3",
]
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK, SearchSpaceNames = {
'nats-bench' : NAS_BENCH_201, "connect-nas": CONNECT_NAS_BENCHMARK,
'nas-bench-201': NAS_BENCH_201, "nats-bench": NAS_BENCH_201,
'darts' : DARTS_SPACE} "nas-bench-201": NAS_BENCH_201,
"darts": DARTS_SPACE,
}
class ReLUConvBN(nn.Module): class ReLUConvBN(nn.Module):
def __init__(
self,
C_in,
C_out,
kernel_size,
stride,
padding,
dilation,
affine,
track_running_stats=True,
):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in,
C_out,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=not affine,
),
nn.BatchNorm2d(
C_out, affine=affine, track_running_stats=track_running_stats
),
)
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): def forward(self, x):
super(ReLUConvBN, self).__init__() return self.op(x)
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module): class SepConv(nn.Module):
def __init__(
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): self,
super(SepConv, self).__init__() C_in,
self.op = nn.Sequential( C_out,
nn.ReLU(inplace=False), kernel_size,
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), stride,
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine), padding,
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats), dilation,
) affine,
track_running_stats=True,
):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in,
C_in,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=C_in,
bias=False,
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine),
nn.BatchNorm2d(
C_out, affine=affine, track_running_stats=track_running_stats
),
)
def forward(self, x): def forward(self, x):
return self.op(x) return self.op(x)
class DualSepConv(nn.Module): class DualSepConv(nn.Module):
def __init__(
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): self,
super(DualSepConv, self).__init__() C_in,
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats) C_out,
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats) kernel_size,
stride,
padding,
dilation,
affine,
track_running_stats=True,
):
super(DualSepConv, self).__init__()
self.op_a = SepConv(
C_in,
C_in,
kernel_size,
stride,
padding,
dilation,
affine,
track_running_stats,
)
self.op_b = SepConv(
C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats
)
def forward(self, x): def forward(self, x):
x = self.op_a(x) x = self.op_a(x)
x = self.op_b(x) x = self.op_b(x)
return x return x
class ResNetBasicblock(nn.Module): class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_a = ReLUConvBN(
inplanes, planes, 3, stride, 1, 1, affine, track_running_stats
)
self.conv_b = ReLUConvBN(
planes, planes, 3, 1, 1, 1, affine, track_running_stats
)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(
inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False
),
)
elif inplanes != planes:
self.downsample = ReLUConvBN(
inplanes, planes, 1, 1, 0, 1, affine, track_running_stats
)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True): def extra_repr(self):
super(ResNetBasicblock, self).__init__() string = "{name}(inC={in_dim}, outC={out_dim}, stride={stride})".format(
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) name=self.__class__.__name__, **self.__dict__
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine, track_running_stats) )
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine, track_running_stats) return string
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine, track_running_stats)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def extra_repr(self): def forward(self, inputs):
string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
return string
def forward(self, inputs): basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
basicblock = self.conv_a(inputs) if self.downsample is not None:
basicblock = self.conv_b(basicblock) residual = self.downsample(inputs)
else:
if self.downsample is not None: residual = inputs
residual = self.downsample(inputs) return residual + basicblock
else:
residual = inputs
return residual + basicblock
class POOLING(nn.Module): class POOLING(nn.Module):
def __init__(
self, C_in, C_out, stride, mode, affine=True, track_running_stats=True
):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(
C_in, C_out, 1, 1, 0, 1, affine, track_running_stats
)
if mode == "avg":
self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == "max":
self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else:
raise ValueError("Invalid mode={:} in POOLING".format(mode))
def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True): def forward(self, inputs):
super(POOLING, self).__init__() if self.preprocess:
if C_in == C_out: x = self.preprocess(inputs)
self.preprocess = None else:
else: x = inputs
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats) return self.op(x)
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
def forward(self, inputs):
if self.preprocess: x = self.preprocess(inputs)
else : x = inputs
return self.op(x)
class Identity(nn.Module): class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def __init__(self): def forward(self, x):
super(Identity, self).__init__() return x
def forward(self, x):
return x
class Zero(nn.Module): class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def __init__(self, C_in, C_out, stride): def forward(self, x):
super(Zero, self).__init__() if self.C_in == self.C_out:
self.C_in = C_in if self.stride == 1:
self.C_out = C_out return x.mul(0.0)
self.stride = stride else:
self.is_zero = True return x[:, :, :: self.stride, :: self.stride].mul(0.0)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
def forward(self, x): def extra_repr(self):
if self.C_in == self.C_out: return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__)
if self.stride == 1: return x.mul(0.)
else : return x[:,:,::self.stride,::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
class FactorizedReduce(nn.Module): class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
# assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(
nn.Conv2d(
C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine
)
)
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
elif stride == 1:
self.conv = nn.Conv2d(
C_in, C_out, 1, stride=stride, padding=0, bias=not affine
)
else:
raise ValueError("Invalid stride : {:}".format(stride))
self.bn = nn.BatchNorm2d(
C_out, affine=affine, track_running_stats=track_running_stats
)
def __init__(self, C_in, C_out, stride, affine, track_running_stats): def forward(self, x):
super(FactorizedReduce, self).__init__() if self.stride == 2:
self.stride = stride x = self.relu(x)
self.C_in = C_in y = self.pad(x)
self.C_out = C_out out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
self.relu = nn.ReLU(inplace=False) else:
if stride == 2: out = self.conv(x)
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out) out = self.bn(out)
C_outs = [C_out // 2, C_out - C_out // 2] return out
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine))
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
elif stride == 1:
self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=not affine)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
def forward(self, x): def extra_repr(self):
if self.stride == 2: return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__)
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
else:
out = self.conv(x)
out = self.bn(out)
return out
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019 # Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
class PartAwareOp(nn.Module): class PartAwareOp(nn.Module):
def __init__(self, C_in, C_out, stride, part=4):
def __init__(self, C_in, C_out, stride, part=4): super().__init__()
super().__init__() self.part = 4
self.part = 4 self.hidden = C_in // 3
self.hidden = C_in // 3 self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.avg_pool = nn.AdaptiveAvgPool2d(1) self.local_conv_list = nn.ModuleList()
self.local_conv_list = nn.ModuleList() for i in range(self.part):
for i in range(self.part): self.local_conv_list.append(
self.local_conv_list.append( nn.Sequential(
nn.Sequential(nn.ReLU(), nn.Conv2d(C_in, self.hidden, 1), nn.BatchNorm2d(self.hidden, affine=True)) nn.ReLU(),
nn.Conv2d(C_in, self.hidden, 1),
nn.BatchNorm2d(self.hidden, affine=True),
)
) )
self.W_K = nn.Linear(self.hidden, self.hidden) self.W_K = nn.Linear(self.hidden, self.hidden)
self.W_Q = nn.Linear(self.hidden, self.hidden) self.W_Q = nn.Linear(self.hidden, self.hidden)
if stride == 2 : self.last = FactorizedReduce(C_in + self.hidden, C_out, 2) if stride == 2:
elif stride == 1: self.last = FactorizedReduce(C_in + self.hidden, C_out, 1) self.last = FactorizedReduce(C_in + self.hidden, C_out, 2)
else: raise ValueError('Invalid Stride : {:}'.format(stride)) elif stride == 1:
self.last = FactorizedReduce(C_in + self.hidden, C_out, 1)
else:
raise ValueError("Invalid Stride : {:}".format(stride))
def forward(self, x): def forward(self, x):
batch, C, H, W = x.size() batch, C, H, W = x.size()
assert H >= self.part, 'input size too small : {:} vs {:}'.format(x.shape, self.part) assert H >= self.part, "input size too small : {:} vs {:}".format(
IHs = [0] x.shape, self.part
for i in range(self.part): IHs.append( min(H, int((i+1)*(float(H)/self.part))) ) )
local_feat_list = [] IHs = [0]
for i in range(self.part): for i in range(self.part):
feature = x[:, :, IHs[i]:IHs[i+1], :] IHs.append(min(H, int((i + 1) * (float(H) / self.part))))
xfeax = self.avg_pool(feature) local_feat_list = []
xfea = self.local_conv_list[i]( xfeax ) for i in range(self.part):
local_feat_list.append( xfea ) feature = x[:, :, IHs[i] : IHs[i + 1], :]
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part) xfeax = self.avg_pool(feature)
part_feature = part_feature.transpose(1,2).contiguous() xfea = self.local_conv_list[i](xfeax)
part_K = self.W_K(part_feature) local_feat_list.append(xfea)
part_Q = self.W_Q(part_feature).transpose(1,2).contiguous() part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
weight_att = torch.bmm(part_K, part_Q) part_feature = part_feature.transpose(1, 2).contiguous()
attention = torch.softmax(weight_att, dim=2) part_K = self.W_K(part_feature)
aggreateF = torch.bmm(attention, part_feature).transpose(1,2).contiguous() part_Q = self.W_Q(part_feature).transpose(1, 2).contiguous()
features = [] weight_att = torch.bmm(part_K, part_Q)
for i in range(self.part): attention = torch.softmax(weight_att, dim=2)
feature = aggreateF[:, :, i:i+1].expand(batch, self.hidden, IHs[i+1]-IHs[i]) aggreateF = torch.bmm(attention, part_feature).transpose(1, 2).contiguous()
feature = feature.view(batch, self.hidden, IHs[i+1]-IHs[i], 1) features = []
features.append( feature ) for i in range(self.part):
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W) feature = aggreateF[:, :, i : i + 1].expand(
final_fea = torch.cat((x,features), dim=1) batch, self.hidden, IHs[i + 1] - IHs[i]
outputs = self.last( final_fea ) )
return outputs feature = feature.view(batch, self.hidden, IHs[i + 1] - IHs[i], 1)
features.append(feature)
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
final_fea = torch.cat((x, features), dim=1)
outputs = self.last(final_fea)
return outputs
def drop_path(x, drop_prob): def drop_path(x, drop_prob):
if drop_prob > 0.: if drop_prob > 0.0:
keep_prob = 1. - drop_prob keep_prob = 1.0 - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1) mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob) mask = mask.bernoulli_(keep_prob)
x = torch.div(x, keep_prob) x = torch.div(x, keep_prob)
x.mul_(mask) x.mul_(mask)
return x return x
# Searching for A Robust Neural Architecture in Four GPU Hours # Searching for A Robust Neural Architecture in Four GPU Hours
class GDAS_Reduction_Cell(nn.Module): class GDAS_Reduction_Cell(nn.Module):
def __init__(
self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats
):
super(GDAS_Reduction_Cell, self).__init__()
if reduction_prev:
self.preprocess0 = FactorizedReduce(
C_prev_prev, C, 2, affine, track_running_stats
)
else:
self.preprocess0 = ReLUConvBN(
C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats
)
self.preprocess1 = ReLUConvBN(
C_prev, C, 1, 1, 0, 1, affine, track_running_stats
)
def __init__(self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats): self.reduction = True
super(GDAS_Reduction_Cell, self).__init__() self.ops1 = nn.ModuleList(
if reduction_prev: [
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats) nn.Sequential(
else: nn.ReLU(inplace=False),
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats) nn.Conv2d(
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats) C,
C,
(1, 3),
stride=(1, 2),
padding=(0, 1),
groups=8,
bias=not affine,
),
nn.Conv2d(
C,
C,
(3, 1),
stride=(2, 1),
padding=(1, 0),
groups=8,
bias=not affine,
),
nn.BatchNorm2d(
C, affine=affine, track_running_stats=track_running_stats
),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
nn.BatchNorm2d(
C, affine=affine, track_running_stats=track_running_stats
),
),
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C,
C,
(1, 3),
stride=(1, 2),
padding=(0, 1),
groups=8,
bias=not affine,
),
nn.Conv2d(
C,
C,
(3, 1),
stride=(2, 1),
padding=(1, 0),
groups=8,
bias=not affine,
),
nn.BatchNorm2d(
C, affine=affine, track_running_stats=track_running_stats
),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
nn.BatchNorm2d(
C, affine=affine, track_running_stats=track_running_stats
),
),
]
)
self.reduction = True self.ops2 = nn.ModuleList(
self.ops1 = nn.ModuleList( [
[nn.Sequential( nn.Sequential(
nn.ReLU(inplace=False), nn.MaxPool2d(3, stride=2, padding=1),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=not affine), nn.BatchNorm2d(
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=not affine), C, affine=affine, track_running_stats=track_running_stats
nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats), ),
nn.ReLU(inplace=False), ),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), nn.Sequential(
nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats)), nn.MaxPool2d(3, stride=2, padding=1),
nn.Sequential( nn.BatchNorm2d(
nn.ReLU(inplace=False), C, affine=affine, track_running_stats=track_running_stats
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=not affine), ),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=not affine), ),
nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats), ]
nn.ReLU(inplace=False), )
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats))])
self.ops2 = nn.ModuleList( @property
[nn.Sequential( def multiplier(self):
nn.MaxPool2d(3, stride=2, padding=1), return 4
nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats)),
nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats))])
@property def forward(self, s0, s1, drop_prob=-1):
def multiplier(self): s0 = self.preprocess0(s0)
return 4 s1 = self.preprocess1(s1)
def forward(self, s0, s1, drop_prob = -1): X0 = self.ops1[0](s0)
s0 = self.preprocess0(s0) X1 = self.ops1[1](s1)
s1 = self.preprocess1(s1) if self.training and drop_prob > 0.0:
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
X0 = self.ops1[0] (s0) # X2 = self.ops2[0] (X0+X1)
X1 = self.ops1[1] (s1) X2 = self.ops2[0](s0)
if self.training and drop_prob > 0.: X3 = self.ops2[1](s1)
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob) if self.training and drop_prob > 0.0:
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
#X2 = self.ops2[0] (X0+X1) return torch.cat([X0, X1, X2, X3], dim=1)
X2 = self.ops2[0] (s0)
X3 = self.ops2[1] (s1)
if self.training and drop_prob > 0.:
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
return torch.cat([X0, X1, X2, X3], dim=1)
# To manage the useful classes in this file. # To manage the useful classes in this file.
RAW_OP_CLASSES = { RAW_OP_CLASSES = {"gdas_reduction": GDAS_Reduction_Cell}
'gdas_reduction': GDAS_Reduction_Cell
}

View File

@@ -2,27 +2,32 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# The macro structure is defined in NAS-Bench-201
from .search_model_darts import TinyNetworkDarts
from .search_model_gdas import TinyNetworkGDAS
from .search_model_setn import TinyNetworkSETN
from .search_model_enas import TinyNetworkENAS
from .search_model_random import TinyNetworkRANDOM
from .generic_model import GenericNAS201Model
from .genotypes import Structure as CellStructure, architectures as CellArchitectures

# NASNet-based macro structure
from .search_model_gdas_nasnet import NASNetworkGDAS
from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC
from .search_model_darts_nasnet import NASNetworkDARTS


nas201_super_nets = {
    "DARTS-V1": TinyNetworkDarts,
    "DARTS-V2": TinyNetworkDarts,
    "GDAS": TinyNetworkGDAS,
    "SETN": TinyNetworkSETN,
    "ENAS": TinyNetworkENAS,
    "RANDOM": TinyNetworkRANDOM,
    "generic": GenericNAS201Model,
}

nasnet_super_nets = {
    "GDAS": NASNetworkGDAS,
    "GDAS_FRC": NASNetworkGDAS_FRC,
    "DARTS": NASNetworkDARTS,
}

View File

@@ -4,9 +4,11 @@
import torch

from search_model_enas_utils import Controller


def main():
    controller = Controller(6, 4)
    predictions = controller()


if __name__ == "__main__":
    main()

View File

@@ -8,296 +8,355 @@ from typing import Text
from torch.distributions.categorical import Categorical from torch.distributions.categorical import Categorical
from ..cell_operations import ResNetBasicblock, drop_path from ..cell_operations import ResNetBasicblock, drop_path
from .search_cells import NAS201SearchCell as SearchCell from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure from .genotypes import Structure
class Controller(nn.Module): class Controller(nn.Module):
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
def __init__(self, edge2index, op_names, max_nodes, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0): def __init__(
super(Controller, self).__init__() self,
# assign the attributes edge2index,
self.max_nodes = max_nodes op_names,
self.num_edge = len(edge2index) max_nodes,
self.edge2index = edge2index lstm_size=32,
self.num_ops = len(op_names) lstm_num_layers=2,
self.op_names = op_names tanh_constant=2.5,
self.lstm_size = lstm_size temperature=5.0,
self.lstm_N = lstm_num_layers ):
self.tanh_constant = tanh_constant super(Controller, self).__init__()
self.temperature = temperature # assign the attributes
# create parameters self.max_nodes = max_nodes
self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size))) self.num_edge = len(edge2index)
self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N) self.edge2index = edge2index
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) self.num_ops = len(op_names)
self.w_pred = nn.Linear(self.lstm_size, self.num_ops) self.op_names = op_names
self.lstm_size = lstm_size
self.lstm_N = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
# create parameters
self.register_parameter(
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
)
self.w_lstm = nn.LSTM(
input_size=self.lstm_size,
hidden_size=self.lstm_size,
num_layers=self.lstm_N,
)
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
nn.init.uniform_(self.input_vars , -0.1, 0.1) nn.init.uniform_(self.input_vars, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
nn.init.uniform_(self.w_embd.weight , -0.1, 0.1) nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
nn.init.uniform_(self.w_pred.weight , -0.1, 0.1) nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
def convert_structure(self, _arch): def convert_structure(self, _arch):
genotypes = [] genotypes = []
for i in range(1, self.max_nodes): for i in range(1, self.max_nodes):
xlist = [] xlist = []
for j in range(i): for j in range(i):
node_str = '{:}<-{:}'.format(i, j) node_str = "{:}<-{:}".format(i, j)
op_index = _arch[self.edge2index[node_str]] op_index = _arch[self.edge2index[node_str]]
op_name = self.op_names[op_index] op_name = self.op_names[op_index]
xlist.append((op_name, j)) xlist.append((op_name, j))
genotypes.append( tuple(xlist) ) genotypes.append(tuple(xlist))
return Structure(genotypes) return Structure(genotypes)
def forward(self): def forward(self):
inputs, h0 = self.input_vars, None inputs, h0 = self.input_vars, None
log_probs, entropys, sampled_arch = [], [], [] log_probs, entropys, sampled_arch = [], [], []
for iedge in range(self.num_edge): for iedge in range(self.num_edge):
outputs, h0 = self.w_lstm(inputs, h0) outputs, h0 = self.w_lstm(inputs, h0)
logits = self.w_pred(outputs)
logits = logits / self.temperature
logits = self.tanh_constant * torch.tanh(logits)
# distribution
op_distribution = Categorical(logits=logits)
op_index = op_distribution.sample()
sampled_arch.append( op_index.item() )
op_log_prob = op_distribution.log_prob(op_index) logits = self.w_pred(outputs)
log_probs.append( op_log_prob.view(-1) ) logits = logits / self.temperature
op_entropy = op_distribution.entropy() logits = self.tanh_constant * torch.tanh(logits)
entropys.append( op_entropy.view(-1) ) # distribution
op_distribution = Categorical(logits=logits)
# obtain the input embedding for the next step op_index = op_distribution.sample()
inputs = self.w_embd(op_index) sampled_arch.append(op_index.item())
return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), self.convert_structure(sampled_arch)
op_log_prob = op_distribution.log_prob(op_index)
log_probs.append(op_log_prob.view(-1))
op_entropy = op_distribution.entropy()
entropys.append(op_entropy.view(-1))
# obtain the input embedding for the next step
inputs = self.w_embd(op_index)
return (
torch.sum(torch.cat(log_probs)),
torch.sum(torch.cat(entropys)),
self.convert_structure(sampled_arch),
)
class GenericNAS201Model(nn.Module): class GenericNAS201Model(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(GenericNAS201Model, self).__init__()
self._C = C
self._layerN = N
self._max_nodes = max_nodes
self._stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self._cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self._cells.append(cell)
C_prev = cell.out_dim
self._op_names = deepcopy(search_space)
self._Layer = len(self._cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(
nn.BatchNorm2d(
C_prev, affine=affine, track_running_stats=track_running_stats
),
nn.ReLU(inplace=True),
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self._num_edge = num_edge
# algorithm related
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self._mode = None
self.dynamic_cell = None
self._tau = None
self._algo = None
self._drop_path = None
self.verbose = False
def set_algo(self, algo: Text):
# used for searching
assert self._algo is None, "This function can only be called once."
self._algo = algo
if algo == "enas":
self.controller = Controller(
self.edge2index, self._op_names, self._max_nodes
)
else:
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(self._num_edge, len(self._op_names))
)
if algo == "gdas":
self._tau = 10

def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ["gdas", "enas", "urs", "joint", "select", "dynamic"]
self._mode = mode
if mode == "dynamic":
self.dynamic_cell = deepcopy(dynamic_cell)
else:
self.dynamic_cell = None

def set_drop_path(self, progress, drop_path_rate):
if drop_path_rate is None:
self._drop_path = None
elif progress is None:
self._drop_path = drop_path_rate
else:
self._drop_path = progress * drop_path_rate

@property
def mode(self):
return self._mode

@property
def drop_path(self):
return self._drop_path

@property
def weights(self):
xlist = list(self._stem.parameters())
xlist += list(self._cells.parameters())
xlist += list(self.lastact.parameters())
xlist += list(self.global_pooling.parameters())
xlist += list(self.classifier.parameters())
return xlist

def set_tau(self, tau):
self._tau = tau

@property
def tau(self):
return self._tau

@property
def alphas(self):
if self._algo == "enas":
return list(self.controller.parameters())
else:
return [self.arch_parameters]

@property
def message(self):
string = self.extra_repr()
for i, cell in enumerate(self._cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self._cells), cell.extra_repr()
)
return string
def show_alphas(self):
with torch.no_grad():
if self._algo == "enas":
return "w_pred :\n{:}".format(self.controller.w_pred.weight)
else:
return "arch-parameters :\n{:}".format(
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})".format(
name=self.__class__.__name__, **self.__dict__
)
@property
def genotype(self):
genotypes = []
for i in range(1, self._max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self._op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self._max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if use_random:
op_name = random.choice(self._op_names)
else:
weights = alphas_cpu[self.edge2index[node_str]]
op_index = torch.multinomial(weights, 1).item()
op_name = self._op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def get_log_prob(self, arch):
with torch.no_grad():
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
select_logits = []
for i, node_info in enumerate(arch.nodes):
for op, xin in node_info:
node_str = "{:}<-{:}".format(i + 1, xin)
op_index = self._op_names.index(op)
select_logits.append(logits[self.edge2index[node_str], op_index])
return sum(select_logits).item()
def return_topK(self, K, use_random=False):
archs = Structure.gen_all(self._op_names, self._max_nodes, False)
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
if K < 0 or K >= len(archs):
K = len(archs)
if use_random:
return random.sample(archs, K)
else:
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
return return_pairs
def normalize_archp(self):
if self.mode == "gdas":
while True:
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
with torch.no_grad():
hardwts_cpu = hardwts.detach().cpu()
return hardwts, hardwts_cpu, index, "GUMBEL"
else:
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
index = alphas.max(-1, keepdim=True)[1]
with torch.no_grad():
alphas_cpu = alphas.detach().cpu()
return alphas, alphas_cpu, index, "SOFTMAX"
def forward(self, inputs):
alphas, alphas_cpu, index, verbose_str = self.normalize_archp()
feature = self._stem(inputs)
for i, cell in enumerate(self._cells):
if isinstance(cell, SearchCell):
if self.mode == "urs":
feature = cell.forward_urs(feature)
if self.verbose:
verbose_str += "-forward_urs"
elif self.mode == "select":
feature = cell.forward_select(feature, alphas_cpu)
if self.verbose:
verbose_str += "-forward_select"
elif self.mode == "joint":
feature = cell.forward_joint(feature, alphas)
if self.verbose:
verbose_str += "-forward_joint"
elif self.mode == "dynamic":
feature = cell.forward_dynamic(feature, self.dynamic_cell)
if self.verbose:
verbose_str += "-forward_dynamic"
elif self.mode == "gdas":
feature = cell.forward_gdas(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas"
else:
raise ValueError("invalid mode={:}".format(self.mode))
else:
feature = cell(feature)
if self.drop_path is not None:
feature = drop_path(feature, self.drop_path)
if self.verbose and random.random() < 0.001:
print(verbose_str)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits
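# A minimal usage sketch for the search model above, assuming the NAS-Bench-201
# operation set; sizes, batch shape, and the tau value are placeholders.
import torch

search_space = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
net = GenericNAS201Model(
    C=16, N=5, max_nodes=4, num_classes=10,
    search_space=search_space, affine=False, track_running_stats=True,
)
net.set_algo("gdas")      # one-time choice: creates arch_parameters and sets tau=10
net.set_cal_mode("gdas")  # forward() then uses the Gumbel straight-through path
net.set_tau(5.0)          # tau is usually annealed over epochs (placeholder value)
out, logits = net(torch.randn(2, 3, 32, 32))
print(net.genotype.tostr())  # argmax architecture derived from arch_parameters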

View File

@@ -5,194 +5,270 @@ from copy import deepcopy
def get_combination(space, num):
combs = []
for i in range(num):
if i == 0:
for func in space:
combs.append([(func, i)])
else:
new_combs = []
for string in combs:
for func in space:
xstring = string + [(func, i)]
new_combs.append(xstring)
combs = new_combs
return combs
class Structure:
def __init__(self, genotype):
assert isinstance(genotype, list) or isinstance(
genotype, tuple
), "invalid class of genotype : {:}".format(type(genotype))
self.node_num = len(genotype) + 1
self.nodes = []
self.node_N = []
for idx, node_info in enumerate(genotype):
assert isinstance(node_info, list) or isinstance(
node_info, tuple
), "invalid class of node_info : {:}".format(type(node_info))
assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info))
for node_in in node_info:
assert isinstance(node_in, list) or isinstance(
node_in, tuple
), "invalid class of in-node : {:}".format(type(node_in))
assert (
len(node_in) == 2 and node_in[1] <= idx
), "invalid in-node : {:}".format(node_in)
self.node_N.append(len(node_info))
self.nodes.append(tuple(deepcopy(node_info)))
def tolist(self, remove_str):
# convert this class to a list; if remove_str is 'none', remove the 'none' operation.
# note that we re-order the input node in this function
# return the genotype list and a success flag [if not successful, it is not a valid connectivity]
genotypes = []
for node_info in self.nodes:
node_info = list(node_info)
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
if len(node_info) == 0:
return None, False
genotypes.append(node_info)
return genotypes, True

def node(self, index):
assert index > 0 and index <= len(self), "invalid index={:} < {:}".format(
index, len(self)
)
return self.nodes[index]

def tostr(self):
strings = []
for node_info in self.nodes:
string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info])
string = "|{:}|".format(string)
strings.append(string)
return "+".join(strings)

def check_valid(self):
nodes = {0: True}
for i, node_info in enumerate(self.nodes):
sums = []
for op, xin in node_info:
if op == "none" or nodes[xin] is False:
x = False
else:
x = True
sums.append(x)
nodes[i + 1] = sum(sums) > 0
return nodes[len(self.nodes)]

def to_unique_str(self, consider_zero=False):
# this is used to identify the isomorphic cell, which requires the prior knowledge of operation
# two operations are special, i.e., none and skip_connect
nodes = {0: "0"}
for i_node, node_info in enumerate(self.nodes):
cur_node = []
for op, xin in node_info:
if consider_zero is None:
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
elif consider_zero:
if op == "none" or nodes[xin] == "#":
x = "#"  # zero
elif op == "skip_connect":
x = nodes[xin]
else:
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
else:
if op == "skip_connect":
x = nodes[xin]
else:
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
cur_node.append(x)
nodes[i_node + 1] = "+".join(sorted(cur_node))
return nodes[len(self.nodes)]

def check_valid_op(self, op_names):
for node_info in self.nodes:
for inode_edge in node_info:
# assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
if inode_edge[0] not in op_names:
return False
return True

def __repr__(self):
return "{name}({node_num} nodes with {node_info})".format(
name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__
)
def __len__(self):
return len(self.nodes) + 1
def __getitem__(self, index):
return self.nodes[index]
@staticmethod
def str2structure(xstr):
if isinstance(xstr, Structure):
return xstr
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
type(xstr)
)
nodestrs = xstr.split("+")
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != "", node_str.split("|")))
for xinput in inputs:
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
xinput
)
inputs = (xi.split("~") for xi in inputs)
input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs)
genotypes.append(input_infos)
return Structure(genotypes)
@staticmethod
def str2fullstructure(xstr, default_name="none"):
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
type(xstr)
)
nodestrs = xstr.split("+")
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != "", node_str.split("|")))
for xinput in inputs:
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
xinput
)
inputs = (xi.split("~") for xi in inputs)
input_infos = list((op, int(IDX)) for (op, IDX) in inputs)
all_in_nodes = list(x[1] for x in input_infos)
for j in range(i):
if j not in all_in_nodes:
input_infos.append((default_name, j))
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
genotypes.append(tuple(node_info))
return Structure(genotypes)
@staticmethod
def gen_all(search_space, num, return_ori):
assert isinstance(search_space, list) or isinstance(
search_space, tuple
), "invalid class of search-space : {:}".format(type(search_space))
assert (
num >= 2
), "There should be at least two nodes in a neural cell instead of {:}".format(
num
)
all_archs = get_combination(search_space, 1)
for i, arch in enumerate(all_archs):
all_archs[i] = [tuple(arch)]
for inode in range(2, num):
cur_nodes = get_combination(search_space, inode)
new_all_archs = []
for previous_arch in all_archs:
for cur_node in cur_nodes:
new_all_archs.append(previous_arch + [tuple(cur_node)])
all_archs = new_all_archs
if return_ori:
return all_archs
else:
return [Structure(x) for x in all_archs]
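# A small round-trip sketch for Structure: build it from a genotype list,
# serialize with tostr(), and parse the string back; the operations below are
# just an example genotype.
genotype = [
    (("nor_conv_3x3", 0),),  # node-1: one edge from node-0
    (("skip_connect", 0), ("nor_conv_1x1", 1)),  # node-2: edges from node-0 and node-1
]
arch = Structure(genotype)
arch_str = arch.tostr()  # "|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|"
assert Structure.str2structure(arch_str).tostr() == arch_str
assert arch.check_valid()  # no node is fed only by 'none' edges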
ResNet_CODE = Structure(
[
(("nor_conv_3x3", 0),),  # node-1
(("nor_conv_3x3", 1),),  # node-2
(("skip_connect", 0), ("skip_connect", 2)),  # node-3
]
)

AllConv3x3_CODE = Structure(
[
(("nor_conv_3x3", 0),),  # node-1
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1)),  # node-2
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),  # node-3
]
)

AllFull_CODE = Structure(
[
(
("skip_connect", 0),
("nor_conv_1x1", 0),
("nor_conv_3x3", 0),
("avg_pool_3x3", 0),
),  # node-1
(
("skip_connect", 0),
("nor_conv_1x1", 0),
("nor_conv_3x3", 0),
("avg_pool_3x3", 0),
("skip_connect", 1),
("nor_conv_1x1", 1),
("nor_conv_3x3", 1),
("avg_pool_3x3", 1),
),  # node-2
(
("skip_connect", 0),
("nor_conv_1x1", 0),
("nor_conv_3x3", 0),
("avg_pool_3x3", 0),
("skip_connect", 1),
("nor_conv_1x1", 1),
("nor_conv_3x3", 1),
("avg_pool_3x3", 1),
("skip_connect", 2),
("nor_conv_1x1", 2),
("nor_conv_3x3", 2),
("avg_pool_3x3", 2),
),  # node-3
]
)

AllConv1x1_CODE = Structure(
[
(("nor_conv_1x1", 0),),  # node-1
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1)),  # node-2
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1), ("nor_conv_1x1", 2)),  # node-3
]
)

AllIdentity_CODE = Structure(
[
(("skip_connect", 0),),  # node-1
(("skip_connect", 0), ("skip_connect", 1)),  # node-2
(("skip_connect", 0), ("skip_connect", 1), ("skip_connect", 2)),  # node-3
]
)

architectures = {
"resnet": ResNet_CODE,
"all_c3x3": AllConv3x3_CODE,
"all_c1x1": AllConv1x1_CODE,
"all_idnt": AllIdentity_CODE,
"all_full": AllFull_CODE,
}

View File

@@ -11,191 +11,241 @@ from ..cell_operations import OPS
# This module is used for NAS-Bench-201, represents a small search space with a complete DAG
class NAS201SearchCell(nn.Module):
def __init__(
self,
C_in,
C_out,
stride,
max_nodes,
op_names,
affine=False,
track_running_stats=True,
):
super(NAS201SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
self.edges = nn.ModuleDict()
self.max_nodes = max_nodes
self.in_dim = C_in
self.out_dim = C_out
for i in range(1, max_nodes):
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if j == 0:
xlists = [
OPS[op_name](C_in, C_out, stride, affine, track_running_stats)
for op_name in op_names
]
else:
xlists = [
OPS[op_name](C_in, C_out, 1, affine, track_running_stats)
for op_name in op_names
]
self.edges[node_str] = nn.ModuleList(xlists)
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
def extra_repr(self):
string = "info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}".format(
**self.__dict__
)
return string

def forward(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = weightss[self.edge2index[node_str]]
inter_nodes.append(
sum(
layer(nodes[j]) * w
for layer, w in zip(self.edges[node_str], weights)
)
)
nodes.append(sum(inter_nodes))
return nodes[-1]
# GDAS
def forward_gdas(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = hardwts[self.edge2index[node_str]]
argmaxs = index[self.edge2index[node_str]].item()
weigsum = sum(
weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie]
for _ie, edge in enumerate(self.edges[node_str])
)
inter_nodes.append(weigsum)
nodes.append(sum(inter_nodes))
return nodes[-1]
# joint
def forward_joint(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = weightss[self.edge2index[node_str]]
# aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
aggregation = sum(
layer(nodes[j]) * w
for layer, w in zip(self.edges[node_str], weights)
)
inter_nodes.append(aggregation)
nodes.append(sum(inter_nodes))
return nodes[-1]
# uniform random sampling per iteration, SETN
def forward_urs(self, inputs):
nodes = [inputs]
for i in range(1, self.max_nodes):
while True:  # to avoid select zero for all ops
sops, has_non_zero = [], False
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
candidates = self.edges[node_str]
select_op = random.choice(candidates)
sops.append(select_op)
if not hasattr(select_op, "is_zero") or select_op.is_zero is False:
has_non_zero = True
if has_non_zero:
break
inter_nodes = []
for j, select_op in enumerate(sops):
inter_nodes.append(select_op(nodes[j]))
nodes.append(sum(inter_nodes))
return nodes[-1]
# select the argmax
def forward_select(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = weightss[self.edge2index[node_str]]
inter_nodes.append(
self.edges[node_str][weights.argmax().item()](nodes[j])
)
# inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
nodes.append(sum(inter_nodes))
return nodes[-1]
# forward with a specific structure
def forward_dynamic(self, inputs, structure):
nodes = [inputs]
for i in range(1, self.max_nodes):
cur_op_node = structure.nodes[i - 1]
inter_nodes = []
for op_name, j in cur_op_node:
node_str = "{:}<-{:}".format(i, j)
op_index = self.op_names.index(op_name)
inter_nodes.append(self.edges[node_str][op_index](nodes[j]))
nodes.append(sum(inter_nodes))
return nodes[-1]
class MixedOp(nn.Module):
def __init__(self, space, C, stride, affine, track_running_stats):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
for primitive in space:
op = OPS[primitive](C, C, stride, affine, track_running_stats)
self._ops.append(op)
def forward_gdas(self, x, weights, index):
return self._ops[index](x) * weights[index]

def forward_darts(self, x, weights):
return sum(w * op(x) for w, op in zip(weights, self._ops))
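# A stand-alone sketch of the weights the two MixedOp paths expect, using the
# same Gumbel straight-through trick as normalize_archp above; the edge size (5)
# and temperature (10.0) are placeholder values.
import torch

alpha = torch.randn(5, requires_grad=True)  # one edge, five candidate operations
gumbels = -torch.empty_like(alpha).exponential_().log()
probs = torch.nn.functional.softmax((alpha.log_softmax(dim=-1) + gumbels) / 10.0, dim=-1)
index = probs.argmax(dim=-1)
one_hot = torch.zeros_like(probs).scatter_(-1, index.view(-1), 1.0)
hardwts = one_hot - probs.detach() + probs  # hard value forward, soft gradient backward
# mixed_op.forward_gdas(x, hardwts, index.item())  # evaluates only the sampled op (hypothetical instance)
# mixed_op.forward_darts(x, probs)                 # dense mixture over all candidate ops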
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetSearchCell(nn.Module):
def __init__(
self,
space,
steps,
multiplier,
C_prev_prev,
C_prev,
C,
reduction,
reduction_prev,
affine,
track_running_stats,
):
super(NASNetSearchCell, self).__init__()
self.reduction = reduction
self.op_names = deepcopy(space)
if reduction_prev:
self.preprocess0 = OPS["skip_connect"](
C_prev_prev, C, 2, affine, track_running_stats
)
else:
self.preprocess0 = OPS["nor_conv_1x1"](
C_prev_prev, C, 1, affine, track_running_stats
)
self.preprocess1 = OPS["nor_conv_1x1"](
C_prev, C, 1, affine, track_running_stats
)
self._steps = steps
self._multiplier = multiplier
self._ops = nn.ModuleList()
self.edges = nn.ModuleDict()
for i in range(self._steps):
for j in range(2 + i):
node_str = "{:}<-{:}".format(
i, j
)  # indicate the edge from node-(j) to node-(i+2)
stride = 2 if reduction and j < 2 else 1
op = MixedOp(space, C, stride, affine, track_running_stats)
self.edges[node_str] = op
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)

@property
def multiplier(self):
return self._multiplier

def forward_gdas(self, s0, s1, weightss, indexs):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)

states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = "{:}<-{:}".format(i, j)
op = self.edges[node_str]
weights = weightss[self.edge2index[node_str]]
index = indexs[self.edge2index[node_str]].item()
clist.append(op.forward_gdas(h, weights, index))
states.append(sum(clist))

return torch.cat(states[-self._multiplier :], dim=1)

def forward_darts(self, s0, s1, weightss):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)

states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = "{:}<-{:}".format(i, j)
op = self.edges[node_str]
weights = weightss[self.edge2index[node_str]]
clist.append(op.forward_darts(h, weights))
states.append(sum(clist))

return torch.cat(states[-self._multiplier :], dim=1)

View File

@@ -7,91 +7,116 @@ import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure


class TinyNetworkDarts(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkDarts, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist

def get_alphas(self):
return [self.arch_parameters]

def show_alphas(self):
with torch.no_grad():
return "arch-parameters :\n{:}".format(
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)

def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string

def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)

def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)

def forward(self, inputs):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell(feature, alphas)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits
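# A minimal first-order, alternating update sketch around TinyNetworkDarts,
# assuming `criterion`, data batches, and the two optimizers exist; the
# commented hyper-parameters below are placeholders, not values from this repo.
def darts_search_step(network, criterion, w_optimizer, a_optimizer, train_batch, valid_batch):
    # architecture step on the validation batch
    a_optimizer.zero_grad()
    _, logits = network(valid_batch[0])
    arch_loss = criterion(logits, valid_batch[1])
    arch_loss.backward()
    a_optimizer.step()
    # weight step on the training batch
    w_optimizer.zero_grad()
    _, logits = network(train_batch[0])
    weight_loss = criterion(logits, train_batch[1])
    weight_loss.backward()
    w_optimizer.step()
    return weight_loss.item(), arch_loss.item()
# w_optimizer = torch.optim.SGD(network.get_weights(), lr=0.025, momentum=0.9)
# a_optimizer = torch.optim.Adam(network.get_alphas(), lr=3e-4, weight_decay=1e-3)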

View File

@@ -10,103 +10,169 @@ from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkDARTS(nn.Module):
def __init__(
self,
C: int,
N: int,
steps: int,
multiplier: int,
stem_multiplier: int,
num_classes: int,
search_space: List[Text],
affine: bool,
track_running_stats: bool,
):
super(NASNetworkDARTS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
# config for each layer
layer_channels = (
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
)
layer_reductions = (
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
)

num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
cell = SearchCell(
search_space,
steps,
multiplier,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.arch_reduce_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
def get_weights(self) -> List[torch.nn.Parameter]:
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist

def get_alphas(self) -> List[torch.nn.Parameter]:
return [self.arch_normal_parameters, self.arch_reduce_parameters]

def show_alphas(self) -> Text:
with torch.no_grad():
A = "arch-normal-parameters :\n{:}".format(
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
)
B = "arch-reduce-parameters :\n{:}".format(
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
)
return "{:}\n{:}".format(A, B)

def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string

def extra_repr(self) -> Text:
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)

def genotype(self) -> Dict[Text, List]:
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2 + i):
node_str = "{:}<-{:}".format(i, j)
ws = weights[self.edge2index[node_str]]
for k, op_name in enumerate(self.op_names):
if op_name == "none":
continue
edges.append((op_name, j, ws[k]))
# (TODO) xuanyidong:
# Here the selected two edges might come from the same input node.
# And this case could be a problem that two edges will collapse into a single one
# due to our assumption -- at most one edge from an input node during evaluation.
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append(tuple(selected_edges))
return gene

with torch.no_grad():
gene_normal = _parse(
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
)
gene_reduce = _parse(
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
)
return {
"normal": gene_normal,
"normal_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
"reduce": gene_reduce,
"reduce_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}

def forward(self, inputs):
normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)

s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction:
ww = reduce_w
else:
ww = normal_w
s0, s1 = s1, cell.forward_darts(s0, s1, ww)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits
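# Reading the search result out of genotype() above: each entry under "normal" /
# "reduce" holds the two highest-weighted (op_name, input_node, score) edges of
# one intermediate node, and "*_concat" lists the nodes whose outputs are
# concatenated; `network` is an assumed, already-constructed instance.
geno = network.genotype()
for step, edges in enumerate(geno["normal"]):
    print("step {:}: {:}".format(step, [(op, j) for op, j, _ in edges]))
print("normal_concat:", geno["normal_concat"])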

View File

@@ -7,88 +7,108 @@ import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
from .search_model_enas_utils import Controller


class TinyNetworkENAS(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkENAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
# to maintain the sampled architecture
self.sampled_arch = None
    def update_arch(self, _arch):
        if _arch is None:
            self.sampled_arch = None
        elif isinstance(_arch, Structure):
            self.sampled_arch = _arch
        elif isinstance(_arch, (list, tuple)):
            genotypes = []
            for i in range(1, self.max_nodes):
                xlist = []
                for j in range(i):
                    node_str = "{:}<-{:}".format(i, j)
                    op_index = _arch[self.edge2index[node_str]]
                    op_name = self.op_names[op_index]
                    xlist.append((op_name, j))
                genotypes.append(tuple(xlist))
            self.sampled_arch = Structure(genotypes)
        else:
            raise ValueError("invalid type of input architecture : {:}".format(_arch))
        return self.sampled_arch
    def create_controller(self):
        return Controller(len(self.edge2index), len(self.op_names))

    def get_message(self):
        string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
    def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
    def forward(self, inputs):
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            if isinstance(cell, SearchCell):
                feature = cell.forward_dynamic(feature, self.sampled_arch)
else:
feature = cell(feature)
        out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -7,49 +7,68 @@ import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
class Controller(nn.Module):
    # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
    def __init__(
        self,
        num_edge,
        num_ops,
        lstm_size=32,
        lstm_num_layers=2,
        tanh_constant=2.5,
        temperature=5.0,
    ):
        super(Controller, self).__init__()
        # assign the attributes
        self.num_edge = num_edge
        self.num_ops = num_ops
        self.lstm_size = lstm_size
self.lstm_N = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
# create parameters
self.register_parameter(
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
)
self.w_lstm = nn.LSTM(
input_size=self.lstm_size,
hidden_size=self.lstm_size,
num_layers=self.lstm_N,
)
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
        nn.init.uniform_(self.input_vars, -0.1, 0.1)
        nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
        nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
        nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
        nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
    def forward(self):
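        # Autoregressively sample one operation per edge with the LSTM; accumulate the log-probability and entropy of each choice.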
        inputs, h0 = self.input_vars, None
        log_probs, entropys, sampled_arch = [], [], []
        for iedge in range(self.num_edge):
            outputs, h0 = self.w_lstm(inputs, h0)
            logits = self.w_pred(outputs)
            logits = logits / self.temperature
            logits = self.tanh_constant * torch.tanh(logits)
            # distribution
op_distribution = Categorical(logits=logits)
            op_index = op_distribution.sample()
            sampled_arch.append(op_index.item())
op_log_prob = op_distribution.log_prob(op_index)
log_probs.append(op_log_prob.view(-1))
op_entropy = op_distribution.entropy()
entropys.append(op_entropy.view(-1))
# obtain the input embedding for the next step
inputs = self.w_embd(op_index)
return (
torch.sum(torch.cat(log_probs)),
torch.sum(torch.cat(entropys)),
sampled_arch,
)

View File

@@ -5,107 +5,138 @@ import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkGDAS(nn.Module):
    # def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
    def __init__(
        self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
    ):
        super(TinyNetworkGDAS, self).__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
        )
        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
        C_prev, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.tau = 10
    def get_weights(self):
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
    def set_tau(self, tau):
        self.tau = tau

    def get_tau(self):
        return self.tau

    def get_alphas(self):
        return [self.arch_parameters]

    def show_alphas(self):
        with torch.no_grad():
            return "arch-parameters :\n{:}".format(
                nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
            )

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
            )
        return string

    def extra_repr(self):
        return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def forward(self, inputs):
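        # Straight-through Gumbel-softmax: draw Gumbel noise, pick a hard one-hot operation per edge, and route gradients through the soft probabilities; resample whenever the draw is non-finite.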
while True:
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_gdas(feature, hardwts, index)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -10,116 +10,190 @@ from models.cell_operations import RAW_OP_CLASSES
# The macro structure is based on NASNet
class NASNetworkGDAS_FRC(nn.Module):
def __init__(
self,
C,
N,
steps,
multiplier,
stem_multiplier,
num_classes,
search_space,
affine,
track_running_stats,
):
super(NASNetworkGDAS_FRC, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
        # config for each layer
        layer_channels = (
            [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
        )
        layer_reductions = (
            [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
        )
        num_edge, edge2index = None, None
        C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            if reduction:
                cell = RAW_OP_CLASSES["gdas_reduction"](
                    C_prev_prev,
                    C_prev,
                    C_curr,
                    reduction_prev,
                    affine,
                    track_running_stats,
                )
            else:
                cell = SearchCell(
                    search_space,
                    steps,
                    multiplier,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
reduction
or num_edge == cell.num_edges
and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = (
C_prev,
cell.multiplier * C_curr,
reduction,
)
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.tau = 10
    def get_weights(self):
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(
            self.global_pooling.parameters()
        )
xlist += list(self.classifier.parameters())
return xlist
    def set_tau(self, tau):
        self.tau = tau

    def get_tau(self):
        return self.tau

    def get_alphas(self):
        return [self.arch_parameters]

    def show_alphas(self):
        with torch.no_grad():
            A = "arch-normal-parameters :\n{:}".format(
                nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)
return "{:}".format(A)
    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
)
return string
    def extra_repr(self):
        return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
    def genotype(self):
        def _parse(weights):
            gene = []
            for i in range(self._steps):
                edges = []
                for j in range(2 + i):
                    node_str = "{:}<-{:}".format(i, j)
                    ws = weights[self.edge2index[node_str]]
                    for k, op_name in enumerate(self.op_names):
                        if op_name == "none":
                            continue
                        edges.append((op_name, j, ws[k]))
                edges = sorted(edges, key=lambda x: -x[-1])
                selected_edges = edges[:2]
                gene.append(tuple(selected_edges))
            return gene

        with torch.no_grad():
            gene_normal = _parse(
                torch.softmax(self.arch_parameters, dim=-1).cpu().numpy()
            )
        return {
            "normal": gene_normal,
            "normal_concat": list(
                range(2 + self._steps - self._multiplier, self._steps + 2)
            ),
        }

    def forward(self, inputs):
def get_gumbel_prob(xins):
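            # Draw Gumbel noise and build straight-through hard/soft weights; retry until every value is finite.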
while True:
gumbels = -torch.empty_like(xins).exponential_().log()
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
return hardwts, index
        hardwts, index = get_gumbel_prob(self.arch_parameters)

        s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction:
s0, s1 = s1, cell(s0, s1)
else:
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -9,117 +9,189 @@ from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkGDAS(nn.Module):
def __init__(
self,
C,
N,
steps,
multiplier,
stem_multiplier,
num_classes,
search_space,
affine,
track_running_stats,
):
super(NASNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
        # config for each layer
        layer_channels = (
            [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
        )
        layer_reductions = (
            [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
        )
        num_edge, edge2index = None, None
        C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            cell = SearchCell(
                search_space,
                steps,
                multiplier,
                C_prev_prev,
                C_prev,
                C_curr,
                reduction,
                reduction_prev,
                affine,
                track_running_stats,
            )
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.arch_reduce_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.tau = 10
    def get_weights(self):
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(
            self.global_pooling.parameters()
        )
xlist += list(self.classifier.parameters())
return xlist
    def set_tau(self, tau):
        self.tau = tau

    def get_tau(self):
        return self.tau

    def get_alphas(self):
        return [self.arch_normal_parameters, self.arch_reduce_parameters]
    def show_alphas(self):
        with torch.no_grad():
            A = "arch-normal-parameters :\n{:}".format(
                nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
            )
B = "arch-reduce-parameters :\n{:}".format(
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
)
return "{:}\n{:}".format(A, B)
    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
)
return string
    def extra_repr(self):
        return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
    def genotype(self):
        def _parse(weights):
            gene = []
            for i in range(self._steps):
                edges = []
                for j in range(2 + i):
                    node_str = "{:}<-{:}".format(i, j)
                    ws = weights[self.edge2index[node_str]]
                    for k, op_name in enumerate(self.op_names):
                        if op_name == "none":
                            continue
                        edges.append((op_name, j, ws[k]))
                edges = sorted(edges, key=lambda x: -x[-1])
                selected_edges = edges[:2]
                gene.append(tuple(selected_edges))
            return gene

        with torch.no_grad():
            gene_normal = _parse(
                torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
            )
            gene_reduce = _parse(
                torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
            )
        return {
            "normal": gene_normal,
            "normal_concat": list(
                range(2 + self._steps - self._multiplier, self._steps + 2)
            ),
            "reduce": gene_reduce,
"reduce_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}
    def forward(self, inputs):
        def get_gumbel_prob(xins):
while True:
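                # Hard one-hot in the forward pass, soft softmax gradients in the backward pass; applied separately to the normal and reduction parameters.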
gumbels = -torch.empty_like(xins).exponential_().log()
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
return hardwts, index
        normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
        reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)

        s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction:
hardwts, index = reduce_hardwts, reduce_index
else:
hardwts, index = normal_hardwts, normal_index
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -1,81 +1,102 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkRANDOM(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkRANDOM, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
        C_prev, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(
                    C_prev,
                    C_curr,
                    1,
                    max_nodes,
                    search_space,
                    affine,
                    track_running_stats,
                )
            if num_edge is None:
                num_edge, edge2index = cell.num_edges, cell.edge2index
            else:
                assert (
                    num_edge == cell.num_edges and edge2index == cell.edge2index
                ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
            self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_cache = None
    def get_message(self):
        string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
    def extra_repr(self):
        return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def random_genotype(self, set_cache):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = random.choice(self.op_names)
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
arch = Structure(genotypes)
if set_cache:
self.arch_cache = arch
return arch
    def forward(self, inputs):
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            if isinstance(cell, SearchCell):
                feature = cell.forward_dynamic(feature, self.arch_cache)
            else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -7,146 +7,172 @@ import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkSETN(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkSETN, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
        C_prev, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(
                    C_prev,
                    C_curr,
                    1,
                    max_nodes,
                    search_space,
                    affine,
                    track_running_stats,
                )
            if num_edge is None:
                num_edge, edge2index = cell.num_edges, cell.edge2index
            else:
                assert (
                    num_edge == cell.num_edges and edge2index == cell.edge2index
                ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.mode = "urs"
self.dynamic_cell = None
    def set_cal_mode(self, mode, dynamic_cell=None):
        assert mode in ["urs", "joint", "select", "dynamic"]
        self.mode = mode
        if mode == "dynamic":
            self.dynamic_cell = deepcopy(dynamic_cell)
        else:
            self.dynamic_cell = None

    def get_cal_mode(self):
        return self.mode
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
    def get_alphas(self):
        return [self.arch_parameters]
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
    def extra_repr(self):
        return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def genotype(self):
        genotypes = []
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = "{:}<-{:}".format(i, j)
                with torch.no_grad():
                    weights = self.arch_parameters[self.edge2index[node_str]]
                    op_name = self.op_names[weights.argmax().item()]
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        return Structure(genotypes)

    def dync_genotype(self, use_random=False):
        genotypes = []
        with torch.no_grad():
            alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[self.edge2index[node_str]]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
    def get_log_prob(self, arch):
with torch.no_grad():
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
select_logits = []
for i, node_info in enumerate(arch.nodes):
for op, xin in node_info:
node_str = "{:}<-{:}".format(i + 1, xin)
op_index = self.op_names.index(op)
select_logits.append(logits[self.edge2index[node_str], op_index])
return sum(select_logits).item()
def return_topK(self, K):
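        # Enumerate every candidate architecture and keep the K with the highest log-probability under the current arch parameters.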
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
if K < 0 or K >= len(archs):
K = len(archs)
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
return return_pairs
def forward(self, inputs):
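        # Dispatch each search cell according to self.mode ("urs", "select", "joint", or "dynamic"); non-search cells run as plain modules.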
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
with torch.no_grad():
alphas_cpu = alphas.detach().cpu()
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
if self.mode == "urs":
feature = cell.forward_urs(feature)
elif self.mode == "select":
feature = cell.forward_select(feature, alphas_cpu)
elif self.mode == "joint":
feature = cell.forward_joint(feature, alphas)
elif self.mode == "dynamic":
feature = cell.forward_dynamic(feature, self.dynamic_cell)
else:
raise ValueError("invalid mode={:}".format(self.mode))
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -7,133 +7,199 @@ import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkSETN(nn.Module):
def __init__(
self,
C: int,
N: int,
steps: int,
multiplier: int,
stem_multiplier: int,
num_classes: int,
search_space: List[Text],
affine: bool,
track_running_stats: bool,
):
super(NASNetworkSETN, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
        # config for each layer
        layer_channels = (
            [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
        )
        layer_reductions = (
            [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
        )
        num_edge, edge2index = None, None
        C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            cell = SearchCell(
                search_space,
                steps,
                multiplier,
                C_prev_prev,
                C_prev,
                C_curr,
                reduction,
                reduction_prev,
                affine,
                track_running_stats,
            )
            if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.arch_reduce_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.mode = "urs"
self.dynamic_cell = None
    def set_cal_mode(self, mode, dynamic_cell=None):
        assert mode in ["urs", "joint", "select", "dynamic"]
        self.mode = mode
        if mode == "dynamic":
            self.dynamic_cell = deepcopy(dynamic_cell)
        else:
            self.dynamic_cell = None
    def get_weights(self):
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(
            self.global_pooling.parameters()
        )
        xlist += list(self.classifier.parameters())
        return xlist
    def get_alphas(self):
        return [self.arch_normal_parameters, self.arch_reduce_parameters]

    def show_alphas(self):
        with torch.no_grad():
            A = "arch-normal-parameters :\n{:}".format(
                nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
            )
            B = "arch-reduce-parameters :\n{:}".format(
                nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
            )
        return "{:}\n{:}".format(A, B)

    def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[self.edge2index[node_str]]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2 + i):
node_str = "{:}<-{:}".format(i, j)
ws = weights[self.edge2index[node_str]]
for k, op_name in enumerate(self.op_names):
if op_name == "none":
continue
edges.append((op_name, j, ws[k]))
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append(tuple(selected_edges))
return gene
with torch.no_grad():
gene_normal = _parse(
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
)
gene_reduce = _parse(
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
)
return {
"normal": gene_normal,
"normal_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
"reduce": gene_reduce,
"reduce_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}
def forward(self, inputs):
normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
# [TODO]
raise NotImplementedError
if cell.reduction:
hardwts, index = reduce_hardwts, reduce_index
else:
hardwts, index = normal_hardwts, normal_index
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits
View File

@@ -3,60 +3,72 @@ import torch.nn as nn
def copy_conv(module, init):
    assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init)
    new_i, new_o = module.in_channels, module.out_channels
    module.weight.copy_(init.weight.detach()[:new_o, :new_i])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:new_o])


def copy_bn(module, init):
    assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init)
    num_features = module.num_features
    if module.weight is not None:
        module.weight.copy_(init.weight.detach()[:num_features])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:num_features])
    if module.running_mean is not None:
        module.running_mean.copy_(init.running_mean.detach()[:num_features])
    if module.running_var is not None:
        module.running_var.copy_(init.running_var.detach()[:num_features])


def copy_fc(module, init):
    assert isinstance(module, nn.Linear), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Linear), "invalid module : {:}".format(init)
    new_i, new_o = module.in_features, module.out_features
    module.weight.copy_(init.weight.detach()[:new_o, :new_i])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:new_o])


def copy_base(module, init):
    assert type(module).__name__ in [
        "ConvBNReLU",
        "Downsample",
    ], "invalid module : {:}".format(module)
    assert type(init).__name__ in [
        "ConvBNReLU",
        "Downsample",
    ], "invalid module : {:}".format(init)
    if module.conv is not None:
        copy_conv(module.conv, init.conv)
    if module.bn is not None:
        copy_bn(module.bn, init.bn)


def copy_basic(module, init):
    copy_base(module.conv_a, init.conv_a)
    copy_base(module.conv_b, init.conv_b)
    if module.downsample is not None:
        if init.downsample is not None:
            copy_base(module.downsample, init.downsample)
        # else:
        #   import pdb; pdb.set_trace()


def init_from_model(network, init_model):
    with torch.no_grad():
        copy_fc(network.classifier, init_model.classifier)
        for base, target in zip(init_model.layers, network.layers):
            assert (
                type(base).__name__ == type(target).__name__
            ), "invalid type : {:} vs {:}".format(base, target)
            if type(base).__name__ == "ConvBNReLU":
                copy_base(target, base)
            elif type(base).__name__ == "ResNetBasicblock":
                copy_basic(target, base)
            else:
                raise ValueError("unknown type name : {:}".format(type(base).__name__))
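A minimal usage sketch of the helpers above (the factory name and widths are hypothetical, not part of this commit): init_from_model copies the leading channel slices of a pretrained network into a smaller candidate, so the candidate starts from inherited weights rather than random initialization.

# Hedged example: assumes `pretrained` and `candidate` are CifarResNet-style networks
# whose .layers hold ConvBNReLU / ResNetBasicblock modules and whose .classifier is an
# nn.Linear, as required by copy_base / copy_basic / copy_fc above.
pretrained = build_cifar_resnet(width=1.0)  # hypothetical factory
candidate = build_cifar_resnet(width=0.5)   # hypothetical, narrower network
init_from_model(candidate, pretrained)      # copying runs under torch.no_grad()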
View File

@@ -3,16 +3,14 @@ import torch.nn as nn
def initialize_resnet(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.constant_(m.bias, 0)
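A minimal usage sketch, assuming any torch.nn module tree: initialize_resnet is written to be passed to nn.Module.apply, which visits every submodule, so one call covers convolutions, batch-norm layers, and the final linear classifier.

# Hedged example with a throwaway module; the Infer* networks below call
# self.apply(initialize_resnet) in exactly the same way.
net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))
net.apply(initialize_resnet)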
View File

@@ -7,161 +7,280 @@ from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
    def __init__(
        self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
    ):
        super(ConvBNReLU, self).__init__()
        if has_avg:
            self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.avg = None
        self.conv = nn.Conv2d(
            nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
if has_bn:
self.bn = nn.BatchNorm2d(nOut)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
    def forward(self, inputs):
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        conv = self.conv(out)
        if self.bn:
            out = self.bn(conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
        return out


class ResNetBasicblock(nn.Module):
    num_conv = 2
    expansion = 1

    def __init__(self, iCs, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        assert isinstance(iCs, tuple) or isinstance(
            iCs, list
        ), "invalid type of iCs : {:}".format(iCs)
        assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
        self.conv_a = ConvBNReLU(
            iCs[0],
            iCs[1],
            3,
            stride,
            1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
    expansion = 4
    num_conv = 3

    def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
assert isinstance(iCs, tuple) or isinstance(
iCs, list
), "invalid type of iCs : {:}".format(iCs)
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
self.conv_1x1 = ConvBNReLU(
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
iCs[1],
iCs[2],
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=False,
has_bn=False,
has_relu=False,
)
residual_in = iCs[3]
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
    def forward(self, inputs):
        bottleneck = self.conv_1x1(inputs)
        bottleneck = self.conv_3x3(bottleneck)
        bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferCifarResNet(nn.Module):
    def __init__(
        self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual
    ):
        super(InferCifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks)
        self.message = (
            "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
                depth, layer_blocks
            )
        )
        self.num_classes = num_classes
        self.xchannels = xchannels
        self.layers = nn.ModuleList(
            [
                ConvBNReLU(
                    xchannels[0],
                    xchannels[1],
3,
1,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
]
)
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iCs,
module.out_dim,
stride,
)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL + 1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(self.xchannels[-1], num_classes)

        self.apply(initialize_resnet)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, ResNetBasicblock):
                    nn.init.constant_(m.conv_b.bn.weight, 0)
                elif isinstance(m, ResNetBottleneck):
                    nn.init.constant_(m.conv_1x4.bn.weight, 0)

    def get_message(self):
        return self.message

    def forward(self, inputs):
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
        features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
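A hedged instantiation sketch (the channel list is illustrative, not a searched architecture): for a depth-20 basic-block network, xchannels lists every convolution width (the input channels plus two entries per block) and xblocks caps how many blocks per stage are kept.

# Hypothetical full-width depth-20 configuration: 3 stages x 3 blocks x 2 convs.
xchannels = [3, 16] + [16] * 6 + [32] * 6 + [64] * 6  # length 20
net = InferCifarResNet(
    "ResNetBasicblock",
    20,
    xblocks=[3, 3, 3],
    xchannels=xchannels,
    num_classes=10,
    zero_init_residual=True,
)
features, logits = net(torch.randn(2, 3, 32, 32))  # assumes torch is imported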
View File

@@ -7,144 +7,257 @@ from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
    def __init__(
        self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
    ):
        super(ConvBNReLU, self).__init__()
        if has_avg:
            self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.avg = None
        self.conv = nn.Conv2d(
            nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
if has_bn:
self.bn = nn.BatchNorm2d(nOut)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
    def forward(self, inputs):
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        conv = self.conv(out)
        if self.bn:
            out = self.bn(conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
        return out


class ResNetBasicblock(nn.Module):
    num_conv = 2
    expansion = 1

    def __init__(self, inplanes, planes, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_a = ConvBNReLU(
            inplanes,
            planes,
            3,
            stride,
            1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
    expansion = 4
    num_conv = 3

    def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_1x1 = ConvBNReLU(
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
planes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
planes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes * self.expansion:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=False,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes * self.expansion
    def forward(self, inputs):
        bottleneck = self.conv_1x1(inputs)
        bottleneck = self.conv_3x3(bottleneck)
        bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferDepthCifarResNet(nn.Module):
    def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
        super(InferDepthCifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks)
        self.message = (
            "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
                depth, layer_blocks
            )
        )
        self.num_classes = num_classes
        self.layers = nn.ModuleList(
            [
                ConvBNReLU(
                    3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
                )
            ]
)
self.channels = [16]
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2 ** stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append(module.out_dim)
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
planes,
module.out_dim,
stride,
)
if iL + 1 == xblocks[stage]: # reach the maximum depth
break
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(self.channels[-1], num_classes)

        self.apply(initialize_resnet)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, ResNetBasicblock):
                    nn.init.constant_(m.conv_b.bn.weight, 0)
                elif isinstance(m, ResNetBottleneck):
                    nn.init.constant_(m.conv_1x4.bn.weight, 0)

    def get_message(self):
        return self.message

    def forward(self, inputs):
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
        features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
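A hedged instantiation sketch: this variant prunes depth only, so xblocks (the number of blocks kept per stage) is the only architecture input and widths follow the fixed 16/32/64 schedule.

# Hypothetical depth-pruned configuration: keep 2 of the 3 blocks in each stage.
net = InferDepthCifarResNet(
    "ResNetBasicblock", 20, xblocks=[2, 2, 2], num_classes=10, zero_init_residual=False
)
features, logits = net(torch.randn(2, 3, 32, 32))  # assumes torch is imported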
View File

@@ -7,154 +7,271 @@ from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
    def __init__(
        self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
    ):
        super(ConvBNReLU, self).__init__()
        if has_avg:
            self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.avg = None
        self.conv = nn.Conv2d(
            nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
if has_bn:
self.bn = nn.BatchNorm2d(nOut)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
    def forward(self, inputs):
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        conv = self.conv(out)
        if self.bn:
            out = self.bn(conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
        return out


class ResNetBasicblock(nn.Module):
    num_conv = 2
    expansion = 1

    def __init__(self, iCs, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        assert isinstance(iCs, tuple) or isinstance(
            iCs, list
        ), "invalid type of iCs : {:}".format(iCs)
        assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
        self.conv_a = ConvBNReLU(
            iCs[0],
            iCs[1],
            3,
            stride,
            1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
    expansion = 4
    num_conv = 3

    def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
assert isinstance(iCs, tuple) or isinstance(
iCs, list
), "invalid type of iCs : {:}".format(iCs)
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
self.conv_1x1 = ConvBNReLU(
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
iCs[1],
iCs[2],
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=False,
has_bn=False,
has_relu=False,
)
residual_in = iCs[3]
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
    def forward(self, inputs):
        bottleneck = self.conv_1x1(inputs)
        bottleneck = self.conv_3x3(bottleneck)
        bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferWidthCifarResNet(nn.Module):
    def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
        super(InferWidthCifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
        self.message = (
            "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
                depth, layer_blocks
            )
        )
        self.num_classes = num_classes
        self.xchannels = xchannels
        self.layers = nn.ModuleList(
            [
                ConvBNReLU(
                    xchannels[0],
xchannels[1],
3,
1,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
]
)
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iCs,
module.out_dim,
stride,
)
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(self.xchannels[-1], num_classes)

        self.apply(initialize_resnet)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, ResNetBasicblock):
                    nn.init.constant_(m.conv_b.bn.weight, 0)
                elif isinstance(m, ResNetBottleneck):
                    nn.init.constant_(m.conv_1x4.bn.weight, 0)

    def get_message(self):
        return self.message

    def forward(self, inputs):
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
        features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
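A hedged instantiation sketch: here only widths are pruned, the full depth is kept, and xchannels uses the same per-convolution layout as in InferCifarResNet.

# Hypothetical half-width depth-20 configuration: 8/16/32 instead of 16/32/64.
xchannels = [3, 8] + [8] * 6 + [16] * 6 + [32] * 6
net = InferWidthCifarResNet(
    "ResNetBasicblock", 20, xchannels=xchannels, num_classes=10, zero_init_residual=False
)
features, logits = net(torch.randn(2, 3, 32, 32))  # assumes torch is imported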
View File

@@ -7,164 +7,318 @@ from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
    num_conv = 1

    def __init__(
        self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
    ):
        super(ConvBNReLU, self).__init__()
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
if has_bn:
self.bn = nn.BatchNorm2d(nOut)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
def forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.bn:
out = self.bn(conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
return out
class ResNetBasicblock(nn.Module):
    num_conv = 2
    expansion = 1

    def __init__(self, iCs, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        assert isinstance(iCs, tuple) or isinstance(
            iCs, list
        ), "invalid type of iCs : {:}".format(iCs)
        assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
        self.conv_a = ConvBNReLU(
            iCs[0],
            iCs[1],
            3,
            stride,
            1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=True,
has_bn=True,
has_relu=False,
)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
    expansion = 4
    num_conv = 3

    def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
assert isinstance(iCs, tuple) or isinstance(
iCs, list
), "invalid type of iCs : {:}".format(iCs)
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
self.conv_1x1 = ConvBNReLU(
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
iCs[1],
iCs[2],
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=True,
has_bn=True,
has_relu=False,
)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
residual_in = iCs[3]
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
    def forward(self, inputs):
        bottleneck = self.conv_1x1(inputs)
        bottleneck = self.conv_3x3(bottleneck)
        bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferImagenetResNet(nn.Module):
def __init__(
self,
block_name,
layers,
xblocks,
xchannels,
deep_stem,
num_classes,
zero_init_residual,
):
super(InferImagenetResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        if block_name == "BasicBlock":
block = ResNetBasicblock
elif block_name == "Bottleneck":
block = ResNetBottleneck
else:
raise ValueError("invalid block : {:}".format(block_name))
assert len(xblocks) == len(
layers
), "invalid layers : {:} vs xblocks : {:}".format(layers, xblocks)
        self.message = "InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}".format(
            sum(layers) * block.num_conv, sum(xblocks) * block.num_conv, xblocks
        )
        self.num_classes = num_classes
        self.xchannels = xchannels
        if not deep_stem:
            self.layers = nn.ModuleList(
                [
ConvBNReLU(
xchannels[0],
xchannels[1],
7,
2,
3,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
]
)
last_channel_idx = 1
else:
self.layers = nn.ModuleList(
[
ConvBNReLU(
xchannels[0],
xchannels[1],
3,
2,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
),
ConvBNReLU(
xchannels[1],
xchannels[2],
3,
1,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
),
]
)
last_channel_idx = 2
self.layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
for stage, layer_blocks in enumerate(layers):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iCs,
module.out_dim,
stride,
)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL + 1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
assert last_channel_idx + 1 == len(self.xchannels), "{:} vs {:}".format(
last_channel_idx, len(self.xchannels)
)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
        self.apply(initialize_resnet)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, ResNetBasicblock):
                    nn.init.constant_(m.conv_b.bn.weight, 0)
                elif isinstance(m, ResNetBottleneck):
                    nn.init.constant_(m.conv_1x4.bn.weight, 0)

    def get_message(self):
        return self.message

    def forward(self, inputs):
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
        features = self.avgpool(x)
        features = features.view(features.size(0), -1)
        logits = self.classifier(features)
        return features, logits
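A hedged instantiation sketch for the ImageNet variant (an illustrative ResNet-18-like setting, not a searched one): layers gives the original blocks per stage, xblocks the kept blocks, and xchannels again lists every convolution width starting from the RGB input.

# Hypothetical ResNet-18-like setup: 4 stages x 2 basic blocks, full width kept.
xchannels = [3, 64]
for width in [64, 128, 256, 512]:
    xchannels += [width] * 4  # 2 blocks x 2 convs per stage
net = InferImagenetResNet(
    "BasicBlock",
    layers=[2, 2, 2, 2],
    xblocks=[2, 2, 2, 2],
    xchannels=xchannels,
    deep_stem=False,
    num_classes=1000,
    zero_init_residual=True,
)
features, logits = net(torch.randn(2, 3, 224, 224))  # assumes torch is imported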
View File

@@ -4,119 +4,171 @@
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from ..initialization import initialize_resnet
from ..SharedUtils import parse_channel_info


class ConvBNReLU(nn.Module):
    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride,
        groups,
        has_bn=True,
        has_relu=True,
    ):
        super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_planes,
out_planes,
kernel_size,
stride,
padding,
groups=groups,
bias=False,
)
if has_bn:
self.bn = nn.BatchNorm2d(out_planes)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU6(inplace=True)
else:
self.relu = None
def forward(self, x):
out = self.conv(x)
if self.bn:
out = self.bn(out)
if self.relu:
out = self.relu(out)
return out
class InvertedResidual(nn.Module):
    def __init__(self, channels, stride, expand_ratio, additive):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2], "invalid stride : {:}".format(stride)
        assert len(channels) in [2, 3], "invalid channels : {:}".format(channels)

        if len(channels) == 2:
            layers = []
        else:
            layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
        layers.extend(
            [
                # dw
                ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
                # pw-linear
                ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
            ]
        )
        self.conv = nn.Sequential(*layers)
        self.additive = additive
        if self.additive and channels[0] != channels[-1]:
            self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
        else:
            self.shortcut = None
        self.out_dim = channels[-1]

    def forward(self, x):
        out = self.conv(x)
        # if self.additive: return additive_func(out, x)
        if self.shortcut:
            return out + self.shortcut(x)
        else:
            return out
class InferMobileNetV2(nn.Module):
    def __init__(self, num_classes, xchannels, xblocks, dropout):
        super(InferMobileNetV2, self).__init__()
        block = InvertedResidual
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        assert len(inverted_residual_setting) == len(
            xblocks
        ), "invalid number of layers : {:} vs {:}".format(
            len(inverted_residual_setting), len(xblocks)
        )
        for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
            assert block_num <= ir_setting[2], "{:} vs {:}".format(
                block_num, ir_setting
            )

        xchannels = parse_channel_info(xchannels)
        # for i, chs in enumerate(xchannels):
# if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
self.xchannels = xchannels
self.message = "InferMobileNetV2 : xblocks={:}".format(xblocks)
# building first layer
features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
last_channel_idx = 1
# building inverted residual blocks # building inverted residual blocks
for stage, (t, c, n, s) in enumerate(inverted_residual_setting): for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
for i in range(n): for i in range(n):
stride = s if i == 0 else 1 stride = s if i == 0 else 1
additv = True if i > 0 else False additv = True if i > 0 else False
module = block(self.xchannels[last_channel_idx], stride, t, additv) module = block(self.xchannels[last_channel_idx], stride, t, additv)
features.append(module) features.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c) self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(
last_channel_idx += 1 stage,
if i + 1 == xblocks[stage]: i,
out_channel = module.out_dim n,
for iiL in range(i+1, n): len(features),
last_channel_idx += 1 self.xchannels[last_channel_idx],
self.xchannels[last_channel_idx][0] = module.out_dim stride,
break t,
# building last several layers c,
features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1)) )
assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels)) last_channel_idx += 1
# make it nn.Sequential if i + 1 == xblocks[stage]:
self.features = nn.Sequential(*features) out_channel = module.out_dim
for iiL in range(i + 1, n):
last_channel_idx += 1
self.xchannels[last_channel_idx][0] = module.out_dim
break
# building last several layers
features.append(
ConvBNReLU(
self.xchannels[last_channel_idx][0],
self.xchannels[last_channel_idx][1],
1,
1,
1,
)
)
assert last_channel_idx + 2 == len(self.xchannels), "{:} vs {:}".format(
last_channel_idx, len(self.xchannels)
)
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier # building classifier
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(self.xchannels[last_channel_idx][1], num_classes), nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
) )
# weight initialization # weight initialization
self.apply( initialize_resnet ) self.apply(initialize_resnet)
def get_message(self): def get_message(self):
return self.message return self.message
def forward(self, inputs): def forward(self, inputs):
features = self.features(inputs) features = self.features(inputs)
vectors = features.mean([2, 3]) vectors = features.mean([2, 3])
predicts = self.classifier(vectors) predicts = self.classifier(vectors)
return features, predicts return features, predicts
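# Usage note: `xchannels` is the space/dash-encoded channel string decoded by
# parse_channel_info (reformatted later in this commit), `xblocks` lists how many
# blocks are kept per row of inverted_residual_setting (each entry <= n), and
# forward() returns both the final feature map and the classifier logits.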


@@ -8,51 +8,57 @@ from models.cell_infers.cells import InferCell
class DynamicShapeTinyNet(nn.Module):
    def __init__(self, channels: List[int], genotype: Any, num_classes: int):
        super(DynamicShapeTinyNet, self).__init__()
        self._channels = channels
        if len(channels) % 3 != 2:
            raise ValueError("invalid number of layers : {:}".format(len(channels)))
        self._num_stage = N = len(channels) // 3

        self.stem = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels[0]),
        )

        # layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        c_prev = channels[0]
        self.cells = nn.ModuleList()
        for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(c_prev, c_curr, 2, True)
            else:
                cell = InferCell(genotype, c_prev, c_curr, 1)
            self.cells.append(cell)
            c_prev = cell.out_dim
        self._num_layer = len(self.cells)

        self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(c_prev, num_classes)

    def get_message(self) -> Text:
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
            )
        return string

    def extra_repr(self):
        return "{name}(C={_channels}, N={_num_stage}, L={_num_layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def forward(self, inputs):
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
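A minimal usage sketch for the class above (illustrative only): `genotype` is assumed to be a searched cell genotype compatible with InferCell, and the channel list simply satisfies the `len(channels) % 3 == 2` check.

    channels = [16, 16, 32, 32, 32, 64, 64, 64]  # N = 2 cells per stage
    net = DynamicShapeTinyNet(channels=channels, genotype=genotype, num_classes=10)
    out, logits = net(torch.randn(2, 3, 32, 32))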


@@ -6,4 +6,4 @@ from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2
from .InferTinyCellNet import DynamicShapeTinyNet


@@ -1,5 +1,5 @@
def parse_channel_info(xstring):
    blocks = xstring.split(" ")
    blocks = [x.split("-") for x in blocks]
    blocks = [[int(_) for _ in x] for x in blocks]
    return blocks
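As a quick illustration of the helper above: the string is split on spaces, then on dashes, and every piece is cast to int.

    >>> parse_channel_info("3-16 16-32 32-32-64")
    [[3, 16], [16, 32], [32, 32, 64]]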

File diff suppressed because it is too large


@@ -6,335 +6,510 @@ from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices


def get_depth_choices(nDepth, return_num):
    if nDepth == 2:
        choices = (1, 2)
    elif nDepth == 3:
        choices = (1, 2, 3)
    elif nDepth > 3:
        choices = list(range(1, nDepth + 1, 2))
        if choices[-1] < nDepth:
            choices.append(nDepth)
    else:
        raise ValueError("invalid nDepth : {:}".format(nDepth))
    if return_num:
        return len(choices)
    else:
        return choices
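# Illustration (derived from the function above): with basic blocks the per-stage
# count is layer_blocks = (depth - 2) // 6, so
#   get_depth_choices(3, False)  -> (1, 2, 3)                             # depth = 20
#   get_depth_choices(5, False)  -> [1, 3, 5]                             # depth = 32
#   get_depth_choices(18, False) -> [1, 3, 5, 7, 9, 11, 13, 15, 17, 18]   # depth = 110
# and return_num=True yields the number of candidate depths instead.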
class ConvBNReLU(nn.Module):
    num_conv = 1

    def __init__(
        self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
    ):
        super(ConvBNReLU, self).__init__()
        self.InShape = None
        self.OutShape = None
        self.choices = get_width_choices(nOut)
        self.register_buffer("choices_tensor", torch.Tensor(self.choices))

        if has_avg:
            self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.avg = None
        self.conv = nn.Conv2d(
            nIn,
            nOut,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            dilation=1,
            groups=1,
            bias=bias,
        )
        if has_bn:
            self.bn = nn.BatchNorm2d(nOut)
        else:
            self.bn = None
        if has_relu:
            self.relu = nn.ReLU(inplace=False)
        else:
            self.relu = None
        self.in_dim = nIn
        self.out_dim = nOut

    def get_flops(self, divide=1):
        iC, oC = self.in_dim, self.out_dim
        assert (
            iC <= self.conv.in_channels and oC <= self.conv.out_channels
        ), "{:} vs {:} | {:} vs {:}".format(
            iC, self.conv.in_channels, oC, self.conv.out_channels
        )
        assert (
            isinstance(self.InShape, tuple) and len(self.InShape) == 2
        ), "invalid in-shape : {:}".format(self.InShape)
        assert (
            isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
        ), "invalid out-shape : {:}".format(self.OutShape)
        # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
        conv_per_position_flops = (
            self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
        )
        all_positions = self.OutShape[0] * self.OutShape[1]
        flops = (conv_per_position_flops * all_positions / divide) * iC * oC
        if self.conv.bias is not None:
            flops += all_positions / divide
        return flops

    def forward(self, inputs):
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        conv = self.conv(out)
        if self.bn:
            out = self.bn(conv)
        else:
            out = conv
        if self.relu:
            out = self.relu(out)
        else:
            out = out
        if self.InShape is None:
            self.InShape = (inputs.size(-2), inputs.size(-1))
            self.OutShape = (out.size(-2), out.size(-1))
        return out
class ResNetBasicblock(nn.Module):
    expansion = 1
    num_conv = 2

    def __init__(self, inplanes, planes, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_a = ConvBNReLU(
            inplanes,
            planes,
            3,
            stride,
            1,
            False,
            has_avg=False,
            has_bn=True,
            has_relu=True,
        )
        self.conv_b = ConvBNReLU(
            planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
        )
        if stride == 2:
            self.downsample = ConvBNReLU(
                inplanes,
                planes,
                1,
                1,
                0,
                False,
                has_avg=True,
                has_bn=False,
                has_relu=False,
            )
        elif inplanes != planes:
            self.downsample = ConvBNReLU(
                inplanes,
                planes,
                1,
                1,
                0,
                False,
                has_avg=False,
                has_bn=True,
                has_relu=False,
            )
        else:
            self.downsample = None
        self.out_dim = planes
        self.search_mode = "basic"

    def get_flops(self, divide=1):
        flop_A = self.conv_a.get_flops(divide)
        flop_B = self.conv_b.get_flops(divide)
        if hasattr(self.downsample, "get_flops"):
            flop_C = self.downsample.get_flops(divide)
        else:
            flop_C = 0
        return flop_A + flop_B + flop_C

    def forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)
        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, basicblock)
        return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
    expansion = 4
    num_conv = 3

    def __init__(self, inplanes, planes, stride):
        super(ResNetBottleneck, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_1x1 = ConvBNReLU(
            inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
        )
        self.conv_3x3 = ConvBNReLU(
            planes,
            planes,
            3,
            stride,
            1,
            False,
            has_avg=False,
            has_bn=True,
            has_relu=True,
        )
        self.conv_1x4 = ConvBNReLU(
            planes,
            planes * self.expansion,
            1,
            1,
            0,
            False,
            has_avg=False,
            has_bn=True,
            has_relu=False,
        )
        if stride == 2:
            self.downsample = ConvBNReLU(
                inplanes,
                planes * self.expansion,
                1,
                1,
                0,
                False,
                has_avg=True,
                has_bn=False,
                has_relu=False,
            )
        elif inplanes != planes * self.expansion:
            self.downsample = ConvBNReLU(
                inplanes,
                planes * self.expansion,
                1,
                1,
                0,
                False,
                has_avg=False,
                has_bn=True,
                has_relu=False,
            )
        else:
            self.downsample = None
        self.out_dim = planes * self.expansion
        self.search_mode = "basic"

    def get_range(self):
        return (
            self.conv_1x1.get_range()
            + self.conv_3x3.get_range()
            + self.conv_1x4.get_range()
        )

    def get_flops(self, divide):
        flop_A = self.conv_1x1.get_flops(divide)
        flop_B = self.conv_3x3.get_flops(divide)
        flop_C = self.conv_1x4.get_flops(divide)
        if hasattr(self.downsample, "get_flops"):
            flop_D = self.downsample.get_flops(divide)
        else:
            flop_D = 0
        return flop_A + flop_B + flop_C + flop_D

    def forward(self, inputs):
        bottleneck = self.conv_1x1(inputs)
        bottleneck = self.conv_3x3(bottleneck)
        bottleneck = self.conv_1x4(bottleneck)
        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, bottleneck)
        return nn.functional.relu(out, inplace=True)
class SearchDepthCifarResNet(nn.Module):
    def __init__(self, block_name, depth, num_classes):
        super(SearchDepthCifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        if block_name == "ResNetBasicblock":
            block = ResNetBasicblock
            assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
            layer_blocks = (depth - 2) // 6
        elif block_name == "ResNetBottleneck":
            block = ResNetBottleneck
            assert (depth - 2) % 9 == 0, "depth should be one of 164"
            layer_blocks = (depth - 2) // 9
        else:
            raise ValueError("invalid block : {:}".format(block_name))

        self.message = (
            "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format(
                depth, layer_blocks
            )
        )
        self.num_classes = num_classes
        self.channels = [16]
        self.layers = nn.ModuleList(
            [
                ConvBNReLU(
                    3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
                )
            ]
        )
        self.InShape = None
        self.depth_info = OrderedDict()
        self.depth_at_i = OrderedDict()
        for stage in range(3):
            cur_block_choices = get_depth_choices(layer_blocks, False)
            assert (
                cur_block_choices[-1] == layer_blocks
            ), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks)
            self.message += (
                "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(
                    stage, cur_block_choices, layer_blocks
                )
            )
            block_choices, xstart = [], len(self.layers)
            for iL in range(layer_blocks):
                iC = self.channels[-1]
                planes = 16 * (2 ** stage)
                stride = 2 if stage > 0 and iL == 0 else 1
                module = block(iC, planes, stride)
                self.channels.append(module.out_dim)
                self.layers.append(module)
                self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
                    stage,
                    iL,
                    layer_blocks,
                    len(self.layers) - 1,
                    iC,
                    module.out_dim,
                    stride,
                )
                # added for depth
                layer_index = len(self.layers) - 1
                if iL + 1 in cur_block_choices:
                    block_choices.append(layer_index)
                if iL + 1 == layer_blocks:
                    self.depth_info[layer_index] = {
                        "choices": block_choices,
                        "stage": stage,
                        "xstart": xstart,
                    }
        self.depth_info_list = []
        for xend, info in self.depth_info.items():
            self.depth_info_list.append((xend, info))
            xstart, xstage = info["xstart"], info["stage"]
            for ilayer in range(xstart, xend + 1):
                idx = bisect_right(info["choices"], ilayer - 1)
                self.depth_at_i[ilayer] = (xstage, idx)

        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(module.out_dim, num_classes)
        self.InShape = None
        self.tau = -1
        self.search_mode = "basic"
        # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)

        self.register_parameter(
            "depth_attentions",
            nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))),
        )
        nn.init.normal_(self.depth_attentions, 0, 0.01)
        self.apply(initialize_resnet)

    def arch_parameters(self):
        return [self.depth_attentions]

    def base_parameters(self):
        return (
            list(self.layers.parameters())
            + list(self.avgpool.parameters())
            + list(self.classifier.parameters())
        )

    def get_flop(self, mode, config_dict, extra_info):
        if config_dict is not None:
            config_dict = config_dict.copy()
        # select depth
        if mode == "genotype":
            with torch.no_grad():
                depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
                choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
        elif mode == "max":
            choices = [depth_probs.size(1) - 1 for _ in range(depth_probs.size(0))]
        elif mode == "random":
            with torch.no_grad():
                depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
                choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
        else:
            raise ValueError("invalid mode : {:}".format(mode))
        selected_layers = []
        for choice, xvalue in zip(choices, self.depth_info_list):
            xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1
            selected_layers.append(xtemp)
        flop = 0
        for i, layer in enumerate(self.layers):
            if i in self.depth_at_i:
                xstagei, xatti = self.depth_at_i[i]
                if xatti <= choices[xstagei]:  # leave this depth
                    flop += layer.get_flops()
                else:
                    flop += 0  # do not use this layer
            else:
                flop += layer.get_flops()
        # the last fc layer
        flop += self.classifier.in_features * self.classifier.out_features
        if config_dict is None:
            return flop / 1e6
        else:
            config_dict["xblocks"] = selected_layers
            config_dict["super_type"] = "infer-depth"
            config_dict["estimated_FLOP"] = flop / 1e6
            return flop / 1e6, config_dict

    def get_arch_info(self):
        string = "for depth, there are {:} attention probabilities.".format(
            len(self.depth_attentions)
        )
        string += "\n{:}".format(self.depth_info)
        discrepancy = []
        with torch.no_grad():
            for i, att in enumerate(self.depth_attentions):
                prob = nn.functional.softmax(att, dim=0)
                prob = prob.cpu()
                selc = prob.argmax().item()
                prob = prob.tolist()
                prob = ["{:.3f}".format(x) for x in prob]
                xstring = "{:03d}/{:03d}-th : {:}".format(
                    i, len(self.depth_attentions), " ".join(prob)
                )
                logt = ["{:.4f}".format(x) for x in att.cpu().tolist()]
                xstring += " || {:17s}".format(" ".join(logt))
                prob = sorted([float(x) for x in prob])
                disc = prob[-1] - prob[-2]
                xstring += " || discrepancy={:.2f} || select={:}/{:}".format(
                    disc, selc, len(prob)
                )
                discrepancy.append(disc)
                string += "\n{:}".format(xstring)
        return string, discrepancy

    def set_tau(self, tau_max, tau_min, epoch_ratio):
        assert (
            epoch_ratio >= 0 and epoch_ratio <= 1
        ), "invalid epoch-ratio : {:}".format(epoch_ratio)
        tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
        self.tau = tau

    def get_message(self):
        return self.message

    def forward(self, inputs):
        if self.search_mode == "basic":
            return self.basic_forward(inputs)
        elif self.search_mode == "search":
            return self.search_forward(inputs)
        else:
            raise ValueError("invalid search_mode = {:}".format(self.search_mode))

    def search_forward(self, inputs):
        flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
        flop_depth_probs = torch.flip(
            torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1]
        )
        selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)

        x, flops = inputs, []
        feature_maps = []
        for i, layer in enumerate(self.layers):
            layer_i = layer(x)
            feature_maps.append(layer_i)
            if i in self.depth_info:  # aggregate the information
                choices = self.depth_info[i]["choices"]
                xstagei = self.depth_info[i]["stage"]
                possible_tensors = []
                for tempi, A in enumerate(choices):
                    xtensor = feature_maps[A]
                    possible_tensors.append(xtensor)
                weighted_sum = sum(
                    xtensor * W
                    for xtensor, W in zip(
                        possible_tensors, selected_depth_probs[xstagei]
                    )
                )
                x = weighted_sum
            else:
                x = layer_i

            if i in self.depth_at_i:
                xstagei, xatti = self.depth_at_i[i]
                # print ('layer-{:03d}, stage={:}, att={:}, prob={:}, flop={:}'.format(i, xstagei, xatti, flop_depth_probs[xstagei, xatti].item(), layer.get_flops(1e6)))
                x_expected_flop = flop_depth_probs[xstagei, xatti] * layer.get_flops(
                    1e6
                )
            else:
                x_expected_flop = layer.get_flops(1e6)
            flops.append(x_expected_flop)
        flops.append(
            (self.classifier.in_features * self.classifier.out_features * 1.0 / 1e6)
        )

        features = self.avgpool(x)
        features = features.view(features.size(0), -1)
        logits = linear_forward(features, self.classifier)
        return logits, torch.stack([sum(flops)])

    def basic_forward(self, inputs):
        if self.InShape is None:
            self.InShape = (inputs.size(-2), inputs.size(-1))
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
        features = self.avgpool(x)
        features = features.view(features.size(0), -1)
        logits = self.classifier(features)
        return features, logits
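For orientation, a minimal sketch of how a depth-search network like the one above is typically exercised; the hyper-parameter values are illustrative, and the sketch sticks to the "genotype" mode of get_flop because the "max" branch reads depth_probs before it is assigned.

    net = SearchDepthCifarResNet("ResNetBasicblock", depth=20, num_classes=10)
    features, logits = net(torch.randn(2, 3, 32, 32))  # basic forward records the conv shapes
    mflops = net.get_flop("genotype", None, None)       # estimated FLOPs (in millions) for the argmax depths
    w_optimizer = torch.optim.SGD(net.base_parameters(), lr=0.1, momentum=0.9)
    a_optimizer = torch.optim.Adam(net.arch_parameters(), lr=0.001)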


@@ -4,390 +4,616 @@
import math, torch
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices as get_choices


def conv_forward(inputs, conv, choices):
    iC = conv.in_channels
    fill_size = list(inputs.size())
    fill_size[1] = iC - fill_size[1]
    filled = torch.zeros(fill_size, device=inputs.device)
    xinputs = torch.cat((inputs, filled), dim=1)
    outputs = conv(xinputs)
    selecteds = [outputs[:, :oC] for oC in choices]
    return selecteds
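# How conv_forward supports width search: the incoming tensor may carry fewer
# channels than conv.in_channels, so it is zero-padded up to the full width, the
# shared convolution runs once, and the first oC output channels are sliced out
# for every candidate width in `choices`; narrower candidates therefore reuse
# the same kernel weights.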
class ConvBNReLU(nn.Module):
    num_conv = 1

    def __init__(
        self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
    ):
        super(ConvBNReLU, self).__init__()
        self.InShape = None
        self.OutShape = None
        self.choices = get_choices(nOut)
        self.register_buffer("choices_tensor", torch.Tensor(self.choices))

        if has_avg:
            self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.avg = None
        self.conv = nn.Conv2d(
            nIn,
            nOut,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            dilation=1,
            groups=1,
            bias=bias,
        )
        # if has_bn : self.bn = nn.BatchNorm2d(nOut)
        # else      : self.bn = None
        self.has_bn = has_bn
        self.BNs = nn.ModuleList()
        for i, _out in enumerate(self.choices):
            self.BNs.append(nn.BatchNorm2d(_out))
        if has_relu:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None
        self.in_dim = nIn
        self.out_dim = nOut
        self.search_mode = "basic"

    def get_flops(self, channels, check_range=True, divide=1):
        iC, oC = channels
        if check_range:
            assert (
                iC <= self.conv.in_channels and oC <= self.conv.out_channels
            ), "{:} vs {:} | {:} vs {:}".format(
                iC, self.conv.in_channels, oC, self.conv.out_channels
            )
        assert (
            isinstance(self.InShape, tuple) and len(self.InShape) == 2
        ), "invalid in-shape : {:}".format(self.InShape)
        assert (
            isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
        ), "invalid out-shape : {:}".format(self.OutShape)
        # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
        conv_per_position_flops = (
            self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
        )
        all_positions = self.OutShape[0] * self.OutShape[1]
        flops = (conv_per_position_flops * all_positions / divide) * iC * oC
        if self.conv.bias is not None:
            flops += all_positions / divide
        return flops

    def get_range(self):
        return [self.choices]

    def forward(self, inputs):
        if self.search_mode == "basic":
            return self.basic_forward(inputs)
        elif self.search_mode == "search":
            return self.search_forward(inputs)
        else:
            raise ValueError("invalid search_mode = {:}".format(self.search_mode))

    def search_forward(self, tuple_inputs):
        assert (
            isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
        ), "invalid type input : {:}".format(type(tuple_inputs))
        inputs, expected_inC, probability, index, prob = tuple_inputs
        index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
        probability = torch.squeeze(probability)
        assert len(index) == 2, "invalid length : {:}".format(index)
        # compute expected flop
        # coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
        expected_outC = (self.choices_tensor * probability).sum()
        expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        # convolutional layer
        out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
        out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
        # merge
        out_channel = max([x.size(1) for x in out_bns])
        outA = ChannelWiseInter(out_bns[0], out_channel)
        outB = ChannelWiseInter(out_bns[1], out_channel)
        out = outA * prob[0] + outB * prob[1]
        # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])

        if self.relu:
            out = self.relu(out)
        else:
            out = out
        return out, expected_outC, expected_flop

    def basic_forward(self, inputs):
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        conv = self.conv(out)
        if self.has_bn:
            out = self.BNs[-1](conv)
        else:
            out = conv
        if self.relu:
            out = self.relu(out)
        else:
            out = out
        if self.InShape is None:
            self.InShape = (inputs.size(-2), inputs.size(-1))
            self.OutShape = (out.size(-2), out.size(-1))
        return out
class ResNetBasicblock(nn.Module):
    expansion = 1
    num_conv = 2

    def __init__(self, inplanes, planes, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_a = ConvBNReLU(
            inplanes,
            planes,
            3,
            stride,
            1,
            False,
            has_avg=False,
            has_bn=True,
            has_relu=True,
        )
        self.conv_b = ConvBNReLU(
            planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
        )
        if stride == 2:
            self.downsample = ConvBNReLU(
                inplanes,
                planes,
                1,
                1,
                0,
                False,
                has_avg=True,
                has_bn=False,
                has_relu=False,
            )
        elif inplanes != planes:
            self.downsample = ConvBNReLU(
                inplanes,
                planes,
                1,
                1,
                0,
                False,
                has_avg=False,
                has_bn=True,
                has_relu=False,
            )
        else:
            self.downsample = None
        self.out_dim = planes
        self.search_mode = "basic"

    def get_range(self):
        return self.conv_a.get_range() + self.conv_b.get_range()

    def get_flops(self, channels):
        assert len(channels) == 3, "invalid channels : {:}".format(channels)
        flop_A = self.conv_a.get_flops([channels[0], channels[1]])
        flop_B = self.conv_b.get_flops([channels[1], channels[2]])
        if hasattr(self.downsample, "get_flops"):
            flop_C = self.downsample.get_flops([channels[0], channels[-1]])
        else:
            flop_C = 0
        if (
            channels[0] != channels[-1] and self.downsample is None
        ):  # this short-cut will be added during the infer-train
            flop_C = (
                channels[0]
                * channels[-1]
                * self.conv_b.OutShape[0]
                * self.conv_b.OutShape[1]
            )
        return flop_A + flop_B + flop_C

    def forward(self, inputs):
        if self.search_mode == "basic":
            return self.basic_forward(inputs)
        elif self.search_mode == "search":
            return self.search_forward(inputs)
        else:
            raise ValueError("invalid search_mode = {:}".format(self.search_mode))

    def search_forward(self, tuple_inputs):
        assert (
            isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
        ), "invalid type input : {:}".format(type(tuple_inputs))
        inputs, expected_inC, probability, indexes, probs = tuple_inputs
        assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
        out_a, expected_inC_a, expected_flop_a = self.conv_a(
            (inputs, expected_inC, probability[0], indexes[0], probs[0])
        )
        out_b, expected_inC_b, expected_flop_b = self.conv_b(
            (out_a, expected_inC_a, probability[1], indexes[1], probs[1])
        )
        if self.downsample is not None:
            residual, _, expected_flop_c = self.downsample(
                (inputs, expected_inC, probability[1], indexes[1], probs[1])
            )
        else:
            residual, expected_flop_c = inputs, 0
        out = additive_func(residual, out_b)
        return (
            nn.functional.relu(out, inplace=True),
            expected_inC_b,
            sum([expected_flop_a, expected_flop_b, expected_flop_c]),
        )

    def basic_forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)
        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, basicblock)
        return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
    expansion = 4
    num_conv = 3

    def __init__(self, inplanes, planes, stride):
        super(ResNetBottleneck, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_1x1 = ConvBNReLU(
            inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
        )
        self.conv_3x3 = ConvBNReLU(
            planes,
            planes,
            3,
            stride,
            1,
            False,
            has_avg=False,
            has_bn=True,
            has_relu=True,
        )
        self.conv_1x4 = ConvBNReLU(
            planes,
            planes * self.expansion,
            1,
            1,
            0,
            False,
            has_avg=False,
            has_bn=True,
            has_relu=False,
        )
        if stride == 2:
            self.downsample = ConvBNReLU(
                inplanes,
                planes * self.expansion,
                1,
                1,
                0,
                False,
                has_avg=True,
                has_bn=False,
                has_relu=False,
            )
        elif inplanes != planes * self.expansion:
            self.downsample = ConvBNReLU(
                inplanes,
                planes * self.expansion,
                1,
                1,
                0,
                False,
                has_avg=False,
                has_bn=True,
                has_relu=False,
            )
        else:
            self.downsample = None
        self.out_dim = planes * self.expansion
        self.search_mode = "basic"

    def get_range(self):
        return (
            self.conv_1x1.get_range()
            + self.conv_3x3.get_range()
            + self.conv_1x4.get_range()
        )

    def get_flops(self, channels):
        assert len(channels) == 4, "invalid channels : {:}".format(channels)
        flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
        flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
        flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
        if hasattr(self.downsample, "get_flops"):
            flop_D = self.downsample.get_flops([channels[0], channels[-1]])
        else:
            flop_D = 0
        if (
            channels[0] != channels[-1] and self.downsample is None
        ):  # this short-cut will be added during the infer-train
            flop_D = (
                channels[0]
                * channels[-1]
                * self.conv_1x4.OutShape[0]
                * self.conv_1x4.OutShape[1]
            )
        return flop_A + flop_B + flop_C + flop_D

    def forward(self, inputs):
        if self.search_mode == "basic":
            return self.basic_forward(inputs)
        elif self.search_mode == "search":
            return self.search_forward(inputs)
        else:
            raise ValueError("invalid search_mode = {:}".format(self.search_mode))

    def basic_forward(self, inputs):
        bottleneck = self.conv_1x1(inputs)
        bottleneck = self.conv_3x3(bottleneck)
        bottleneck = self.conv_1x4(bottleneck)
        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, bottleneck)
        return nn.functional.relu(out, inplace=True)

    def search_forward(self, tuple_inputs):
        assert (
            isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
        ), "invalid type input : {:}".format(type(tuple_inputs))
        inputs, expected_inC, probability, indexes, probs = tuple_inputs
        assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
        out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1(
            (inputs, expected_inC, probability[0], indexes[0], probs[0])
        )
        out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3(
            (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1])
        )
        out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4(
            (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2])
        )
        if self.downsample is not None:
            residual, _, expected_flop_c = self.downsample(
                (inputs, expected_inC, probability[2], indexes[2], probs[2])
            )
        else:
            residual, expected_flop_c = inputs, 0
        out = additive_func(residual, out_1x4)
        return (
            nn.functional.relu(out, inplace=True),
            expected_inC_1x4,
            sum(
                [
                    expected_flop_1x1,
                    expected_flop_3x3,
                    expected_flop_1x4,
                    expected_flop_c,
                ]
            ),
        )
class SearchWidthCifarResNet(nn.Module): class SearchWidthCifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes):
super(SearchWidthCifarResNet, self).__init__()
def __init__(self, block_name, depth, num_classes): # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
super(SearchWidthCifarResNet, self).__init__() if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model self.message = (
if block_name == 'ResNetBasicblock': "SearchWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
block = ResNetBasicblock depth, layer_blocks
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' )
layer_blocks = (depth - 2) // 6 )
elif block_name == 'ResNetBottleneck': self.num_classes = num_classes
block = ResNetBottleneck self.channels = [16]
assert (depth - 2) % 9 == 0, 'depth should be one of 164' self.layers = nn.ModuleList(
layer_blocks = (depth - 2) // 9 [
else: ConvBNReLU(
raise ValueError('invalid block : {:}'.format(block_name)) 3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
)
]
)
self.InShape = None
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2 ** stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append(module.out_dim)
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iC,
module.out_dim,
stride,
)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = "basic"
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append((start_index, len(self.Ranges)))
assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(
len(self.Ranges) + 1, depth
)
self.register_parameter(
"width_attentions",
nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))),
)
nn.init.normal_(self.width_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.width_attentions]
def base_parameters(self):
return (
list(self.layers.parameters())
+ list(self.avgpool.parameters())
+ list(self.classifier.parameters())
)
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None:
config_dict = config_dict.copy()
# weights = [F.softmax(x, dim=0) for x in self.width_attentions]
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == "genotype":
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][torch.argmax(probe).item()]
elif mode == "max":
C = self.Ranges[i][-1]
elif mode == "fix":
C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
elif mode == "random":
assert isinstance(extra_info, float), "invalid extra_info : {:}".format(
extra_info
)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
for j in range(prob.size(0)):
prob[j] = 1 / (
abs(j - (approximate_C - self.Ranges[i][j])) + 0.2
)
C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
else:
raise ValueError("invalid mode : {:}".format(mode))
channels.append(C)
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple(channels[s : e + 1])
flop += layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict["xchannels"] = channels
config_dict["super_type"] = "infer-width"
config_dict["estimated_FLOP"] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for width, there are {:} attention probabilities.".format(
len(self.width_attentions)
)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu()
selc = prob.argmax().item()
prob = prob.tolist()
prob = ["{:.3f}".format(x) for x in prob]
xstring = "{:03d}/{:03d}-th : {:}".format(
i, len(self.width_attentions), " ".join(prob)
)
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
xstring += " || {:52s}".format(" ".join(logt))
prob = sorted([float(x) for x in prob])
disc = prob[-1] - prob[-2]
xstring += " || dis={:.2f} || select={:}/{:}".format(
disc, selc, len(prob)
)
discrepancy.append(disc)
string += "\n{:}".format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert (
epoch_ratio >= 0 and epoch_ratio <= 1
), "invalid epoch-ratio : {:}".format(epoch_ratio)
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, inputs):
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[
last_channel_idx : last_channel_idx + layer.num_conv
]
selected_w_probs = selected_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
layer_prob = flop_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
x, expected_inC, expected_flop = layer(
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
)
last_channel_idx += layer.num_conv
flops.append(expected_flop)
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack([sum(flops)])
def basic_forward(self, inputs):
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
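Editor's note: a minimal usage sketch for the width-search model above, not part of the diff. The import path and the CIFAR-style 32x32 input size are assumptions (the nn.AvgPool2d(8) head implies 32x32 inputs); a basic forward pass is run first so that the per-layer In/Out shapes needed by the FLOP estimates get recorded.

import torch
from models.shape_searchs import SearchWidthCifarResNet  # assumed package path

net = SearchWidthCifarResNet("ResNetBasicblock", depth=20, num_classes=10)
x = torch.randn(2, 3, 32, 32)
features, logits = net(x)  # search_mode == "basic": also records In/Out shapes
net.set_tau(tau_max=10.0, tau_min=0.1, epoch_ratio=0.5)
net.search_mode = "search"
logits, expected_flop = net(x)  # logits plus the expected FLOP count (in millions)
mflop, cfg = net.get_flop("genotype", config_dict={}, extra_info=None)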

File diff suppressed because it is too large

View File

@@ -4,313 +4,463 @@
import math, torch
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices as get_choices
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:, :oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_choices(nOut)
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
# else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = "basic"
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range:
assert (
iC <= self.conv.in_channels and oC <= self.conv.out_channels
), "{:} vs {:} | {:} vs {:}".format(
iC, self.conv.in_channels, oC, self.conv.out_channels
)
assert (
isinstance(self.InShape, tuple) and len(self.InShape) == 2
), "invalid in-shape : {:}".format(self.InShape)
assert (
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
), "invalid out-shape : {:}".format(self.OutShape)
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None:
flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, "invalid length : {:}".format(index)
# compute expected flop
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg:
out = self.avg(inputs)
else:
out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu:
out = self.relu(out)
else:
out = out
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.has_bn:
out = self.BNs[-1](conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2), out.size(-1))
return out
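Editor's note: a small sketch, not part of the diff, of the five-element tuple that ConvBNReLU.search_forward consumes (inputs, expected input channels, per-choice probability, two sampled width indices, and their relative weights). The concrete indices and probabilities below are made-up placeholders, and one basic pass is run first so In/Out shapes are recorded.

import torch

layer = ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True)
layer(torch.randn(2, 3, 32, 32))  # basic pass records InShape/OutShape for FLOP estimates
layer.search_mode = "search"
probability = torch.softmax(torch.randn(len(layer.choices)), dim=0)  # over all width choices
index = torch.tensor([[1, 3]])  # two candidate width indices, as select2withP would return
prob = torch.softmax(torch.randn(1, 2), dim=1)  # relative weights of the two candidates
out, expected_outC, expected_flop = layer(
    (torch.randn(2, 3, 32, 32), 3, probability, index, prob)
)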
class SimBlock(nn.Module):
expansion = 1
num_conv = 1
def __init__(self, inplanes, planes, stride):
super(SimBlock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv = ConvBNReLU(
inplanes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = "basic"
def get_range(self):
return self.conv.get_range()
def get_flops(self, channels):
assert len(channels) == 2, "invalid channels : {:}".format(channels)
flop_A = self.conv.get_flops([channels[0], channels[1]])
if hasattr(self.downsample, "get_flops"):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if (
channels[0] != channels[-1] and self.downsample is None
): # this short-cut will be added during the infer-train
flop_C = (
channels[0]
* channels[-1]
* self.conv.OutShape[0]
* self.conv.OutShape[1]
)
return flop_A + flop_C
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert (
indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1
), "invalid size : {:}, {:}, {:}".format(
indexes.size(), probs.size(), probability.size()
)
out, expected_next_inC, expected_flop = self.conv(
(inputs, expected_inC, probability[0], indexes[0], probs[0])
)
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample(
(inputs, expected_inC, probability[-1], indexes[-1], probs[-1])
)
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out)
return (
nn.functional.relu(out, inplace=True),
expected_next_inC,
sum([expected_flop, expected_flop_c]),
)
def basic_forward(self, inputs):
basicblock = self.conv(inputs)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class SearchWidthSimResNet(nn.Module):
def __init__(self, depth, num_classes):
super(SearchWidthSimResNet, self).__init__()
assert (
depth - 2
) % 3 == 0, "depth should be one of 5, 8, 11, 14, ... instead of {:}".format(
depth
)
layer_blocks = (depth - 2) // 3
self.message = (
"SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}".format(
depth, layer_blocks
)
)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList(
[
ConvBNReLU(
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
)
]
)
self.InShape = None
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2 ** stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = SimBlock(iC, planes, stride)
self.channels.append(module.out_dim)
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iC,
module.out_dim,
stride,
)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = "basic"
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append((start_index, len(self.Ranges)))
assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(
len(self.Ranges) + 1, depth
)
self.register_parameter(
"width_attentions",
nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))),
)
nn.init.normal_(self.width_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.width_attentions]
def base_parameters(self):
return (
list(self.layers.parameters())
+ list(self.avgpool.parameters())
+ list(self.classifier.parameters())
)
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None:
config_dict = config_dict.copy()
# weights = [F.softmax(x, dim=0) for x in self.width_attentions]
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == "genotype":
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][torch.argmax(probe).item()]
elif mode == "max":
C = self.Ranges[i][-1]
elif mode == "fix":
C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
elif mode == "random":
assert isinstance(extra_info, float), "invalid extra_info : {:}".format(
extra_info
)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
for j in range(prob.size(0)):
prob[j] = 1 / (
abs(j - (approximate_C - self.Ranges[i][j])) + 0.2
)
C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
else:
raise ValueError("invalid mode : {:}".format(mode))
channels.append(C)
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple(channels[s : e + 1])
flop += layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict["xchannels"] = channels
config_dict["super_type"] = "infer-width"
config_dict["estimated_FLOP"] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for width, there are {:} attention probabilities.".format(
len(self.width_attentions)
)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu()
selc = prob.argmax().item()
prob = prob.tolist()
prob = ["{:.3f}".format(x) for x in prob]
xstring = "{:03d}/{:03d}-th : {:}".format(
i, len(self.width_attentions), " ".join(prob)
)
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
xstring += " || {:52s}".format(" ".join(logt))
prob = sorted([float(x) for x in prob])
disc = prob[-1] - prob[-2]
xstring += " || dis={:.2f} || select={:}/{:}".format(
disc, selc, len(prob)
)
discrepancy.append(disc)
string += "\n{:}".format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert (
epoch_ratio >= 0 and epoch_ratio <= 1
), "invalid epoch-ratio : {:}".format(epoch_ratio)
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, inputs):
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[
last_channel_idx : last_channel_idx + layer.num_conv
]
selected_w_probs = selected_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
layer_prob = flop_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
x, expected_inC, expected_flop = layer(
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
)
last_channel_idx += layer.num_conv
flops.append(expected_flop)
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack([sum(flops)])
def basic_forward(self, inputs):
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -6,106 +6,123 @@ import torch.nn as nn
def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
if tau <= 0:
new_logits = logits
probs = nn.functional.softmax(new_logits, dim=1)
else:
while True:  # a trick to avoid the gumbels bug
gumbels = -torch.empty_like(logits).exponential_().log()
new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
probs = nn.functional.softmax(new_logits, dim=1)
if (
(not torch.isinf(gumbels).any())
and (not torch.isinf(probs).any())
and (not torch.isnan(probs).any())
):
break
if just_prob:
return probs
# with torch.no_grad(): # add eps for unexpected torch error
# probs = nn.functional.softmax(new_logits, dim=1)
# selected_index = torch.multinomial(probs + eps, 2, False)
with torch.no_grad():  # add eps for unexpected torch error
probs = probs.cpu()
selected_index = torch.multinomial(probs + eps, num, False).to(logits.device)
selected_logit = torch.gather(new_logits, 1, selected_index)
selcted_probs = nn.functional.softmax(selected_logit, dim=1)
return selected_index, selcted_probs
def ChannelWiseInter(inputs, oC, mode="v2"):
if mode == "v1":
return ChannelWiseInterV1(inputs, oC)
elif mode == "v2":
return ChannelWiseInterV2(inputs, oC)
else:
raise ValueError("invalid mode : {:}".format(mode))
def ChannelWiseInterV1(inputs, oC):
assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size())
def start_index(a, b, c):
return int(math.floor(float(a * c) / b))
def end_index(a, b, c):
return int(math.ceil(float((a + 1) * c) / b))
batch, iC, H, W = inputs.size()
outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device)
if iC == oC:
return inputs
for ot in range(oC):
istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC)
values = inputs[:, istartT:iendT].mean(dim=1)
outputs[:, ot, :, :] = values
return outputs
def ChannelWiseInterV2(inputs, oC):
assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size())
batch, C, H, W = inputs.size()
if C == oC:
return inputs
else:
return nn.functional.adaptive_avg_pool3d(inputs, (oC, H, W))
# inputs_5D = inputs.view(batch, 1, C, H, W)
# otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None)
# otputs = otputs_5D.view(batch, oC, H, W)
# otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False)
# return otputs
def linear_forward(inputs, linear):
if linear is None:
return inputs
iC = inputs.size(1)
weight = linear.weight[:, :iC]
if linear.bias is None:
bias = None
else:
bias = linear.bias
return nn.functional.linear(inputs, weight, bias)
def get_width_choices(nOut):
xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
if nOut is None:
return len(xsrange)
else:
Xs = [int(nOut * i) for i in xsrange]
# xs = [ int(nOut * i // 10) for i in range(2, 11)]
# Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1]
Xs = sorted(list(set(Xs)))
return tuple(Xs)
def get_depth_choices(nDepth):
if nDepth is None:
return 3
else:
assert nDepth >= 3, "nDepth should be greater than 2 vs {:}".format(nDepth)
if nDepth == 1:
return (1, 1, 1)
elif nDepth == 2:
return (1, 1, 2)
elif nDepth >= 3:
return (nDepth // 3, nDepth * 2 // 3, nDepth)
else:
raise ValueError("invalid Depth : {:}".format(nDepth))
def drop_path(x, drop_prob):
if drop_prob > 0.0:
keep_prob = 1.0 - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x = x * (mask / keep_prob)
# x.div_(keep_prob)
# x.mul_(mask)
return x
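Editor's note: a short sketch, not part of the diff, of how select2withP and ChannelWiseInter are meant to be combined (mirroring the "tas" branch in the model further below). It assumes the two functions above are importable from this module, and all tensor sizes are arbitrary.

import torch
import torch.nn as nn

arch_logits = nn.Parameter(torch.randn(1, 8))  # 8 candidate widths for one layer
index, probs = select2withP(arch_logits, tau=5.0)  # two sampled widths + renormalized probs
feature = torch.randn(4, 64, 8, 8)
c1, c2 = 32, 48  # channel counts that `index` would map to
out = (
    ChannelWiseInter(feature[:, :c1], max(c1, c2)) * probs[0, 0]
    + ChannelWiseInter(feature[:, :c2], max(c1, c2)) * probs[0, 1]
)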

View File

@@ -3,7 +3,7 @@
##################################################
from .SearchCifarResNet_width import SearchWidthCifarResNet
from .SearchCifarResNet_depth import SearchDepthCifarResNet
from .SearchCifarResNet import SearchShapeCifarResNet
from .SearchSimResNet_width import SearchWidthSimResNet
from .SearchImagenetResNet import SearchShapeImagenetResNet
from .generic_size_tiny_cell_model import GenericNAS301Model

View File

@@ -15,152 +15,195 @@ from models.shape_searchs.SoftSelect import select2withP, ChannelWiseInter
class GenericNAS301Model(nn.Module):
def __init__(
self,
candidate_Cs: List[int],
max_num_Cs: int,
genotype: Any,
num_classes: int,
affine: bool,
track_running_stats: bool,
):
super(GenericNAS301Model, self).__init__()
self._max_num_Cs = max_num_Cs
self._candidate_Cs = candidate_Cs
if max_num_Cs % 3 != 2:
raise ValueError("invalid number of layers : {:}".format(max_num_Cs))
self._num_stage = N = max_num_Cs // 3
self._max_C = max(candidate_Cs)
stem = nn.Sequential(
nn.Conv2d(3, self._max_C, kernel_size=3, padding=1, bias=not affine),
nn.BatchNorm2d(
self._max_C, affine=affine, track_running_stats=track_running_stats
),
)
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
c_prev = self._max_C
self._cells = nn.ModuleList()
self._cells.append(stem)
for index, reduction in enumerate(layer_reductions):
if reduction:
cell = ResNetBasicblock(c_prev, self._max_C, 2, True)
else:
cell = InferCell(
genotype, c_prev, self._max_C, 1, affine, track_running_stats
)
self._cells.append(cell)
c_prev = cell.out_dim
self._num_layer = len(self._cells)
self.lastact = nn.Sequential(
nn.BatchNorm2d(
c_prev, affine=affine, track_running_stats=track_running_stats
),
nn.ReLU(inplace=True),
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(c_prev, num_classes)
# algorithm related
self.register_buffer("_tau", torch.zeros(1))
self._algo = None
self._warmup_ratio = None
def set_algo(self, algo: Text):
# used for searching
assert self._algo is None, "This functioin can only be called once."
assert algo in ["mask_gumbel", "mask_rl", "tas"], "invalid algo : {:}".format(
algo
)
self._algo = algo
self._arch_parameters = nn.Parameter(
1e-3 * torch.randn(self._max_num_Cs, len(self._candidate_Cs))
)
# if algo == 'mask_gumbel' or algo == 'mask_rl':
self.register_buffer(
"_masks", torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))
)
for i in range(len(self._candidate_Cs)):
self._masks.data[i, : self._candidate_Cs[i]] = 1
@property
def tau(self):
return self._tau
def set_tau(self, tau):
self._tau.data[:] = tau
@property
def warmup_ratio(self):
return self._warmup_ratio
def set_warmup_ratio(self, ratio: float):
self._warmup_ratio = ratio
@property
def weights(self):
xlist = list(self._cells.parameters())
xlist += list(self.lastact.parameters())
xlist += list(self.global_pooling.parameters())
xlist += list(self.classifier.parameters())
return xlist
@property
def alphas(self):
return [self._arch_parameters]
def show_alphas(self):
with torch.no_grad():
return "arch-parameters :\n{:}".format(
nn.functional.softmax(self._arch_parameters, dim=-1).cpu()
)
@property
def random(self):
cs = []
for i in range(self._max_num_Cs):
index = random.randint(0, len(self._candidate_Cs) - 1)
cs.append(str(self._candidate_Cs[index]))
return ":".join(cs)
@property
def genotype(self):
cs = []
for i in range(self._max_num_Cs):
with torch.no_grad():
index = self._arch_parameters[i].argmax().item()
cs.append(str(self._candidate_Cs[index]))
return ":".join(cs)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self._cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self._cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(candidates={_candidate_Cs}, num={_max_num_Cs}, N={_num_stage}, L={_num_layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def forward(self, inputs):
feature = inputs
log_probs = []
for i, cell in enumerate(self._cells):
feature = cell(feature)
# apply different searching algorithms
idx = max(0, i - 1)
if self._warmup_ratio is not None:
if random.random() < self._warmup_ratio:
mask = self._masks[-1]
else:
mask = self._masks[random.randint(0, len(self._masks) - 1)]
feature = feature * mask.view(1, -1, 1, 1)
elif self._algo == "mask_gumbel":
weights = nn.functional.gumbel_softmax(
self._arch_parameters[idx : idx + 1], tau=self.tau, dim=-1
)
mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
feature = feature * mask
elif self._algo == "tas":
selected_cs, selected_probs = select2withP(
self._arch_parameters[idx : idx + 1], self.tau, num=2
)
with torch.no_grad():
i1, i2 = selected_cs.cpu().view(-1).tolist()
c1, c2 = self._candidate_Cs[i1], self._candidate_Cs[i2]
out_channel = max(c1, c2)
out1 = ChannelWiseInter(feature[:, :c1], out_channel)
out2 = ChannelWiseInter(feature[:, :c2], out_channel)
out = out1 * selected_probs[0, 0] + out2 * selected_probs[0, 1]
if feature.shape[1] == out.shape[1]:
feature = out
else:
miss = torch.zeros(
feature.shape[0],
feature.shape[1] - out.shape[1],
feature.shape[2],
feature.shape[3],
device=feature.device,
)
feature = torch.cat((out, miss), dim=1)
elif self._algo == "mask_rl":
prob = nn.functional.softmax(
self._arch_parameters[idx : idx + 1], dim=-1
)
dist = torch.distributions.Categorical(prob)
action = dist.sample()
log_probs.append(dist.log_prob(action))
mask = self._masks[action.item()].view(1, -1, 1, 1)
feature = feature * mask
else:
raise ValueError("invalid algorithm : {:}".format(self._algo))
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits, log_probs
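Editor's note: a hedged sketch, not part of the diff, of the expected call order for GenericNAS301Model. The `genotype` argument must be a NAS-Bench-201-style cell genotype and is left abstract here; the candidate channel list and input size are illustrative only.

import torch

model = GenericNAS301Model(
    candidate_Cs=[8, 16, 24, 32, 40, 48, 56, 64],
    max_num_Cs=8,  # must satisfy max_num_Cs % 3 == 2
    genotype=genotype,  # placeholder: a searched cell genotype for InferCell
    num_classes=10,
    affine=False,
    track_running_stats=True,
)
model.set_algo("mask_gumbel")  # one of "mask_gumbel", "mask_rl", "tas"
model.set_tau(10.0)
out, logits, log_probs = model(torch.randn(2, 3, 32, 32))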

View File

@@ -6,15 +6,15 @@ import torch.nn as nn
from SoftSelect import ChannelWiseInter
if __name__ == "__main__":
tensors = torch.rand((16, 128, 7, 7))
for oc in range(200, 210):
out_v1 = ChannelWiseInter(tensors, oc, "v1")
out_v2 = ChannelWiseInter(tensors, oc, "v2")
assert (out_v1 == out_v2).any().item() == 1
for oc in range(48, 160):
out_v1 = ChannelWiseInter(tensors, oc, "v1")
out_v2 = ChannelWiseInter(tensors, oc, "v2")
assert (out_v1 == out_v2).any().item() == 1

View File

@@ -35,6 +35,22 @@ def get_model(config: Dict[Text, Any], **kwargs):
act_cls(),
SuperLinear(hidden_dim2, kwargs["output_dim"]),
)
elif model_type == "norm_mlp":
act_cls = super_name2activation[kwargs["act_cls"]]
norm_cls = super_name2norm[kwargs["norm_cls"]]
sub_layers, last_dim = [], kwargs["input_dim"]
for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
sub_layers.extend(
[
norm_cls(last_dim, elementwise_affine=False),
SuperLinear(last_dim, hidden_dim),
act_cls(),
]
)
last_dim = hidden_dim
sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
model = SuperSequential(*sub_layers)
else:
raise TypeError("Unkonwn model type: {:}".format(model_type))
return model
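Editor's note: a hedged sketch, not part of the diff, of calling get_model through the new "norm_mlp" branch. It assumes model_type is read from the config dict and that the function lives in models.xcore; the "relu" and "layer_norm" aliases are guesses at keys in super_name2activation / super_name2norm.

from models.xcore import get_model  # assumed module path

model = get_model(
    dict(model_type="norm_mlp"),
    input_dim=16,
    hidden_dims=[32, 32],
    output_dim=1,
    act_cls="relu",  # guessed key in super_name2activation
    norm_cls="layer_norm",  # guessed key in super_name2norm
)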