diff --git a/exps/LFNA/lfna-tall-hpnet.py b/exps/LFNA/lfna-tall-hpnet.py
new file mode 100644
index 0000000..99f8f49
--- /dev/null
+++ b/exps/LFNA/lfna-tall-hpnet.py
@@ -0,0 +1,179 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
+#####################################################
+# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16
+#####################################################
+import sys, time, copy, torch, random, argparse
+from tqdm import tqdm
+from copy import deepcopy
+from pathlib import Path
+
+lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
+if str(lib_dir) not in sys.path:
+    sys.path.insert(0, str(lib_dir))
+from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
+from log_utils import time_string
+from log_utils import AverageMeter, convert_secs2time
+
+from utils import split_str2indexes
+
+from procedures.advanced_main import basic_train_fn, basic_eval_fn
+from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
+from datasets.synthetic_core import get_synthetic_env
+from models.xcore import get_model
+from xlayers import super_core, trunc_normal_
+
+
+from lfna_utils import lfna_setup, train_model, TimeData
+
+# from lfna_models import HyperNet_VX as HyperNet
+from lfna_models import HyperNet
+
+
+def main(args):
+    logger, env_info, model_kwargs = lfna_setup(args)
+    dynamic_env = env_info["dynamic_env"]
+    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+    criterion = torch.nn.MSELoss()
+
+    logger.log("There are {:} weights.".format(model.get_w_container().numel()))
+
+    shape_container = model.get_w_container().to_shape_container()
+    hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim)
+    total_bar = env_info["total"] - 1
+    task_embeds = []
+    for i in range(total_bar):
+        task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim)))
+    for task_embed in task_embeds:
+        trunc_normal_(task_embed, std=0.02)
+
+    parameters = list(hypernet.parameters()) + task_embeds
+    optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True)
+    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer,
+        milestones=[
+            int(args.epochs * 0.8),
+            int(args.epochs * 0.9),
+        ],
+        gamma=0.1,
+    )
+
+    # LFNA meta-training
+    loss_meter = AverageMeter()
+    per_epoch_time, start_time = AverageMeter(), time.time()
+    for iepoch in range(args.epochs):
+
+        need_time = "Time Left: {:}".format(
+            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
+        )
+        head_str = (
+            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
+            + need_time
+        )
+
+        limit_bar = float(iepoch + 1) / args.epochs * total_bar
+        limit_bar = min(max(0, int(limit_bar)), total_bar)
+        losses = []
+        for ibatch in range(args.meta_batch):
+            cur_time = random.randint(0, limit_bar)
+            cur_task_embed = task_embeds[cur_time]
+            cur_container = hypernet(cur_task_embed)
+            cur_x = env_info["{:}-x".format(cur_time)]
+            cur_y = env_info["{:}-y".format(cur_time)]
+            cur_dataset = TimeData(cur_time, cur_x, cur_y)
+
+            preds = model.forward_with_container(cur_dataset.x, cur_container)
+            optimizer.zero_grad()
+            loss = criterion(preds, cur_dataset.y)
+
+            losses.append(loss)
+
+        final_loss = torch.stack(losses).mean()
+        final_loss.backward()
+        optimizer.step()
+        lr_scheduler.step()
+
+        loss_meter.update(final_loss.item())
+        if iepoch % 200 == 0:
+            logger.log(
+                head_str
+                + "meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}, limit={:}".format(
+                    loss_meter.avg,
+                    loss_meter.val,
+                    min(lr_scheduler.get_last_lr()),
+                    len(losses),
+                    limit_bar,
+                )
+            )
+
+            save_checkpoint(
+                {
+                    "hypernet": hypernet.state_dict(),
+                    "task_embed": task_embed,
+                    "lr_scheduler": lr_scheduler.state_dict(),
+                    "iepoch": iepoch,
+                },
+                logger.path("model"),
+                logger,
+            )
+            loss_meter.reset()
+        per_epoch_time.update(time.time() - start_time)
+        start_time = time.time()
+
+    print(model)
+    print(hypernet)
+
+    logger.log("-" * 200 + "\n")
+    logger.close()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Use the data in the past.")
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="./outputs/lfna-synthetic/lfna-tall-hpnet",
+        help="The checkpoint directory.",
+    )
+    parser.add_argument(
+        "--env_version",
+        type=str,
+        required=True,
+        help="The synthetic environment version.",
+    )
+    parser.add_argument(
+        "--hidden_dim",
+        type=int,
+        required=True,
+        help="The hidden dimension.",
+    )
+    #####
+    parser.add_argument(
+        "--init_lr",
+        type=float,
+        default=0.1,
+        help="The initial learning rate for the optimizer (default is Adam)",
+    )
+    parser.add_argument(
+        "--meta_batch",
+        type=int,
+        default=64,
+        help="The batch size for the meta-model",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=2000,
+        help="The total number of epochs.",
+    )
+    # Random Seed
+    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
+    args = parser.parse_args()
+    if args.rand_seed is None or args.rand_seed < 0:
+        args.rand_seed = random.randint(1, 100000)
+    assert args.save_dir is not None, "The save dir argument can not be None"
+    args.task_dim = args.hidden_dim
+    args.save_dir = "{:}-{:}-d{:}".format(
+        args.save_dir, args.env_version, args.hidden_dim
+    )
+    main(args)
diff --git a/exps/LFNA/lfna-test-hpnet.py b/exps/LFNA/lfna-test-hpnet.py
index 3cf5552..ce7715d 100644
--- a/exps/LFNA/lfna-test-hpnet.py
+++ b/exps/LFNA/lfna-test-hpnet.py
@@ -41,10 +41,14 @@ def main(args):
     shape_container = model.get_w_container().to_shape_container()
     hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim)
     # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim))
-    task_embed = torch.nn.Parameter(torch.Tensor(1, args.task_dim))
-    trunc_normal_(task_embed, std=0.02)
+    total_bar = 10
+    task_embeds = []
+    for i in range(total_bar):
+        task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim)))
+    for task_embed in task_embeds:
+        trunc_normal_(task_embed, std=0.02)
 
-    parameters = list(hypernet.parameters()) + [task_embed]
+    parameters = list(hypernet.parameters()) + task_embeds
     optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True)
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer,
@@ -56,7 +60,6 @@ def main(args):
     )
 
     # total_bar = env_info["total"] - 1
-    total_bar = 1
     # LFNA meta-training
     loss_meter = AverageMeter()
     per_epoch_time, start_time = AverageMeter(), time.time()
@@ -74,7 +77,7 @@ def main(args):
         # for ibatch in range(args.meta_batch):
         for cur_time in range(total_bar):
             # cur_time = random.randint(0, total_bar)
-            cur_task_embed = task_embed
+            cur_task_embed = task_embeds[cur_time]
             cur_container = hypernet(cur_task_embed)
             cur_x = env_info["{:}-x".format(cur_time)]
             cur_y = env_info["{:}-y".format(cur_time)]
@@ -98,7 +101,7 @@ def main(args):
                 + "meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format(
                     loss_meter.avg,
                     loss_meter.val,
-                    min(lr_scheduler.get_lr()),
+                    min(lr_scheduler.get_last_lr()),
                     len(losses),
                 )
             )
diff --git 
a/exps/LFNA/lfna_models.py b/exps/LFNA/lfna_models.py index c85f4fa..063d5b6 100644 --- a/exps/LFNA/lfna_models.py +++ b/exps/LFNA/lfna_models.py @@ -28,6 +28,15 @@ class HyperNet(super_core.SuperModule): ) trunc_normal_(self._super_layer_embed, std=0.02) + model_kwargs = dict( + input_dim=layer_embeding + task_embedding, + output_dim=max(self._numel_per_layer), + hidden_dims=[layer_embeding * 4] * 4, + act_cls="gelu", + norm_cls="layer_norm_1d", + ) + self._generator = get_model(dict(model_type="norm_mlp"), **model_kwargs) + """ model_kwargs = dict( input_dim=layer_embeding + task_embedding, output_dim=max(self._numel_per_layer), @@ -36,6 +45,7 @@ class HyperNet(super_core.SuperModule): norm_cls="identity", ) self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs) + """ self._return_container = return_container print("generator: {:}".format(self._generator)) diff --git a/lib/models/cell_infers/cells.py b/lib/models/cell_infers/cells.py index 0e9aae4..40df57b 100644 --- a/lib/models/cell_infers/cells.py +++ b/lib/models/cell_infers/cells.py @@ -11,111 +11,145 @@ from models.cell_operations import OPS # Cell for NAS-Bench-201 class InferCell(nn.Module): + def __init__( + self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True + ): + super(InferCell, self).__init__() - def __init__(self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True): - super(InferCell, self).__init__() + self.layers = nn.ModuleList() + self.node_IN = [] + self.node_IX = [] + self.genotype = deepcopy(genotype) + for i in range(1, len(genotype)): + node_info = genotype[i - 1] + cur_index = [] + cur_innod = [] + for (op_name, op_in) in node_info: + if op_in == 0: + layer = OPS[op_name]( + C_in, C_out, stride, affine, track_running_stats + ) + else: + layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats) + cur_index.append(len(self.layers)) + cur_innod.append(op_in) + self.layers.append(layer) + self.node_IX.append(cur_index) + self.node_IN.append(cur_innod) + self.nodes = len(genotype) + self.in_dim = C_in + self.out_dim = C_out - self.layers = nn.ModuleList() - self.node_IN = [] - self.node_IX = [] - self.genotype = deepcopy(genotype) - for i in range(1, len(genotype)): - node_info = genotype[i-1] - cur_index = [] - cur_innod = [] - for (op_name, op_in) in node_info: - if op_in == 0: - layer = OPS[op_name](C_in , C_out, stride, affine, track_running_stats) - else: - layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats) - cur_index.append( len(self.layers) ) - cur_innod.append( op_in ) - self.layers.append( layer ) - self.node_IX.append( cur_index ) - self.node_IN.append( cur_innod ) - self.nodes = len(genotype) - self.in_dim = C_in - self.out_dim = C_out - - def extra_repr(self): - string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) - laystr = [] - for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): - y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)] - x = '{:}<-({:})'.format(i+1, ','.join(y)) - laystr.append( x ) - return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr()) - - def forward(self, inputs): - nodes = [inputs] - for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): - node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) ) - nodes.append( node_feature ) - return nodes[-1] + def extra_repr(self): + string = "info :: nodes={nodes}, 
inC={in_dim}, outC={out_dim}".format( + **self.__dict__ + ) + laystr = [] + for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)): + y = [ + "I{:}-L{:}".format(_ii, _il) + for _il, _ii in zip(node_layers, node_innods) + ] + x = "{:}<-({:})".format(i + 1, ",".join(y)) + laystr.append(x) + return ( + string + + ", [{:}]".format(" | ".join(laystr)) + + ", {:}".format(self.genotype.tostr()) + ) + def forward(self, inputs): + nodes = [inputs] + for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)): + node_feature = sum( + self.layers[_il](nodes[_ii]) + for _il, _ii in zip(node_layers, node_innods) + ) + nodes.append(node_feature) + return nodes[-1] # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 class NASNetInferCell(nn.Module): + def __init__( + self, + genotype, + C_prev_prev, + C_prev, + C, + reduction, + reduction_prev, + affine, + track_running_stats, + ): + super(NASNetInferCell, self).__init__() + self.reduction = reduction + if reduction_prev: + self.preprocess0 = OPS["skip_connect"]( + C_prev_prev, C, 2, affine, track_running_stats + ) + else: + self.preprocess0 = OPS["nor_conv_1x1"]( + C_prev_prev, C, 1, affine, track_running_stats + ) + self.preprocess1 = OPS["nor_conv_1x1"]( + C_prev, C, 1, affine, track_running_stats + ) - def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats): - super(NASNetInferCell, self).__init__() - self.reduction = reduction - if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats) - else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats) - self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats) + if not reduction: + nodes, concats = genotype["normal"], genotype["normal_concat"] + else: + nodes, concats = genotype["reduce"], genotype["reduce_concat"] + self._multiplier = len(concats) + self._concats = concats + self._steps = len(nodes) + self._nodes = nodes + self.edges = nn.ModuleDict() + for i, node in enumerate(nodes): + for in_node in node: + name, j = in_node[0], in_node[1] + stride = 2 if reduction and j < 2 else 1 + node_str = "{:}<-{:}".format(i + 2, j) + self.edges[node_str] = OPS[name]( + C, C, stride, affine, track_running_stats + ) - if not reduction: - nodes, concats = genotype['normal'], genotype['normal_concat'] - else: - nodes, concats = genotype['reduce'], genotype['reduce_concat'] - self._multiplier = len(concats) - self._concats = concats - self._steps = len(nodes) - self._nodes = nodes - self.edges = nn.ModuleDict() - for i, node in enumerate(nodes): - for in_node in node: - name, j = in_node[0], in_node[1] - stride = 2 if reduction and j < 2 else 1 - node_str = '{:}<-{:}'.format(i+2, j) - self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats) + # [TODO] to support drop_prob in this function.. + def forward(self, s0, s1, unused_drop_prob): + s0 = self.preprocess0(s0) + s1 = self.preprocess1(s1) - # [TODO] to support drop_prob in this function.. 
- def forward(self, s0, s1, unused_drop_prob): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) - - states = [s0, s1] - for i, node in enumerate(self._nodes): - clist = [] - for in_node in node: - name, j = in_node[0], in_node[1] - node_str = '{:}<-{:}'.format(i+2, j) - op = self.edges[ node_str ] - clist.append( op(states[j]) ) - states.append( sum(clist) ) - return torch.cat([states[x] for x in self._concats], dim=1) + states = [s0, s1] + for i, node in enumerate(self._nodes): + clist = [] + for in_node in node: + name, j = in_node[0], in_node[1] + node_str = "{:}<-{:}".format(i + 2, j) + op = self.edges[node_str] + clist.append(op(states[j])) + states.append(sum(clist)) + return torch.cat([states[x] for x in self._concats], dim=1) class AuxiliaryHeadCIFAR(nn.Module): + def __init__(self, C, num_classes): + """assuming input size 8x8""" + super(AuxiliaryHeadCIFAR, self).__init__() + self.features = nn.Sequential( + nn.ReLU(inplace=True), + nn.AvgPool2d( + 5, stride=3, padding=0, count_include_pad=False + ), # image size = 2 x 2 + nn.Conv2d(C, 128, 1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 768, 2, bias=False), + nn.BatchNorm2d(768), + nn.ReLU(inplace=True), + ) + self.classifier = nn.Linear(768, num_classes) - def __init__(self, C, num_classes): - """assuming input size 8x8""" - super(AuxiliaryHeadCIFAR, self).__init__() - self.features = nn.Sequential( - nn.ReLU(inplace=True), - nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 - nn.Conv2d(C, 128, 1, bias=False), - nn.BatchNorm2d(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 768, 2, bias=False), - nn.BatchNorm2d(768), - nn.ReLU(inplace=True) - ) - self.classifier = nn.Linear(768, num_classes) - - def forward(self, x): - x = self.features(x) - x = self.classifier(x.view(x.size(0),-1)) - return x + def forward(self, x): + x = self.features(x) + x = self.classifier(x.view(x.size(0), -1)) + return x diff --git a/lib/models/cell_infers/nasnet_cifar.py b/lib/models/cell_infers/nasnet_cifar.py index 20b0f82..bdef399 100644 --- a/lib/models/cell_infers/nasnet_cifar.py +++ b/lib/models/cell_infers/nasnet_cifar.py @@ -9,63 +9,109 @@ from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR # The macro structure is based on NASNet class NASNetonCIFAR(nn.Module): + def __init__( + self, + C, + N, + stem_multiplier, + num_classes, + genotype, + auxiliary, + affine=True, + track_running_stats=True, + ): + super(NASNetonCIFAR, self).__init__() + self._C = C + self._layerN = N + self.stem = nn.Sequential( + nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C * stem_multiplier), + ) - def __init__(self, C, N, stem_multiplier, num_classes, genotype, auxiliary, affine=True, track_running_stats=True): - super(NASNetonCIFAR, self).__init__() - self._C = C - self._layerN = N - self.stem = nn.Sequential( - nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C*stem_multiplier)) - - # config for each layer - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1) - layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1) + # config for each layer + layer_channels = ( + [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) + ) + layer_reductions = ( + [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) + ) - C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False - 
self.auxiliary_index = None - self.auxiliary_head = None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) - self.cells.append( cell ) - C_prev_prev, C_prev, reduction_prev = C_prev, cell._multiplier*C_curr, reduction - if reduction and C_curr == C*4 and auxiliary: - self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) - self.auxiliary_index = index - self._Layer = len(self.cells) - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.drop_path_prob = -1 + C_prev_prev, C_prev, C_curr, reduction_prev = ( + C * stem_multiplier, + C * stem_multiplier, + C, + False, + ) + self.auxiliary_index = None + self.auxiliary_head = None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + cell = InferCell( + genotype, + C_prev_prev, + C_prev, + C_curr, + reduction, + reduction_prev, + affine, + track_running_stats, + ) + self.cells.append(cell) + C_prev_prev, C_prev, reduction_prev = ( + C_prev, + cell._multiplier * C_curr, + reduction, + ) + if reduction and C_curr == C * 4 and auxiliary: + self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) + self.auxiliary_index = index + self._Layer = len(self.cells) + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.drop_path_prob = -1 - def update_drop_path(self, drop_path_prob): - self.drop_path_prob = drop_path_prob + def update_drop_path(self, drop_path_prob): + self.drop_path_prob = drop_path_prob - def auxiliary_param(self): - if self.auxiliary_head is None: return [] - else: return list( self.auxiliary_head.parameters() ) + def auxiliary_param(self): + if self.auxiliary_head is None: + return [] + else: + return list(self.auxiliary_head.parameters()) - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - def extra_repr(self): - return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + def extra_repr(self): + return "{name}(C={_C}, N={_layerN}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - def forward(self, inputs): - stem_feature, logits_aux = self.stem(inputs), None - cell_results = [stem_feature, stem_feature] - for i, cell in enumerate(self.cells): - cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) - cell_results.append( cell_feature ) - if self.auxiliary_index is not None and i == self.auxiliary_index and self.training: - logits_aux = self.auxiliary_head( cell_results[-1] ) - out = self.lastact(cell_results[-1]) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - if logits_aux is None: return out, logits - else: return out, [logits, logits_aux] + def forward(self, inputs): + stem_feature, logits_aux = self.stem(inputs), None + cell_results = 
[stem_feature, stem_feature] + for i, cell in enumerate(self.cells): + cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) + cell_results.append(cell_feature) + if ( + self.auxiliary_index is not None + and i == self.auxiliary_index + and self.training + ): + logits_aux = self.auxiliary_head(cell_results[-1]) + out = self.lastact(cell_results[-1]) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + if logits_aux is None: + return out, logits + else: + return out, [logits, logits_aux] diff --git a/lib/models/cell_infers/tiny_network.py b/lib/models/cell_infers/tiny_network.py index 6dd72b3..e8da1e4 100644 --- a/lib/models/cell_infers/tiny_network.py +++ b/lib/models/cell_infers/tiny_network.py @@ -8,51 +8,56 @@ from .cells import InferCell # The macro structure for architectures in NAS-Bench-201 class TinyNetwork(nn.Module): + def __init__(self, C, N, genotype, num_classes): + super(TinyNetwork, self).__init__() + self._C = C + self._layerN = N - def __init__(self, C, N, genotype, num_classes): - super(TinyNetwork, self).__init__() - self._C = C - self._layerN = N + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) + ) - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - C_prev = C - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2, True) - else: - cell = InferCell(genotype, C_prev, C_curr, 1) - self.cells.append( cell ) - C_prev = cell.out_dim - self._Layer= len(self.cells) + C_prev = C + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2, True) + else: + cell = InferCell(genotype, C_prev, C_curr, 1) + self.cells.append(cell) + C_prev = cell.out_dim + self._Layer = len(self.cells) - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - def extra_repr(self): - return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + def extra_repr(self): + return "{name}(C={_C}, N={_layerN}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - def forward(self, inputs): - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - feature = cell(feature) + def forward(self, inputs): + feature = self.stem(inputs) + 
for i, cell in enumerate(self.cells): + feature = cell(feature) - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) + out = self.lastact(feature) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) - return out, logits + return out, logits diff --git a/lib/models/cell_operations.py b/lib/models/cell_operations.py index 465f516..051539c 100644 --- a/lib/models/cell_operations.py +++ b/lib/models/cell_operations.py @@ -4,315 +4,550 @@ import torch import torch.nn as nn -__all__ = ['OPS', 'RAW_OP_CLASSES', 'ResNetBasicblock', 'SearchSpaceNames'] +__all__ = ["OPS", "RAW_OP_CLASSES", "ResNetBasicblock", "SearchSpaceNames"] OPS = { - 'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride), - 'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats), - 'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats), - 'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats), - 'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), - 'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats), - 'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), - 'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats), - 'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats), - 'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats), - 'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), + "none": lambda C_in, C_out, stride, affine, track_running_stats: Zero( + C_in, C_out, stride + ), + "avg_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING( + C_in, C_out, stride, "avg", affine, track_running_stats + ), + "max_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING( + C_in, C_out, stride, "max", affine, track_running_stats + ), + "nor_conv_7x7": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN( + C_in, + C_out, + (7, 7), + (stride, stride), + (3, 3), + (1, 1), + affine, + track_running_stats, + ), + "nor_conv_3x3": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN( + C_in, + C_out, + (3, 3), + (stride, stride), + (1, 1), + (1, 1), + affine, + track_running_stats, + ), + "nor_conv_1x1": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN( + C_in, + C_out, + (1, 1), + (stride, stride), + (0, 0), + (1, 1), + affine, + track_running_stats, + ), + "dua_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: 
DualSepConv( + C_in, + C_out, + (3, 3), + (stride, stride), + (1, 1), + (1, 1), + affine, + track_running_stats, + ), + "dua_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv( + C_in, + C_out, + (5, 5), + (stride, stride), + (2, 2), + (1, 1), + affine, + track_running_stats, + ), + "dil_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: SepConv( + C_in, + C_out, + (3, 3), + (stride, stride), + (2, 2), + (2, 2), + affine, + track_running_stats, + ), + "dil_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: SepConv( + C_in, + C_out, + (5, 5), + (stride, stride), + (4, 4), + (2, 2), + affine, + track_running_stats, + ), + "skip_connect": lambda C_in, C_out, stride, affine, track_running_stats: Identity() + if stride == 1 and C_in == C_out + else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), } -CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] -NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] -DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3'] +CONNECT_NAS_BENCHMARK = ["none", "skip_connect", "nor_conv_3x3"] +NAS_BENCH_201 = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"] +DARTS_SPACE = [ + "none", + "skip_connect", + "dua_sepc_3x3", + "dua_sepc_5x5", + "dil_sepc_3x3", + "dil_sepc_5x5", + "avg_pool_3x3", + "max_pool_3x3", +] -SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK, - 'nats-bench' : NAS_BENCH_201, - 'nas-bench-201': NAS_BENCH_201, - 'darts' : DARTS_SPACE} +SearchSpaceNames = { + "connect-nas": CONNECT_NAS_BENCHMARK, + "nats-bench": NAS_BENCH_201, + "nas-bench-201": NAS_BENCH_201, + "darts": DARTS_SPACE, +} class ReLUConvBN(nn.Module): + def __init__( + self, + C_in, + C_out, + kernel_size, + stride, + padding, + dilation, + affine, + track_running_stats=True, + ): + super(ReLUConvBN, self).__init__() + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d( + C_in, + C_out, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=not affine, + ), + nn.BatchNorm2d( + C_out, affine=affine, track_running_stats=track_running_stats + ), + ) - def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): - super(ReLUConvBN, self).__init__() - self.op = nn.Sequential( - nn.ReLU(inplace=False), - nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine), - nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) - ) - - def forward(self, x): - return self.op(x) + def forward(self, x): + return self.op(x) class SepConv(nn.Module): - - def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): - super(SepConv, self).__init__() - self.op = nn.Sequential( - nn.ReLU(inplace=False), - nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), - nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine), - nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats), - ) + def __init__( + self, + C_in, + C_out, + kernel_size, + stride, + padding, + dilation, + affine, + track_running_stats=True, + ): + super(SepConv, self).__init__() + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d( + C_in, + C_in, + kernel_size=kernel_size, + stride=stride, + 
padding=padding, + dilation=dilation, + groups=C_in, + bias=False, + ), + nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine), + nn.BatchNorm2d( + C_out, affine=affine, track_running_stats=track_running_stats + ), + ) - def forward(self, x): - return self.op(x) + def forward(self, x): + return self.op(x) class DualSepConv(nn.Module): - - def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): - super(DualSepConv, self).__init__() - self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats) - self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats) + def __init__( + self, + C_in, + C_out, + kernel_size, + stride, + padding, + dilation, + affine, + track_running_stats=True, + ): + super(DualSepConv, self).__init__() + self.op_a = SepConv( + C_in, + C_in, + kernel_size, + stride, + padding, + dilation, + affine, + track_running_stats, + ) + self.op_b = SepConv( + C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats + ) - def forward(self, x): - x = self.op_a(x) - x = self.op_b(x) - return x + def forward(self, x): + x = self.op_a(x) + x = self.op_b(x) + return x class ResNetBasicblock(nn.Module): + def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_a = ReLUConvBN( + inplanes, planes, 3, stride, 1, 1, affine, track_running_stats + ) + self.conv_b = ReLUConvBN( + planes, planes, 3, 1, 1, 1, affine, track_running_stats + ) + if stride == 2: + self.downsample = nn.Sequential( + nn.AvgPool2d(kernel_size=2, stride=2, padding=0), + nn.Conv2d( + inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False + ), + ) + elif inplanes != planes: + self.downsample = ReLUConvBN( + inplanes, planes, 1, 1, 0, 1, affine, track_running_stats + ) + else: + self.downsample = None + self.in_dim = inplanes + self.out_dim = planes + self.stride = stride + self.num_conv = 2 - def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True): - super(ResNetBasicblock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine, track_running_stats) - self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine, track_running_stats) - if stride == 2: - self.downsample = nn.Sequential( - nn.AvgPool2d(kernel_size=2, stride=2, padding=0), - nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False)) - elif inplanes != planes: - self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine, track_running_stats) - else: - self.downsample = None - self.in_dim = inplanes - self.out_dim = planes - self.stride = stride - self.num_conv = 2 + def extra_repr(self): + string = "{name}(inC={in_dim}, outC={out_dim}, stride={stride})".format( + name=self.__class__.__name__, **self.__dict__ + ) + return string - def extra_repr(self): - string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__) - return string + def forward(self, inputs): - def forward(self, inputs): + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) - basicblock = self.conv_a(inputs) - basicblock = self.conv_b(basicblock) - - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs 
- return residual + basicblock + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + return residual + basicblock class POOLING(nn.Module): + def __init__( + self, C_in, C_out, stride, mode, affine=True, track_running_stats=True + ): + super(POOLING, self).__init__() + if C_in == C_out: + self.preprocess = None + else: + self.preprocess = ReLUConvBN( + C_in, C_out, 1, 1, 0, 1, affine, track_running_stats + ) + if mode == "avg": + self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) + elif mode == "max": + self.op = nn.MaxPool2d(3, stride=stride, padding=1) + else: + raise ValueError("Invalid mode={:} in POOLING".format(mode)) - def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True): - super(POOLING, self).__init__() - if C_in == C_out: - self.preprocess = None - else: - self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats) - if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) - elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1) - else : raise ValueError('Invalid mode={:} in POOLING'.format(mode)) - - def forward(self, inputs): - if self.preprocess: x = self.preprocess(inputs) - else : x = inputs - return self.op(x) + def forward(self, inputs): + if self.preprocess: + x = self.preprocess(inputs) + else: + x = inputs + return self.op(x) class Identity(nn.Module): + def __init__(self): + super(Identity, self).__init__() - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return x + def forward(self, x): + return x class Zero(nn.Module): + def __init__(self, C_in, C_out, stride): + super(Zero, self).__init__() + self.C_in = C_in + self.C_out = C_out + self.stride = stride + self.is_zero = True - def __init__(self, C_in, C_out, stride): - super(Zero, self).__init__() - self.C_in = C_in - self.C_out = C_out - self.stride = stride - self.is_zero = True + def forward(self, x): + if self.C_in == self.C_out: + if self.stride == 1: + return x.mul(0.0) + else: + return x[:, :, :: self.stride, :: self.stride].mul(0.0) + else: + shape = list(x.shape) + shape[1] = self.C_out + zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device) + return zeros - def forward(self, x): - if self.C_in == self.C_out: - if self.stride == 1: return x.mul(0.) - else : return x[:,:,::self.stride,::self.stride].mul(0.) 
- else: - shape = list(x.shape) - shape[1] = self.C_out - zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device) - return zeros - - def extra_repr(self): - return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__) + def extra_repr(self): + return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__) class FactorizedReduce(nn.Module): + def __init__(self, C_in, C_out, stride, affine, track_running_stats): + super(FactorizedReduce, self).__init__() + self.stride = stride + self.C_in = C_in + self.C_out = C_out + self.relu = nn.ReLU(inplace=False) + if stride == 2: + # assert C_out % 2 == 0, 'C_out : {:}'.format(C_out) + C_outs = [C_out // 2, C_out - C_out // 2] + self.convs = nn.ModuleList() + for i in range(2): + self.convs.append( + nn.Conv2d( + C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine + ) + ) + self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) + elif stride == 1: + self.conv = nn.Conv2d( + C_in, C_out, 1, stride=stride, padding=0, bias=not affine + ) + else: + raise ValueError("Invalid stride : {:}".format(stride)) + self.bn = nn.BatchNorm2d( + C_out, affine=affine, track_running_stats=track_running_stats + ) - def __init__(self, C_in, C_out, stride, affine, track_running_stats): - super(FactorizedReduce, self).__init__() - self.stride = stride - self.C_in = C_in - self.C_out = C_out - self.relu = nn.ReLU(inplace=False) - if stride == 2: - #assert C_out % 2 == 0, 'C_out : {:}'.format(C_out) - C_outs = [C_out // 2, C_out - C_out // 2] - self.convs = nn.ModuleList() - for i in range(2): - self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine)) - self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) - elif stride == 1: - self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=not affine) - else: - raise ValueError('Invalid stride : {:}'.format(stride)) - self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) + def forward(self, x): + if self.stride == 2: + x = self.relu(x) + y = self.pad(x) + out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1) + else: + out = self.conv(x) + out = self.bn(out) + return out - def forward(self, x): - if self.stride == 2: - x = self.relu(x) - y = self.pad(x) - out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1) - else: - out = self.conv(x) - out = self.bn(out) - return out - - def extra_repr(self): - return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__) + def extra_repr(self): + return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__) # Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019 class PartAwareOp(nn.Module): - - def __init__(self, C_in, C_out, stride, part=4): - super().__init__() - self.part = 4 - self.hidden = C_in // 3 - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.local_conv_list = nn.ModuleList() - for i in range(self.part): - self.local_conv_list.append( - nn.Sequential(nn.ReLU(), nn.Conv2d(C_in, self.hidden, 1), nn.BatchNorm2d(self.hidden, affine=True)) + def __init__(self, C_in, C_out, stride, part=4): + super().__init__() + self.part = 4 + self.hidden = C_in // 3 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.local_conv_list = nn.ModuleList() + for i in range(self.part): + self.local_conv_list.append( + nn.Sequential( + nn.ReLU(), + nn.Conv2d(C_in, self.hidden, 1), + nn.BatchNorm2d(self.hidden, affine=True), + ) ) - self.W_K = nn.Linear(self.hidden, self.hidden) - self.W_Q = nn.Linear(self.hidden, 
self.hidden) + self.W_K = nn.Linear(self.hidden, self.hidden) + self.W_Q = nn.Linear(self.hidden, self.hidden) - if stride == 2 : self.last = FactorizedReduce(C_in + self.hidden, C_out, 2) - elif stride == 1: self.last = FactorizedReduce(C_in + self.hidden, C_out, 1) - else: raise ValueError('Invalid Stride : {:}'.format(stride)) + if stride == 2: + self.last = FactorizedReduce(C_in + self.hidden, C_out, 2) + elif stride == 1: + self.last = FactorizedReduce(C_in + self.hidden, C_out, 1) + else: + raise ValueError("Invalid Stride : {:}".format(stride)) - def forward(self, x): - batch, C, H, W = x.size() - assert H >= self.part, 'input size too small : {:} vs {:}'.format(x.shape, self.part) - IHs = [0] - for i in range(self.part): IHs.append( min(H, int((i+1)*(float(H)/self.part))) ) - local_feat_list = [] - for i in range(self.part): - feature = x[:, :, IHs[i]:IHs[i+1], :] - xfeax = self.avg_pool(feature) - xfea = self.local_conv_list[i]( xfeax ) - local_feat_list.append( xfea ) - part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part) - part_feature = part_feature.transpose(1,2).contiguous() - part_K = self.W_K(part_feature) - part_Q = self.W_Q(part_feature).transpose(1,2).contiguous() - weight_att = torch.bmm(part_K, part_Q) - attention = torch.softmax(weight_att, dim=2) - aggreateF = torch.bmm(attention, part_feature).transpose(1,2).contiguous() - features = [] - for i in range(self.part): - feature = aggreateF[:, :, i:i+1].expand(batch, self.hidden, IHs[i+1]-IHs[i]) - feature = feature.view(batch, self.hidden, IHs[i+1]-IHs[i], 1) - features.append( feature ) - features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W) - final_fea = torch.cat((x,features), dim=1) - outputs = self.last( final_fea ) - return outputs + def forward(self, x): + batch, C, H, W = x.size() + assert H >= self.part, "input size too small : {:} vs {:}".format( + x.shape, self.part + ) + IHs = [0] + for i in range(self.part): + IHs.append(min(H, int((i + 1) * (float(H) / self.part)))) + local_feat_list = [] + for i in range(self.part): + feature = x[:, :, IHs[i] : IHs[i + 1], :] + xfeax = self.avg_pool(feature) + xfea = self.local_conv_list[i](xfeax) + local_feat_list.append(xfea) + part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part) + part_feature = part_feature.transpose(1, 2).contiguous() + part_K = self.W_K(part_feature) + part_Q = self.W_Q(part_feature).transpose(1, 2).contiguous() + weight_att = torch.bmm(part_K, part_Q) + attention = torch.softmax(weight_att, dim=2) + aggreateF = torch.bmm(attention, part_feature).transpose(1, 2).contiguous() + features = [] + for i in range(self.part): + feature = aggreateF[:, :, i : i + 1].expand( + batch, self.hidden, IHs[i + 1] - IHs[i] + ) + feature = feature.view(batch, self.hidden, IHs[i + 1] - IHs[i], 1) + features.append(feature) + features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W) + final_fea = torch.cat((x, features), dim=1) + outputs = self.last(final_fea) + return outputs def drop_path(x, drop_prob): - if drop_prob > 0.: - keep_prob = 1. 
- drop_prob - mask = x.new_zeros(x.size(0), 1, 1, 1) - mask = mask.bernoulli_(keep_prob) - x = torch.div(x, keep_prob) - x.mul_(mask) - return x + if drop_prob > 0.0: + keep_prob = 1.0 - drop_prob + mask = x.new_zeros(x.size(0), 1, 1, 1) + mask = mask.bernoulli_(keep_prob) + x = torch.div(x, keep_prob) + x.mul_(mask) + return x # Searching for A Robust Neural Architecture in Four GPU Hours class GDAS_Reduction_Cell(nn.Module): + def __init__( + self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats + ): + super(GDAS_Reduction_Cell, self).__init__() + if reduction_prev: + self.preprocess0 = FactorizedReduce( + C_prev_prev, C, 2, affine, track_running_stats + ) + else: + self.preprocess0 = ReLUConvBN( + C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats + ) + self.preprocess1 = ReLUConvBN( + C_prev, C, 1, 1, 0, 1, affine, track_running_stats + ) - def __init__(self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats): - super(GDAS_Reduction_Cell, self).__init__() - if reduction_prev: - self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats) - else: - self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats) - self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats) + self.reduction = True + self.ops1 = nn.ModuleList( + [ + nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d( + C, + C, + (1, 3), + stride=(1, 2), + padding=(0, 1), + groups=8, + bias=not affine, + ), + nn.Conv2d( + C, + C, + (3, 1), + stride=(2, 1), + padding=(1, 0), + groups=8, + bias=not affine, + ), + nn.BatchNorm2d( + C, affine=affine, track_running_stats=track_running_stats + ), + nn.ReLU(inplace=False), + nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), + nn.BatchNorm2d( + C, affine=affine, track_running_stats=track_running_stats + ), + ), + nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d( + C, + C, + (1, 3), + stride=(1, 2), + padding=(0, 1), + groups=8, + bias=not affine, + ), + nn.Conv2d( + C, + C, + (3, 1), + stride=(2, 1), + padding=(1, 0), + groups=8, + bias=not affine, + ), + nn.BatchNorm2d( + C, affine=affine, track_running_stats=track_running_stats + ), + nn.ReLU(inplace=False), + nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), + nn.BatchNorm2d( + C, affine=affine, track_running_stats=track_running_stats + ), + ), + ] + ) - self.reduction = True - self.ops1 = nn.ModuleList( - [nn.Sequential( - nn.ReLU(inplace=False), - nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=not affine), - nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=not affine), - nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats), - nn.ReLU(inplace=False), - nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), - nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats)), - nn.Sequential( - nn.ReLU(inplace=False), - nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=not affine), - nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=not affine), - nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats), - nn.ReLU(inplace=False), - nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), - nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats))]) + self.ops2 = nn.ModuleList( + [ + nn.Sequential( + nn.MaxPool2d(3, stride=2, padding=1), + nn.BatchNorm2d( + C, affine=affine, track_running_stats=track_running_stats + ), + ), + nn.Sequential( + 
nn.MaxPool2d(3, stride=2, padding=1), + nn.BatchNorm2d( + C, affine=affine, track_running_stats=track_running_stats + ), + ), + ] + ) - self.ops2 = nn.ModuleList( - [nn.Sequential( - nn.MaxPool2d(3, stride=2, padding=1), - nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats)), - nn.Sequential( - nn.MaxPool2d(3, stride=2, padding=1), - nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats))]) + @property + def multiplier(self): + return 4 - @property - def multiplier(self): - return 4 + def forward(self, s0, s1, drop_prob=-1): + s0 = self.preprocess0(s0) + s1 = self.preprocess1(s1) - def forward(self, s0, s1, drop_prob = -1): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) + X0 = self.ops1[0](s0) + X1 = self.ops1[1](s1) + if self.training and drop_prob > 0.0: + X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob) - X0 = self.ops1[0] (s0) - X1 = self.ops1[1] (s1) - if self.training and drop_prob > 0.: - X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob) - - #X2 = self.ops2[0] (X0+X1) - X2 = self.ops2[0] (s0) - X3 = self.ops2[1] (s1) - if self.training and drop_prob > 0.: - X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob) - return torch.cat([X0, X1, X2, X3], dim=1) + # X2 = self.ops2[0] (X0+X1) + X2 = self.ops2[0](s0) + X3 = self.ops2[1](s1) + if self.training and drop_prob > 0.0: + X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob) + return torch.cat([X0, X1, X2, X3], dim=1) # To manage the useful classes in this file. -RAW_OP_CLASSES = { - 'gdas_reduction': GDAS_Reduction_Cell -} - +RAW_OP_CLASSES = {"gdas_reduction": GDAS_Reduction_Cell} diff --git a/lib/models/cell_searchs/__init__.py b/lib/models/cell_searchs/__init__.py index 05a315c..0d770cb 100644 --- a/lib/models/cell_searchs/__init__.py +++ b/lib/models/cell_searchs/__init__.py @@ -2,27 +2,32 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # ################################################## # The macro structure is defined in NAS-Bench-201 -from .search_model_darts import TinyNetworkDarts -from .search_model_gdas import TinyNetworkGDAS -from .search_model_setn import TinyNetworkSETN -from .search_model_enas import TinyNetworkENAS -from .search_model_random import TinyNetworkRANDOM -from .generic_model import GenericNAS201Model -from .genotypes import Structure as CellStructure, architectures as CellArchitectures +from .search_model_darts import TinyNetworkDarts +from .search_model_gdas import TinyNetworkGDAS +from .search_model_setn import TinyNetworkSETN +from .search_model_enas import TinyNetworkENAS +from .search_model_random import TinyNetworkRANDOM +from .generic_model import GenericNAS201Model +from .genotypes import Structure as CellStructure, architectures as CellArchitectures + # NASNet-based macro structure from .search_model_gdas_nasnet import NASNetworkGDAS from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC from .search_model_darts_nasnet import NASNetworkDARTS -nas201_super_nets = {'DARTS-V1': TinyNetworkDarts, - "DARTS-V2": TinyNetworkDarts, - "GDAS": TinyNetworkGDAS, - "SETN": TinyNetworkSETN, - "ENAS": TinyNetworkENAS, - "RANDOM": TinyNetworkRANDOM, - "generic": GenericNAS201Model} +nas201_super_nets = { + "DARTS-V1": TinyNetworkDarts, + "DARTS-V2": TinyNetworkDarts, + "GDAS": TinyNetworkGDAS, + "SETN": TinyNetworkSETN, + "ENAS": TinyNetworkENAS, + "RANDOM": TinyNetworkRANDOM, + "generic": GenericNAS201Model, +} -nasnet_super_nets = {"GDAS": NASNetworkGDAS, - "GDAS_FRC": NASNetworkGDAS_FRC, - "DARTS": NASNetworkDARTS} 
+nasnet_super_nets = { + "GDAS": NASNetworkGDAS, + "GDAS_FRC": NASNetworkGDAS_FRC, + "DARTS": NASNetworkDARTS, +} diff --git a/lib/models/cell_searchs/_test_module.py b/lib/models/cell_searchs/_test_module.py index c603ba6..cd6fbfb 100644 --- a/lib/models/cell_searchs/_test_module.py +++ b/lib/models/cell_searchs/_test_module.py @@ -4,9 +4,11 @@ import torch from search_model_enas_utils import Controller -def main(): - controller = Controller(6, 4) - predictions = controller() -if __name__ == '__main__': - main() +def main(): + controller = Controller(6, 4) + predictions = controller() + + +if __name__ == "__main__": + main() diff --git a/lib/models/cell_searchs/generic_model.py b/lib/models/cell_searchs/generic_model.py index 25f72f9..ad0cd30 100644 --- a/lib/models/cell_searchs/generic_model.py +++ b/lib/models/cell_searchs/generic_model.py @@ -8,296 +8,355 @@ from typing import Text from torch.distributions.categorical import Categorical from ..cell_operations import ResNetBasicblock, drop_path -from .search_cells import NAS201SearchCell as SearchCell -from .genotypes import Structure +from .search_cells import NAS201SearchCell as SearchCell +from .genotypes import Structure class Controller(nn.Module): - # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py - def __init__(self, edge2index, op_names, max_nodes, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0): - super(Controller, self).__init__() - # assign the attributes - self.max_nodes = max_nodes - self.num_edge = len(edge2index) - self.edge2index = edge2index - self.num_ops = len(op_names) - self.op_names = op_names - self.lstm_size = lstm_size - self.lstm_N = lstm_num_layers - self.tanh_constant = tanh_constant - self.temperature = temperature - # create parameters - self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size))) - self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N) - self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) - self.w_pred = nn.Linear(self.lstm_size, self.num_ops) + # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py + def __init__( + self, + edge2index, + op_names, + max_nodes, + lstm_size=32, + lstm_num_layers=2, + tanh_constant=2.5, + temperature=5.0, + ): + super(Controller, self).__init__() + # assign the attributes + self.max_nodes = max_nodes + self.num_edge = len(edge2index) + self.edge2index = edge2index + self.num_ops = len(op_names) + self.op_names = op_names + self.lstm_size = lstm_size + self.lstm_N = lstm_num_layers + self.tanh_constant = tanh_constant + self.temperature = temperature + # create parameters + self.register_parameter( + "input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size)) + ) + self.w_lstm = nn.LSTM( + input_size=self.lstm_size, + hidden_size=self.lstm_size, + num_layers=self.lstm_N, + ) + self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) + self.w_pred = nn.Linear(self.lstm_size, self.num_ops) - nn.init.uniform_(self.input_vars , -0.1, 0.1) - nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) - nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) - nn.init.uniform_(self.w_embd.weight , -0.1, 0.1) - nn.init.uniform_(self.w_pred.weight , -0.1, 0.1) + nn.init.uniform_(self.input_vars, -0.1, 0.1) + nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) + nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) + nn.init.uniform_(self.w_embd.weight, -0.1, 0.1) + nn.init.uniform_(self.w_pred.weight, -0.1, 0.1) 
- def convert_structure(self, _arch): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_index = _arch[self.edge2index[node_str]] - op_name = self.op_names[op_index] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure(genotypes) + def convert_structure(self, _arch): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + op_index = _arch[self.edge2index[node_str]] + op_name = self.op_names[op_index] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return Structure(genotypes) - def forward(self): + def forward(self): - inputs, h0 = self.input_vars, None - log_probs, entropys, sampled_arch = [], [], [] - for iedge in range(self.num_edge): - outputs, h0 = self.w_lstm(inputs, h0) - - logits = self.w_pred(outputs) - logits = logits / self.temperature - logits = self.tanh_constant * torch.tanh(logits) - # distribution - op_distribution = Categorical(logits=logits) - op_index = op_distribution.sample() - sampled_arch.append( op_index.item() ) + inputs, h0 = self.input_vars, None + log_probs, entropys, sampled_arch = [], [], [] + for iedge in range(self.num_edge): + outputs, h0 = self.w_lstm(inputs, h0) - op_log_prob = op_distribution.log_prob(op_index) - log_probs.append( op_log_prob.view(-1) ) - op_entropy = op_distribution.entropy() - entropys.append( op_entropy.view(-1) ) - - # obtain the input embedding for the next step - inputs = self.w_embd(op_index) - return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), self.convert_structure(sampled_arch) + logits = self.w_pred(outputs) + logits = logits / self.temperature + logits = self.tanh_constant * torch.tanh(logits) + # distribution + op_distribution = Categorical(logits=logits) + op_index = op_distribution.sample() + sampled_arch.append(op_index.item()) + op_log_prob = op_distribution.log_prob(op_index) + log_probs.append(op_log_prob.view(-1)) + op_entropy = op_distribution.entropy() + entropys.append(op_entropy.view(-1)) + + # obtain the input embedding for the next step + inputs = self.w_embd(op_index) + return ( + torch.sum(torch.cat(log_probs)), + torch.sum(torch.cat(entropys)), + self.convert_structure(sampled_arch), + ) class GenericNAS201Model(nn.Module): + def __init__( + self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats + ): + super(GenericNAS201Model, self).__init__() + self._C = C + self._layerN = N + self._max_nodes = max_nodes + self._stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) + ) + layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + C_prev, num_edge, edge2index = C, None, None + self._cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell( + C_prev, + C_curr, + 1, + max_nodes, + search_space, + affine, + track_running_stats, + ) + if num_edge is None: + num_edge, edge2index = cell.num_edges, cell.edge2index + else: + assert ( + num_edge == cell.num_edges and edge2index == cell.edge2index + ), "invalid {:} vs. 
{:}.".format(num_edge, cell.num_edges) + self._cells.append(cell) + C_prev = cell.out_dim + self._op_names = deepcopy(search_space) + self._Layer = len(self._cells) + self.edge2index = edge2index + self.lastact = nn.Sequential( + nn.BatchNorm2d( + C_prev, affine=affine, track_running_stats=track_running_stats + ), + nn.ReLU(inplace=True), + ) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self._num_edge = num_edge + # algorithm related + self.arch_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) + self._mode = None + self.dynamic_cell = None + self._tau = None + self._algo = None + self._drop_path = None + self.verbose = False - def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): - super(GenericNAS201Model, self).__init__() - self._C = C - self._layerN = N - self._max_nodes = max_nodes - self._stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - C_prev, num_edge, edge2index = C, None, None - self._cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self._cells.append(cell) - C_prev = cell.out_dim - self._op_names = deepcopy(search_space) - self._Layer = len(self._cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev, affine=affine, track_running_stats=track_running_stats), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self._num_edge = num_edge - # algorithm related - self.arch_parameters = nn.Parameter(1e-3*torch.randn(num_edge, len(search_space))) - self._mode = None - self.dynamic_cell = None - self._tau = None - self._algo = None - self._drop_path = None - self.verbose = False - - def set_algo(self, algo: Text): - # used for searching - assert self._algo is None, 'This functioin can only be called once.' 
- self._algo = algo - if algo == 'enas': - self.controller = Controller(self.edge2index, self._op_names, self._max_nodes) - else: - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(self._num_edge, len(self._op_names)) ) - if algo == 'gdas': - self._tau = 10 - - def set_cal_mode(self, mode, dynamic_cell=None): - assert mode in ['gdas', 'enas', 'urs', 'joint', 'select', 'dynamic'] - self._mode = mode - if mode == 'dynamic': self.dynamic_cell = deepcopy(dynamic_cell) - else : self.dynamic_cell = None - - def set_drop_path(self, progress, drop_path_rate): - if drop_path_rate is None: - self._drop_path = None - elif progress is None: - self._drop_path = drop_path_rate - else: - self._drop_path = progress * drop_path_rate - - @property - def mode(self): - return self._mode - - @property - def drop_path(self): - return self._drop_path - - @property - def weights(self): - xlist = list(self._stem.parameters()) - xlist+= list(self._cells.parameters()) - xlist+= list(self.lastact.parameters()) - xlist+= list(self.global_pooling.parameters()) - xlist+= list(self.classifier.parameters()) - return xlist - - def set_tau(self, tau): - self._tau = tau - - @property - def tau(self): - return self._tau - - @property - def alphas(self): - if self._algo == 'enas': - return list(self.controller.parameters()) - else: - return [self.arch_parameters] - - @property - def message(self): - string = self.extra_repr() - for i, cell in enumerate(self._cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self._cells), cell.extra_repr()) - return string - - def show_alphas(self): - with torch.no_grad(): - if self._algo == 'enas': - return 'w_pred :\n{:}'.format(self.controller.w_pred.weight) - else: - return 'arch-parameters :\n{:}'.format(nn.functional.softmax(self.arch_parameters, dim=-1).cpu()) - - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})'.format(name=self.__class__.__name__, **self.__dict__)) - - @property - def genotype(self): - genotypes = [] - for i in range(1, self._max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self._op_names[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append(tuple(xlist)) - return Structure(genotypes) - - def dync_genotype(self, use_random=False): - genotypes = [] - with torch.no_grad(): - alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) - for i in range(1, self._max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - if use_random: - op_name = random.choice(self._op_names) + def set_algo(self, algo: Text): + # used for searching + assert self._algo is None, "This functioin can only be called once." 
+ self._algo = algo + if algo == "enas": + self.controller = Controller( + self.edge2index, self._op_names, self._max_nodes + ) else: - weights = alphas_cpu[ self.edge2index[node_str] ] - op_index = torch.multinomial(weights, 1).item() - op_name = self._op_names[ op_index ] - xlist.append((op_name, j)) - genotypes.append(tuple(xlist)) - return Structure(genotypes) + self.arch_parameters = nn.Parameter( + 1e-3 * torch.randn(self._num_edge, len(self._op_names)) + ) + if algo == "gdas": + self._tau = 10 - def get_log_prob(self, arch): - with torch.no_grad(): - logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) - select_logits = [] - for i, node_info in enumerate(arch.nodes): - for op, xin in node_info: - node_str = '{:}<-{:}'.format(i+1, xin) - op_index = self._op_names.index(op) - select_logits.append( logits[self.edge2index[node_str], op_index] ) - return sum(select_logits).item() + def set_cal_mode(self, mode, dynamic_cell=None): + assert mode in ["gdas", "enas", "urs", "joint", "select", "dynamic"] + self._mode = mode + if mode == "dynamic": + self.dynamic_cell = deepcopy(dynamic_cell) + else: + self.dynamic_cell = None - def return_topK(self, K, use_random=False): - archs = Structure.gen_all(self._op_names, self._max_nodes, False) - pairs = [(self.get_log_prob(arch), arch) for arch in archs] - if K < 0 or K >= len(archs): K = len(archs) - if use_random: - return random.sample(archs, K) - else: - sorted_pairs = sorted(pairs, key=lambda x: -x[0]) - return_pairs = [sorted_pairs[_][1] for _ in range(K)] - return return_pairs + def set_drop_path(self, progress, drop_path_rate): + if drop_path_rate is None: + self._drop_path = None + elif progress is None: + self._drop_path = drop_path_rate + else: + self._drop_path = progress * drop_path_rate - def normalize_archp(self): - if self.mode == 'gdas': - while True: - gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() - logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau - probs = nn.functional.softmax(logits, dim=1) - index = probs.max(-1, keepdim=True)[1] - one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) - hardwts = one_h - probs.detach() + probs - if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): - continue - else: break - with torch.no_grad(): - hardwts_cpu = hardwts.detach().cpu() - return hardwts, hardwts_cpu, index, 'GUMBEL' - else: - alphas = nn.functional.softmax(self.arch_parameters, dim=-1) - index = alphas.max(-1, keepdim=True)[1] - with torch.no_grad(): - alphas_cpu = alphas.detach().cpu() - return alphas, alphas_cpu, index, 'SOFTMAX' + @property + def mode(self): + return self._mode - def forward(self, inputs): - alphas, alphas_cpu, index, verbose_str = self.normalize_archp() - feature = self._stem(inputs) - for i, cell in enumerate(self._cells): - if isinstance(cell, SearchCell): - if self.mode == 'urs': - feature = cell.forward_urs(feature) - if self.verbose: - verbose_str += '-forward_urs' - elif self.mode == 'select': - feature = cell.forward_select(feature, alphas_cpu) - if self.verbose: - verbose_str += '-forward_select' - elif self.mode == 'joint': - feature = cell.forward_joint(feature, alphas) - if self.verbose: - verbose_str += '-forward_joint' - elif self.mode == 'dynamic': - feature = cell.forward_dynamic(feature, self.dynamic_cell) - if self.verbose: - verbose_str += '-forward_dynamic' - elif self.mode == 'gdas': - feature = cell.forward_gdas(feature, alphas, index) - if self.verbose: - verbose_str += '-forward_gdas' 
- else: raise ValueError('invalid mode={:}'.format(self.mode)) - else: feature = cell(feature) - if self.drop_path is not None: - feature = drop_path(feature, self.drop_path) - if self.verbose and random.random() < 0.001: - print(verbose_str) - out = self.lastact(feature) - out = self.global_pooling(out) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - return out, logits + @property + def drop_path(self): + return self._drop_path + + @property + def weights(self): + xlist = list(self._stem.parameters()) + xlist += list(self._cells.parameters()) + xlist += list(self.lastact.parameters()) + xlist += list(self.global_pooling.parameters()) + xlist += list(self.classifier.parameters()) + return xlist + + def set_tau(self, tau): + self._tau = tau + + @property + def tau(self): + return self._tau + + @property + def alphas(self): + if self._algo == "enas": + return list(self.controller.parameters()) + else: + return [self.arch_parameters] + + @property + def message(self): + string = self.extra_repr() + for i, cell in enumerate(self._cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self._cells), cell.extra_repr() + ) + return string + + def show_alphas(self): + with torch.no_grad(): + if self._algo == "enas": + return "w_pred :\n{:}".format(self.controller.w_pred.weight) + else: + return "arch-parameters :\n{:}".format( + nn.functional.softmax(self.arch_parameters, dim=-1).cpu() + ) + + def extra_repr(self): + return "{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})".format( + name=self.__class__.__name__, **self.__dict__ + ) + + @property + def genotype(self): + genotypes = [] + for i in range(1, self._max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[self.edge2index[node_str]] + op_name = self._op_names[weights.argmax().item()] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return Structure(genotypes) + + def dync_genotype(self, use_random=False): + genotypes = [] + with torch.no_grad(): + alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) + for i in range(1, self._max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + if use_random: + op_name = random.choice(self._op_names) + else: + weights = alphas_cpu[self.edge2index[node_str]] + op_index = torch.multinomial(weights, 1).item() + op_name = self._op_names[op_index] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return Structure(genotypes) + + def get_log_prob(self, arch): + with torch.no_grad(): + logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) + select_logits = [] + for i, node_info in enumerate(arch.nodes): + for op, xin in node_info: + node_str = "{:}<-{:}".format(i + 1, xin) + op_index = self._op_names.index(op) + select_logits.append(logits[self.edge2index[node_str], op_index]) + return sum(select_logits).item() + + def return_topK(self, K, use_random=False): + archs = Structure.gen_all(self._op_names, self._max_nodes, False) + pairs = [(self.get_log_prob(arch), arch) for arch in archs] + if K < 0 or K >= len(archs): + K = len(archs) + if use_random: + return random.sample(archs, K) + else: + sorted_pairs = sorted(pairs, key=lambda x: -x[0]) + return_pairs = [sorted_pairs[_][1] for _ in range(K)] + return return_pairs + + def normalize_archp(self): + if self.mode == "gdas": + while True: + gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() + logits = 
(self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau + probs = nn.functional.softmax(logits, dim=1) + index = probs.max(-1, keepdim=True)[1] + one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) + hardwts = one_h - probs.detach() + probs + if ( + (torch.isinf(gumbels).any()) + or (torch.isinf(probs).any()) + or (torch.isnan(probs).any()) + ): + continue + else: + break + with torch.no_grad(): + hardwts_cpu = hardwts.detach().cpu() + return hardwts, hardwts_cpu, index, "GUMBEL" + else: + alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + index = alphas.max(-1, keepdim=True)[1] + with torch.no_grad(): + alphas_cpu = alphas.detach().cpu() + return alphas, alphas_cpu, index, "SOFTMAX" + + def forward(self, inputs): + alphas, alphas_cpu, index, verbose_str = self.normalize_archp() + feature = self._stem(inputs) + for i, cell in enumerate(self._cells): + if isinstance(cell, SearchCell): + if self.mode == "urs": + feature = cell.forward_urs(feature) + if self.verbose: + verbose_str += "-forward_urs" + elif self.mode == "select": + feature = cell.forward_select(feature, alphas_cpu) + if self.verbose: + verbose_str += "-forward_select" + elif self.mode == "joint": + feature = cell.forward_joint(feature, alphas) + if self.verbose: + verbose_str += "-forward_joint" + elif self.mode == "dynamic": + feature = cell.forward_dynamic(feature, self.dynamic_cell) + if self.verbose: + verbose_str += "-forward_dynamic" + elif self.mode == "gdas": + feature = cell.forward_gdas(feature, alphas, index) + if self.verbose: + verbose_str += "-forward_gdas" + else: + raise ValueError("invalid mode={:}".format(self.mode)) + else: + feature = cell(feature) + if self.drop_path is not None: + feature = drop_path(feature, self.drop_path) + if self.verbose and random.random() < 0.001: + print(verbose_str) + out = self.lastact(feature) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + return out, logits diff --git a/lib/models/cell_searchs/genotypes.py b/lib/models/cell_searchs/genotypes.py index b2b4091..f0ec8f2 100644 --- a/lib/models/cell_searchs/genotypes.py +++ b/lib/models/cell_searchs/genotypes.py @@ -5,194 +5,270 @@ from copy import deepcopy def get_combination(space, num): - combs = [] - for i in range(num): - if i == 0: - for func in space: - combs.append( [(func, i)] ) - else: - new_combs = [] - for string in combs: - for func in space: - xstring = string + [(func, i)] - new_combs.append( xstring ) - combs = new_combs - return combs + combs = [] + for i in range(num): + if i == 0: + for func in space: + combs.append([(func, i)]) + else: + new_combs = [] + for string in combs: + for func in space: + xstring = string + [(func, i)] + new_combs.append(xstring) + combs = new_combs + return combs class Structure: + def __init__(self, genotype): + assert isinstance(genotype, list) or isinstance( + genotype, tuple + ), "invalid class of genotype : {:}".format(type(genotype)) + self.node_num = len(genotype) + 1 + self.nodes = [] + self.node_N = [] + for idx, node_info in enumerate(genotype): + assert isinstance(node_info, list) or isinstance( + node_info, tuple + ), "invalid class of node_info : {:}".format(type(node_info)) + assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info)) + for node_in in node_info: + assert isinstance(node_in, list) or isinstance( + node_in, tuple + ), "invalid class of in-node : {:}".format(type(node_in)) + assert ( + len(node_in) == 2 and node_in[1] <= idx + ), "invalid in-node : 
{:}".format(node_in) + self.node_N.append(len(node_info)) + self.nodes.append(tuple(deepcopy(node_info))) - def __init__(self, genotype): - assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype)) - self.node_num = len(genotype) + 1 - self.nodes = [] - self.node_N = [] - for idx, node_info in enumerate(genotype): - assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info)) - assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info)) - for node_in in node_info: - assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in)) - assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in) - self.node_N.append( len(node_info) ) - self.nodes.append( tuple(deepcopy(node_info)) ) + def tolist(self, remove_str): + # convert this class to the list, if remove_str is 'none', then remove the 'none' operation. + # note that we re-order the input node in this function + # return the-genotype-list and success [if unsuccess, it is not a connectivity] + genotypes = [] + for node_info in self.nodes: + node_info = list(node_info) + node_info = sorted(node_info, key=lambda x: (x[1], x[0])) + node_info = tuple(filter(lambda x: x[0] != remove_str, node_info)) + if len(node_info) == 0: + return None, False + genotypes.append(node_info) + return genotypes, True - def tolist(self, remove_str): - # convert this class to the list, if remove_str is 'none', then remove the 'none' operation. - # note that we re-order the input node in this function - # return the-genotype-list and success [if unsuccess, it is not a connectivity] - genotypes = [] - for node_info in self.nodes: - node_info = list( node_info ) - node_info = sorted(node_info, key=lambda x: (x[1], x[0])) - node_info = tuple(filter(lambda x: x[0] != remove_str, node_info)) - if len(node_info) == 0: return None, False - genotypes.append( node_info ) - return genotypes, True + def node(self, index): + assert index > 0 and index <= len(self), "invalid index={:} < {:}".format( + index, len(self) + ) + return self.nodes[index] - def node(self, index): - assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self)) - return self.nodes[index] + def tostr(self): + strings = [] + for node_info in self.nodes: + string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info]) + string = "|{:}|".format(string) + strings.append(string) + return "+".join(strings) - def tostr(self): - strings = [] - for node_info in self.nodes: - string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info]) - string = '|{:}|'.format(string) - strings.append( string ) - return '+'.join(strings) + def check_valid(self): + nodes = {0: True} + for i, node_info in enumerate(self.nodes): + sums = [] + for op, xin in node_info: + if op == "none" or nodes[xin] is False: + x = False + else: + x = True + sums.append(x) + nodes[i + 1] = sum(sums) > 0 + return nodes[len(self.nodes)] - def check_valid(self): - nodes = {0: True} - for i, node_info in enumerate(self.nodes): - sums = [] - for op, xin in node_info: - if op == 'none' or nodes[xin] is False: x = False - else: x = True - sums.append( x ) - nodes[i+1] = sum(sums) > 0 - return nodes[len(self.nodes)] + def to_unique_str(self, consider_zero=False): + # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation + # two operations are special, i.e., 
none and skip_connect + nodes = {0: "0"} + for i_node, node_info in enumerate(self.nodes): + cur_node = [] + for op, xin in node_info: + if consider_zero is None: + x = "(" + nodes[xin] + ")" + "@{:}".format(op) + elif consider_zero: + if op == "none" or nodes[xin] == "#": + x = "#" # zero + elif op == "skip_connect": + x = nodes[xin] + else: + x = "(" + nodes[xin] + ")" + "@{:}".format(op) + else: + if op == "skip_connect": + x = nodes[xin] + else: + x = "(" + nodes[xin] + ")" + "@{:}".format(op) + cur_node.append(x) + nodes[i_node + 1] = "+".join(sorted(cur_node)) + return nodes[len(self.nodes)] - def to_unique_str(self, consider_zero=False): - # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation - # two operations are special, i.e., none and skip_connect - nodes = {0: '0'} - for i_node, node_info in enumerate(self.nodes): - cur_node = [] - for op, xin in node_info: - if consider_zero is None: - x = '('+nodes[xin]+')' + '@{:}'.format(op) - elif consider_zero: - if op == 'none' or nodes[xin] == '#': x = '#' # zero - elif op == 'skip_connect': x = nodes[xin] - else: x = '('+nodes[xin]+')' + '@{:}'.format(op) + def check_valid_op(self, op_names): + for node_info in self.nodes: + for inode_edge in node_info: + # assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0]) + if inode_edge[0] not in op_names: + return False + return True + + def __repr__(self): + return "{name}({node_num} nodes with {node_info})".format( + name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__ + ) + + def __len__(self): + return len(self.nodes) + 1 + + def __getitem__(self, index): + return self.nodes[index] + + @staticmethod + def str2structure(xstr): + if isinstance(xstr, Structure): + return xstr + assert isinstance(xstr, str), "must take string (not {:}) as input".format( + type(xstr) + ) + nodestrs = xstr.split("+") + genotypes = [] + for i, node_str in enumerate(nodestrs): + inputs = list(filter(lambda x: x != "", node_str.split("|"))) + for xinput in inputs: + assert len(xinput.split("~")) == 2, "invalid input length : {:}".format( + xinput + ) + inputs = (xi.split("~") for xi in inputs) + input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs) + genotypes.append(input_infos) + return Structure(genotypes) + + @staticmethod + def str2fullstructure(xstr, default_name="none"): + assert isinstance(xstr, str), "must take string (not {:}) as input".format( + type(xstr) + ) + nodestrs = xstr.split("+") + genotypes = [] + for i, node_str in enumerate(nodestrs): + inputs = list(filter(lambda x: x != "", node_str.split("|"))) + for xinput in inputs: + assert len(xinput.split("~")) == 2, "invalid input length : {:}".format( + xinput + ) + inputs = (xi.split("~") for xi in inputs) + input_infos = list((op, int(IDX)) for (op, IDX) in inputs) + all_in_nodes = list(x[1] for x in input_infos) + for j in range(i): + if j not in all_in_nodes: + input_infos.append((default_name, j)) + node_info = sorted(input_infos, key=lambda x: (x[1], x[0])) + genotypes.append(tuple(node_info)) + return Structure(genotypes) + + @staticmethod + def gen_all(search_space, num, return_ori): + assert isinstance(search_space, list) or isinstance( + search_space, tuple + ), "invalid class of search-space : {:}".format(type(search_space)) + assert ( + num >= 2 + ), "There should be at least two nodes in a neural cell instead of {:}".format( + num + ) + all_archs = get_combination(search_space, 1) + for i, arch in enumerate(all_archs): + all_archs[i] = 
[tuple(arch)] + + for inode in range(2, num): + cur_nodes = get_combination(search_space, inode) + new_all_archs = [] + for previous_arch in all_archs: + for cur_node in cur_nodes: + new_all_archs.append(previous_arch + [tuple(cur_node)]) + all_archs = new_all_archs + if return_ori: + return all_archs else: - if op == 'skip_connect': x = nodes[xin] - else: x = '('+nodes[xin]+')' + '@{:}'.format(op) - cur_node.append(x) - nodes[i_node+1] = '+'.join( sorted(cur_node) ) - return nodes[ len(self.nodes) ] - - def check_valid_op(self, op_names): - for node_info in self.nodes: - for inode_edge in node_info: - #assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0]) - if inode_edge[0] not in op_names: return False - return True - - def __repr__(self): - return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__)) - - def __len__(self): - return len(self.nodes) + 1 - - def __getitem__(self, index): - return self.nodes[index] - - @staticmethod - def str2structure(xstr): - if isinstance(xstr, Structure): return xstr - assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) - nodestrs = xstr.split('+') - genotypes = [] - for i, node_str in enumerate(nodestrs): - inputs = list(filter(lambda x: x != '', node_str.split('|'))) - for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) - inputs = ( xi.split('~') for xi in inputs ) - input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs) - genotypes.append( input_infos ) - return Structure( genotypes ) - - @staticmethod - def str2fullstructure(xstr, default_name='none'): - assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) - nodestrs = xstr.split('+') - genotypes = [] - for i, node_str in enumerate(nodestrs): - inputs = list(filter(lambda x: x != '', node_str.split('|'))) - for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) - inputs = ( xi.split('~') for xi in inputs ) - input_infos = list( (op, int(IDX)) for (op, IDX) in inputs) - all_in_nodes= list(x[1] for x in input_infos) - for j in range(i): - if j not in all_in_nodes: input_infos.append((default_name, j)) - node_info = sorted(input_infos, key=lambda x: (x[1], x[0])) - genotypes.append( tuple(node_info) ) - return Structure( genotypes ) - - @staticmethod - def gen_all(search_space, num, return_ori): - assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space)) - assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num) - all_archs = get_combination(search_space, 1) - for i, arch in enumerate(all_archs): - all_archs[i] = [ tuple(arch) ] - - for inode in range(2, num): - cur_nodes = get_combination(search_space, inode) - new_all_archs = [] - for previous_arch in all_archs: - for cur_node in cur_nodes: - new_all_archs.append( previous_arch + [tuple(cur_node)] ) - all_archs = new_all_archs - if return_ori: - return all_archs - else: - return [Structure(x) for x in all_archs] - + return [Structure(x) for x in all_archs] ResNet_CODE = Structure( - [(('nor_conv_3x3', 0), ), # node-1 - (('nor_conv_3x3', 1), ), # node-2 - (('skip_connect', 0), ('skip_connect', 2))] # node-3 - ) + [ + (("nor_conv_3x3", 0),), # node-1 + (("nor_conv_3x3", 1),), # node-2 + (("skip_connect", 0), ("skip_connect", 2)), + ] # node-3 +) AllConv3x3_CODE = 
Structure( - [(('nor_conv_3x3', 0), ), # node-1 - (('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2 - (('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3 - ) + [ + (("nor_conv_3x3", 0),), # node-1 + (("nor_conv_3x3", 0), ("nor_conv_3x3", 1)), # node-2 + (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)), + ] # node-3 +) AllFull_CODE = Structure( - [(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1 - (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2 - (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3 - ) + [ + ( + ("skip_connect", 0), + ("nor_conv_1x1", 0), + ("nor_conv_3x3", 0), + ("avg_pool_3x3", 0), + ), # node-1 + ( + ("skip_connect", 0), + ("nor_conv_1x1", 0), + ("nor_conv_3x3", 0), + ("avg_pool_3x3", 0), + ("skip_connect", 1), + ("nor_conv_1x1", 1), + ("nor_conv_3x3", 1), + ("avg_pool_3x3", 1), + ), # node-2 + ( + ("skip_connect", 0), + ("nor_conv_1x1", 0), + ("nor_conv_3x3", 0), + ("avg_pool_3x3", 0), + ("skip_connect", 1), + ("nor_conv_1x1", 1), + ("nor_conv_3x3", 1), + ("avg_pool_3x3", 1), + ("skip_connect", 2), + ("nor_conv_1x1", 2), + ("nor_conv_3x3", 2), + ("avg_pool_3x3", 2), + ), + ] # node-3 +) AllConv1x1_CODE = Structure( - [(('nor_conv_1x1', 0), ), # node-1 - (('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2 - (('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3 - ) + [ + (("nor_conv_1x1", 0),), # node-1 + (("nor_conv_1x1", 0), ("nor_conv_1x1", 1)), # node-2 + (("nor_conv_1x1", 0), ("nor_conv_1x1", 1), ("nor_conv_1x1", 2)), + ] # node-3 +) AllIdentity_CODE = Structure( - [(('skip_connect', 0), ), # node-1 - (('skip_connect', 0), ('skip_connect', 1)), # node-2 - (('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3 - ) + [ + (("skip_connect", 0),), # node-1 + (("skip_connect", 0), ("skip_connect", 1)), # node-2 + (("skip_connect", 0), ("skip_connect", 1), ("skip_connect", 2)), + ] # node-3 +) -architectures = {'resnet' : ResNet_CODE, - 'all_c3x3': AllConv3x3_CODE, - 'all_c1x1': AllConv1x1_CODE, - 'all_idnt': AllIdentity_CODE, - 'all_full': AllFull_CODE} +architectures = { + "resnet": ResNet_CODE, + "all_c3x3": AllConv3x3_CODE, + "all_c1x1": AllConv1x1_CODE, + "all_idnt": AllIdentity_CODE, + "all_full": AllFull_CODE, +} diff --git a/lib/models/cell_searchs/search_cells.py b/lib/models/cell_searchs/search_cells.py index 818a32c..9235823 100644 --- a/lib/models/cell_searchs/search_cells.py +++ b/lib/models/cell_searchs/search_cells.py @@ -11,191 +11,241 @@ from ..cell_operations import OPS # This module is used for NAS-Bench-201, represents a small search space with a complete DAG class NAS201SearchCell(nn.Module): + def __init__( + self, + C_in, + C_out, + stride, + max_nodes, + op_names, + affine=False, + track_running_stats=True, + ): + super(NAS201SearchCell, self).__init__() - def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True): - super(NAS201SearchCell, self).__init__() + self.op_names = deepcopy(op_names) + self.edges = nn.ModuleDict() + self.max_nodes = max_nodes + self.in_dim = C_in + self.out_dim = C_out + for i in range(1, max_nodes): + for j in range(i): + 
node_str = "{:}<-{:}".format(i, j) + if j == 0: + xlists = [ + OPS[op_name](C_in, C_out, stride, affine, track_running_stats) + for op_name in op_names + ] + else: + xlists = [ + OPS[op_name](C_in, C_out, 1, affine, track_running_stats) + for op_name in op_names + ] + self.edges[node_str] = nn.ModuleList(xlists) + self.edge_keys = sorted(list(self.edges.keys())) + self.edge2index = {key: i for i, key in enumerate(self.edge_keys)} + self.num_edges = len(self.edges) - self.op_names = deepcopy(op_names) - self.edges = nn.ModuleDict() - self.max_nodes = max_nodes - self.in_dim = C_in - self.out_dim = C_out - for i in range(1, max_nodes): - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - if j == 0: - xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names] - else: - xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names] - self.edges[ node_str ] = nn.ModuleList( xlists ) - self.edge_keys = sorted(list(self.edges.keys())) - self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} - self.num_edges = len(self.edges) + def extra_repr(self): + string = "info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}".format( + **self.__dict__ + ) + return string - def extra_repr(self): - string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) - return string + def forward(self, inputs, weightss): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + weights = weightss[self.edge2index[node_str]] + inter_nodes.append( + sum( + layer(nodes[j]) * w + for layer, w in zip(self.edges[node_str], weights) + ) + ) + nodes.append(sum(inter_nodes)) + return nodes[-1] - def forward(self, inputs, weightss): - nodes = [inputs] - for i in range(1, self.max_nodes): - inter_nodes = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - weights = weightss[ self.edge2index[node_str] ] - inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] + # GDAS + def forward_gdas(self, inputs, hardwts, index): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + weights = hardwts[self.edge2index[node_str]] + argmaxs = index[self.edge2index[node_str]].item() + weigsum = sum( + weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] + for _ie, edge in enumerate(self.edges[node_str]) + ) + inter_nodes.append(weigsum) + nodes.append(sum(inter_nodes)) + return nodes[-1] - # GDAS - def forward_gdas(self, inputs, hardwts, index): - nodes = [inputs] - for i in range(1, self.max_nodes): - inter_nodes = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - weights = hardwts[ self.edge2index[node_str] ] - argmaxs = index[ self.edge2index[node_str] ].item() - weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) - inter_nodes.append( weigsum ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] + # joint + def forward_joint(self, inputs, weightss): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + weights = weightss[self.edge2index[node_str]] + # aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() + aggregation = 
sum( + layer(nodes[j]) * w + for layer, w in zip(self.edges[node_str], weights) + ) + inter_nodes.append(aggregation) + nodes.append(sum(inter_nodes)) + return nodes[-1] - # joint - def forward_joint(self, inputs, weightss): - nodes = [inputs] - for i in range(1, self.max_nodes): - inter_nodes = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - weights = weightss[ self.edge2index[node_str] ] - #aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() - aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) - inter_nodes.append( aggregation ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] + # uniform random sampling per iteration, SETN + def forward_urs(self, inputs): + nodes = [inputs] + for i in range(1, self.max_nodes): + while True: # to avoid select zero for all ops + sops, has_non_zero = [], False + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + candidates = self.edges[node_str] + select_op = random.choice(candidates) + sops.append(select_op) + if not hasattr(select_op, "is_zero") or select_op.is_zero is False: + has_non_zero = True + if has_non_zero: + break + inter_nodes = [] + for j, select_op in enumerate(sops): + inter_nodes.append(select_op(nodes[j])) + nodes.append(sum(inter_nodes)) + return nodes[-1] - # uniform random sampling per iteration, SETN - def forward_urs(self, inputs): - nodes = [inputs] - for i in range(1, self.max_nodes): - while True: # to avoid select zero for all ops - sops, has_non_zero = [], False - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - candidates = self.edges[node_str] - select_op = random.choice(candidates) - sops.append( select_op ) - if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True - if has_non_zero: break - inter_nodes = [] - for j, select_op in enumerate(sops): - inter_nodes.append( select_op(nodes[j]) ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] - - # select the argmax - def forward_select(self, inputs, weightss): - nodes = [inputs] - for i in range(1, self.max_nodes): - inter_nodes = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - weights = weightss[ self.edge2index[node_str] ] - inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) ) - #inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] - - # forward with a specific structure - def forward_dynamic(self, inputs, structure): - nodes = [inputs] - for i in range(1, self.max_nodes): - cur_op_node = structure.nodes[i-1] - inter_nodes = [] - for op_name, j in cur_op_node: - node_str = '{:}<-{:}'.format(i, j) - op_index = self.op_names.index( op_name ) - inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] + # select the argmax + def forward_select(self, inputs, weightss): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + weights = weightss[self.edge2index[node_str]] + inter_nodes.append( + self.edges[node_str][weights.argmax().item()](nodes[j]) + ) + # inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) + nodes.append(sum(inter_nodes)) + return nodes[-1] + # forward with a specific structure + def forward_dynamic(self, inputs, structure): + nodes = [inputs] + for i in range(1, 
self.max_nodes): + cur_op_node = structure.nodes[i - 1] + inter_nodes = [] + for op_name, j in cur_op_node: + node_str = "{:}<-{:}".format(i, j) + op_index = self.op_names.index(op_name) + inter_nodes.append(self.edges[node_str][op_index](nodes[j])) + nodes.append(sum(inter_nodes)) + return nodes[-1] class MixedOp(nn.Module): + def __init__(self, space, C, stride, affine, track_running_stats): + super(MixedOp, self).__init__() + self._ops = nn.ModuleList() + for primitive in space: + op = OPS[primitive](C, C, stride, affine, track_running_stats) + self._ops.append(op) - def __init__(self, space, C, stride, affine, track_running_stats): - super(MixedOp, self).__init__() - self._ops = nn.ModuleList() - for primitive in space: - op = OPS[primitive](C, C, stride, affine, track_running_stats) - self._ops.append(op) + def forward_gdas(self, x, weights, index): + return self._ops[index](x) * weights[index] - def forward_gdas(self, x, weights, index): - return self._ops[index](x) * weights[index] - - def forward_darts(self, x, weights): - return sum(w * op(x) for w, op in zip(weights, self._ops)) + def forward_darts(self, x, weights): + return sum(w * op(x) for w, op in zip(weights, self._ops)) # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 class NASNetSearchCell(nn.Module): + def __init__( + self, + space, + steps, + multiplier, + C_prev_prev, + C_prev, + C, + reduction, + reduction_prev, + affine, + track_running_stats, + ): + super(NASNetSearchCell, self).__init__() + self.reduction = reduction + self.op_names = deepcopy(space) + if reduction_prev: + self.preprocess0 = OPS["skip_connect"]( + C_prev_prev, C, 2, affine, track_running_stats + ) + else: + self.preprocess0 = OPS["nor_conv_1x1"]( + C_prev_prev, C, 1, affine, track_running_stats + ) + self.preprocess1 = OPS["nor_conv_1x1"]( + C_prev, C, 1, affine, track_running_stats + ) + self._steps = steps + self._multiplier = multiplier - def __init__(self, space, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats): - super(NASNetSearchCell, self).__init__() - self.reduction = reduction - self.op_names = deepcopy(space) - if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats) - else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats) - self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats) - self._steps = steps - self._multiplier = multiplier + self._ops = nn.ModuleList() + self.edges = nn.ModuleDict() + for i in range(self._steps): + for j in range(2 + i): + node_str = "{:}<-{:}".format( + i, j + ) # indicate the edge from node-(j) to node-(i+2) + stride = 2 if reduction and j < 2 else 1 + op = MixedOp(space, C, stride, affine, track_running_stats) + self.edges[node_str] = op + self.edge_keys = sorted(list(self.edges.keys())) + self.edge2index = {key: i for i, key in enumerate(self.edge_keys)} + self.num_edges = len(self.edges) - self._ops = nn.ModuleList() - self.edges = nn.ModuleDict() - for i in range(self._steps): - for j in range(2+i): - node_str = '{:}<-{:}'.format(i, j) # indicate the edge from node-(j) to node-(i+2) - stride = 2 if reduction and j < 2 else 1 - op = MixedOp(space, C, stride, affine, track_running_stats) - self.edges[ node_str ] = op - self.edge_keys = sorted(list(self.edges.keys())) - self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} - self.num_edges = len(self.edges) + @property + def multiplier(self): + return 
self._multiplier - @property - def multiplier(self): - return self._multiplier + def forward_gdas(self, s0, s1, weightss, indexs): + s0 = self.preprocess0(s0) + s1 = self.preprocess1(s1) - def forward_gdas(self, s0, s1, weightss, indexs): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) + states = [s0, s1] + for i in range(self._steps): + clist = [] + for j, h in enumerate(states): + node_str = "{:}<-{:}".format(i, j) + op = self.edges[node_str] + weights = weightss[self.edge2index[node_str]] + index = indexs[self.edge2index[node_str]].item() + clist.append(op.forward_gdas(h, weights, index)) + states.append(sum(clist)) - states = [s0, s1] - for i in range(self._steps): - clist = [] - for j, h in enumerate(states): - node_str = '{:}<-{:}'.format(i, j) - op = self.edges[ node_str ] - weights = weightss[ self.edge2index[node_str] ] - index = indexs[ self.edge2index[node_str] ].item() - clist.append( op.forward_gdas(h, weights, index) ) - states.append( sum(clist) ) + return torch.cat(states[-self._multiplier :], dim=1) - return torch.cat(states[-self._multiplier:], dim=1) + def forward_darts(self, s0, s1, weightss): + s0 = self.preprocess0(s0) + s1 = self.preprocess1(s1) - def forward_darts(self, s0, s1, weightss): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) + states = [s0, s1] + for i in range(self._steps): + clist = [] + for j, h in enumerate(states): + node_str = "{:}<-{:}".format(i, j) + op = self.edges[node_str] + weights = weightss[self.edge2index[node_str]] + clist.append(op.forward_darts(h, weights)) + states.append(sum(clist)) - states = [s0, s1] - for i in range(self._steps): - clist = [] - for j, h in enumerate(states): - node_str = '{:}<-{:}'.format(i, j) - op = self.edges[ node_str ] - weights = weightss[ self.edge2index[node_str] ] - clist.append( op.forward_darts(h, weights) ) - states.append( sum(clist) ) - - return torch.cat(states[-self._multiplier:], dim=1) + return torch.cat(states[-self._multiplier :], dim=1) diff --git a/lib/models/cell_searchs/search_model_darts.py b/lib/models/cell_searchs/search_model_darts.py index e7e61a7..31041b6 100644 --- a/lib/models/cell_searchs/search_model_darts.py +++ b/lib/models/cell_searchs/search_model_darts.py @@ -7,91 +7,116 @@ import torch import torch.nn as nn from copy import deepcopy from ..cell_operations import ResNetBasicblock -from .search_cells import NAS201SearchCell as SearchCell -from .genotypes import Structure +from .search_cells import NAS201SearchCell as SearchCell +from .genotypes import Structure class TinyNetworkDarts(nn.Module): + def __init__( + self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats + ): + super(TinyNetworkDarts, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) + ) - def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): - super(TinyNetworkDarts, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - C_prev, num_edge, edge2index = C, None, None - 
self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell( + C_prev, + C_curr, + 1, + max_nodes, + search_space, + affine, + track_running_stats, + ) + if num_edge is None: + num_edge, edge2index = cell.num_edges, cell.edge2index + else: + assert ( + num_edge == cell.num_edges and edge2index == cell.edge2index + ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) + self.cells.append(cell) + C_prev = cell.out_dim + self.op_names = deepcopy(search_space) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist + def get_weights(self): + xlist = list(self.stem.parameters()) + list(self.cells.parameters()) + xlist += list(self.lastact.parameters()) + list( + self.global_pooling.parameters() + ) + xlist += list(self.classifier.parameters()) + return xlist - def get_alphas(self): - return [self.arch_parameters] + def get_alphas(self): + return [self.arch_parameters] - def show_alphas(self): - with torch.no_grad(): - return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() ) - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def genotype(self): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) + def show_alphas(self): with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self.op_names[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) + return "arch-parameters :\n{:}".format( + nn.functional.softmax(self.arch_parameters, dim=-1).cpu() + ) - def forward(self, 
inputs): - alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - feature = cell(feature, alphas) - else: - feature = cell(feature) + def extra_repr(self): + return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[self.edge2index[node_str]] + op_name = self.op_names[weights.argmax().item()] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return Structure(genotypes) - return out, logits + def forward(self, inputs): + alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell(feature, alphas) + else: + feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/cell_searchs/search_model_darts_nasnet.py b/lib/models/cell_searchs/search_model_darts_nasnet.py index 6eb3278..7cfdb47 100644 --- a/lib/models/cell_searchs/search_model_darts_nasnet.py +++ b/lib/models/cell_searchs/search_model_darts_nasnet.py @@ -10,103 +10,169 @@ from .search_cells import NASNetSearchCell as SearchCell # The macro structure is based on NASNet class NASNetworkDARTS(nn.Module): + def __init__( + self, + C: int, + N: int, + steps: int, + multiplier: int, + stem_multiplier: int, + num_classes: int, + search_space: List[Text], + affine: bool, + track_running_stats: bool, + ): + super(NASNetworkDARTS, self).__init__() + self._C = C + self._layerN = N + self._steps = steps + self._multiplier = multiplier + self.stem = nn.Sequential( + nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C * stem_multiplier), + ) - def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, - num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool): - super(NASNetworkDARTS, self).__init__() - self._C = C - self._layerN = N - self._steps = steps - self._multiplier = multiplier - self.stem = nn.Sequential( - nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C*stem_multiplier)) - - # config for each layer - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1) - layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1) + # config for each layer + layer_channels = ( + [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) + ) + layer_reductions = ( + [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) + ) - num_edge, edge2index = None, None - C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False + num_edge, edge2index = None, None + C_prev_prev, C_prev, C_curr, reduction_prev = ( + C * stem_multiplier, + C * 
stem_multiplier, + C, + False, + ) - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + cell = SearchCell( + search_space, + steps, + multiplier, + C_prev_prev, + C_prev, + C_curr, + reduction, + reduction_prev, + affine, + track_running_stats, + ) + if num_edge is None: + num_edge, edge2index = cell.num_edges, cell.edge2index + else: + assert ( + num_edge == cell.num_edges and edge2index == cell.edge2index + ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) + self.cells.append(cell) + C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction + self.op_names = deepcopy(search_space) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_normal_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) + self.arch_reduce_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) - def get_weights(self) -> List[torch.nn.Parameter]: - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist + def get_weights(self) -> List[torch.nn.Parameter]: + xlist = list(self.stem.parameters()) + list(self.cells.parameters()) + xlist += list(self.lastact.parameters()) + list( + self.global_pooling.parameters() + ) + xlist += list(self.classifier.parameters()) + return xlist - def get_alphas(self) -> List[torch.nn.Parameter]: - return [self.arch_normal_parameters, self.arch_reduce_parameters] + def get_alphas(self) -> List[torch.nn.Parameter]: + return [self.arch_normal_parameters, self.arch_reduce_parameters] - def show_alphas(self) -> Text: - with torch.no_grad(): - A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() ) - B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() ) - return '{:}\n{:}'.format(A, B) + def show_alphas(self) -> Text: + with torch.no_grad(): + A = "arch-normal-parameters :\n{:}".format( + nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() + ) + B = "arch-reduce-parameters :\n{:}".format( + nn.functional.softmax(self.arch_reduce_parameters, 
dim=-1).cpu() + ) + return "{:}\n{:}".format(A, B) - def get_message(self) -> Text: - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string + def get_message(self) -> Text: + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - def extra_repr(self) -> Text: - return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + def extra_repr(self) -> Text: + return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - def genotype(self) -> Dict[Text, List]: - def _parse(weights): - gene = [] - for i in range(self._steps): - edges = [] - for j in range(2+i): - node_str = '{:}<-{:}'.format(i, j) - ws = weights[ self.edge2index[node_str] ] - for k, op_name in enumerate(self.op_names): - if op_name == 'none': continue - edges.append( (op_name, j, ws[k]) ) - # (TODO) xuanyidong: - # Here the selected two edges might come from the same input node. - # And this case could be a problem that two edges will collapse into a single one - # due to our assumption -- at most one edge from an input node during evaluation. - edges = sorted(edges, key=lambda x: -x[-1]) - selected_edges = edges[:2] - gene.append( tuple(selected_edges) ) - return gene - with torch.no_grad(): - gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()) - gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()) - return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)), - 'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))} + def genotype(self) -> Dict[Text, List]: + def _parse(weights): + gene = [] + for i in range(self._steps): + edges = [] + for j in range(2 + i): + node_str = "{:}<-{:}".format(i, j) + ws = weights[self.edge2index[node_str]] + for k, op_name in enumerate(self.op_names): + if op_name == "none": + continue + edges.append((op_name, j, ws[k])) + # (TODO) xuanyidong: + # Here the selected two edges might come from the same input node. + # And this case could be a problem that two edges will collapse into a single one + # due to our assumption -- at most one edge from an input node during evaluation. 
+ edges = sorted(edges, key=lambda x: -x[-1]) + selected_edges = edges[:2] + gene.append(tuple(selected_edges)) + return gene - def forward(self, inputs): + with torch.no_grad(): + gene_normal = _parse( + torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy() + ) + gene_reduce = _parse( + torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy() + ) + return { + "normal": gene_normal, + "normal_concat": list( + range(2 + self._steps - self._multiplier, self._steps + 2) + ), + "reduce": gene_reduce, + "reduce_concat": list( + range(2 + self._steps - self._multiplier, self._steps + 2) + ), + } - normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1) - reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1) + def forward(self, inputs): - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - if cell.reduction: ww = reduce_w - else : ww = normal_w - s0, s1 = s1, cell.forward_darts(s0, s1, ww) - out = self.lastact(s1) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) + normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1) + reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1) - return out, logits + s0 = s1 = self.stem(inputs) + for i, cell in enumerate(self.cells): + if cell.reduction: + ww = reduce_w + else: + ww = normal_w + s0, s1 = s1, cell.forward_darts(s0, s1, ww) + out = self.lastact(s1) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/cell_searchs/search_model_enas.py b/lib/models/cell_searchs/search_model_enas.py index 58aca9c..7ba91d4 100644 --- a/lib/models/cell_searchs/search_model_enas.py +++ b/lib/models/cell_searchs/search_model_enas.py @@ -7,88 +7,108 @@ import torch import torch.nn as nn from copy import deepcopy from ..cell_operations import ResNetBasicblock -from .search_cells import NAS201SearchCell as SearchCell -from .genotypes import Structure +from .search_cells import NAS201SearchCell as SearchCell +from .genotypes import Structure from .search_model_enas_utils import Controller class TinyNetworkENAS(nn.Module): + def __init__( + self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats + ): + super(TinyNetworkENAS, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) + ) - def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): - super(TinyNetworkENAS, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert 
num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - # to maintain the sampled architecture - self.sampled_arch = None + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell( + C_prev, + C_curr, + 1, + max_nodes, + search_space, + affine, + track_running_stats, + ) + if num_edge is None: + num_edge, edge2index = cell.num_edges, cell.edge2index + else: + assert ( + num_edge == cell.num_edges and edge2index == cell.edge2index + ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) + self.cells.append(cell) + C_prev = cell.out_dim + self.op_names = deepcopy(search_space) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + # to maintain the sampled architecture + self.sampled_arch = None - def update_arch(self, _arch): - if _arch is None: - self.sampled_arch = None - elif isinstance(_arch, Structure): - self.sampled_arch = _arch - elif isinstance(_arch, (list, tuple)): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_index = _arch[ self.edge2index[node_str] ] - op_name = self.op_names[ op_index ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - self.sampled_arch = Structure(genotypes) - else: - raise ValueError('invalid type of input architecture : {:}'.format(_arch)) - return self.sampled_arch - - def create_controller(self): - return Controller(len(self.edge2index), len(self.op_names)) + def update_arch(self, _arch): + if _arch is None: + self.sampled_arch = None + elif isinstance(_arch, Structure): + self.sampled_arch = _arch + elif isinstance(_arch, (list, tuple)): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + op_index = _arch[self.edge2index[node_str]] + op_name = self.op_names[op_index] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + self.sampled_arch = Structure(genotypes) + else: + raise ValueError("invalid type of input architecture : {:}".format(_arch)) + return self.sampled_arch - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string + def create_controller(self): + return Controller(len(self.edge2index), len(self.op_names)) - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - def forward(self, inputs): + def extra_repr(self): + return "{name}(C={_C}, Max-Nodes={max_nodes}, 
N={_layerN}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - feature = cell.forward_dynamic(feature, self.sampled_arch) - else: feature = cell(feature) + def forward(self, inputs): - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell.forward_dynamic(feature, self.sampled_arch) + else: + feature = cell(feature) - return out, logits + out = self.lastact(feature) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/cell_searchs/search_model_enas_utils.py b/lib/models/cell_searchs/search_model_enas_utils.py index e03f57b..71d5d0f 100644 --- a/lib/models/cell_searchs/search_model_enas_utils.py +++ b/lib/models/cell_searchs/search_model_enas_utils.py @@ -7,49 +7,68 @@ import torch import torch.nn as nn from torch.distributions.categorical import Categorical + class Controller(nn.Module): - # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py - def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0): - super(Controller, self).__init__() - # assign the attributes - self.num_edge = num_edge - self.num_ops = num_ops - self.lstm_size = lstm_size - self.lstm_N = lstm_num_layers - self.tanh_constant = tanh_constant - self.temperature = temperature - # create parameters - self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size))) - self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N) - self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) - self.w_pred = nn.Linear(self.lstm_size, self.num_ops) + # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py + def __init__( + self, + num_edge, + num_ops, + lstm_size=32, + lstm_num_layers=2, + tanh_constant=2.5, + temperature=5.0, + ): + super(Controller, self).__init__() + # assign the attributes + self.num_edge = num_edge + self.num_ops = num_ops + self.lstm_size = lstm_size + self.lstm_N = lstm_num_layers + self.tanh_constant = tanh_constant + self.temperature = temperature + # create parameters + self.register_parameter( + "input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size)) + ) + self.w_lstm = nn.LSTM( + input_size=self.lstm_size, + hidden_size=self.lstm_size, + num_layers=self.lstm_N, + ) + self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) + self.w_pred = nn.Linear(self.lstm_size, self.num_ops) - nn.init.uniform_(self.input_vars , -0.1, 0.1) - nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) - nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) - nn.init.uniform_(self.w_embd.weight , -0.1, 0.1) - nn.init.uniform_(self.w_pred.weight , -0.1, 0.1) + nn.init.uniform_(self.input_vars, -0.1, 0.1) + nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) + nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) + nn.init.uniform_(self.w_embd.weight, -0.1, 0.1) + nn.init.uniform_(self.w_pred.weight, -0.1, 0.1) - def forward(self): + def forward(self): - inputs, h0 = self.input_vars, None - log_probs, entropys, sampled_arch = [], [], [] - for iedge in range(self.num_edge): - outputs, h0 = self.w_lstm(inputs, h0) - - logits = 
self.w_pred(outputs) - logits = logits / self.temperature - logits = self.tanh_constant * torch.tanh(logits) - # distribution - op_distribution = Categorical(logits=logits) - op_index = op_distribution.sample() - sampled_arch.append( op_index.item() ) + inputs, h0 = self.input_vars, None + log_probs, entropys, sampled_arch = [], [], [] + for iedge in range(self.num_edge): + outputs, h0 = self.w_lstm(inputs, h0) - op_log_prob = op_distribution.log_prob(op_index) - log_probs.append( op_log_prob.view(-1) ) - op_entropy = op_distribution.entropy() - entropys.append( op_entropy.view(-1) ) - - # obtain the input embedding for the next step - inputs = self.w_embd(op_index) - return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch + logits = self.w_pred(outputs) + logits = logits / self.temperature + logits = self.tanh_constant * torch.tanh(logits) + # distribution + op_distribution = Categorical(logits=logits) + op_index = op_distribution.sample() + sampled_arch.append(op_index.item()) + + op_log_prob = op_distribution.log_prob(op_index) + log_probs.append(op_log_prob.view(-1)) + op_entropy = op_distribution.entropy() + entropys.append(op_entropy.view(-1)) + + # obtain the input embedding for the next step + inputs = self.w_embd(op_index) + return ( + torch.sum(torch.cat(log_probs)), + torch.sum(torch.cat(entropys)), + sampled_arch, + ) diff --git a/lib/models/cell_searchs/search_model_gdas.py b/lib/models/cell_searchs/search_model_gdas.py index 0400922..82f7b9a 100644 --- a/lib/models/cell_searchs/search_model_gdas.py +++ b/lib/models/cell_searchs/search_model_gdas.py @@ -5,107 +5,138 @@ import torch import torch.nn as nn from copy import deepcopy from ..cell_operations import ResNetBasicblock -from .search_cells import NAS201SearchCell as SearchCell -from .genotypes import Structure +from .search_cells import NAS201SearchCell as SearchCell +from .genotypes import Structure class TinyNetworkGDAS(nn.Module): - #def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True): - def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): - super(TinyNetworkGDAS, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + # def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True): + def __init__( + self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats + ): + super(TinyNetworkGDAS, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) + ) - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. 
{:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - self.tau = 10 + layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell( + C_prev, + C_curr, + 1, + max_nodes, + search_space, + affine, + track_running_stats, + ) + if num_edge is None: + num_edge, edge2index = cell.num_edges, cell.edge2index + else: + assert ( + num_edge == cell.num_edges and edge2index == cell.edge2index + ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) + self.cells.append(cell) + C_prev = cell.out_dim + self.op_names = deepcopy(search_space) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) + self.tau = 10 - def set_tau(self, tau): - self.tau = tau + def get_weights(self): + xlist = list(self.stem.parameters()) + list(self.cells.parameters()) + xlist += list(self.lastact.parameters()) + list( + self.global_pooling.parameters() + ) + xlist += list(self.classifier.parameters()) + return xlist - def get_tau(self): - return self.tau + def set_tau(self, tau): + self.tau = tau - def get_alphas(self): - return [self.arch_parameters] + def get_tau(self): + return self.tau - def show_alphas(self): - with torch.no_grad(): - return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() ) + def get_alphas(self): + return [self.arch_parameters] - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def genotype(self): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) + def show_alphas(self): with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self.op_names[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) + return "arch-parameters :\n{:}".format( + nn.functional.softmax(self.arch_parameters, dim=-1).cpu() + ) - def forward(self, inputs): - while True: - gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() - 
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau - probs = nn.functional.softmax(logits, dim=1) - index = probs.max(-1, keepdim=True)[1] - one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) - hardwts = one_h - probs.detach() + probs - if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): - continue - else: break + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - feature = cell.forward_gdas(feature, hardwts, index) - else: - feature = cell(feature) - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) + def extra_repr(self): + return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - return out, logits + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[self.edge2index[node_str]] + op_name = self.op_names[weights.argmax().item()] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return Structure(genotypes) + + def forward(self, inputs): + while True: + gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() + logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau + probs = nn.functional.softmax(logits, dim=1) + index = probs.max(-1, keepdim=True)[1] + one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) + hardwts = one_h - probs.detach() + probs + if ( + (torch.isinf(gumbels).any()) + or (torch.isinf(probs).any()) + or (torch.isnan(probs).any()) + ): + continue + else: + break + + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell.forward_gdas(feature, hardwts, index) + else: + feature = cell(feature) + out = self.lastact(feature) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/cell_searchs/search_model_gdas_frc_nasnet.py b/lib/models/cell_searchs/search_model_gdas_frc_nasnet.py index 06aaf93..ee04d1e 100644 --- a/lib/models/cell_searchs/search_model_gdas_frc_nasnet.py +++ b/lib/models/cell_searchs/search_model_gdas_frc_nasnet.py @@ -10,116 +10,190 @@ from models.cell_operations import RAW_OP_CLASSES # The macro structure is based on NASNet class NASNetworkGDAS_FRC(nn.Module): + def __init__( + self, + C, + N, + steps, + multiplier, + stem_multiplier, + num_classes, + search_space, + affine, + track_running_stats, + ): + super(NASNetworkGDAS_FRC, self).__init__() + self._C = C + self._layerN = N + self._steps = steps + self._multiplier = multiplier + self.stem = nn.Sequential( + nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C * stem_multiplier), + ) - def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats): - super(NASNetworkGDAS_FRC, self).__init__() - self._C = C - self._layerN = N - self._steps = steps - self._multiplier = multiplier - self.stem = nn.Sequential( - nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False), - 
nn.BatchNorm2d(C*stem_multiplier)) - - # config for each layer - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1) - layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1) + # config for each layer + layer_channels = ( + [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) + ) + layer_reductions = ( + [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) + ) - num_edge, edge2index = None, None - C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False + num_edge, edge2index = None, None + C_prev_prev, C_prev, C_curr, reduction_prev = ( + C * stem_multiplier, + C * stem_multiplier, + C, + False, + ) - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = RAW_OP_CLASSES['gdas_reduction'](C_prev_prev, C_prev, C_curr, reduction_prev, affine, track_running_stats) - else: - cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert reduction or num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev_prev, C_prev, reduction_prev = C_prev, cell.multiplier * C_curr, reduction - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - self.tau = 10 + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + if reduction: + cell = RAW_OP_CLASSES["gdas_reduction"]( + C_prev_prev, + C_prev, + C_curr, + reduction_prev, + affine, + track_running_stats, + ) + else: + cell = SearchCell( + search_space, + steps, + multiplier, + C_prev_prev, + C_prev, + C_curr, + reduction, + reduction_prev, + affine, + track_running_stats, + ) + if num_edge is None: + num_edge, edge2index = cell.num_edges, cell.edge2index + else: + assert ( + reduction + or num_edge == cell.num_edges + and edge2index == cell.edge2index + ), "invalid {:} vs. 
{:}.".format(num_edge, cell.num_edges) + self.cells.append(cell) + C_prev_prev, C_prev, reduction_prev = ( + C_prev, + cell.multiplier * C_curr, + reduction, + ) + self.op_names = deepcopy(search_space) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) + self.tau = 10 - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist + def get_weights(self): + xlist = list(self.stem.parameters()) + list(self.cells.parameters()) + xlist += list(self.lastact.parameters()) + list( + self.global_pooling.parameters() + ) + xlist += list(self.classifier.parameters()) + return xlist - def set_tau(self, tau): - self.tau = tau + def set_tau(self, tau): + self.tau = tau - def get_tau(self): - return self.tau + def get_tau(self): + return self.tau - def get_alphas(self): - return [self.arch_parameters] + def get_alphas(self): + return [self.arch_parameters] - def show_alphas(self): - with torch.no_grad(): - A = 'arch-normal-parameters :\n{:}'.format(nn.functional.softmax(self.arch_parameters, dim=-1).cpu()) - return '{:}'.format(A) + def show_alphas(self): + with torch.no_grad(): + A = "arch-normal-parameters :\n{:}".format( + nn.functional.softmax(self.arch_parameters, dim=-1).cpu() + ) + return "{:}".format(A) - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - def extra_repr(self): - return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + def extra_repr(self): + return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - def genotype(self): - def _parse(weights): - gene = [] - for i in range(self._steps): - edges = [] - for j in range(2+i): - node_str = '{:}<-{:}'.format(i, j) - ws = weights[ self.edge2index[node_str] ] - for k, op_name in enumerate(self.op_names): - if op_name == 'none': continue - edges.append( (op_name, j, ws[k]) ) - edges = sorted(edges, key=lambda x: -x[-1]) - selected_edges = edges[:2] - gene.append( tuple(selected_edges) ) - return gene - with torch.no_grad(): - gene_normal = _parse(torch.softmax(self.arch_parameters, dim=-1).cpu().numpy()) - return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2))} + def genotype(self): + def _parse(weights): + gene = [] + for i in range(self._steps): + edges = [] + for j in range(2 + i): + node_str = "{:}<-{:}".format(i, j) + ws = weights[self.edge2index[node_str]] + for k, op_name in enumerate(self.op_names): + if op_name == "none": + continue + edges.append((op_name, j, ws[k])) + edges = sorted(edges, key=lambda x: -x[-1]) + selected_edges = edges[:2] + gene.append(tuple(selected_edges)) + return gene - def 
forward(self, inputs): - def get_gumbel_prob(xins): - while True: - gumbels = -torch.empty_like(xins).exponential_().log() - logits = (xins.log_softmax(dim=1) + gumbels) / self.tau - probs = nn.functional.softmax(logits, dim=1) - index = probs.max(-1, keepdim=True)[1] - one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) - hardwts = one_h - probs.detach() + probs - if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): - continue - else: break - return hardwts, index + with torch.no_grad(): + gene_normal = _parse( + torch.softmax(self.arch_parameters, dim=-1).cpu().numpy() + ) + return { + "normal": gene_normal, + "normal_concat": list( + range(2 + self._steps - self._multiplier, self._steps + 2) + ), + } - hardwts, index = get_gumbel_prob(self.arch_parameters) + def forward(self, inputs): + def get_gumbel_prob(xins): + while True: + gumbels = -torch.empty_like(xins).exponential_().log() + logits = (xins.log_softmax(dim=1) + gumbels) / self.tau + probs = nn.functional.softmax(logits, dim=1) + index = probs.max(-1, keepdim=True)[1] + one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) + hardwts = one_h - probs.detach() + probs + if ( + (torch.isinf(gumbels).any()) + or (torch.isinf(probs).any()) + or (torch.isnan(probs).any()) + ): + continue + else: + break + return hardwts, index - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - if cell.reduction: - s0, s1 = s1, cell(s0, s1) - else: - s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) - out = self.lastact(s1) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) + hardwts, index = get_gumbel_prob(self.arch_parameters) - return out, logits + s0 = s1 = self.stem(inputs) + for i, cell in enumerate(self.cells): + if cell.reduction: + s0, s1 = s1, cell(s0, s1) + else: + s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) + out = self.lastact(s1) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/cell_searchs/search_model_gdas_nasnet.py b/lib/models/cell_searchs/search_model_gdas_nasnet.py index 2115ce4..5aff5d3 100644 --- a/lib/models/cell_searchs/search_model_gdas_nasnet.py +++ b/lib/models/cell_searchs/search_model_gdas_nasnet.py @@ -9,117 +9,189 @@ from .search_cells import NASNetSearchCell as SearchCell # The macro structure is based on NASNet class NASNetworkGDAS(nn.Module): + def __init__( + self, + C, + N, + steps, + multiplier, + stem_multiplier, + num_classes, + search_space, + affine, + track_running_stats, + ): + super(NASNetworkGDAS, self).__init__() + self._C = C + self._layerN = N + self._steps = steps + self._multiplier = multiplier + self.stem = nn.Sequential( + nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C * stem_multiplier), + ) - def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats): - super(NASNetworkGDAS, self).__init__() - self._C = C - self._layerN = N - self._steps = steps - self._multiplier = multiplier - self.stem = nn.Sequential( - nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C*stem_multiplier)) - - # config for each layer - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1) - layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1) + # config for each layer + layer_channels = ( + [C] * N + [C * 
2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) + ) + layer_reductions = ( + [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) + ) - num_edge, edge2index = None, None - C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False + num_edge, edge2index = None, None + C_prev_prev, C_prev, C_curr, reduction_prev = ( + C * stem_multiplier, + C * stem_multiplier, + C, + False, + ) - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - self.tau = 10 + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + cell = SearchCell( + search_space, + steps, + multiplier, + C_prev_prev, + C_prev, + C_curr, + reduction, + reduction_prev, + affine, + track_running_stats, + ) + if num_edge is None: + num_edge, edge2index = cell.num_edges, cell.edge2index + else: + assert ( + num_edge == cell.num_edges and edge2index == cell.edge2index + ), "invalid {:} vs. 
{:}.".format(num_edge, cell.num_edges) + self.cells.append(cell) + C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction + self.op_names = deepcopy(search_space) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_normal_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) + self.arch_reduce_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) + self.tau = 10 - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist + def get_weights(self): + xlist = list(self.stem.parameters()) + list(self.cells.parameters()) + xlist += list(self.lastact.parameters()) + list( + self.global_pooling.parameters() + ) + xlist += list(self.classifier.parameters()) + return xlist - def set_tau(self, tau): - self.tau = tau + def set_tau(self, tau): + self.tau = tau - def get_tau(self): - return self.tau + def get_tau(self): + return self.tau - def get_alphas(self): - return [self.arch_normal_parameters, self.arch_reduce_parameters] + def get_alphas(self): + return [self.arch_normal_parameters, self.arch_reduce_parameters] - def show_alphas(self): - with torch.no_grad(): - A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() ) - B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() ) - return '{:}\n{:}'.format(A, B) + def show_alphas(self): + with torch.no_grad(): + A = "arch-normal-parameters :\n{:}".format( + nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() + ) + B = "arch-reduce-parameters :\n{:}".format( + nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() + ) + return "{:}\n{:}".format(A, B) - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - def extra_repr(self): - return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + def extra_repr(self): + return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - def genotype(self): - def _parse(weights): - gene = [] - for i in range(self._steps): - edges = [] - for j in range(2+i): - node_str = '{:}<-{:}'.format(i, j) - ws = weights[ self.edge2index[node_str] ] - for k, op_name in enumerate(self.op_names): - if op_name == 'none': continue - edges.append( (op_name, j, ws[k]) ) - edges = sorted(edges, key=lambda x: -x[-1]) - selected_edges = edges[:2] - gene.append( tuple(selected_edges) ) - return gene - with torch.no_grad(): - gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()) - gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()) - return {'normal': gene_normal, 'normal_concat': 
list(range(2+self._steps-self._multiplier, self._steps+2)), - 'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))} + def genotype(self): + def _parse(weights): + gene = [] + for i in range(self._steps): + edges = [] + for j in range(2 + i): + node_str = "{:}<-{:}".format(i, j) + ws = weights[self.edge2index[node_str]] + for k, op_name in enumerate(self.op_names): + if op_name == "none": + continue + edges.append((op_name, j, ws[k])) + edges = sorted(edges, key=lambda x: -x[-1]) + selected_edges = edges[:2] + gene.append(tuple(selected_edges)) + return gene - def forward(self, inputs): - def get_gumbel_prob(xins): - while True: - gumbels = -torch.empty_like(xins).exponential_().log() - logits = (xins.log_softmax(dim=1) + gumbels) / self.tau - probs = nn.functional.softmax(logits, dim=1) - index = probs.max(-1, keepdim=True)[1] - one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) - hardwts = one_h - probs.detach() + probs - if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): - continue - else: break - return hardwts, index + with torch.no_grad(): + gene_normal = _parse( + torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy() + ) + gene_reduce = _parse( + torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy() + ) + return { + "normal": gene_normal, + "normal_concat": list( + range(2 + self._steps - self._multiplier, self._steps + 2) + ), + "reduce": gene_reduce, + "reduce_concat": list( + range(2 + self._steps - self._multiplier, self._steps + 2) + ), + } - normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters) - reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters) + def forward(self, inputs): + def get_gumbel_prob(xins): + while True: + gumbels = -torch.empty_like(xins).exponential_().log() + logits = (xins.log_softmax(dim=1) + gumbels) / self.tau + probs = nn.functional.softmax(logits, dim=1) + index = probs.max(-1, keepdim=True)[1] + one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) + hardwts = one_h - probs.detach() + probs + if ( + (torch.isinf(gumbels).any()) + or (torch.isinf(probs).any()) + or (torch.isnan(probs).any()) + ): + continue + else: + break + return hardwts, index - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - if cell.reduction: hardwts, index = reduce_hardwts, reduce_index - else : hardwts, index = normal_hardwts, normal_index - s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) - out = self.lastact(s1) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) + normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters) + reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters) - return out, logits + s0 = s1 = self.stem(inputs) + for i, cell in enumerate(self.cells): + if cell.reduction: + hardwts, index = reduce_hardwts, reduce_index + else: + hardwts, index = normal_hardwts, normal_index + s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) + out = self.lastact(s1) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/cell_searchs/search_model_random.py b/lib/models/cell_searchs/search_model_random.py index 3345577..611dc75 100644 --- a/lib/models/cell_searchs/search_model_random.py +++ b/lib/models/cell_searchs/search_model_random.py @@ -1,81 +1,102 @@ ################################################## # 
Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # ############################################################################## -# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 # +# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 # ############################################################################## import torch, random import torch.nn as nn from copy import deepcopy from ..cell_operations import ResNetBasicblock -from .search_cells import NAS201SearchCell as SearchCell -from .genotypes import Structure +from .search_cells import NAS201SearchCell as SearchCell +from .genotypes import Structure class TinyNetworkRANDOM(nn.Module): + def __init__( + self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats + ): + super(TinyNetworkRANDOM, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) + ) - def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): - super(TinyNetworkRANDOM, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_cache = None - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell( + C_prev, + C_curr, + 1, + max_nodes, + search_space, + affine, + track_running_stats, + ) + if num_edge is None: + num_edge, edge2index = cell.num_edges, cell.edge2index + else: + assert ( + num_edge == cell.num_edges and edge2index == cell.edge2index + ), "invalid {:} vs. 
{:}.".format(num_edge, cell.num_edges) + self.cells.append(cell) + C_prev = cell.out_dim + self.op_names = deepcopy(search_space) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_cache = None - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - def random_genotype(self, set_cache): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_name = random.choice( self.op_names ) - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - arch = Structure( genotypes ) - if set_cache: self.arch_cache = arch - return arch + def extra_repr(self): + return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - def forward(self, inputs): + def random_genotype(self, set_cache): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + op_name = random.choice(self.op_names) + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + arch = Structure(genotypes) + if set_cache: + self.arch_cache = arch + return arch - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - feature = cell.forward_dynamic(feature, self.arch_cache) - else: feature = cell(feature) + def forward(self, inputs): - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - return out, logits + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell.forward_dynamic(feature, self.arch_cache) + else: + feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + return out, logits diff --git a/lib/models/cell_searchs/search_model_setn.py b/lib/models/cell_searchs/search_model_setn.py index 83d7659..ce38be9 100644 --- a/lib/models/cell_searchs/search_model_setn.py +++ b/lib/models/cell_searchs/search_model_setn.py @@ -7,146 +7,172 @@ import torch, random import torch.nn as nn from copy import deepcopy from ..cell_operations import ResNetBasicblock -from .search_cells import NAS201SearchCell as SearchCell -from .genotypes import Structure +from .search_cells import NAS201SearchCell as SearchCell +from .genotypes import Structure class TinyNetworkSETN(nn.Module): + def __init__( + self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats + ): + super(TinyNetworkSETN, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) + ) - def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): - super(TinyNetworkSETN, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, 
bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - self.mode = 'urs' - self.dynamic_cell = None - - def set_cal_mode(self, mode, dynamic_cell=None): - assert mode in ['urs', 'joint', 'select', 'dynamic'] - self.mode = mode - if mode == 'dynamic': self.dynamic_cell = deepcopy( dynamic_cell ) - else : self.dynamic_cell = None + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell( + C_prev, + C_curr, + 1, + max_nodes, + search_space, + affine, + track_running_stats, + ) + if num_edge is None: + num_edge, edge2index = cell.num_edges, cell.edge2index + else: + assert ( + num_edge == cell.num_edges and edge2index == cell.edge2index + ), "invalid {:} vs. 
{:}.".format(num_edge, cell.num_edges) + self.cells.append(cell) + C_prev = cell.out_dim + self.op_names = deepcopy(search_space) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) + self.mode = "urs" + self.dynamic_cell = None - def get_cal_mode(self): - return self.mode - - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist - - def get_alphas(self): - return [self.arch_parameters] - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def genotype(self): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self.op_names[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) - - def dync_genotype(self, use_random=False): - genotypes = [] - with torch.no_grad(): - alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - if use_random: - op_name = random.choice(self.op_names) + def set_cal_mode(self, mode, dynamic_cell=None): + assert mode in ["urs", "joint", "select", "dynamic"] + self.mode = mode + if mode == "dynamic": + self.dynamic_cell = deepcopy(dynamic_cell) else: - weights = alphas_cpu[ self.edge2index[node_str] ] - op_index = torch.multinomial(weights, 1).item() - op_name = self.op_names[ op_index ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) + self.dynamic_cell = None - def get_log_prob(self, arch): - with torch.no_grad(): - logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) - select_logits = [] - for i, node_info in enumerate(arch.nodes): - for op, xin in node_info: - node_str = '{:}<-{:}'.format(i+1, xin) - op_index = self.op_names.index(op) - select_logits.append( logits[self.edge2index[node_str], op_index] ) - return sum(select_logits).item() + def get_cal_mode(self): + return self.mode + def get_weights(self): + xlist = list(self.stem.parameters()) + list(self.cells.parameters()) + xlist += list(self.lastact.parameters()) + list( + self.global_pooling.parameters() + ) + xlist += list(self.classifier.parameters()) + return xlist - def return_topK(self, K): - archs = Structure.gen_all(self.op_names, self.max_nodes, False) - pairs = [(self.get_log_prob(arch), arch) for arch in archs] - if K < 0 or K >= len(archs): K = len(archs) - sorted_pairs = sorted(pairs, key=lambda x: -x[0]) - return_pairs = [sorted_pairs[_][1] for _ in range(K)] - return return_pairs + def get_alphas(self): + return [self.arch_parameters] + def get_message(self): + string = self.extra_repr() + for i, cell in 
enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - def forward(self, inputs): - alphas = nn.functional.softmax(self.arch_parameters, dim=-1) - with torch.no_grad(): - alphas_cpu = alphas.detach().cpu() + def extra_repr(self): + return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - if self.mode == 'urs': - feature = cell.forward_urs(feature) - elif self.mode == 'select': - feature = cell.forward_select(feature, alphas_cpu) - elif self.mode == 'joint': - feature = cell.forward_joint(feature, alphas) - elif self.mode == 'dynamic': - feature = cell.forward_dynamic(feature, self.dynamic_cell) - else: raise ValueError('invalid mode={:}'.format(self.mode)) - else: feature = cell(feature) + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[self.edge2index[node_str]] + op_name = self.op_names[weights.argmax().item()] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return Structure(genotypes) - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) + def dync_genotype(self, use_random=False): + genotypes = [] + with torch.no_grad(): + alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + if use_random: + op_name = random.choice(self.op_names) + else: + weights = alphas_cpu[self.edge2index[node_str]] + op_index = torch.multinomial(weights, 1).item() + op_name = self.op_names[op_index] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return Structure(genotypes) - return out, logits + def get_log_prob(self, arch): + with torch.no_grad(): + logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) + select_logits = [] + for i, node_info in enumerate(arch.nodes): + for op, xin in node_info: + node_str = "{:}<-{:}".format(i + 1, xin) + op_index = self.op_names.index(op) + select_logits.append(logits[self.edge2index[node_str], op_index]) + return sum(select_logits).item() + + def return_topK(self, K): + archs = Structure.gen_all(self.op_names, self.max_nodes, False) + pairs = [(self.get_log_prob(arch), arch) for arch in archs] + if K < 0 or K >= len(archs): + K = len(archs) + sorted_pairs = sorted(pairs, key=lambda x: -x[0]) + return_pairs = [sorted_pairs[_][1] for _ in range(K)] + return return_pairs + + def forward(self, inputs): + alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + with torch.no_grad(): + alphas_cpu = alphas.detach().cpu() + + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + if self.mode == "urs": + feature = cell.forward_urs(feature) + elif self.mode == "select": + feature = cell.forward_select(feature, alphas_cpu) + elif self.mode == "joint": + feature = cell.forward_joint(feature, alphas) + elif self.mode == "dynamic": + feature = cell.forward_dynamic(feature, self.dynamic_cell) + else: + raise ValueError("invalid mode={:}".format(self.mode)) + else: + feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + 
logits = self.classifier(out) + + return out, logits diff --git a/lib/models/cell_searchs/search_model_setn_nasnet.py b/lib/models/cell_searchs/search_model_setn_nasnet.py index 9082032..c406fc3 100644 --- a/lib/models/cell_searchs/search_model_setn_nasnet.py +++ b/lib/models/cell_searchs/search_model_setn_nasnet.py @@ -7,133 +7,199 @@ import torch import torch.nn as nn from copy import deepcopy from typing import List, Text, Dict -from .search_cells import NASNetSearchCell as SearchCell +from .search_cells import NASNetSearchCell as SearchCell # The macro structure is based on NASNet class NASNetworkSETN(nn.Module): + def __init__( + self, + C: int, + N: int, + steps: int, + multiplier: int, + stem_multiplier: int, + num_classes: int, + search_space: List[Text], + affine: bool, + track_running_stats: bool, + ): + super(NASNetworkSETN, self).__init__() + self._C = C + self._layerN = N + self._steps = steps + self._multiplier = multiplier + self.stem = nn.Sequential( + nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C * stem_multiplier), + ) - def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, - num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool): - super(NASNetworkSETN, self).__init__() - self._C = C - self._layerN = N - self._steps = steps - self._multiplier = multiplier - self.stem = nn.Sequential( - nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C*stem_multiplier)) - - # config for each layer - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1) - layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1) + # config for each layer + layer_channels = ( + [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) + ) + layer_reductions = ( + [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) + ) - num_edge, edge2index = None, None - C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False + num_edge, edge2index = None, None + C_prev_prev, C_prev, C_curr, reduction_prev = ( + C * stem_multiplier, + C * stem_multiplier, + C, + False, + ) - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. 
{:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - self.mode = 'urs' - self.dynamic_cell = None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate( + zip(layer_channels, layer_reductions) + ): + cell = SearchCell( + search_space, + steps, + multiplier, + C_prev_prev, + C_prev, + C_curr, + reduction, + reduction_prev, + affine, + track_running_stats, + ) + if num_edge is None: + num_edge, edge2index = cell.num_edges, cell.edge2index + else: + assert ( + num_edge == cell.num_edges and edge2index == cell.edge2index + ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) + self.cells.append(cell) + C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction + self.op_names = deepcopy(search_space) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_normal_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) + self.arch_reduce_parameters = nn.Parameter( + 1e-3 * torch.randn(num_edge, len(search_space)) + ) + self.mode = "urs" + self.dynamic_cell = None - def set_cal_mode(self, mode, dynamic_cell=None): - assert mode in ['urs', 'joint', 'select', 'dynamic'] - self.mode = mode - if mode == 'dynamic': - self.dynamic_cell = deepcopy(dynamic_cell) - else: - self.dynamic_cell = None - - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist - - def get_alphas(self): - return [self.arch_normal_parameters, self.arch_reduce_parameters] - - def show_alphas(self): - with torch.no_grad(): - A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() ) - B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() ) - return '{:}\n{:}'.format(A, B) - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def dync_genotype(self, use_random=False): - genotypes = [] - with torch.no_grad(): - alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - if use_random: - op_name = random.choice(self.op_names) + def set_cal_mode(self, mode, dynamic_cell=None): + assert mode in ["urs", "joint", "select", "dynamic"] + self.mode = mode + if mode == "dynamic": + self.dynamic_cell = 
deepcopy(dynamic_cell) else: - weights = alphas_cpu[ self.edge2index[node_str] ] - op_index = torch.multinomial(weights, 1).item() - op_name = self.op_names[ op_index ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) + self.dynamic_cell = None - def genotype(self): - def _parse(weights): - gene = [] - for i in range(self._steps): - edges = [] - for j in range(2+i): - node_str = '{:}<-{:}'.format(i, j) - ws = weights[ self.edge2index[node_str] ] - for k, op_name in enumerate(self.op_names): - if op_name == 'none': continue - edges.append( (op_name, j, ws[k]) ) - edges = sorted(edges, key=lambda x: -x[-1]) - selected_edges = edges[:2] - gene.append( tuple(selected_edges) ) - return gene - with torch.no_grad(): - gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()) - gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()) - return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)), - 'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))} + def get_weights(self): + xlist = list(self.stem.parameters()) + list(self.cells.parameters()) + xlist += list(self.lastact.parameters()) + list( + self.global_pooling.parameters() + ) + xlist += list(self.classifier.parameters()) + return xlist - def forward(self, inputs): - normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1) - reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1) + def get_alphas(self): + return [self.arch_normal_parameters, self.arch_reduce_parameters] - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - # [TODO] - raise NotImplementedError - if cell.reduction: hardwts, index = reduce_hardwts, reduce_index - else : hardwts, index = normal_hardwts, normal_index - s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) - out = self.lastact(s1) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) + def show_alphas(self): + with torch.no_grad(): + A = "arch-normal-parameters :\n{:}".format( + nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() + ) + B = "arch-reduce-parameters :\n{:}".format( + nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() + ) + return "{:}\n{:}".format(A, B) - return out, logits + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string + + def extra_repr(self): + return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) + + def dync_genotype(self, use_random=False): + genotypes = [] + with torch.no_grad(): + alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + if use_random: + op_name = random.choice(self.op_names) + else: + weights = alphas_cpu[self.edge2index[node_str]] + op_index = torch.multinomial(weights, 1).item() + op_name = self.op_names[op_index] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return Structure(genotypes) + + def genotype(self): + def _parse(weights): + gene = [] + for i in range(self._steps): + edges = [] + for j in range(2 + i): + node_str = "{:}<-{:}".format(i, j) + ws = 
weights[self.edge2index[node_str]] + for k, op_name in enumerate(self.op_names): + if op_name == "none": + continue + edges.append((op_name, j, ws[k])) + edges = sorted(edges, key=lambda x: -x[-1]) + selected_edges = edges[:2] + gene.append(tuple(selected_edges)) + return gene + + with torch.no_grad(): + gene_normal = _parse( + torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy() + ) + gene_reduce = _parse( + torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy() + ) + return { + "normal": gene_normal, + "normal_concat": list( + range(2 + self._steps - self._multiplier, self._steps + 2) + ), + "reduce": gene_reduce, + "reduce_concat": list( + range(2 + self._steps - self._multiplier, self._steps + 2) + ), + } + + def forward(self, inputs): + normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1) + reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1) + + s0 = s1 = self.stem(inputs) + for i, cell in enumerate(self.cells): + # [TODO] + raise NotImplementedError + if cell.reduction: + hardwts, index = reduce_hardwts, reduce_index + else: + hardwts, index = normal_hardwts, normal_index + s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) + out = self.lastact(s1) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/clone_weights.py b/lib/models/clone_weights.py index 2e5b1c1..9e904ac 100644 --- a/lib/models/clone_weights.py +++ b/lib/models/clone_weights.py @@ -3,60 +3,72 @@ import torch.nn as nn def copy_conv(module, init): - assert isinstance(module, nn.Conv2d), 'invalid module : {:}'.format(module) - assert isinstance(init , nn.Conv2d), 'invalid module : {:}'.format(init) - new_i, new_o = module.in_channels, module.out_channels - module.weight.copy_( init.weight.detach()[:new_o, :new_i] ) - if module.bias is not None: - module.bias.copy_( init.bias.detach()[:new_o] ) + assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module) + assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init) + new_i, new_o = module.in_channels, module.out_channels + module.weight.copy_(init.weight.detach()[:new_o, :new_i]) + if module.bias is not None: + module.bias.copy_(init.bias.detach()[:new_o]) -def copy_bn (module, init): - assert isinstance(module, nn.BatchNorm2d), 'invalid module : {:}'.format(module) - assert isinstance(init , nn.BatchNorm2d), 'invalid module : {:}'.format(init) - num_features = module.num_features - if module.weight is not None: - module.weight.copy_( init.weight.detach()[:num_features] ) - if module.bias is not None: - module.bias.copy_( init.bias.detach()[:num_features] ) - if module.running_mean is not None: - module.running_mean.copy_( init.running_mean.detach()[:num_features] ) - if module.running_var is not None: - module.running_var.copy_( init.running_var.detach()[:num_features] ) -def copy_fc (module, init): - assert isinstance(module, nn.Linear), 'invalid module : {:}'.format(module) - assert isinstance(init , nn.Linear), 'invalid module : {:}'.format(init) - new_i, new_o = module.in_features, module.out_features - module.weight.copy_( init.weight.detach()[:new_o, :new_i] ) - if module.bias is not None: - module.bias.copy_( init.bias.detach()[:new_o] ) +def copy_bn(module, init): + assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module) + assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init) + num_features = module.num_features + if 
module.weight is not None: + module.weight.copy_(init.weight.detach()[:num_features]) + if module.bias is not None: + module.bias.copy_(init.bias.detach()[:num_features]) + if module.running_mean is not None: + module.running_mean.copy_(init.running_mean.detach()[:num_features]) + if module.running_var is not None: + module.running_var.copy_(init.running_var.detach()[:num_features]) + + +def copy_fc(module, init): + assert isinstance(module, nn.Linear), "invalid module : {:}".format(module) + assert isinstance(init, nn.Linear), "invalid module : {:}".format(init) + new_i, new_o = module.in_features, module.out_features + module.weight.copy_(init.weight.detach()[:new_o, :new_i]) + if module.bias is not None: + module.bias.copy_(init.bias.detach()[:new_o]) + def copy_base(module, init): - assert type(module).__name__ in ['ConvBNReLU', 'Downsample'], 'invalid module : {:}'.format(module) - assert type( init).__name__ in ['ConvBNReLU', 'Downsample'], 'invalid module : {:}'.format( init) - if module.conv is not None: - copy_conv(module.conv, init.conv) - if module.bn is not None: - copy_bn (module.bn, init.bn) + assert type(module).__name__ in [ + "ConvBNReLU", + "Downsample", + ], "invalid module : {:}".format(module) + assert type(init).__name__ in [ + "ConvBNReLU", + "Downsample", + ], "invalid module : {:}".format(init) + if module.conv is not None: + copy_conv(module.conv, init.conv) + if module.bn is not None: + copy_bn(module.bn, init.bn) + def copy_basic(module, init): - copy_base(module.conv_a, init.conv_a) - copy_base(module.conv_b, init.conv_b) - if module.downsample is not None: - if init.downsample is not None: - copy_base(module.downsample, init.downsample) - #else: - # import pdb; pdb.set_trace() + copy_base(module.conv_a, init.conv_a) + copy_base(module.conv_b, init.conv_b) + if module.downsample is not None: + if init.downsample is not None: + copy_base(module.downsample, init.downsample) + # else: + # import pdb; pdb.set_trace() def init_from_model(network, init_model): - with torch.no_grad(): - copy_fc(network.classifier, init_model.classifier) - for base, target in zip(init_model.layers, network.layers): - assert type(base).__name__ == type(target).__name__, 'invalid type : {:} vs {:}'.format(base, target) - if type(base).__name__ == 'ConvBNReLU': - copy_base(target, base) - elif type(base).__name__ == 'ResNetBasicblock': - copy_basic(target, base) - else: - raise ValueError('unknown type name : {:}'.format( type(base).__name__ )) + with torch.no_grad(): + copy_fc(network.classifier, init_model.classifier) + for base, target in zip(init_model.layers, network.layers): + assert ( + type(base).__name__ == type(target).__name__ + ), "invalid type : {:} vs {:}".format(base, target) + if type(base).__name__ == "ConvBNReLU": + copy_base(target, base) + elif type(base).__name__ == "ResNetBasicblock": + copy_basic(target, base) + else: + raise ValueError("unknown type name : {:}".format(type(base).__name__)) diff --git a/lib/models/initialization.py b/lib/models/initialization.py index e35ca6f..e82d723 100644 --- a/lib/models/initialization.py +++ b/lib/models/initialization.py @@ -3,16 +3,14 @@ import torch.nn as nn def initialize_resnet(m): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 
0, 0.01) - nn.init.constant_(m.bias, 0) - - + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) diff --git a/lib/models/shape_infers/InferCifarResNet.py b/lib/models/shape_infers/InferCifarResNet.py index a6524d6..0575fd0 100644 --- a/lib/models/shape_infers/InferCifarResNet.py +++ b/lib/models/shape_infers/InferCifarResNet.py @@ -7,161 +7,280 @@ from ..initialization import initialize_resnet class ConvBNReLU(nn.Module): - - def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): - super(ConvBNReLU, self).__init__() - if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - else : self.avg = None - self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) - if has_bn : self.bn = nn.BatchNorm2d(nOut) - else : self.bn = None - if has_relu: self.relu = nn.ReLU(inplace=True) - else : self.relu = None + def __init__( + self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu + ): + super(ConvBNReLU, self).__init__() + if has_avg: + self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.avg = None + self.conv = nn.Conv2d( + nIn, + nOut, + kernel_size=kernel, + stride=stride, + padding=padding, + dilation=1, + groups=1, + bias=bias, + ) + if has_bn: + self.bn = nn.BatchNorm2d(nOut) + else: + self.bn = None + if has_relu: + self.relu = nn.ReLU(inplace=True) + else: + self.relu = None - def forward(self, inputs): - if self.avg : out = self.avg( inputs ) - else : out = inputs - conv = self.conv( out ) - if self.bn : out = self.bn( conv ) - else : out = conv - if self.relu: out = self.relu( out ) - else : out = out + def forward(self, inputs): + if self.avg: + out = self.avg(inputs) + else: + out = inputs + conv = self.conv(out) + if self.bn: + out = self.bn(conv) + else: + out = conv + if self.relu: + out = self.relu(out) + else: + out = out - return out + return out class ResNetBasicblock(nn.Module): - num_conv = 2 - expansion = 1 - def __init__(self, iCs, stride): - super(ResNetBasicblock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) - assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs) - - self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False) - residual_in = iCs[0] - if stride == 2: - self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) - residual_in = iCs[2] - elif iCs[0] != iCs[2]: - self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - #self.out_dim = max(residual_in, iCs[2]) - self.out_dim = iCs[2] + num_conv = 2 + expansion = 1 - def forward(self, inputs): - basicblock = self.conv_a(inputs) - basicblock = self.conv_b(basicblock) + def __init__(self, iCs, stride): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) 
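# A small sketch (assumed toy widths, not taken from the patch) of how the searched width
# vector xchannels is consumed by blocks like the one below: each block reads num_conv + 1
# consecutive entries as its (in, mid, out) channel tuple iCs, and consecutive blocks overlap
# by one entry, so the output width of one block is the input width of the next.
num_conv = 2                                   # a basic block has two convolutions
xchannels = [3, 16, 16, 32, 32, 64, 64, 64]    # assumed example: stem widths + three blocks
last_channel_idx = 1                           # index 0 is the image input, index 1 the stem output
blocks = []
while last_channel_idx + num_conv < len(xchannels):
    iCs = xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
    blocks.append(tuple(iCs))                  # e.g. (16, 16, 32): conv_a 16->16, conv_b 16->32
    last_channel_idx += num_conv
print(blocks)                                  # [(16, 16, 32), (32, 32, 64), (64, 64, 64)]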
+ assert isinstance(iCs, tuple) or isinstance( + iCs, list + ), "invalid type of iCs : {:}".format(iCs) + assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs - out = residual + basicblock - return F.relu(out, inplace=True) + self.conv_a = ConvBNReLU( + iCs[0], + iCs[1], + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_b = ConvBNReLU( + iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False + ) + residual_in = iCs[0] + if stride == 2: + self.downsample = ConvBNReLU( + iCs[0], + iCs[2], + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + residual_in = iCs[2] + elif iCs[0] != iCs[2]: + self.downsample = ConvBNReLU( + iCs[0], + iCs[2], + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + # self.out_dim = max(residual_in, iCs[2]) + self.out_dim = iCs[2] + def forward(self, inputs): + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) + + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = residual + basicblock + return F.relu(out, inplace=True) class ResNetBottleneck(nn.Module): - expansion = 4 - num_conv = 3 - def __init__(self, iCs, stride): - super(ResNetBottleneck, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) - assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs) - self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False) - residual_in = iCs[0] - if stride == 2: - self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False) - residual_in = iCs[3] - elif iCs[0] != iCs[3]: - self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False) - residual_in = iCs[3] - else: - self.downsample = None - #self.out_dim = max(residual_in, iCs[3]) - self.out_dim = iCs[3] + expansion = 4 + num_conv = 3 - def forward(self, inputs): + def __init__(self, iCs, stride): + super(ResNetBottleneck, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + assert isinstance(iCs, tuple) or isinstance( + iCs, list + ), "invalid type of iCs : {:}".format(iCs) + assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) + self.conv_1x1 = ConvBNReLU( + iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True + ) + self.conv_3x3 = ConvBNReLU( + iCs[1], + iCs[2], + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_1x4 = ConvBNReLU( + iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False + ) + residual_in = iCs[0] + if stride == 2: + self.downsample = ConvBNReLU( + iCs[0], + iCs[3], + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + residual_in = iCs[3] + elif iCs[0] != iCs[3]: + self.downsample = ConvBNReLU( + iCs[0], + iCs[3], + 1, + 1, + 0, + False, + has_avg=False, + has_bn=False, + has_relu=False, + ) + residual_in = iCs[3] + else: + 
self.downsample = None + # self.out_dim = max(residual_in, iCs[3]) + self.out_dim = iCs[3] - bottleneck = self.conv_1x1(inputs) - bottleneck = self.conv_3x3(bottleneck) - bottleneck = self.conv_1x4(bottleneck) + def forward(self, inputs): - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs - out = residual + bottleneck - return F.relu(out, inplace=True) + bottleneck = self.conv_1x1(inputs) + bottleneck = self.conv_3x3(bottleneck) + bottleneck = self.conv_1x4(bottleneck) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = residual + bottleneck + return F.relu(out, inplace=True) class InferCifarResNet(nn.Module): + def __init__( + self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual + ): + super(InferCifarResNet, self).__init__() - def __init__(self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual): - super(InferCifarResNet, self).__init__() + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + if block_name == "ResNetBasicblock": + block = ResNetBasicblock + assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" + layer_blocks = (depth - 2) // 6 + elif block_name == "ResNetBottleneck": + block = ResNetBottleneck + assert (depth - 2) % 9 == 0, "depth should be one of 164" + layer_blocks = (depth - 2) // 9 + else: + raise ValueError("invalid block : {:}".format(block_name)) + assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks) - #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model - if block_name == 'ResNetBasicblock': - block = ResNetBasicblock - assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' - layer_blocks = (depth - 2) // 6 - elif block_name == 'ResNetBottleneck': - block = ResNetBottleneck - assert (depth - 2) % 9 == 0, 'depth should be one of 164' - layer_blocks = (depth - 2) // 9 - else: - raise ValueError('invalid block : {:}'.format(block_name)) - assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks) + self.message = ( + "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( + depth, layer_blocks + ) + ) + self.num_classes = num_classes + self.xchannels = xchannels + self.layers = nn.ModuleList( + [ + ConvBNReLU( + xchannels[0], + xchannels[1], + 3, + 1, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + ] + ) + last_channel_idx = 1 + for stage in range(3): + for iL in range(layer_blocks): + num_conv = block.num_conv + iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] + stride = 2 if stage > 0 and iL == 0 else 1 + module = block(iCs, stride) + last_channel_idx += num_conv + self.xchannels[last_channel_idx] = module.out_dim + self.layers.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( + stage, + iL, + layer_blocks, + len(self.layers) - 1, + iCs, + module.out_dim, + stride, + ) + if iL + 1 == xblocks[stage]: # reach the maximum depth + out_channel = module.out_dim + for iiL in range(iL + 1, layer_blocks): + last_channel_idx += num_conv + self.xchannels[last_channel_idx] = module.out_dim + break - self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) - self.num_classes = num_classes - self.xchannels = xchannels - self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) 
] ) - last_channel_idx = 1 - for stage in range(3): - for iL in range(layer_blocks): - num_conv = block.num_conv - iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1] - stride = 2 if stage > 0 and iL == 0 else 1 - module = block(iCs, stride) - last_channel_idx += num_conv - self.xchannels[last_channel_idx] = module.out_dim - self.layers.append ( module ) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride) - if iL + 1 == xblocks[stage]: # reach the maximum depth - out_channel = module.out_dim - for iiL in range(iL+1, layer_blocks): - last_channel_idx += num_conv - self.xchannels[last_channel_idx] = module.out_dim - break - - self.avgpool = nn.AvgPool2d(8) - self.classifier = nn.Linear(self.xchannels[-1], num_classes) - - self.apply(initialize_resnet) - if zero_init_residual: - for m in self.modules(): - if isinstance(m, ResNetBasicblock): - nn.init.constant_(m.conv_b.bn.weight, 0) - elif isinstance(m, ResNetBottleneck): - nn.init.constant_(m.conv_1x4.bn.weight, 0) + self.avgpool = nn.AvgPool2d(8) + self.classifier = nn.Linear(self.xchannels[-1], num_classes) - def get_message(self): - return self.message + self.apply(initialize_resnet) + if zero_init_residual: + for m in self.modules(): + if isinstance(m, ResNetBasicblock): + nn.init.constant_(m.conv_b.bn.weight, 0) + elif isinstance(m, ResNetBottleneck): + nn.init.constant_(m.conv_1x4.bn.weight, 0) - def forward(self, inputs): - x = inputs - for i, layer in enumerate(self.layers): - x = layer( x ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = self.classifier(features) - return features, logits + def get_message(self): + return self.message + + def forward(self, inputs): + x = inputs + for i, layer in enumerate(self.layers): + x = layer(x) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/shape_infers/InferCifarResNet_depth.py b/lib/models/shape_infers/InferCifarResNet_depth.py index d773fc5..c6f9bb3 100644 --- a/lib/models/shape_infers/InferCifarResNet_depth.py +++ b/lib/models/shape_infers/InferCifarResNet_depth.py @@ -7,144 +7,257 @@ from ..initialization import initialize_resnet class ConvBNReLU(nn.Module): - - def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): - super(ConvBNReLU, self).__init__() - if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - else : self.avg = None - self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) - if has_bn : self.bn = nn.BatchNorm2d(nOut) - else : self.bn = None - if has_relu: self.relu = nn.ReLU(inplace=True) - else : self.relu = None + def __init__( + self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu + ): + super(ConvBNReLU, self).__init__() + if has_avg: + self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.avg = None + self.conv = nn.Conv2d( + nIn, + nOut, + kernel_size=kernel, + stride=stride, + padding=padding, + dilation=1, + groups=1, + bias=bias, + ) + if has_bn: + self.bn = nn.BatchNorm2d(nOut) + else: + self.bn = None + if has_relu: + self.relu = nn.ReLU(inplace=True) + else: + self.relu = None - def forward(self, inputs): - if self.avg : out = self.avg( inputs ) - else : out = inputs - conv = self.conv( out ) - if self.bn : out 
= self.bn( conv ) - else : out = conv - if self.relu: out = self.relu( out ) - else : out = out + def forward(self, inputs): + if self.avg: + out = self.avg(inputs) + else: + out = inputs + conv = self.conv(out) + if self.bn: + out = self.bn(conv) + else: + out = conv + if self.relu: + out = self.relu(out) + else: + out = out - return out + return out class ResNetBasicblock(nn.Module): - num_conv = 2 - expansion = 1 - def __init__(self, inplanes, planes, stride): - super(ResNetBasicblock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - - self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) - elif inplanes != planes: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - self.out_dim = planes + num_conv = 2 + expansion = 1 - def forward(self, inputs): - basicblock = self.conv_a(inputs) - basicblock = self.conv_b(basicblock) + def __init__(self, inplanes, planes, stride): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs - out = residual + basicblock - return F.relu(out, inplace=True) + self.conv_a = ConvBNReLU( + inplanes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_b = ConvBNReLU( + planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + elif inplanes != planes: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes + def forward(self, inputs): + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) + + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = residual + basicblock + return F.relu(out, inplace=True) class ResNetBottleneck(nn.Module): - expansion = 4 - num_conv = 3 - def __init__(self, inplanes, planes, stride): - super(ResNetBottleneck, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False) - elif inplanes != planes*self.expansion: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False) - else: - self.downsample = None - self.out_dim = planes*self.expansion + expansion = 4 + num_conv = 3 - def forward(self, inputs): + def __init__(self, inplanes, planes, stride): + super(ResNetBottleneck, 
self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_1x1 = ConvBNReLU( + inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True + ) + self.conv_3x3 = ConvBNReLU( + planes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_1x4 = ConvBNReLU( + planes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + elif inplanes != planes * self.expansion: + self.downsample = ConvBNReLU( + inplanes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=False, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes * self.expansion - bottleneck = self.conv_1x1(inputs) - bottleneck = self.conv_3x3(bottleneck) - bottleneck = self.conv_1x4(bottleneck) + def forward(self, inputs): - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs - out = residual + bottleneck - return F.relu(out, inplace=True) + bottleneck = self.conv_1x1(inputs) + bottleneck = self.conv_3x3(bottleneck) + bottleneck = self.conv_1x4(bottleneck) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = residual + bottleneck + return F.relu(out, inplace=True) class InferDepthCifarResNet(nn.Module): + def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual): + super(InferDepthCifarResNet, self).__init__() - def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual): - super(InferDepthCifarResNet, self).__init__() + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + if block_name == "ResNetBasicblock": + block = ResNetBasicblock + assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" + layer_blocks = (depth - 2) // 6 + elif block_name == "ResNetBottleneck": + block = ResNetBottleneck + assert (depth - 2) % 9 == 0, "depth should be one of 164" + layer_blocks = (depth - 2) // 9 + else: + raise ValueError("invalid block : {:}".format(block_name)) + assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks) - #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model - if block_name == 'ResNetBasicblock': - block = ResNetBasicblock - assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' - layer_blocks = (depth - 2) // 6 - elif block_name == 'ResNetBottleneck': - block = ResNetBottleneck - assert (depth - 2) % 9 == 0, 'depth should be one of 164' - layer_blocks = (depth - 2) // 9 - else: - raise ValueError('invalid block : {:}'.format(block_name)) - assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks) + self.message = ( + "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( + depth, layer_blocks + ) + ) + self.num_classes = num_classes + self.layers = nn.ModuleList( + [ + ConvBNReLU( + 3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True + ) + ] + ) + self.channels = [16] + for stage in range(3): + for iL in range(layer_blocks): + iC = self.channels[-1] + planes = 16 * (2 ** stage) + stride = 2 if stage > 0 and iL == 0 else 1 + module = block(iC, planes, stride) + self.channels.append(module.out_dim) + self.layers.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, 
iC={:}, oC={:3d}, stride={:}".format( + stage, + iL, + layer_blocks, + len(self.layers) - 1, + planes, + module.out_dim, + stride, + ) + if iL + 1 == xblocks[stage]: # reach the maximum depth + break - self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) - self.num_classes = num_classes - self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) - self.channels = [16] - for stage in range(3): - for iL in range(layer_blocks): - iC = self.channels[-1] - planes = 16 * (2**stage) - stride = 2 if stage > 0 and iL == 0 else 1 - module = block(iC, planes, stride) - self.channels.append( module.out_dim ) - self.layers.append ( module ) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, planes, module.out_dim, stride) - if iL + 1 == xblocks[stage]: # reach the maximum depth - break - - self.avgpool = nn.AvgPool2d(8) - self.classifier = nn.Linear(self.channels[-1], num_classes) - - self.apply(initialize_resnet) - if zero_init_residual: - for m in self.modules(): - if isinstance(m, ResNetBasicblock): - nn.init.constant_(m.conv_b.bn.weight, 0) - elif isinstance(m, ResNetBottleneck): - nn.init.constant_(m.conv_1x4.bn.weight, 0) + self.avgpool = nn.AvgPool2d(8) + self.classifier = nn.Linear(self.channels[-1], num_classes) - def get_message(self): - return self.message + self.apply(initialize_resnet) + if zero_init_residual: + for m in self.modules(): + if isinstance(m, ResNetBasicblock): + nn.init.constant_(m.conv_b.bn.weight, 0) + elif isinstance(m, ResNetBottleneck): + nn.init.constant_(m.conv_1x4.bn.weight, 0) - def forward(self, inputs): - x = inputs - for i, layer in enumerate(self.layers): - x = layer( x ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = self.classifier(features) - return features, logits + def get_message(self): + return self.message + + def forward(self, inputs): + x = inputs + for i, layer in enumerate(self.layers): + x = layer(x) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/shape_infers/InferCifarResNet_width.py b/lib/models/shape_infers/InferCifarResNet_width.py index 7183875..9400f71 100644 --- a/lib/models/shape_infers/InferCifarResNet_width.py +++ b/lib/models/shape_infers/InferCifarResNet_width.py @@ -7,154 +7,271 @@ from ..initialization import initialize_resnet class ConvBNReLU(nn.Module): - - def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): - super(ConvBNReLU, self).__init__() - if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - else : self.avg = None - self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) - if has_bn : self.bn = nn.BatchNorm2d(nOut) - else : self.bn = None - if has_relu: self.relu = nn.ReLU(inplace=True) - else : self.relu = None + def __init__( + self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu + ): + super(ConvBNReLU, self).__init__() + if has_avg: + self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.avg = None + self.conv = nn.Conv2d( + nIn, + nOut, + kernel_size=kernel, + stride=stride, + padding=padding, + dilation=1, + groups=1, + bias=bias, + ) + if has_bn: + self.bn = nn.BatchNorm2d(nOut) + else: + self.bn 
= None + if has_relu: + self.relu = nn.ReLU(inplace=True) + else: + self.relu = None - def forward(self, inputs): - if self.avg : out = self.avg( inputs ) - else : out = inputs - conv = self.conv( out ) - if self.bn : out = self.bn( conv ) - else : out = conv - if self.relu: out = self.relu( out ) - else : out = out + def forward(self, inputs): + if self.avg: + out = self.avg(inputs) + else: + out = inputs + conv = self.conv(out) + if self.bn: + out = self.bn(conv) + else: + out = conv + if self.relu: + out = self.relu(out) + else: + out = out - return out + return out class ResNetBasicblock(nn.Module): - num_conv = 2 - expansion = 1 - def __init__(self, iCs, stride): - super(ResNetBasicblock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) - assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs) - - self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False) - residual_in = iCs[0] - if stride == 2: - self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) - residual_in = iCs[2] - elif iCs[0] != iCs[2]: - self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - #self.out_dim = max(residual_in, iCs[2]) - self.out_dim = iCs[2] + num_conv = 2 + expansion = 1 - def forward(self, inputs): - basicblock = self.conv_a(inputs) - basicblock = self.conv_b(basicblock) + def __init__(self, iCs, stride): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + assert isinstance(iCs, tuple) or isinstance( + iCs, list + ), "invalid type of iCs : {:}".format(iCs) + assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs - out = residual + basicblock - return F.relu(out, inplace=True) + self.conv_a = ConvBNReLU( + iCs[0], + iCs[1], + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_b = ConvBNReLU( + iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False + ) + residual_in = iCs[0] + if stride == 2: + self.downsample = ConvBNReLU( + iCs[0], + iCs[2], + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + residual_in = iCs[2] + elif iCs[0] != iCs[2]: + self.downsample = ConvBNReLU( + iCs[0], + iCs[2], + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + # self.out_dim = max(residual_in, iCs[2]) + self.out_dim = iCs[2] + def forward(self, inputs): + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) + + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = residual + basicblock + return F.relu(out, inplace=True) class ResNetBottleneck(nn.Module): - expansion = 4 - num_conv = 3 - def __init__(self, iCs, stride): - super(ResNetBottleneck, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) - assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs) - self.conv_1x1 = 
ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False) - residual_in = iCs[0] - if stride == 2: - self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False) - residual_in = iCs[3] - elif iCs[0] != iCs[3]: - self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False) - residual_in = iCs[3] - else: - self.downsample = None - #self.out_dim = max(residual_in, iCs[3]) - self.out_dim = iCs[3] + expansion = 4 + num_conv = 3 - def forward(self, inputs): + def __init__(self, iCs, stride): + super(ResNetBottleneck, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + assert isinstance(iCs, tuple) or isinstance( + iCs, list + ), "invalid type of iCs : {:}".format(iCs) + assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) + self.conv_1x1 = ConvBNReLU( + iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True + ) + self.conv_3x3 = ConvBNReLU( + iCs[1], + iCs[2], + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_1x4 = ConvBNReLU( + iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False + ) + residual_in = iCs[0] + if stride == 2: + self.downsample = ConvBNReLU( + iCs[0], + iCs[3], + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + residual_in = iCs[3] + elif iCs[0] != iCs[3]: + self.downsample = ConvBNReLU( + iCs[0], + iCs[3], + 1, + 1, + 0, + False, + has_avg=False, + has_bn=False, + has_relu=False, + ) + residual_in = iCs[3] + else: + self.downsample = None + # self.out_dim = max(residual_in, iCs[3]) + self.out_dim = iCs[3] - bottleneck = self.conv_1x1(inputs) - bottleneck = self.conv_3x3(bottleneck) - bottleneck = self.conv_1x4(bottleneck) + def forward(self, inputs): - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs - out = residual + bottleneck - return F.relu(out, inplace=True) + bottleneck = self.conv_1x1(inputs) + bottleneck = self.conv_3x3(bottleneck) + bottleneck = self.conv_1x4(bottleneck) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = residual + bottleneck + return F.relu(out, inplace=True) class InferWidthCifarResNet(nn.Module): + def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual): + super(InferWidthCifarResNet, self).__init__() - def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual): - super(InferWidthCifarResNet, self).__init__() + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + if block_name == "ResNetBasicblock": + block = ResNetBasicblock + assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" + layer_blocks = (depth - 2) // 6 + elif block_name == "ResNetBottleneck": + block = ResNetBottleneck + assert (depth - 2) % 9 == 0, "depth should be one of 164" + layer_blocks = (depth - 2) // 9 + else: + raise ValueError("invalid block : {:}".format(block_name)) - #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model - if block_name == 'ResNetBasicblock': - block = ResNetBasicblock - assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' - layer_blocks = 
(depth - 2) // 6 - elif block_name == 'ResNetBottleneck': - block = ResNetBottleneck - assert (depth - 2) % 9 == 0, 'depth should be one of 164' - layer_blocks = (depth - 2) // 9 - else: - raise ValueError('invalid block : {:}'.format(block_name)) + self.message = ( + "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( + depth, layer_blocks + ) + ) + self.num_classes = num_classes + self.xchannels = xchannels + self.layers = nn.ModuleList( + [ + ConvBNReLU( + xchannels[0], + xchannels[1], + 3, + 1, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + ] + ) + last_channel_idx = 1 + for stage in range(3): + for iL in range(layer_blocks): + num_conv = block.num_conv + iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] + stride = 2 if stage > 0 and iL == 0 else 1 + module = block(iCs, stride) + last_channel_idx += num_conv + self.xchannels[last_channel_idx] = module.out_dim + self.layers.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( + stage, + iL, + layer_blocks, + len(self.layers) - 1, + iCs, + module.out_dim, + stride, + ) - self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) - self.num_classes = num_classes - self.xchannels = xchannels - self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) - last_channel_idx = 1 - for stage in range(3): - for iL in range(layer_blocks): - num_conv = block.num_conv - iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1] - stride = 2 if stage > 0 and iL == 0 else 1 - module = block(iCs, stride) - last_channel_idx += num_conv - self.xchannels[last_channel_idx] = module.out_dim - self.layers.append ( module ) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride) - - self.avgpool = nn.AvgPool2d(8) - self.classifier = nn.Linear(self.xchannels[-1], num_classes) - - self.apply(initialize_resnet) - if zero_init_residual: - for m in self.modules(): - if isinstance(m, ResNetBasicblock): - nn.init.constant_(m.conv_b.bn.weight, 0) - elif isinstance(m, ResNetBottleneck): - nn.init.constant_(m.conv_1x4.bn.weight, 0) + self.avgpool = nn.AvgPool2d(8) + self.classifier = nn.Linear(self.xchannels[-1], num_classes) - def get_message(self): - return self.message + self.apply(initialize_resnet) + if zero_init_residual: + for m in self.modules(): + if isinstance(m, ResNetBasicblock): + nn.init.constant_(m.conv_b.bn.weight, 0) + elif isinstance(m, ResNetBottleneck): + nn.init.constant_(m.conv_1x4.bn.weight, 0) - def forward(self, inputs): - x = inputs - for i, layer in enumerate(self.layers): - x = layer( x ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = self.classifier(features) - return features, logits + def get_message(self): + return self.message + + def forward(self, inputs): + x = inputs + for i, layer in enumerate(self.layers): + x = layer(x) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/shape_infers/InferImagenetResNet.py b/lib/models/shape_infers/InferImagenetResNet.py index 8f06db7..0415e58 100644 --- a/lib/models/shape_infers/InferImagenetResNet.py +++ b/lib/models/shape_infers/InferImagenetResNet.py @@ 
-7,164 +7,318 @@ from ..initialization import initialize_resnet class ConvBNReLU(nn.Module): - - num_conv = 1 - def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): - super(ConvBNReLU, self).__init__() - if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - else : self.avg = None - self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) - if has_bn : self.bn = nn.BatchNorm2d(nOut) - else : self.bn = None - if has_relu: self.relu = nn.ReLU(inplace=True) - else : self.relu = None - def forward(self, inputs): - if self.avg : out = self.avg( inputs ) - else : out = inputs - conv = self.conv( out ) - if self.bn : out = self.bn( conv ) - else : out = conv - if self.relu: out = self.relu( out ) - else : out = out + num_conv = 1 - return out + def __init__( + self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu + ): + super(ConvBNReLU, self).__init__() + if has_avg: + self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.avg = None + self.conv = nn.Conv2d( + nIn, + nOut, + kernel_size=kernel, + stride=stride, + padding=padding, + dilation=1, + groups=1, + bias=bias, + ) + if has_bn: + self.bn = nn.BatchNorm2d(nOut) + else: + self.bn = None + if has_relu: + self.relu = nn.ReLU(inplace=True) + else: + self.relu = None + + def forward(self, inputs): + if self.avg: + out = self.avg(inputs) + else: + out = inputs + conv = self.conv(out) + if self.bn: + out = self.bn(conv) + else: + out = conv + if self.relu: + out = self.relu(out) + else: + out = out + + return out class ResNetBasicblock(nn.Module): - num_conv = 2 - expansion = 1 - def __init__(self, iCs, stride): - super(ResNetBasicblock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) - assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs) - - self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False) - residual_in = iCs[0] - if stride == 2: - self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False) - residual_in = iCs[2] - elif iCs[0] != iCs[2]: - self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - #self.out_dim = max(residual_in, iCs[2]) - self.out_dim = iCs[2] + num_conv = 2 + expansion = 1 - def forward(self, inputs): - basicblock = self.conv_a(inputs) - basicblock = self.conv_b(basicblock) + def __init__(self, iCs, stride): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + assert isinstance(iCs, tuple) or isinstance( + iCs, list + ), "invalid type of iCs : {:}".format(iCs) + assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs - out = residual + basicblock - return F.relu(out, inplace=True) + self.conv_a = ConvBNReLU( + iCs[0], + iCs[1], + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_b = ConvBNReLU( + iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False + ) + residual_in = iCs[0] + if stride == 2: + self.downsample = 
ConvBNReLU( + iCs[0], + iCs[2], + 1, + 1, + 0, + False, + has_avg=True, + has_bn=True, + has_relu=False, + ) + residual_in = iCs[2] + elif iCs[0] != iCs[2]: + self.downsample = ConvBNReLU( + iCs[0], + iCs[2], + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + # self.out_dim = max(residual_in, iCs[2]) + self.out_dim = iCs[2] + def forward(self, inputs): + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) + + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = residual + basicblock + return F.relu(out, inplace=True) class ResNetBottleneck(nn.Module): - expansion = 4 - num_conv = 3 - def __init__(self, iCs, stride): - super(ResNetBottleneck, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs ) - assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs) - self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False) - residual_in = iCs[0] - if stride == 2: - self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=True, has_relu=False) - residual_in = iCs[3] - elif iCs[0] != iCs[3]: - self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False) - residual_in = iCs[3] - else: - self.downsample = None - #self.out_dim = max(residual_in, iCs[3]) - self.out_dim = iCs[3] + expansion = 4 + num_conv = 3 - def forward(self, inputs): + def __init__(self, iCs, stride): + super(ResNetBottleneck, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + assert isinstance(iCs, tuple) or isinstance( + iCs, list + ), "invalid type of iCs : {:}".format(iCs) + assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) + self.conv_1x1 = ConvBNReLU( + iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True + ) + self.conv_3x3 = ConvBNReLU( + iCs[1], + iCs[2], + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_1x4 = ConvBNReLU( + iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False + ) + residual_in = iCs[0] + if stride == 2: + self.downsample = ConvBNReLU( + iCs[0], + iCs[3], + 1, + 1, + 0, + False, + has_avg=True, + has_bn=True, + has_relu=False, + ) + residual_in = iCs[3] + elif iCs[0] != iCs[3]: + self.downsample = ConvBNReLU( + iCs[0], + iCs[3], + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + residual_in = iCs[3] + else: + self.downsample = None + # self.out_dim = max(residual_in, iCs[3]) + self.out_dim = iCs[3] - bottleneck = self.conv_1x1(inputs) - bottleneck = self.conv_3x3(bottleneck) - bottleneck = self.conv_1x4(bottleneck) + def forward(self, inputs): - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs - out = residual + bottleneck - return F.relu(out, inplace=True) + bottleneck = self.conv_1x1(inputs) + bottleneck = self.conv_3x3(bottleneck) + bottleneck = self.conv_1x4(bottleneck) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = residual + bottleneck + 
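# The shortcut selection in these infer-ResNet blocks follows one rule: spatially
# downsampling blocks (stride == 2) project the residual with avg-pool plus a 1x1 conv,
# width-changing blocks use a plain 1x1 conv, and everything else keeps the identity.
# A framework-free restatement of that rule (hypothetical helper, not part of the patch):
def shortcut_kind(in_dim, out_dim, stride):
    if stride == 2:
        return "avgpool+1x1conv"   # spatial size halves, so the residual must be pooled first
    if in_dim != out_dim:
        return "1x1conv"           # widths differ, so the residual must be projected
    return "identity"              # shapes already match; add the input directly

assert shortcut_kind(64, 64, 1) == "identity"
assert shortcut_kind(64, 128, 1) == "1x1conv"
assert shortcut_kind(64, 128, 2) == "avgpool+1x1conv"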
return F.relu(out, inplace=True) class InferImagenetResNet(nn.Module): + def __init__( + self, + block_name, + layers, + xblocks, + xchannels, + deep_stem, + num_classes, + zero_init_residual, + ): + super(InferImagenetResNet, self).__init__() - def __init__(self, block_name, layers, xblocks, xchannels, deep_stem, num_classes, zero_init_residual): - super(InferImagenetResNet, self).__init__() + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + if block_name == "BasicBlock": + block = ResNetBasicblock + elif block_name == "Bottleneck": + block = ResNetBottleneck + else: + raise ValueError("invalid block : {:}".format(block_name)) + assert len(xblocks) == len( + layers + ), "invalid layers : {:} vs xblocks : {:}".format(layers, xblocks) - #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model - if block_name == 'BasicBlock': - block = ResNetBasicblock - elif block_name == 'Bottleneck': - block = ResNetBottleneck - else: - raise ValueError('invalid block : {:}'.format(block_name)) - assert len(xblocks) == len(layers), 'invalid layers : {:} vs xblocks : {:}'.format(layers, xblocks) + self.message = "InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}".format( + sum(layers) * block.num_conv, sum(xblocks) * block.num_conv, xblocks + ) + self.num_classes = num_classes + self.xchannels = xchannels + if not deep_stem: + self.layers = nn.ModuleList( + [ + ConvBNReLU( + xchannels[0], + xchannels[1], + 7, + 2, + 3, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + ] + ) + last_channel_idx = 1 + else: + self.layers = nn.ModuleList( + [ + ConvBNReLU( + xchannels[0], + xchannels[1], + 3, + 2, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ), + ConvBNReLU( + xchannels[1], + xchannels[2], + 3, + 1, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ), + ] + ) + last_channel_idx = 2 + self.layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + for stage, layer_blocks in enumerate(layers): + for iL in range(layer_blocks): + num_conv = block.num_conv + iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] + stride = 2 if stage > 0 and iL == 0 else 1 + module = block(iCs, stride) + last_channel_idx += num_conv + self.xchannels[last_channel_idx] = module.out_dim + self.layers.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( + stage, + iL, + layer_blocks, + len(self.layers) - 1, + iCs, + module.out_dim, + stride, + ) + if iL + 1 == xblocks[stage]: # reach the maximum depth + out_channel = module.out_dim + for iiL in range(iL + 1, layer_blocks): + last_channel_idx += num_conv + self.xchannels[last_channel_idx] = module.out_dim + break + assert last_channel_idx + 1 == len(self.xchannels), "{:} vs {:}".format( + last_channel_idx, len(self.xchannels) + ) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.classifier = nn.Linear(self.xchannels[-1], num_classes) - self.message = 'InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}'.format(sum(layers)*block.num_conv, sum(xblocks)*block.num_conv, xblocks) - self.num_classes = num_classes - self.xchannels = xchannels - if not deep_stem: - self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True) ] ) - last_channel_idx = 1 - else: - self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True) - 
,ConvBNReLU(xchannels[1], xchannels[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) - last_channel_idx = 2 - self.layers.append( nn.MaxPool2d(kernel_size=3, stride=2, padding=1) ) - for stage, layer_blocks in enumerate(layers): - for iL in range(layer_blocks): - num_conv = block.num_conv - iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1] - stride = 2 if stage > 0 and iL == 0 else 1 - module = block(iCs, stride) - last_channel_idx += num_conv - self.xchannels[last_channel_idx] = module.out_dim - self.layers.append ( module ) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride) - if iL + 1 == xblocks[stage]: # reach the maximum depth - out_channel = module.out_dim - for iiL in range(iL+1, layer_blocks): - last_channel_idx += num_conv - self.xchannels[last_channel_idx] = module.out_dim - break - assert last_channel_idx + 1 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels)) - self.avgpool = nn.AdaptiveAvgPool2d((1,1)) - self.classifier = nn.Linear(self.xchannels[-1], num_classes) - - self.apply(initialize_resnet) - if zero_init_residual: - for m in self.modules(): - if isinstance(m, ResNetBasicblock): - nn.init.constant_(m.conv_b.bn.weight, 0) - elif isinstance(m, ResNetBottleneck): - nn.init.constant_(m.conv_1x4.bn.weight, 0) + self.apply(initialize_resnet) + if zero_init_residual: + for m in self.modules(): + if isinstance(m, ResNetBasicblock): + nn.init.constant_(m.conv_b.bn.weight, 0) + elif isinstance(m, ResNetBottleneck): + nn.init.constant_(m.conv_1x4.bn.weight, 0) - def get_message(self): - return self.message + def get_message(self): + return self.message - def forward(self, inputs): - x = inputs - for i, layer in enumerate(self.layers): - x = layer( x ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = self.classifier(features) - return features, logits + def forward(self, inputs): + x = inputs + for i, layer in enumerate(self.layers): + x = layer(x) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/shape_infers/InferMobileNetV2.py b/lib/models/shape_infers/InferMobileNetV2.py index d072b99..4057547 100644 --- a/lib/models/shape_infers/InferMobileNetV2.py +++ b/lib/models/shape_infers/InferMobileNetV2.py @@ -4,119 +4,171 @@ # MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018 from torch import nn from ..initialization import initialize_resnet -from ..SharedUtils import parse_channel_info +from ..SharedUtils import parse_channel_info class ConvBNReLU(nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, groups, has_bn=True, has_relu=True): - super(ConvBNReLU, self).__init__() - padding = (kernel_size - 1) // 2 - self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False) - if has_bn: self.bn = nn.BatchNorm2d(out_planes) - else : self.bn = None - if has_relu: self.relu = nn.ReLU6(inplace=True) - else : self.relu = None - - def forward(self, x): - out = self.conv( x ) - if self.bn: out = self.bn ( out ) - if self.relu: out = self.relu( out ) - return out + def __init__( + self, + in_planes, + out_planes, + kernel_size, + stride, + groups, + has_bn=True, + has_relu=True, + ): + super(ConvBNReLU, self).__init__() + padding = (kernel_size - 1) // 2 + self.conv = 
nn.Conv2d( + in_planes, + out_planes, + kernel_size, + stride, + padding, + groups=groups, + bias=False, + ) + if has_bn: + self.bn = nn.BatchNorm2d(out_planes) + else: + self.bn = None + if has_relu: + self.relu = nn.ReLU6(inplace=True) + else: + self.relu = None + + def forward(self, x): + out = self.conv(x) + if self.bn: + out = self.bn(out) + if self.relu: + out = self.relu(out) + return out class InvertedResidual(nn.Module): - def __init__(self, channels, stride, expand_ratio, additive): - super(InvertedResidual, self).__init__() - self.stride = stride - assert stride in [1, 2], 'invalid stride : {:}'.format(stride) - assert len(channels) in [2, 3], 'invalid channels : {:}'.format(channels) + def __init__(self, channels, stride, expand_ratio, additive): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2], "invalid stride : {:}".format(stride) + assert len(channels) in [2, 3], "invalid channels : {:}".format(channels) - if len(channels) == 2: - layers = [] - else: - layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)] - layers.extend([ - # dw - ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]), - # pw-linear - ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False), - ]) - self.conv = nn.Sequential(*layers) - self.additive = additive - if self.additive and channels[0] != channels[-1]: - self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False) - else: - self.shortcut = None - self.out_dim = channels[-1] + if len(channels) == 2: + layers = [] + else: + layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)] + layers.extend( + [ + # dw + ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]), + # pw-linear + ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False), + ] + ) + self.conv = nn.Sequential(*layers) + self.additive = additive + if self.additive and channels[0] != channels[-1]: + self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False) + else: + self.shortcut = None + self.out_dim = channels[-1] - def forward(self, x): - out = self.conv(x) - # if self.additive: return additive_func(out, x) - if self.shortcut: return out + self.shortcut(x) - else : return out + def forward(self, x): + out = self.conv(x) + # if self.additive: return additive_func(out, x) + if self.shortcut: + return out + self.shortcut(x) + else: + return out class InferMobileNetV2(nn.Module): - def __init__(self, num_classes, xchannels, xblocks, dropout): - super(InferMobileNetV2, self).__init__() - block = InvertedResidual - inverted_residual_setting = [ - # t, c, n, s - [1, 16 , 1, 1], - [6, 24 , 2, 2], - [6, 32 , 3, 2], - [6, 64 , 4, 2], - [6, 96 , 3, 1], - [6, 160, 3, 2], - [6, 320, 1, 1], - ] - assert len(inverted_residual_setting) == len(xblocks), 'invalid number of layers : {:} vs {:}'.format(len(inverted_residual_setting), len(xblocks)) - for block_num, ir_setting in zip(xblocks, inverted_residual_setting): - assert block_num <= ir_setting[2], '{:} vs {:}'.format(block_num, ir_setting) - xchannels = parse_channel_info(xchannels) - #for i, chs in enumerate(xchannels): - # if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs) - self.xchannels = xchannels - self.message = 'InferMobileNetV2 : xblocks={:}'.format(xblocks) - # building first layer - features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)] - last_channel_idx = 1 + def __init__(self, num_classes, xchannels, xblocks, dropout): + super(InferMobileNetV2, 
self).__init__() + block = InvertedResidual + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + assert len(inverted_residual_setting) == len( + xblocks + ), "invalid number of layers : {:} vs {:}".format( + len(inverted_residual_setting), len(xblocks) + ) + for block_num, ir_setting in zip(xblocks, inverted_residual_setting): + assert block_num <= ir_setting[2], "{:} vs {:}".format( + block_num, ir_setting + ) + xchannels = parse_channel_info(xchannels) + # for i, chs in enumerate(xchannels): + # if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs) + self.xchannels = xchannels + self.message = "InferMobileNetV2 : xblocks={:}".format(xblocks) + # building first layer + features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)] + last_channel_idx = 1 - # building inverted residual blocks - for stage, (t, c, n, s) in enumerate(inverted_residual_setting): - for i in range(n): - stride = s if i == 0 else 1 - additv = True if i > 0 else False - module = block(self.xchannels[last_channel_idx], stride, t, additv) - features.append(module) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c) - last_channel_idx += 1 - if i + 1 == xblocks[stage]: - out_channel = module.out_dim - for iiL in range(i+1, n): - last_channel_idx += 1 - self.xchannels[last_channel_idx][0] = module.out_dim - break - # building last several layers - features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1)) - assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels)) - # make it nn.Sequential - self.features = nn.Sequential(*features) + # building inverted residual blocks + for stage, (t, c, n, s) in enumerate(inverted_residual_setting): + for i in range(n): + stride = s if i == 0 else 1 + additv = True if i > 0 else False + module = block(self.xchannels[last_channel_idx], stride, t, additv) + features.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format( + stage, + i, + n, + len(features), + self.xchannels[last_channel_idx], + stride, + t, + c, + ) + last_channel_idx += 1 + if i + 1 == xblocks[stage]: + out_channel = module.out_dim + for iiL in range(i + 1, n): + last_channel_idx += 1 + self.xchannels[last_channel_idx][0] = module.out_dim + break + # building last several layers + features.append( + ConvBNReLU( + self.xchannels[last_channel_idx][0], + self.xchannels[last_channel_idx][1], + 1, + 1, + 1, + ) + ) + assert last_channel_idx + 2 == len(self.xchannels), "{:} vs {:}".format( + last_channel_idx, len(self.xchannels) + ) + # make it nn.Sequential + self.features = nn.Sequential(*features) - # building classifier - self.classifier = nn.Sequential( - nn.Dropout(dropout), - nn.Linear(self.xchannels[last_channel_idx][1], num_classes), - ) + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(self.xchannels[last_channel_idx][1], num_classes), + ) - # weight initialization - self.apply( initialize_resnet ) + # weight initialization + self.apply(initialize_resnet) - def get_message(self): - return self.message + def get_message(self): + return self.message - def 
forward(self, inputs): - features = self.features(inputs) - vectors = features.mean([2, 3]) - predicts = self.classifier(vectors) - return features, predicts + def forward(self, inputs): + features = self.features(inputs) + vectors = features.mean([2, 3]) + predicts = self.classifier(vectors) + return features, predicts diff --git a/lib/models/shape_infers/InferTinyCellNet.py b/lib/models/shape_infers/InferTinyCellNet.py index d92c222..3320b9e 100644 --- a/lib/models/shape_infers/InferTinyCellNet.py +++ b/lib/models/shape_infers/InferTinyCellNet.py @@ -8,51 +8,57 @@ from models.cell_infers.cells import InferCell class DynamicShapeTinyNet(nn.Module): + def __init__(self, channels: List[int], genotype: Any, num_classes: int): + super(DynamicShapeTinyNet, self).__init__() + self._channels = channels + if len(channels) % 3 != 2: + raise ValueError("invalid number of layers : {:}".format(len(channels))) + self._num_stage = N = len(channels) // 3 - def __init__(self, channels: List[int], genotype: Any, num_classes: int): - super(DynamicShapeTinyNet, self).__init__() - self._channels = channels - if len(channels) % 3 != 2: - raise ValueError('invalid number of layers : {:}'.format(len(channels))) - self._num_stage = N = len(channels) // 3 + self.stem = nn.Sequential( + nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(channels[0]), + ) - self.stem = nn.Sequential( - nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(channels[0])) + # layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - # layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + c_prev = channels[0] + self.cells = nn.ModuleList() + for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)): + if reduction: + cell = ResNetBasicblock(c_prev, c_curr, 2, True) + else: + cell = InferCell(genotype, c_prev, c_curr, 1) + self.cells.append(cell) + c_prev = cell.out_dim + self._num_layer = len(self.cells) - c_prev = channels[0] - self.cells = nn.ModuleList() - for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)): - if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True) - else : cell = InferCell(genotype, c_prev, c_curr, 1) - self.cells.append( cell ) - c_prev = cell.out_dim - self._num_layer = len(self.cells) + self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(c_prev, num_classes) - self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(c_prev, num_classes) + def get_message(self) -> Text: + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self.cells), cell.extra_repr() + ) + return string - def get_message(self) -> Text: - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string + def extra_repr(self): + return "{name}(C={_channels}, N={_num_stage}, L={_num_layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) - def extra_repr(self): - return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, 
**self.__dict__)) + def forward(self, inputs): + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + feature = cell(feature) - def forward(self, inputs): - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - feature = cell(feature) + out = self.lastact(feature) + out = self.global_pooling(out) + out = out.view(out.size(0), -1) + logits = self.classifier(out) - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - - return out, logits + return out, logits diff --git a/lib/models/shape_infers/__init__.py b/lib/models/shape_infers/__init__.py index 0f6cf36..9c305ff 100644 --- a/lib/models/shape_infers/__init__.py +++ b/lib/models/shape_infers/__init__.py @@ -6,4 +6,4 @@ from .InferImagenetResNet import InferImagenetResNet from .InferCifarResNet_depth import InferDepthCifarResNet from .InferCifarResNet import InferCifarResNet from .InferMobileNetV2 import InferMobileNetV2 -from .InferTinyCellNet import DynamicShapeTinyNet \ No newline at end of file +from .InferTinyCellNet import DynamicShapeTinyNet diff --git a/lib/models/shape_infers/shared_utils.py b/lib/models/shape_infers/shared_utils.py index c29620c..86ab949 100644 --- a/lib/models/shape_infers/shared_utils.py +++ b/lib/models/shape_infers/shared_utils.py @@ -1,5 +1,5 @@ def parse_channel_info(xstring): - blocks = xstring.split(' ') - blocks = [x.split('-') for x in blocks] - blocks = [[int(_) for _ in x] for x in blocks] - return blocks + blocks = xstring.split(" ") + blocks = [x.split("-") for x in blocks] + blocks = [[int(_) for _ in x] for x in blocks] + return blocks diff --git a/lib/models/shape_searchs/SearchCifarResNet.py b/lib/models/shape_searchs/SearchCifarResNet.py index 828c052..653051b 100644 --- a/lib/models/shape_searchs/SearchCifarResNet.py +++ b/lib/models/shape_searchs/SearchCifarResNet.py @@ -6,497 +6,755 @@ from collections import OrderedDict from bisect import bisect_right import torch.nn as nn from ..initialization import initialize_resnet -from ..SharedUtils import additive_func -from .SoftSelect import select2withP, ChannelWiseInter -from .SoftSelect import linear_forward -from .SoftSelect import get_width_choices +from ..SharedUtils import additive_func +from .SoftSelect import select2withP, ChannelWiseInter +from .SoftSelect import linear_forward +from .SoftSelect import get_width_choices def get_depth_choices(nDepth, return_num): - if nDepth == 2: - choices = (1, 2) - elif nDepth == 3: - choices = (1, 2, 3) - elif nDepth > 3: - choices = list(range(1, nDepth+1, 2)) - if choices[-1] < nDepth: choices.append(nDepth) - else: - raise ValueError('invalid nDepth : {:}'.format(nDepth)) - if return_num: return len(choices) - else : return choices - + if nDepth == 2: + choices = (1, 2) + elif nDepth == 3: + choices = (1, 2, 3) + elif nDepth > 3: + choices = list(range(1, nDepth + 1, 2)) + if choices[-1] < nDepth: + choices.append(nDepth) + else: + raise ValueError("invalid nDepth : {:}".format(nDepth)) + if return_num: + return len(choices) + else: + return choices + def conv_forward(inputs, conv, choices): - iC = conv.in_channels - fill_size = list(inputs.size()) - fill_size[1] = iC - fill_size[1] - filled = torch.zeros(fill_size, device=inputs.device) - xinputs = torch.cat((inputs, filled), dim=1) - outputs = conv(xinputs) - selecteds = [outputs[:,:oC] for oC in choices] - return selecteds + iC = conv.in_channels + fill_size = list(inputs.size()) + fill_size[1] = iC - fill_size[1] + filled = 
torch.zeros(fill_size, device=inputs.device) + xinputs = torch.cat((inputs, filled), dim=1) + outputs = conv(xinputs) + selecteds = [outputs[:, :oC] for oC in choices] + return selecteds class ConvBNReLU(nn.Module): - num_conv = 1 - def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): - super(ConvBNReLU, self).__init__() - self.InShape = None - self.OutShape = None - self.choices = get_width_choices(nOut) - self.register_buffer('choices_tensor', torch.Tensor( self.choices )) + num_conv = 1 - if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - else : self.avg = None - self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) - #if has_bn : self.bn = nn.BatchNorm2d(nOut) - #else : self.bn = None - self.has_bn = has_bn - self.BNs = nn.ModuleList() - for i, _out in enumerate(self.choices): - self.BNs.append(nn.BatchNorm2d(_out)) - if has_relu: self.relu = nn.ReLU(inplace=True) - else : self.relu = None - self.in_dim = nIn - self.out_dim = nOut - self.search_mode = 'basic' + def __init__( + self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu + ): + super(ConvBNReLU, self).__init__() + self.InShape = None + self.OutShape = None + self.choices = get_width_choices(nOut) + self.register_buffer("choices_tensor", torch.Tensor(self.choices)) - def get_flops(self, channels, check_range=True, divide=1): - iC, oC = channels - if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels) - assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape) - assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape) - #conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups - conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups) - all_positions = self.OutShape[0] * self.OutShape[1] - flops = (conv_per_position_flops * all_positions / divide) * iC * oC - if self.conv.bias is not None: flops += all_positions / divide - return flops + if has_avg: + self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.avg = None + self.conv = nn.Conv2d( + nIn, + nOut, + kernel_size=kernel, + stride=stride, + padding=padding, + dilation=1, + groups=1, + bias=bias, + ) + # if has_bn : self.bn = nn.BatchNorm2d(nOut) + # else : self.bn = None + self.has_bn = has_bn + self.BNs = nn.ModuleList() + for i, _out in enumerate(self.choices): + self.BNs.append(nn.BatchNorm2d(_out)) + if has_relu: + self.relu = nn.ReLU(inplace=True) + else: + self.relu = None + self.in_dim = nIn + self.out_dim = nOut + self.search_mode = "basic" - def get_range(self): - return [self.choices] + def get_flops(self, channels, check_range=True, divide=1): + iC, oC = channels + if check_range: + assert ( + iC <= self.conv.in_channels and oC <= self.conv.out_channels + ), "{:} vs {:} | {:} vs {:}".format( + iC, self.conv.in_channels, oC, self.conv.out_channels + ) + assert ( + isinstance(self.InShape, tuple) and len(self.InShape) == 2 + ), "invalid in-shape : {:}".format(self.InShape) + assert ( + isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 + ), "invalid out-shape : {:}".format(self.OutShape) + # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * 
iC * oC / self.conv.groups + conv_per_position_flops = ( + self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups + ) + all_positions = self.OutShape[0] * self.OutShape[1] + flops = (conv_per_position_flops * all_positions / divide) * iC * oC + if self.conv.bias is not None: + flops += all_positions / divide + return flops - def forward(self, inputs): - if self.search_mode == 'basic': - return self.basic_forward(inputs) - elif self.search_mode == 'search': - return self.search_forward(inputs) - else: - raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_range(self): + return [self.choices] - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, index, prob = tuple_inputs - index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) - probability = torch.squeeze(probability) - assert len(index) == 2, 'invalid length : {:}'.format(index) - # compute expected flop - #coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) - expected_outC = (self.choices_tensor * probability).sum() - expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) - if self.avg : out = self.avg( inputs ) - else : out = inputs - # convolutional layer - out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) - out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] - # merge - out_channel = max([x.size(1) for x in out_bns]) - outA = ChannelWiseInter(out_bns[0], out_channel) - outB = ChannelWiseInter(out_bns[1], out_channel) - out = outA * prob[0] + outB * prob[1] - #out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - if self.relu: out = self.relu( out ) - else : out = out - return out, expected_outC, expected_flop + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, index, prob = tuple_inputs + index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) + probability = torch.squeeze(probability) + assert len(index) == 2, "invalid length : {:}".format(index) + # compute expected flop + # coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) + expected_outC = (self.choices_tensor * probability).sum() + expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) + if self.avg: + out = self.avg(inputs) + else: + out = inputs + # convolutional layer + out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) + out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] + # merge + out_channel = max([x.size(1) for x in out_bns]) + outA = ChannelWiseInter(out_bns[0], out_channel) + outB = ChannelWiseInter(out_bns[1], out_channel) + out = outA * prob[0] + outB * prob[1] + # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) - def basic_forward(self, inputs): - if self.avg : out = self.avg( inputs ) - else : out = inputs - conv = self.conv( out ) - if self.has_bn:out= self.BNs[-1]( conv ) - else : out = conv - if self.relu: out = 
self.relu( out ) - else : out = out - if self.InShape is None: - self.InShape = (inputs.size(-2), inputs.size(-1)) - self.OutShape = (out.size(-2) , out.size(-1)) - return out + if self.relu: + out = self.relu(out) + else: + out = out + return out, expected_outC, expected_flop + + def basic_forward(self, inputs): + if self.avg: + out = self.avg(inputs) + else: + out = inputs + conv = self.conv(out) + if self.has_bn: + out = self.BNs[-1](conv) + else: + out = conv + if self.relu: + out = self.relu(out) + else: + out = out + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + self.OutShape = (out.size(-2), out.size(-1)) + return out class ResNetBasicblock(nn.Module): - expansion = 1 - num_conv = 2 - def __init__(self, inplanes, planes, stride): - super(ResNetBasicblock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) - elif inplanes != planes: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - self.out_dim = planes - self.search_mode = 'basic' + expansion = 1 + num_conv = 2 - def get_range(self): - return self.conv_a.get_range() + self.conv_b.get_range() + def __init__(self, inplanes, planes, stride): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_a = ConvBNReLU( + inplanes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_b = ConvBNReLU( + planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + elif inplanes != planes: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes + self.search_mode = "basic" - def get_flops(self, channels): - assert len(channels) == 3, 'invalid channels : {:}'.format(channels) - flop_A = self.conv_a.get_flops([channels[0], channels[1]]) - flop_B = self.conv_b.get_flops([channels[1], channels[2]]) - if hasattr(self.downsample, 'get_flops'): - flop_C = self.downsample.get_flops([channels[0], channels[-1]]) - else: - flop_C = 0 - if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train - flop_C = channels[0] * channels[-1] * self.conv_b.OutShape[0] * self.conv_b.OutShape[1] - return flop_A + flop_B + flop_C + def get_range(self): + return self.conv_a.get_range() + self.conv_b.get_range() - def forward(self, inputs): - if self.search_mode == 'basic' : return self.basic_forward(inputs) - elif self.search_mode == 'search': return self.search_forward(inputs) - else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_flops(self, channels): + assert len(channels) == 3, "invalid channels : {:}".format(channels) + flop_A = self.conv_a.get_flops([channels[0], channels[1]]) + flop_B = self.conv_b.get_flops([channels[1], channels[2]]) + if hasattr(self.downsample, 
"get_flops"): + flop_C = self.downsample.get_flops([channels[0], channels[-1]]) + else: + flop_C = 0 + if ( + channels[0] != channels[-1] and self.downsample is None + ): # this short-cut will be added during the infer-train + flop_C = ( + channels[0] + * channels[-1] + * self.conv_b.OutShape[0] + * self.conv_b.OutShape[1] + ) + return flop_A + flop_B + flop_C - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, indexes, probs = tuple_inputs - assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 - out_a, expected_inC_a, expected_flop_a = self.conv_a( (inputs, expected_inC , probability[0], indexes[0], probs[0]) ) - out_b, expected_inC_b, expected_flop_b = self.conv_b( (out_a , expected_inC_a, probability[1], indexes[1], probs[1]) ) - if self.downsample is not None: - residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[1], indexes[1], probs[1]) ) - else: - residual, expected_flop_c = inputs, 0 - out = additive_func(residual, out_b) - return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c]) + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - def basic_forward(self, inputs): - basicblock = self.conv_a(inputs) - basicblock = self.conv_b(basicblock) - if self.downsample is not None: residual = self.downsample(inputs) - else : residual = inputs - out = additive_func(residual, basicblock) - return nn.functional.relu(out, inplace=True) + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, indexes, probs = tuple_inputs + assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 + out_a, expected_inC_a, expected_flop_a = self.conv_a( + (inputs, expected_inC, probability[0], indexes[0], probs[0]) + ) + out_b, expected_inC_b, expected_flop_b = self.conv_b( + (out_a, expected_inC_a, probability[1], indexes[1], probs[1]) + ) + if self.downsample is not None: + residual, _, expected_flop_c = self.downsample( + (inputs, expected_inC, probability[1], indexes[1], probs[1]) + ) + else: + residual, expected_flop_c = inputs, 0 + out = additive_func(residual, out_b) + return ( + nn.functional.relu(out, inplace=True), + expected_inC_b, + sum([expected_flop_a, expected_flop_b, expected_flop_c]), + ) + def basic_forward(self, inputs): + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, basicblock) + return nn.functional.relu(out, inplace=True) class ResNetBottleneck(nn.Module): - expansion = 4 - num_conv = 3 - def __init__(self, inplanes, planes, stride): - super(ResNetBottleneck, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_1x4 = ConvBNReLU(planes, 
planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) - elif inplanes != planes*self.expansion: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - self.out_dim = planes * self.expansion - self.search_mode = 'basic' + expansion = 4 + num_conv = 3 - def get_range(self): - return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range() + def __init__(self, inplanes, planes, stride): + super(ResNetBottleneck, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_1x1 = ConvBNReLU( + inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True + ) + self.conv_3x3 = ConvBNReLU( + planes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_1x4 = ConvBNReLU( + planes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + elif inplanes != planes * self.expansion: + self.downsample = ConvBNReLU( + inplanes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes * self.expansion + self.search_mode = "basic" - def get_flops(self, channels): - assert len(channels) == 4, 'invalid channels : {:}'.format(channels) - flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) - flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) - flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) - if hasattr(self.downsample, 'get_flops'): - flop_D = self.downsample.get_flops([channels[0], channels[-1]]) - else: - flop_D = 0 - if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train - flop_D = channels[0] * channels[-1] * self.conv_1x4.OutShape[0] * self.conv_1x4.OutShape[1] - return flop_A + flop_B + flop_C + flop_D + def get_range(self): + return ( + self.conv_1x1.get_range() + + self.conv_3x3.get_range() + + self.conv_1x4.get_range() + ) - def forward(self, inputs): - if self.search_mode == 'basic' : return self.basic_forward(inputs) - elif self.search_mode == 'search': return self.search_forward(inputs) - else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_flops(self, channels): + assert len(channels) == 4, "invalid channels : {:}".format(channels) + flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) + flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) + flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) + if hasattr(self.downsample, "get_flops"): + flop_D = self.downsample.get_flops([channels[0], channels[-1]]) + else: + flop_D = 0 + if ( + channels[0] != channels[-1] and self.downsample is None + ): # this short-cut will be added during the infer-train + flop_D = ( + channels[0] + * channels[-1] + * self.conv_1x4.OutShape[0] + * self.conv_1x4.OutShape[1] + ) + return flop_A + flop_B + flop_C + flop_D - def basic_forward(self, inputs): - bottleneck = self.conv_1x1(inputs) - bottleneck = self.conv_3x3(bottleneck) - bottleneck = 
self.conv_1x4(bottleneck) - if self.downsample is not None: residual = self.downsample(inputs) - else : residual = inputs - out = additive_func(residual, bottleneck) - return nn.functional.relu(out, inplace=True) + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, indexes, probs = tuple_inputs - assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 - out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( (inputs, expected_inC , probability[0], indexes[0], probs[0]) ) - out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( (out_1x1,expected_inC_1x1, probability[1], indexes[1], probs[1]) ) - out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( (out_3x3,expected_inC_3x3, probability[2], indexes[2], probs[2]) ) - if self.downsample is not None: - residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[2], indexes[2], probs[2]) ) - else: - residual, expected_flop_c = inputs, 0 - out = additive_func(residual, out_1x4) - return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c]) + def basic_forward(self, inputs): + bottleneck = self.conv_1x1(inputs) + bottleneck = self.conv_3x3(bottleneck) + bottleneck = self.conv_1x4(bottleneck) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, bottleneck) + return nn.functional.relu(out, inplace=True) + + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, indexes, probs = tuple_inputs + assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 + out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( + (inputs, expected_inC, probability[0], indexes[0], probs[0]) + ) + out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( + (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1]) + ) + out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( + (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2]) + ) + if self.downsample is not None: + residual, _, expected_flop_c = self.downsample( + (inputs, expected_inC, probability[2], indexes[2], probs[2]) + ) + else: + residual, expected_flop_c = inputs, 0 + out = additive_func(residual, out_1x4) + return ( + nn.functional.relu(out, inplace=True), + expected_inC_1x4, + sum( + [ + expected_flop_1x1, + expected_flop_3x3, + expected_flop_1x4, + expected_flop_c, + ] + ), + ) class SearchShapeCifarResNet(nn.Module): + def __init__(self, block_name, depth, num_classes): + super(SearchShapeCifarResNet, self).__init__() - def __init__(self, block_name, depth, num_classes): - super(SearchShapeCifarResNet, self).__init__() - - #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model - if block_name == 'ResNetBasicblock': - block = ResNetBasicblock - assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' - layer_blocks = (depth - 2) // 6 - elif 
block_name == 'ResNetBottleneck': - block = ResNetBottleneck - assert (depth - 2) % 9 == 0, 'depth should be one of 164' - layer_blocks = (depth - 2) // 9 - else: - raise ValueError('invalid block : {:}'.format(block_name)) - - self.message = 'SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) - self.num_classes = num_classes - self.channels = [16] - self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) - self.InShape = None - self.depth_info = OrderedDict() - self.depth_at_i = OrderedDict() - for stage in range(3): - cur_block_choices = get_depth_choices(layer_blocks, False) - assert cur_block_choices[-1] == layer_blocks, 'stage={:}, {:} vs {:}'.format(stage, cur_block_choices, layer_blocks) - self.message += "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(stage, cur_block_choices, layer_blocks) - block_choices, xstart = [], len(self.layers) - for iL in range(layer_blocks): - iC = self.channels[-1] - planes = 16 * (2**stage) - stride = 2 if stage > 0 and iL == 0 else 1 - module = block(iC, planes, stride) - self.channels.append( module.out_dim ) - self.layers.append ( module ) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) - # added for depth - layer_index = len(self.layers) - 1 - if iL + 1 in cur_block_choices: block_choices.append( layer_index ) - if iL + 1 == layer_blocks: - self.depth_info[layer_index] = {'choices': block_choices, - 'stage' : stage, - 'xstart' : xstart} - self.depth_info_list = [] - for xend, info in self.depth_info.items(): - self.depth_info_list.append( (xend, info) ) - xstart, xstage = info['xstart'], info['stage'] - for ilayer in range(xstart, xend+1): - idx = bisect_right(info['choices'], ilayer-1) - self.depth_at_i[ilayer] = (xstage, idx) - - self.avgpool = nn.AvgPool2d(8) - self.classifier = nn.Linear(module.out_dim, num_classes) - self.InShape = None - self.tau = -1 - self.search_mode = 'basic' - #assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) - - # parameters for width - self.Ranges = [] - self.layer2indexRange = [] - for i, layer in enumerate(self.layers): - start_index = len(self.Ranges) - self.Ranges += layer.get_range() - self.layer2indexRange.append( (start_index, len(self.Ranges)) ) - assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth) - - self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None)))) - self.register_parameter('depth_attentions', nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True)))) - nn.init.normal_(self.width_attentions, 0, 0.01) - nn.init.normal_(self.depth_attentions, 0, 0.01) - self.apply(initialize_resnet) - - def arch_parameters(self, LR=None): - if LR is None: - return [self.width_attentions, self.depth_attentions] - else: - return [ - {"params": self.width_attentions, "lr": LR}, - {"params": self.depth_attentions, "lr": LR}, - ] - - def base_parameters(self): - return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters()) - - def get_flop(self, mode, config_dict, extra_info): - if config_dict is not None: config_dict = config_dict.copy() - # select channels - channels = [3] - for i, weight in 
enumerate(self.width_attentions): - if mode == 'genotype': - with torch.no_grad(): - probe = nn.functional.softmax(weight, dim=0) - C = self.Ranges[i][ torch.argmax(probe).item() ] - elif mode == 'max': - C = self.Ranges[i][-1] - elif mode == 'fix': - C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] ) - elif mode == 'random': - assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info) - with torch.no_grad(): - prob = nn.functional.softmax(weight, dim=0) - approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] ) - for j in range(prob.size(0)): - prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2) - C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ] - else: - raise ValueError('invalid mode : {:}'.format(mode)) - channels.append( C ) - # select depth - if mode == 'genotype': - with torch.no_grad(): - depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) - choices = torch.argmax(depth_probs, dim=1).cpu().tolist() - elif mode == 'max' or mode == 'fix': - choices = [depth_probs.size(1)-1 for _ in range(depth_probs.size(0))] - elif mode == 'random': - with torch.no_grad(): - depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) - choices = torch.multinomial(depth_probs, 1, False).cpu().tolist() - else: - raise ValueError('invalid mode : {:}'.format(mode)) - selected_layers = [] - for choice, xvalue in zip(choices, self.depth_info_list): - xtemp = xvalue[1]['choices'][choice] - xvalue[1]['xstart'] + 1 - selected_layers.append(xtemp) - flop = 0 - for i, layer in enumerate(self.layers): - s, e = self.layer2indexRange[i] - xchl = tuple( channels[s:e+1] ) - if i in self.depth_at_i: - xstagei, xatti = self.depth_at_i[i] - if xatti <= choices[xstagei]: # leave this depth - flop+= layer.get_flops(xchl) + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + if block_name == "ResNetBasicblock": + block = ResNetBasicblock + assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" + layer_blocks = (depth - 2) // 6 + elif block_name == "ResNetBottleneck": + block = ResNetBottleneck + assert (depth - 2) % 9 == 0, "depth should be one of 164" + layer_blocks = (depth - 2) // 9 else: - flop+= 0 # do not use this layer - else: - flop+= layer.get_flops(xchl) - # the last fc layer - flop += channels[-1] * self.classifier.out_features - if config_dict is None: - return flop / 1e6 - else: - config_dict['xchannels'] = channels - config_dict['xblocks'] = selected_layers - config_dict['super_type'] = 'infer-shape' - config_dict['estimated_FLOP'] = flop / 1e6 - return flop / 1e6, config_dict + raise ValueError("invalid block : {:}".format(block_name)) - def get_arch_info(self): - string = "for depth and width, there are {:} + {:} attention probabilities.".format(len(self.depth_attentions), len(self.width_attentions)) - string+= '\n{:}'.format(self.depth_info) - discrepancy = [] - with torch.no_grad(): - for i, att in enumerate(self.depth_attentions): - prob = nn.functional.softmax(att, dim=0) - prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist() - prob = ['{:.3f}'.format(x) for x in prob] - xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.depth_attentions), ' '.join(prob)) - logt = ['{:.4f}'.format(x) for x in att.cpu().tolist()] - xstring += ' || {:17s}'.format(' '.join(logt)) - prob = sorted( [float(x) for x in prob] ) - disc = prob[-1] - prob[-2] - xstring += ' || discrepancy={:.2f} || select={:}/{:}'.format(disc, selc, len(prob)) - discrepancy.append( disc ) - string 
+= '\n{:}'.format(xstring) - string += '\n-----------------------------------------------' - for i, att in enumerate(self.width_attentions): - prob = nn.functional.softmax(att, dim=0) - prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist() - prob = ['{:.3f}'.format(x) for x in prob] - xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob)) - logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()] - xstring += ' || {:52s}'.format(' '.join(logt)) - prob = sorted( [float(x) for x in prob] ) - disc = prob[-1] - prob[-2] - xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob)) - discrepancy.append( disc ) - string += '\n{:}'.format(xstring) - return string, discrepancy + self.message = ( + "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format( + depth, layer_blocks + ) + ) + self.num_classes = num_classes + self.channels = [16] + self.layers = nn.ModuleList( + [ + ConvBNReLU( + 3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True + ) + ] + ) + self.InShape = None + self.depth_info = OrderedDict() + self.depth_at_i = OrderedDict() + for stage in range(3): + cur_block_choices = get_depth_choices(layer_blocks, False) + assert ( + cur_block_choices[-1] == layer_blocks + ), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks) + self.message += ( + "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format( + stage, cur_block_choices, layer_blocks + ) + ) + block_choices, xstart = [], len(self.layers) + for iL in range(layer_blocks): + iC = self.channels[-1] + planes = 16 * (2 ** stage) + stride = 2 if stage > 0 and iL == 0 else 1 + module = block(iC, planes, stride) + self.channels.append(module.out_dim) + self.layers.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( + stage, + iL, + layer_blocks, + len(self.layers) - 1, + iC, + module.out_dim, + stride, + ) + # added for depth + layer_index = len(self.layers) - 1 + if iL + 1 in cur_block_choices: + block_choices.append(layer_index) + if iL + 1 == layer_blocks: + self.depth_info[layer_index] = { + "choices": block_choices, + "stage": stage, + "xstart": xstart, + } + self.depth_info_list = [] + for xend, info in self.depth_info.items(): + self.depth_info_list.append((xend, info)) + xstart, xstage = info["xstart"], info["stage"] + for ilayer in range(xstart, xend + 1): + idx = bisect_right(info["choices"], ilayer - 1) + self.depth_at_i[ilayer] = (xstage, idx) - def set_tau(self, tau_max, tau_min, epoch_ratio): - assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio) - tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 - self.tau = tau + self.avgpool = nn.AvgPool2d(8) + self.classifier = nn.Linear(module.out_dim, num_classes) + self.InShape = None + self.tau = -1 + self.search_mode = "basic" + # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) - def get_message(self): - return self.message + # parameters for width + self.Ranges = [] + self.layer2indexRange = [] + for i, layer in enumerate(self.layers): + start_index = len(self.Ranges) + self.Ranges += layer.get_range() + self.layer2indexRange.append((start_index, len(self.Ranges))) + assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format( + len(self.Ranges) + 1, depth + ) - def forward(self, inputs): - if self.search_mode 
== 'basic': - return self.basic_forward(inputs) - elif self.search_mode == 'search': - return self.search_forward(inputs) - else: - raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + self.register_parameter( + "width_attentions", + nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))), + ) + self.register_parameter( + "depth_attentions", + nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))), + ) + nn.init.normal_(self.width_attentions, 0, 0.01) + nn.init.normal_(self.depth_attentions, 0, 0.01) + self.apply(initialize_resnet) - def search_forward(self, inputs): - flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1) - flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) - flop_depth_probs = torch.flip( torch.cumsum( torch.flip(flop_depth_probs, [1]), 1 ), [1] ) - selected_widths, selected_width_probs = select2withP(self.width_attentions, self.tau) - selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) - with torch.no_grad(): - selected_widths = selected_widths.cpu() + def arch_parameters(self, LR=None): + if LR is None: + return [self.width_attentions, self.depth_attentions] + else: + return [ + {"params": self.width_attentions, "lr": LR}, + {"params": self.depth_attentions, "lr": LR}, + ] - x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] - feature_maps = [] - for i, layer in enumerate(self.layers): - selected_w_index = selected_widths [last_channel_idx: last_channel_idx+layer.num_conv] - selected_w_probs = selected_width_probs[last_channel_idx: last_channel_idx+layer.num_conv] - layer_prob = flop_width_probs [last_channel_idx: last_channel_idx+layer.num_conv] - x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) ) - feature_maps.append( x ) - last_channel_idx += layer.num_conv - if i in self.depth_info: # aggregate the information - choices = self.depth_info[i]['choices'] - xstagei = self.depth_info[i]['stage'] - #print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist())) - #for A, W in zip(choices, selected_depth_probs[xstagei]): - # print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W)) - possible_tensors = [] - max_C = max( feature_maps[A].size(1) for A in choices ) - for tempi, A in enumerate(choices): - xtensor = ChannelWiseInter(feature_maps[A], max_C) - #drop_ratio = 1-(tempi+1.0)/len(choices) - #xtensor = drop_path(xtensor, drop_ratio) - possible_tensors.append( xtensor ) - weighted_sum = sum( xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei]) ) - x = weighted_sum - - if i in self.depth_at_i: - xstagei, xatti = self.depth_at_i[i] - x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop - else: - x_expected_flop = expected_flop - flops.append( x_expected_flop ) - flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = linear_forward(features, self.classifier) - return logits, torch.stack( [sum(flops)] ) + def base_parameters(self): + return ( + list(self.layers.parameters()) + + list(self.avgpool.parameters()) + + list(self.classifier.parameters()) + ) - def basic_forward(self, inputs): - if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1)) - x = inputs - for i, layer in enumerate(self.layers): - x = layer( x ) - features = self.avgpool(x) - features = 
features.view(features.size(0), -1) - logits = self.classifier(features) - return features, logits + def get_flop(self, mode, config_dict, extra_info): + if config_dict is not None: + config_dict = config_dict.copy() + # select channels + channels = [3] + for i, weight in enumerate(self.width_attentions): + if mode == "genotype": + with torch.no_grad(): + probe = nn.functional.softmax(weight, dim=0) + C = self.Ranges[i][torch.argmax(probe).item()] + elif mode == "max": + C = self.Ranges[i][-1] + elif mode == "fix": + C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) + elif mode == "random": + assert isinstance(extra_info, float), "invalid extra_info : {:}".format( + extra_info + ) + with torch.no_grad(): + prob = nn.functional.softmax(weight, dim=0) + approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) + for j in range(prob.size(0)): + prob[j] = 1 / ( + abs(j - (approximate_C - self.Ranges[i][j])) + 0.2 + ) + C = self.Ranges[i][torch.multinomial(prob, 1, False).item()] + else: + raise ValueError("invalid mode : {:}".format(mode)) + channels.append(C) + # select depth + if mode == "genotype": + with torch.no_grad(): + depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) + choices = torch.argmax(depth_probs, dim=1).cpu().tolist() + elif mode == "max" or mode == "fix": + # use the attention tensor's shape here: depth_probs is only defined in the genotype/random branches + choices = [self.depth_attentions.size(1) - 1 for _ in range(self.depth_attentions.size(0))] + elif mode == "random": + with torch.no_grad(): + depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) + choices = torch.multinomial(depth_probs, 1, False).cpu().tolist() + else: + raise ValueError("invalid mode : {:}".format(mode)) + selected_layers = [] + for choice, xvalue in zip(choices, self.depth_info_list): + xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1 + selected_layers.append(xtemp) + flop = 0 + for i, layer in enumerate(self.layers): + s, e = self.layer2indexRange[i] + xchl = tuple(channels[s : e + 1]) + if i in self.depth_at_i: + xstagei, xatti = self.depth_at_i[i] + if xatti <= choices[xstagei]: # leave this depth + flop += layer.get_flops(xchl) + else: + flop += 0 # do not use this layer + else: + flop += layer.get_flops(xchl) + # the last fc layer + flop += channels[-1] * self.classifier.out_features + if config_dict is None: + return flop / 1e6 + else: + config_dict["xchannels"] = channels + config_dict["xblocks"] = selected_layers + config_dict["super_type"] = "infer-shape" + config_dict["estimated_FLOP"] = flop / 1e6 + return flop / 1e6, config_dict + def get_arch_info(self): + string = ( + "for depth and width, there are {:} + {:} attention probabilities.".format( + len(self.depth_attentions), len(self.width_attentions) + ) + ) + string += "\n{:}".format(self.depth_info) + discrepancy = [] + with torch.no_grad(): + for i, att in enumerate(self.depth_attentions): + prob = nn.functional.softmax(att, dim=0) + prob = prob.cpu() + selc = prob.argmax().item() + prob = prob.tolist() + prob = ["{:.3f}".format(x) for x in prob] + xstring = "{:03d}/{:03d}-th : {:}".format( + i, len(self.depth_attentions), " ".join(prob) + ) + logt = ["{:.4f}".format(x) for x in att.cpu().tolist()] + xstring += " || {:17s}".format(" ".join(logt)) + prob = sorted([float(x) for x in prob]) + disc = prob[-1] - prob[-2] + xstring += " || discrepancy={:.2f} || select={:}/{:}".format( + disc, selc, len(prob) + ) + discrepancy.append(disc) + string += "\n{:}".format(xstring) + string += "\n-----------------------------------------------" + for i, att in enumerate(self.width_attentions): + prob = 
nn.functional.softmax(att, dim=0) + prob = prob.cpu() + selc = prob.argmax().item() + prob = prob.tolist() + prob = ["{:.3f}".format(x) for x in prob] + xstring = "{:03d}/{:03d}-th : {:}".format( + i, len(self.width_attentions), " ".join(prob) + ) + logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] + xstring += " || {:52s}".format(" ".join(logt)) + prob = sorted([float(x) for x in prob]) + disc = prob[-1] - prob[-2] + xstring += " || dis={:.2f} || select={:}/{:}".format( + disc, selc, len(prob) + ) + discrepancy.append(disc) + string += "\n{:}".format(xstring) + return string, discrepancy + + def set_tau(self, tau_max, tau_min, epoch_ratio): + assert ( + epoch_ratio >= 0 and epoch_ratio <= 1 + ), "invalid epoch-ratio : {:}".format(epoch_ratio) + tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 + self.tau = tau + + def get_message(self): + return self.message + + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) + + def search_forward(self, inputs): + flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1) + flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) + flop_depth_probs = torch.flip( + torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1] + ) + selected_widths, selected_width_probs = select2withP( + self.width_attentions, self.tau + ) + selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) + with torch.no_grad(): + selected_widths = selected_widths.cpu() + + x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] + feature_maps = [] + for i, layer in enumerate(self.layers): + selected_w_index = selected_widths[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + selected_w_probs = selected_width_probs[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + layer_prob = flop_width_probs[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + x, expected_inC, expected_flop = layer( + (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) + ) + feature_maps.append(x) + last_channel_idx += layer.num_conv + if i in self.depth_info: # aggregate the information + choices = self.depth_info[i]["choices"] + xstagei = self.depth_info[i]["stage"] + # print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist())) + # for A, W in zip(choices, selected_depth_probs[xstagei]): + # print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W)) + possible_tensors = [] + max_C = max(feature_maps[A].size(1) for A in choices) + for tempi, A in enumerate(choices): + xtensor = ChannelWiseInter(feature_maps[A], max_C) + # drop_ratio = 1-(tempi+1.0)/len(choices) + # xtensor = drop_path(xtensor, drop_ratio) + possible_tensors.append(xtensor) + weighted_sum = sum( + xtensor * W + for xtensor, W in zip( + possible_tensors, selected_depth_probs[xstagei] + ) + ) + x = weighted_sum + + if i in self.depth_at_i: + xstagei, xatti = self.depth_at_i[i] + x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop + else: + x_expected_flop = expected_flop + flops.append(x_expected_flop) + flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = linear_forward(features, self.classifier) + return logits, 
torch.stack([sum(flops)]) + + def basic_forward(self, inputs): + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + x = inputs + for i, layer in enumerate(self.layers): + x = layer(x) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/shape_searchs/SearchCifarResNet_depth.py b/lib/models/shape_searchs/SearchCifarResNet_depth.py index 9395e7d..24c5d83 100644 --- a/lib/models/shape_searchs/SearchCifarResNet_depth.py +++ b/lib/models/shape_searchs/SearchCifarResNet_depth.py @@ -6,335 +6,510 @@ from collections import OrderedDict from bisect import bisect_right import torch.nn as nn from ..initialization import initialize_resnet -from ..SharedUtils import additive_func -from .SoftSelect import select2withP, ChannelWiseInter -from .SoftSelect import linear_forward -from .SoftSelect import get_width_choices +from ..SharedUtils import additive_func +from .SoftSelect import select2withP, ChannelWiseInter +from .SoftSelect import linear_forward +from .SoftSelect import get_width_choices def get_depth_choices(nDepth, return_num): - if nDepth == 2: - choices = (1, 2) - elif nDepth == 3: - choices = (1, 2, 3) - elif nDepth > 3: - choices = list(range(1, nDepth+1, 2)) - if choices[-1] < nDepth: choices.append(nDepth) - else: - raise ValueError('invalid nDepth : {:}'.format(nDepth)) - if return_num: return len(choices) - else : return choices - + if nDepth == 2: + choices = (1, 2) + elif nDepth == 3: + choices = (1, 2, 3) + elif nDepth > 3: + choices = list(range(1, nDepth + 1, 2)) + if choices[-1] < nDepth: + choices.append(nDepth) + else: + raise ValueError("invalid nDepth : {:}".format(nDepth)) + if return_num: + return len(choices) + else: + return choices class ConvBNReLU(nn.Module): - num_conv = 1 - def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): - super(ConvBNReLU, self).__init__() - self.InShape = None - self.OutShape = None - self.choices = get_width_choices(nOut) - self.register_buffer('choices_tensor', torch.Tensor( self.choices )) + num_conv = 1 - if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - else : self.avg = None - self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) - if has_bn : self.bn = nn.BatchNorm2d(nOut) - else : self.bn = None - if has_relu: self.relu = nn.ReLU(inplace=False) - else : self.relu = None - self.in_dim = nIn - self.out_dim = nOut + def __init__( + self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu + ): + super(ConvBNReLU, self).__init__() + self.InShape = None + self.OutShape = None + self.choices = get_width_choices(nOut) + self.register_buffer("choices_tensor", torch.Tensor(self.choices)) - def get_flops(self, divide=1): - iC, oC = self.in_dim, self.out_dim - assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels) - assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape) - assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape) - #conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups - conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups) - all_positions = 
self.OutShape[0] * self.OutShape[1] - flops = (conv_per_position_flops * all_positions / divide) * iC * oC - if self.conv.bias is not None: flops += all_positions / divide - return flops + if has_avg: + self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.avg = None + self.conv = nn.Conv2d( + nIn, + nOut, + kernel_size=kernel, + stride=stride, + padding=padding, + dilation=1, + groups=1, + bias=bias, + ) + if has_bn: + self.bn = nn.BatchNorm2d(nOut) + else: + self.bn = None + if has_relu: + self.relu = nn.ReLU(inplace=False) + else: + self.relu = None + self.in_dim = nIn + self.out_dim = nOut - def forward(self, inputs): - if self.avg : out = self.avg( inputs ) - else : out = inputs - conv = self.conv( out ) - if self.bn : out = self.bn( conv ) - else : out = conv - if self.relu: out = self.relu( out ) - else : out = out - if self.InShape is None: - self.InShape = (inputs.size(-2), inputs.size(-1)) - self.OutShape = (out.size(-2) , out.size(-1)) - return out + def get_flops(self, divide=1): + iC, oC = self.in_dim, self.out_dim + assert ( + iC <= self.conv.in_channels and oC <= self.conv.out_channels + ), "{:} vs {:} | {:} vs {:}".format( + iC, self.conv.in_channels, oC, self.conv.out_channels + ) + assert ( + isinstance(self.InShape, tuple) and len(self.InShape) == 2 + ), "invalid in-shape : {:}".format(self.InShape) + assert ( + isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 + ), "invalid out-shape : {:}".format(self.OutShape) + # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups + conv_per_position_flops = ( + self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups + ) + all_positions = self.OutShape[0] * self.OutShape[1] + flops = (conv_per_position_flops * all_positions / divide) * iC * oC + if self.conv.bias is not None: + flops += all_positions / divide + return flops + + def forward(self, inputs): + if self.avg: + out = self.avg(inputs) + else: + out = inputs + conv = self.conv(out) + if self.bn: + out = self.bn(conv) + else: + out = conv + if self.relu: + out = self.relu(out) + else: + out = out + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + self.OutShape = (out.size(-2), out.size(-1)) + return out class ResNetBasicblock(nn.Module): - expansion = 1 - num_conv = 2 - def __init__(self, inplanes, planes, stride): - super(ResNetBasicblock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) - elif inplanes != planes: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - self.out_dim = planes - self.search_mode = 'basic' + expansion = 1 + num_conv = 2 - def get_flops(self, divide=1): - flop_A = self.conv_a.get_flops(divide) - flop_B = self.conv_b.get_flops(divide) - if hasattr(self.downsample, 'get_flops'): - flop_C = self.downsample.get_flops(divide) - else: - flop_C = 0 - return flop_A + flop_B + flop_C + def __init__(self, inplanes, planes, stride): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_a = 
ConvBNReLU( + inplanes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_b = ConvBNReLU( + planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + elif inplanes != planes: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes + self.search_mode = "basic" - def forward(self, inputs): - basicblock = self.conv_a(inputs) - basicblock = self.conv_b(basicblock) - if self.downsample is not None: residual = self.downsample(inputs) - else : residual = inputs - out = additive_func(residual, basicblock) - return nn.functional.relu(out, inplace=True) + def get_flops(self, divide=1): + flop_A = self.conv_a.get_flops(divide) + flop_B = self.conv_b.get_flops(divide) + if hasattr(self.downsample, "get_flops"): + flop_C = self.downsample.get_flops(divide) + else: + flop_C = 0 + return flop_A + flop_B + flop_C + def forward(self, inputs): + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, basicblock) + return nn.functional.relu(out, inplace=True) class ResNetBottleneck(nn.Module): - expansion = 4 - num_conv = 3 - def __init__(self, inplanes, planes, stride): - super(ResNetBottleneck, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) - elif inplanes != planes*self.expansion: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - self.out_dim = planes * self.expansion - self.search_mode = 'basic' + expansion = 4 + num_conv = 3 - def get_range(self): - return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range() + def __init__(self, inplanes, planes, stride): + super(ResNetBottleneck, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_1x1 = ConvBNReLU( + inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True + ) + self.conv_3x3 = ConvBNReLU( + planes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_1x4 = ConvBNReLU( + planes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + elif inplanes != planes * self.expansion: + self.downsample = ConvBNReLU( + inplanes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes 
* self.expansion + self.search_mode = "basic" - def get_flops(self, divide): - flop_A = self.conv_1x1.get_flops(divide) - flop_B = self.conv_3x3.get_flops(divide) - flop_C = self.conv_1x4.get_flops(divide) - if hasattr(self.downsample, 'get_flops'): - flop_D = self.downsample.get_flops(divide) - else: - flop_D = 0 - return flop_A + flop_B + flop_C + flop_D + def get_range(self): + return ( + self.conv_1x1.get_range() + + self.conv_3x3.get_range() + + self.conv_1x4.get_range() + ) - def forward(self, inputs): - bottleneck = self.conv_1x1(inputs) - bottleneck = self.conv_3x3(bottleneck) - bottleneck = self.conv_1x4(bottleneck) - if self.downsample is not None: residual = self.downsample(inputs) - else : residual = inputs - out = additive_func(residual, bottleneck) - return nn.functional.relu(out, inplace=True) + def get_flops(self, divide): + flop_A = self.conv_1x1.get_flops(divide) + flop_B = self.conv_3x3.get_flops(divide) + flop_C = self.conv_1x4.get_flops(divide) + if hasattr(self.downsample, "get_flops"): + flop_D = self.downsample.get_flops(divide) + else: + flop_D = 0 + return flop_A + flop_B + flop_C + flop_D + + def forward(self, inputs): + bottleneck = self.conv_1x1(inputs) + bottleneck = self.conv_3x3(bottleneck) + bottleneck = self.conv_1x4(bottleneck) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, bottleneck) + return nn.functional.relu(out, inplace=True) class SearchDepthCifarResNet(nn.Module): + def __init__(self, block_name, depth, num_classes): + super(SearchDepthCifarResNet, self).__init__() - def __init__(self, block_name, depth, num_classes): - super(SearchDepthCifarResNet, self).__init__() - - #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model - if block_name == 'ResNetBasicblock': - block = ResNetBasicblock - assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' - layer_blocks = (depth - 2) // 6 - elif block_name == 'ResNetBottleneck': - block = ResNetBottleneck - assert (depth - 2) % 9 == 0, 'depth should be one of 164' - layer_blocks = (depth - 2) // 9 - else: - raise ValueError('invalid block : {:}'.format(block_name)) - - self.message = 'SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) - self.num_classes = num_classes - self.channels = [16] - self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) - self.InShape = None - self.depth_info = OrderedDict() - self.depth_at_i = OrderedDict() - for stage in range(3): - cur_block_choices = get_depth_choices(layer_blocks, False) - assert cur_block_choices[-1] == layer_blocks, 'stage={:}, {:} vs {:}'.format(stage, cur_block_choices, layer_blocks) - self.message += "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(stage, cur_block_choices, layer_blocks) - block_choices, xstart = [], len(self.layers) - for iL in range(layer_blocks): - iC = self.channels[-1] - planes = 16 * (2**stage) - stride = 2 if stage > 0 and iL == 0 else 1 - module = block(iC, planes, stride) - self.channels.append( module.out_dim ) - self.layers.append ( module ) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) - # added for depth - layer_index = len(self.layers) - 1 - if iL + 1 in cur_block_choices: block_choices.append( layer_index ) - if iL + 1 == layer_blocks: - 
self.depth_info[layer_index] = {'choices': block_choices, - 'stage' : stage, - 'xstart' : xstart} - self.depth_info_list = [] - for xend, info in self.depth_info.items(): - self.depth_info_list.append( (xend, info) ) - xstart, xstage = info['xstart'], info['stage'] - for ilayer in range(xstart, xend+1): - idx = bisect_right(info['choices'], ilayer-1) - self.depth_at_i[ilayer] = (xstage, idx) - - self.avgpool = nn.AvgPool2d(8) - self.classifier = nn.Linear(module.out_dim, num_classes) - self.InShape = None - self.tau = -1 - self.search_mode = 'basic' - #assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) - - - self.register_parameter('depth_attentions', nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True)))) - nn.init.normal_(self.depth_attentions, 0, 0.01) - self.apply(initialize_resnet) - - def arch_parameters(self): - return [self.depth_attentions] - - def base_parameters(self): - return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters()) - - def get_flop(self, mode, config_dict, extra_info): - if config_dict is not None: config_dict = config_dict.copy() - # select depth - if mode == 'genotype': - with torch.no_grad(): - depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) - choices = torch.argmax(depth_probs, dim=1).cpu().tolist() - elif mode == 'max': - choices = [depth_probs.size(1)-1 for _ in range(depth_probs.size(0))] - elif mode == 'random': - with torch.no_grad(): - depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) - choices = torch.multinomial(depth_probs, 1, False).cpu().tolist() - else: - raise ValueError('invalid mode : {:}'.format(mode)) - selected_layers = [] - for choice, xvalue in zip(choices, self.depth_info_list): - xtemp = xvalue[1]['choices'][choice] - xvalue[1]['xstart'] + 1 - selected_layers.append(xtemp) - flop = 0 - for i, layer in enumerate(self.layers): - if i in self.depth_at_i: - xstagei, xatti = self.depth_at_i[i] - if xatti <= choices[xstagei]: # leave this depth - flop+= layer.get_flops() + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + if block_name == "ResNetBasicblock": + block = ResNetBasicblock + assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" + layer_blocks = (depth - 2) // 6 + elif block_name == "ResNetBottleneck": + block = ResNetBottleneck + assert (depth - 2) % 9 == 0, "depth should be one of 164" + layer_blocks = (depth - 2) // 9 else: - flop+= 0 # do not use this layer - else: - flop+= layer.get_flops() - # the last fc layer - flop += self.classifier.in_features * self.classifier.out_features - if config_dict is None: - return flop / 1e6 - else: - config_dict['xblocks'] = selected_layers - config_dict['super_type'] = 'infer-depth' - config_dict['estimated_FLOP'] = flop / 1e6 - return flop / 1e6, config_dict + raise ValueError("invalid block : {:}".format(block_name)) - def get_arch_info(self): - string = "for depth, there are {:} attention probabilities.".format(len(self.depth_attentions)) - string+= '\n{:}'.format(self.depth_info) - discrepancy = [] - with torch.no_grad(): - for i, att in enumerate(self.depth_attentions): - prob = nn.functional.softmax(att, dim=0) - prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist() - prob = ['{:.3f}'.format(x) for x in prob] - xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.depth_attentions), ' '.join(prob)) - logt = ['{:.4f}'.format(x) for x in 
att.cpu().tolist()] - xstring += ' || {:17s}'.format(' '.join(logt)) - prob = sorted( [float(x) for x in prob] ) - disc = prob[-1] - prob[-2] - xstring += ' || discrepancy={:.2f} || select={:}/{:}'.format(disc, selc, len(prob)) - discrepancy.append( disc ) - string += '\n{:}'.format(xstring) - return string, discrepancy + self.message = ( + "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format( + depth, layer_blocks + ) + ) + self.num_classes = num_classes + self.channels = [16] + self.layers = nn.ModuleList( + [ + ConvBNReLU( + 3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True + ) + ] + ) + self.InShape = None + self.depth_info = OrderedDict() + self.depth_at_i = OrderedDict() + for stage in range(3): + cur_block_choices = get_depth_choices(layer_blocks, False) + assert ( + cur_block_choices[-1] == layer_blocks + ), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks) + self.message += ( + "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format( + stage, cur_block_choices, layer_blocks + ) + ) + block_choices, xstart = [], len(self.layers) + for iL in range(layer_blocks): + iC = self.channels[-1] + planes = 16 * (2 ** stage) + stride = 2 if stage > 0 and iL == 0 else 1 + module = block(iC, planes, stride) + self.channels.append(module.out_dim) + self.layers.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( + stage, + iL, + layer_blocks, + len(self.layers) - 1, + iC, + module.out_dim, + stride, + ) + # added for depth + layer_index = len(self.layers) - 1 + if iL + 1 in cur_block_choices: + block_choices.append(layer_index) + if iL + 1 == layer_blocks: + self.depth_info[layer_index] = { + "choices": block_choices, + "stage": stage, + "xstart": xstart, + } + self.depth_info_list = [] + for xend, info in self.depth_info.items(): + self.depth_info_list.append((xend, info)) + xstart, xstage = info["xstart"], info["stage"] + for ilayer in range(xstart, xend + 1): + idx = bisect_right(info["choices"], ilayer - 1) + self.depth_at_i[ilayer] = (xstage, idx) - def set_tau(self, tau_max, tau_min, epoch_ratio): - assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio) - tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 - self.tau = tau + self.avgpool = nn.AvgPool2d(8) + self.classifier = nn.Linear(module.out_dim, num_classes) + self.InShape = None + self.tau = -1 + self.search_mode = "basic" + # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) - def get_message(self): - return self.message + self.register_parameter( + "depth_attentions", + nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))), + ) + nn.init.normal_(self.depth_attentions, 0, 0.01) + self.apply(initialize_resnet) - def forward(self, inputs): - if self.search_mode == 'basic': - return self.basic_forward(inputs) - elif self.search_mode == 'search': - return self.search_forward(inputs) - else: - raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def arch_parameters(self): + return [self.depth_attentions] - def search_forward(self, inputs): - flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) - flop_depth_probs = torch.flip( torch.cumsum( torch.flip(flop_depth_probs, [1]), 1 ), [1] ) - selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) + def base_parameters(self): + 
return ( + list(self.layers.parameters()) + + list(self.avgpool.parameters()) + + list(self.classifier.parameters()) + ) - x, flops = inputs, [] - feature_maps = [] - for i, layer in enumerate(self.layers): - layer_i = layer( x ) - feature_maps.append( layer_i ) - if i in self.depth_info: # aggregate the information - choices = self.depth_info[i]['choices'] - xstagei = self.depth_info[i]['stage'] - possible_tensors = [] - for tempi, A in enumerate(choices): - xtensor = feature_maps[A] - possible_tensors.append( xtensor ) - weighted_sum = sum( xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei]) ) - x = weighted_sum - else: - x = layer_i - - if i in self.depth_at_i: - xstagei, xatti = self.depth_at_i[i] - #print ('layer-{:03d}, stage={:}, att={:}, prob={:}, flop={:}'.format(i, xstagei, xatti, flop_depth_probs[xstagei, xatti].item(), layer.get_flops(1e6))) - x_expected_flop = flop_depth_probs[xstagei, xatti] * layer.get_flops(1e6) - else: - x_expected_flop = layer.get_flops(1e6) - flops.append( x_expected_flop ) - flops.append( (self.classifier.in_features * self.classifier.out_features*1.0/1e6) ) + def get_flop(self, mode, config_dict, extra_info): + if config_dict is not None: + config_dict = config_dict.copy() + # select depth + if mode == "genotype": + with torch.no_grad(): + depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) + choices = torch.argmax(depth_probs, dim=1).cpu().tolist() + elif mode == "max": + choices = [depth_probs.size(1) - 1 for _ in range(depth_probs.size(0))] + elif mode == "random": + with torch.no_grad(): + depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) + choices = torch.multinomial(depth_probs, 1, False).cpu().tolist() + else: + raise ValueError("invalid mode : {:}".format(mode)) + selected_layers = [] + for choice, xvalue in zip(choices, self.depth_info_list): + xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1 + selected_layers.append(xtemp) + flop = 0 + for i, layer in enumerate(self.layers): + if i in self.depth_at_i: + xstagei, xatti = self.depth_at_i[i] + if xatti <= choices[xstagei]: # leave this depth + flop += layer.get_flops() + else: + flop += 0 # do not use this layer + else: + flop += layer.get_flops() + # the last fc layer + flop += self.classifier.in_features * self.classifier.out_features + if config_dict is None: + return flop / 1e6 + else: + config_dict["xblocks"] = selected_layers + config_dict["super_type"] = "infer-depth" + config_dict["estimated_FLOP"] = flop / 1e6 + return flop / 1e6, config_dict - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = linear_forward(features, self.classifier) - return logits, torch.stack( [sum(flops)] ) + def get_arch_info(self): + string = "for depth, there are {:} attention probabilities.".format( + len(self.depth_attentions) + ) + string += "\n{:}".format(self.depth_info) + discrepancy = [] + with torch.no_grad(): + for i, att in enumerate(self.depth_attentions): + prob = nn.functional.softmax(att, dim=0) + prob = prob.cpu() + selc = prob.argmax().item() + prob = prob.tolist() + prob = ["{:.3f}".format(x) for x in prob] + xstring = "{:03d}/{:03d}-th : {:}".format( + i, len(self.depth_attentions), " ".join(prob) + ) + logt = ["{:.4f}".format(x) for x in att.cpu().tolist()] + xstring += " || {:17s}".format(" ".join(logt)) + prob = sorted([float(x) for x in prob]) + disc = prob[-1] - prob[-2] + xstring += " || discrepancy={:.2f} || select={:}/{:}".format( + disc, selc, len(prob) + ) + 
discrepancy.append(disc) + string += "\n{:}".format(xstring) + return string, discrepancy - def basic_forward(self, inputs): - if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1)) - x = inputs - for i, layer in enumerate(self.layers): - x = layer( x ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = self.classifier(features) - return features, logits + def set_tau(self, tau_max, tau_min, epoch_ratio): + assert ( + epoch_ratio >= 0 and epoch_ratio <= 1 + ), "invalid epoch-ratio : {:}".format(epoch_ratio) + tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 + self.tau = tau + + def get_message(self): + return self.message + + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) + + def search_forward(self, inputs): + flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) + flop_depth_probs = torch.flip( + torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1] + ) + selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) + + x, flops = inputs, [] + feature_maps = [] + for i, layer in enumerate(self.layers): + layer_i = layer(x) + feature_maps.append(layer_i) + if i in self.depth_info: # aggregate the information + choices = self.depth_info[i]["choices"] + xstagei = self.depth_info[i]["stage"] + possible_tensors = [] + for tempi, A in enumerate(choices): + xtensor = feature_maps[A] + possible_tensors.append(xtensor) + weighted_sum = sum( + xtensor * W + for xtensor, W in zip( + possible_tensors, selected_depth_probs[xstagei] + ) + ) + x = weighted_sum + else: + x = layer_i + + if i in self.depth_at_i: + xstagei, xatti = self.depth_at_i[i] + # print ('layer-{:03d}, stage={:}, att={:}, prob={:}, flop={:}'.format(i, xstagei, xatti, flop_depth_probs[xstagei, xatti].item(), layer.get_flops(1e6))) + x_expected_flop = flop_depth_probs[xstagei, xatti] * layer.get_flops( + 1e6 + ) + else: + x_expected_flop = layer.get_flops(1e6) + flops.append(x_expected_flop) + flops.append( + (self.classifier.in_features * self.classifier.out_features * 1.0 / 1e6) + ) + + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = linear_forward(features, self.classifier) + return logits, torch.stack([sum(flops)]) + + def basic_forward(self, inputs): + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + x = inputs + for i, layer in enumerate(self.layers): + x = layer(x) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/shape_searchs/SearchCifarResNet_width.py b/lib/models/shape_searchs/SearchCifarResNet_width.py index 5c46e1c..61bee6f 100644 --- a/lib/models/shape_searchs/SearchCifarResNet_width.py +++ b/lib/models/shape_searchs/SearchCifarResNet_width.py @@ -4,390 +4,616 @@ import math, torch import torch.nn as nn from ..initialization import initialize_resnet -from ..SharedUtils import additive_func -from .SoftSelect import select2withP, ChannelWiseInter -from .SoftSelect import linear_forward -from .SoftSelect import get_width_choices as get_choices +from ..SharedUtils import additive_func +from .SoftSelect import select2withP, ChannelWiseInter +from .SoftSelect import linear_forward +from .SoftSelect import get_width_choices as 
get_choices def conv_forward(inputs, conv, choices): - iC = conv.in_channels - fill_size = list(inputs.size()) - fill_size[1] = iC - fill_size[1] - filled = torch.zeros(fill_size, device=inputs.device) - xinputs = torch.cat((inputs, filled), dim=1) - outputs = conv(xinputs) - selecteds = [outputs[:,:oC] for oC in choices] - return selecteds + iC = conv.in_channels + fill_size = list(inputs.size()) + fill_size[1] = iC - fill_size[1] + filled = torch.zeros(fill_size, device=inputs.device) + xinputs = torch.cat((inputs, filled), dim=1) + outputs = conv(xinputs) + selecteds = [outputs[:, :oC] for oC in choices] + return selecteds class ConvBNReLU(nn.Module): - num_conv = 1 - def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): - super(ConvBNReLU, self).__init__() - self.InShape = None - self.OutShape = None - self.choices = get_choices(nOut) - self.register_buffer('choices_tensor', torch.Tensor( self.choices )) + num_conv = 1 - if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - else : self.avg = None - self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) - #if has_bn : self.bn = nn.BatchNorm2d(nOut) - #else : self.bn = None - self.has_bn = has_bn - self.BNs = nn.ModuleList() - for i, _out in enumerate(self.choices): - self.BNs.append(nn.BatchNorm2d(_out)) - if has_relu: self.relu = nn.ReLU(inplace=True) - else : self.relu = None - self.in_dim = nIn - self.out_dim = nOut - self.search_mode = 'basic' + def __init__( + self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu + ): + super(ConvBNReLU, self).__init__() + self.InShape = None + self.OutShape = None + self.choices = get_choices(nOut) + self.register_buffer("choices_tensor", torch.Tensor(self.choices)) - def get_flops(self, channels, check_range=True, divide=1): - iC, oC = channels - if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels) - assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape) - assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape) - #conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups - conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups) - all_positions = self.OutShape[0] * self.OutShape[1] - flops = (conv_per_position_flops * all_positions / divide) * iC * oC - if self.conv.bias is not None: flops += all_positions / divide - return flops + if has_avg: + self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.avg = None + self.conv = nn.Conv2d( + nIn, + nOut, + kernel_size=kernel, + stride=stride, + padding=padding, + dilation=1, + groups=1, + bias=bias, + ) + # if has_bn : self.bn = nn.BatchNorm2d(nOut) + # else : self.bn = None + self.has_bn = has_bn + self.BNs = nn.ModuleList() + for i, _out in enumerate(self.choices): + self.BNs.append(nn.BatchNorm2d(_out)) + if has_relu: + self.relu = nn.ReLU(inplace=True) + else: + self.relu = None + self.in_dim = nIn + self.out_dim = nOut + self.search_mode = "basic" - def get_range(self): - return [self.choices] + def get_flops(self, channels, check_range=True, divide=1): + iC, oC = channels + if check_range: + assert ( + iC <= self.conv.in_channels and oC <= 
self.conv.out_channels + ), "{:} vs {:} | {:} vs {:}".format( + iC, self.conv.in_channels, oC, self.conv.out_channels + ) + assert ( + isinstance(self.InShape, tuple) and len(self.InShape) == 2 + ), "invalid in-shape : {:}".format(self.InShape) + assert ( + isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 + ), "invalid out-shape : {:}".format(self.OutShape) + # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups + conv_per_position_flops = ( + self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups + ) + all_positions = self.OutShape[0] * self.OutShape[1] + flops = (conv_per_position_flops * all_positions / divide) * iC * oC + if self.conv.bias is not None: + flops += all_positions / divide + return flops - def forward(self, inputs): - if self.search_mode == 'basic': - return self.basic_forward(inputs) - elif self.search_mode == 'search': - return self.search_forward(inputs) - else: - raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_range(self): + return [self.choices] - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, index, prob = tuple_inputs - index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) - probability = torch.squeeze(probability) - assert len(index) == 2, 'invalid length : {:}'.format(index) - # compute expected flop - #coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) - expected_outC = (self.choices_tensor * probability).sum() - expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) - if self.avg : out = self.avg( inputs ) - else : out = inputs - # convolutional layer - out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) - out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] - # merge - out_channel = max([x.size(1) for x in out_bns]) - outA = ChannelWiseInter(out_bns[0], out_channel) - outB = ChannelWiseInter(out_bns[1], out_channel) - out = outA * prob[0] + outB * prob[1] - #out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - if self.relu: out = self.relu( out ) - else : out = out - return out, expected_outC, expected_flop + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, index, prob = tuple_inputs + index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) + probability = torch.squeeze(probability) + assert len(index) == 2, "invalid length : {:}".format(index) + # compute expected flop + # coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) + expected_outC = (self.choices_tensor * probability).sum() + expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) + if self.avg: + out = self.avg(inputs) + else: + out = inputs + # convolutional layer + out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) + out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] + # merge + out_channel = 
max([x.size(1) for x in out_bns]) + outA = ChannelWiseInter(out_bns[0], out_channel) + outB = ChannelWiseInter(out_bns[1], out_channel) + out = outA * prob[0] + outB * prob[1] + # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) - def basic_forward(self, inputs): - if self.avg : out = self.avg( inputs ) - else : out = inputs - conv = self.conv( out ) - if self.has_bn:out= self.BNs[-1]( conv ) - else : out = conv - if self.relu: out = self.relu( out ) - else : out = out - if self.InShape is None: - self.InShape = (inputs.size(-2), inputs.size(-1)) - self.OutShape = (out.size(-2) , out.size(-1)) - return out + if self.relu: + out = self.relu(out) + else: + out = out + return out, expected_outC, expected_flop + + def basic_forward(self, inputs): + if self.avg: + out = self.avg(inputs) + else: + out = inputs + conv = self.conv(out) + if self.has_bn: + out = self.BNs[-1](conv) + else: + out = conv + if self.relu: + out = self.relu(out) + else: + out = out + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + self.OutShape = (out.size(-2), out.size(-1)) + return out class ResNetBasicblock(nn.Module): - expansion = 1 - num_conv = 2 - def __init__(self, inplanes, planes, stride): - super(ResNetBasicblock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) - elif inplanes != planes: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - self.out_dim = planes - self.search_mode = 'basic' + expansion = 1 + num_conv = 2 - def get_range(self): - return self.conv_a.get_range() + self.conv_b.get_range() + def __init__(self, inplanes, planes, stride): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_a = ConvBNReLU( + inplanes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_b = ConvBNReLU( + planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + elif inplanes != planes: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes + self.search_mode = "basic" - def get_flops(self, channels): - assert len(channels) == 3, 'invalid channels : {:}'.format(channels) - flop_A = self.conv_a.get_flops([channels[0], channels[1]]) - flop_B = self.conv_b.get_flops([channels[1], channels[2]]) - if hasattr(self.downsample, 'get_flops'): - flop_C = self.downsample.get_flops([channels[0], channels[-1]]) - else: - flop_C = 0 - if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train - flop_C = channels[0] * channels[-1] * self.conv_b.OutShape[0] * self.conv_b.OutShape[1] - return flop_A + flop_B + flop_C + def get_range(self): + return self.conv_a.get_range() + self.conv_b.get_range() - def forward(self, inputs): - if self.search_mode == 
'basic' : return self.basic_forward(inputs) - elif self.search_mode == 'search': return self.search_forward(inputs) - else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_flops(self, channels): + assert len(channels) == 3, "invalid channels : {:}".format(channels) + flop_A = self.conv_a.get_flops([channels[0], channels[1]]) + flop_B = self.conv_b.get_flops([channels[1], channels[2]]) + if hasattr(self.downsample, "get_flops"): + flop_C = self.downsample.get_flops([channels[0], channels[-1]]) + else: + flop_C = 0 + if ( + channels[0] != channels[-1] and self.downsample is None + ): # this short-cut will be added during the infer-train + flop_C = ( + channels[0] + * channels[-1] + * self.conv_b.OutShape[0] + * self.conv_b.OutShape[1] + ) + return flop_A + flop_B + flop_C - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, indexes, probs = tuple_inputs - assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 - out_a, expected_inC_a, expected_flop_a = self.conv_a( (inputs, expected_inC , probability[0], indexes[0], probs[0]) ) - out_b, expected_inC_b, expected_flop_b = self.conv_b( (out_a , expected_inC_a, probability[1], indexes[1], probs[1]) ) - if self.downsample is not None: - residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[1], indexes[1], probs[1]) ) - else: - residual, expected_flop_c = inputs, 0 - out = additive_func(residual, out_b) - return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c]) + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - def basic_forward(self, inputs): - basicblock = self.conv_a(inputs) - basicblock = self.conv_b(basicblock) - if self.downsample is not None: residual = self.downsample(inputs) - else : residual = inputs - out = additive_func(residual, basicblock) - return nn.functional.relu(out, inplace=True) + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, indexes, probs = tuple_inputs + assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 + out_a, expected_inC_a, expected_flop_a = self.conv_a( + (inputs, expected_inC, probability[0], indexes[0], probs[0]) + ) + out_b, expected_inC_b, expected_flop_b = self.conv_b( + (out_a, expected_inC_a, probability[1], indexes[1], probs[1]) + ) + if self.downsample is not None: + residual, _, expected_flop_c = self.downsample( + (inputs, expected_inC, probability[1], indexes[1], probs[1]) + ) + else: + residual, expected_flop_c = inputs, 0 + out = additive_func(residual, out_b) + return ( + nn.functional.relu(out, inplace=True), + expected_inC_b, + sum([expected_flop_a, expected_flop_b, expected_flop_c]), + ) + def basic_forward(self, inputs): + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, basicblock) + return nn.functional.relu(out, inplace=True) class ResNetBottleneck(nn.Module): - 
expansion = 4 - num_conv = 3 - def __init__(self, inplanes, planes, stride): - super(ResNetBottleneck, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) - elif inplanes != planes*self.expansion: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - self.out_dim = planes * self.expansion - self.search_mode = 'basic' + expansion = 4 + num_conv = 3 - def get_range(self): - return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range() + def __init__(self, inplanes, planes, stride): + super(ResNetBottleneck, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_1x1 = ConvBNReLU( + inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True + ) + self.conv_3x3 = ConvBNReLU( + planes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_1x4 = ConvBNReLU( + planes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + elif inplanes != planes * self.expansion: + self.downsample = ConvBNReLU( + inplanes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes * self.expansion + self.search_mode = "basic" - def get_flops(self, channels): - assert len(channels) == 4, 'invalid channels : {:}'.format(channels) - flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) - flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) - flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) - if hasattr(self.downsample, 'get_flops'): - flop_D = self.downsample.get_flops([channels[0], channels[-1]]) - else: - flop_D = 0 - if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train - flop_D = channels[0] * channels[-1] * self.conv_1x4.OutShape[0] * self.conv_1x4.OutShape[1] - return flop_A + flop_B + flop_C + flop_D + def get_range(self): + return ( + self.conv_1x1.get_range() + + self.conv_3x3.get_range() + + self.conv_1x4.get_range() + ) - def forward(self, inputs): - if self.search_mode == 'basic' : return self.basic_forward(inputs) - elif self.search_mode == 'search': return self.search_forward(inputs) - else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_flops(self, channels): + assert len(channels) == 4, "invalid channels : {:}".format(channels) + flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) + flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) + flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) + if hasattr(self.downsample, "get_flops"): + flop_D = self.downsample.get_flops([channels[0], 
channels[-1]]) + else: + flop_D = 0 + if ( + channels[0] != channels[-1] and self.downsample is None + ): # this short-cut will be added during the infer-train + flop_D = ( + channels[0] + * channels[-1] + * self.conv_1x4.OutShape[0] + * self.conv_1x4.OutShape[1] + ) + return flop_A + flop_B + flop_C + flop_D - def basic_forward(self, inputs): - bottleneck = self.conv_1x1(inputs) - bottleneck = self.conv_3x3(bottleneck) - bottleneck = self.conv_1x4(bottleneck) - if self.downsample is not None: residual = self.downsample(inputs) - else : residual = inputs - out = additive_func(residual, bottleneck) - return nn.functional.relu(out, inplace=True) + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, indexes, probs = tuple_inputs - assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 - out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( (inputs, expected_inC , probability[0], indexes[0], probs[0]) ) - out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( (out_1x1,expected_inC_1x1, probability[1], indexes[1], probs[1]) ) - out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( (out_3x3,expected_inC_3x3, probability[2], indexes[2], probs[2]) ) - if self.downsample is not None: - residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[2], indexes[2], probs[2]) ) - else: - residual, expected_flop_c = inputs, 0 - out = additive_func(residual, out_1x4) - return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c]) + def basic_forward(self, inputs): + bottleneck = self.conv_1x1(inputs) + bottleneck = self.conv_3x3(bottleneck) + bottleneck = self.conv_1x4(bottleneck) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, bottleneck) + return nn.functional.relu(out, inplace=True) + + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, indexes, probs = tuple_inputs + assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 + out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( + (inputs, expected_inC, probability[0], indexes[0], probs[0]) + ) + out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( + (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1]) + ) + out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( + (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2]) + ) + if self.downsample is not None: + residual, _, expected_flop_c = self.downsample( + (inputs, expected_inC, probability[2], indexes[2], probs[2]) + ) + else: + residual, expected_flop_c = inputs, 0 + out = additive_func(residual, out_1x4) + return ( + nn.functional.relu(out, inplace=True), + expected_inC_1x4, + sum( + [ + expected_flop_1x1, + expected_flop_3x3, + expected_flop_1x4, + expected_flop_c, + ] + ), + ) class SearchWidthCifarResNet(nn.Module): + def __init__(self, 
block_name, depth, num_classes): + super(SearchWidthCifarResNet, self).__init__() - def __init__(self, block_name, depth, num_classes): - super(SearchWidthCifarResNet, self).__init__() + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + if block_name == "ResNetBasicblock": + block = ResNetBasicblock + assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" + layer_blocks = (depth - 2) // 6 + elif block_name == "ResNetBottleneck": + block = ResNetBottleneck + assert (depth - 2) % 9 == 0, "depth should be one of 164" + layer_blocks = (depth - 2) // 9 + else: + raise ValueError("invalid block : {:}".format(block_name)) - #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model - if block_name == 'ResNetBasicblock': - block = ResNetBasicblock - assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' - layer_blocks = (depth - 2) // 6 - elif block_name == 'ResNetBottleneck': - block = ResNetBottleneck - assert (depth - 2) % 9 == 0, 'depth should be one of 164' - layer_blocks = (depth - 2) // 9 - else: - raise ValueError('invalid block : {:}'.format(block_name)) + self.message = ( + "SearchWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( + depth, layer_blocks + ) + ) + self.num_classes = num_classes + self.channels = [16] + self.layers = nn.ModuleList( + [ + ConvBNReLU( + 3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True + ) + ] + ) + self.InShape = None + for stage in range(3): + for iL in range(layer_blocks): + iC = self.channels[-1] + planes = 16 * (2 ** stage) + stride = 2 if stage > 0 and iL == 0 else 1 + module = block(iC, planes, stride) + self.channels.append(module.out_dim) + self.layers.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( + stage, + iL, + layer_blocks, + len(self.layers) - 1, + iC, + module.out_dim, + stride, + ) - self.message = 'SearchWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) - self.num_classes = num_classes - self.channels = [16] - self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) - self.InShape = None - for stage in range(3): - for iL in range(layer_blocks): - iC = self.channels[-1] - planes = 16 * (2**stage) - stride = 2 if stage > 0 and iL == 0 else 1 - module = block(iC, planes, stride) - self.channels.append( module.out_dim ) - self.layers.append ( module ) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) - - self.avgpool = nn.AvgPool2d(8) - self.classifier = nn.Linear(module.out_dim, num_classes) - self.InShape = None - self.tau = -1 - self.search_mode = 'basic' - #assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) - - # parameters for width - self.Ranges = [] - self.layer2indexRange = [] - for i, layer in enumerate(self.layers): - start_index = len(self.Ranges) - self.Ranges += layer.get_range() - self.layer2indexRange.append( (start_index, len(self.Ranges)) ) - assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth) + self.avgpool = nn.AvgPool2d(8) + self.classifier = nn.Linear(module.out_dim, num_classes) + self.InShape = None + self.tau = -1 + self.search_mode = "basic" + # assert sum(x.num_conv for 
x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) - self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None)))) - nn.init.normal_(self.width_attentions, 0, 0.01) - self.apply(initialize_resnet) + # parameters for width + self.Ranges = [] + self.layer2indexRange = [] + for i, layer in enumerate(self.layers): + start_index = len(self.Ranges) + self.Ranges += layer.get_range() + self.layer2indexRange.append((start_index, len(self.Ranges))) + assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format( + len(self.Ranges) + 1, depth + ) - def arch_parameters(self): - return [self.width_attentions] + self.register_parameter( + "width_attentions", + nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))), + ) + nn.init.normal_(self.width_attentions, 0, 0.01) + self.apply(initialize_resnet) - def base_parameters(self): - return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters()) + def arch_parameters(self): + return [self.width_attentions] - def get_flop(self, mode, config_dict, extra_info): - if config_dict is not None: config_dict = config_dict.copy() - #weights = [F.softmax(x, dim=0) for x in self.width_attentions] - channels = [3] - for i, weight in enumerate(self.width_attentions): - if mode == 'genotype': + def base_parameters(self): + return ( + list(self.layers.parameters()) + + list(self.avgpool.parameters()) + + list(self.classifier.parameters()) + ) + + def get_flop(self, mode, config_dict, extra_info): + if config_dict is not None: + config_dict = config_dict.copy() + # weights = [F.softmax(x, dim=0) for x in self.width_attentions] + channels = [3] + for i, weight in enumerate(self.width_attentions): + if mode == "genotype": + with torch.no_grad(): + probe = nn.functional.softmax(weight, dim=0) + C = self.Ranges[i][torch.argmax(probe).item()] + elif mode == "max": + C = self.Ranges[i][-1] + elif mode == "fix": + C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) + elif mode == "random": + assert isinstance(extra_info, float), "invalid extra_info : {:}".format( + extra_info + ) + with torch.no_grad(): + prob = nn.functional.softmax(weight, dim=0) + approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) + for j in range(prob.size(0)): + prob[j] = 1 / ( + abs(j - (approximate_C - self.Ranges[i][j])) + 0.2 + ) + C = self.Ranges[i][torch.multinomial(prob, 1, False).item()] + else: + raise ValueError("invalid mode : {:}".format(mode)) + channels.append(C) + flop = 0 + for i, layer in enumerate(self.layers): + s, e = self.layer2indexRange[i] + xchl = tuple(channels[s : e + 1]) + flop += layer.get_flops(xchl) + # the last fc layer + flop += channels[-1] * self.classifier.out_features + if config_dict is None: + return flop / 1e6 + else: + config_dict["xchannels"] = channels + config_dict["super_type"] = "infer-width" + config_dict["estimated_FLOP"] = flop / 1e6 + return flop / 1e6, config_dict + + def get_arch_info(self): + string = "for width, there are {:} attention probabilities.".format( + len(self.width_attentions) + ) + discrepancy = [] with torch.no_grad(): - probe = nn.functional.softmax(weight, dim=0) - C = self.Ranges[i][ torch.argmax(probe).item() ] - elif mode == 'max': - C = self.Ranges[i][-1] - elif mode == 'fix': - C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] ) - elif mode == 'random': - assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info) 
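+ # summarize each width-attention row below: softmax the logits, record the argmax as the selected width, and report the discrepancy (gap between the two largest probabilities) as a measure of how decisive that selection is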
+ for i, att in enumerate(self.width_attentions): + prob = nn.functional.softmax(att, dim=0) + prob = prob.cpu() + selc = prob.argmax().item() + prob = prob.tolist() + prob = ["{:.3f}".format(x) for x in prob] + xstring = "{:03d}/{:03d}-th : {:}".format( + i, len(self.width_attentions), " ".join(prob) + ) + logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] + xstring += " || {:52s}".format(" ".join(logt)) + prob = sorted([float(x) for x in prob]) + disc = prob[-1] - prob[-2] + xstring += " || dis={:.2f} || select={:}/{:}".format( + disc, selc, len(prob) + ) + discrepancy.append(disc) + string += "\n{:}".format(xstring) + return string, discrepancy + + def set_tau(self, tau_max, tau_min, epoch_ratio): + assert ( + epoch_ratio >= 0 and epoch_ratio <= 1 + ), "invalid epoch-ratio : {:}".format(epoch_ratio) + tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 + self.tau = tau + + def get_message(self): + return self.message + + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) + + def search_forward(self, inputs): + flop_probs = nn.functional.softmax(self.width_attentions, dim=1) + selected_widths, selected_probs = select2withP(self.width_attentions, self.tau) with torch.no_grad(): - prob = nn.functional.softmax(weight, dim=0) - approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] ) - for j in range(prob.size(0)): - prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2) - C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ] - else: - raise ValueError('invalid mode : {:}'.format(mode)) - channels.append( C ) - flop = 0 - for i, layer in enumerate(self.layers): - s, e = self.layer2indexRange[i] - xchl = tuple( channels[s:e+1] ) - flop+= layer.get_flops(xchl) - # the last fc layer - flop += channels[-1] * self.classifier.out_features - if config_dict is None: - return flop / 1e6 - else: - config_dict['xchannels'] = channels - config_dict['super_type'] = 'infer-width' - config_dict['estimated_FLOP'] = flop / 1e6 - return flop / 1e6, config_dict + selected_widths = selected_widths.cpu() - def get_arch_info(self): - string = "for width, there are {:} attention probabilities.".format(len(self.width_attentions)) - discrepancy = [] - with torch.no_grad(): - for i, att in enumerate(self.width_attentions): - prob = nn.functional.softmax(att, dim=0) - prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist() - prob = ['{:.3f}'.format(x) for x in prob] - xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob)) - logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()] - xstring += ' || {:52s}'.format(' '.join(logt)) - prob = sorted( [float(x) for x in prob] ) - disc = prob[-1] - prob[-2] - xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob)) - discrepancy.append( disc ) - string += '\n{:}'.format(xstring) - return string, discrepancy + x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] + for i, layer in enumerate(self.layers): + selected_w_index = selected_widths[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + selected_w_probs = selected_probs[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + layer_prob = flop_probs[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + x, expected_inC, expected_flop = layer( + (x, expected_inC, 
layer_prob, selected_w_index, selected_w_probs) + ) + last_channel_idx += layer.num_conv + flops.append(expected_flop) + flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = linear_forward(features, self.classifier) + return logits, torch.stack([sum(flops)]) - def set_tau(self, tau_max, tau_min, epoch_ratio): - assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio) - tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 - self.tau = tau - - def get_message(self): - return self.message - - def forward(self, inputs): - if self.search_mode == 'basic': - return self.basic_forward(inputs) - elif self.search_mode == 'search': - return self.search_forward(inputs) - else: - raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) - - def search_forward(self, inputs): - flop_probs = nn.functional.softmax(self.width_attentions, dim=1) - selected_widths, selected_probs = select2withP(self.width_attentions, self.tau) - with torch.no_grad(): - selected_widths = selected_widths.cpu() - - x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] - for i, layer in enumerate(self.layers): - selected_w_index = selected_widths[last_channel_idx: last_channel_idx+layer.num_conv] - selected_w_probs = selected_probs[last_channel_idx: last_channel_idx+layer.num_conv] - layer_prob = flop_probs[last_channel_idx: last_channel_idx+layer.num_conv] - x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) ) - last_channel_idx += layer.num_conv - flops.append( expected_flop ) - flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = linear_forward(features, self.classifier) - return logits, torch.stack( [sum(flops)] ) - - def basic_forward(self, inputs): - if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1)) - x = inputs - for i, layer in enumerate(self.layers): - x = layer( x ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = self.classifier(features) - return features, logits + def basic_forward(self, inputs): + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + x = inputs + for i, layer in enumerate(self.layers): + x = layer(x) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/shape_searchs/SearchImagenetResNet.py b/lib/models/shape_searchs/SearchImagenetResNet.py index ee01d57..11da09a 100644 --- a/lib/models/shape_searchs/SearchImagenetResNet.py +++ b/lib/models/shape_searchs/SearchImagenetResNet.py @@ -3,480 +3,764 @@ from collections import OrderedDict from bisect import bisect_right import torch.nn as nn from ..initialization import initialize_resnet -from ..SharedUtils import additive_func -from .SoftSelect import select2withP, ChannelWiseInter -from .SoftSelect import linear_forward -from .SoftSelect import get_width_choices +from ..SharedUtils import additive_func +from .SoftSelect import select2withP, ChannelWiseInter +from .SoftSelect import linear_forward +from .SoftSelect import get_width_choices def get_depth_choices(layers): - min_depth = min(layers) - info = {'num': min_depth} - for i, depth in enumerate(layers): - choices = [] - for j in range(1, min_depth+1): - 
choices.append( int( float(depth)*j/min_depth ) ) - info[i] = choices - return info + min_depth = min(layers) + info = {"num": min_depth} + for i, depth in enumerate(layers): + choices = [] + for j in range(1, min_depth + 1): + choices.append(int(float(depth) * j / min_depth)) + info[i] = choices + return info def conv_forward(inputs, conv, choices): - iC = conv.in_channels - fill_size = list(inputs.size()) - fill_size[1] = iC - fill_size[1] - filled = torch.zeros(fill_size, device=inputs.device) - xinputs = torch.cat((inputs, filled), dim=1) - outputs = conv(xinputs) - selecteds = [outputs[:,:oC] for oC in choices] - return selecteds + iC = conv.in_channels + fill_size = list(inputs.size()) + fill_size[1] = iC - fill_size[1] + filled = torch.zeros(fill_size, device=inputs.device) + xinputs = torch.cat((inputs, filled), dim=1) + outputs = conv(xinputs) + selecteds = [outputs[:, :oC] for oC in choices] + return selecteds class ConvBNReLU(nn.Module): - num_conv = 1 - def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu, last_max_pool=False): - super(ConvBNReLU, self).__init__() - self.InShape = None - self.OutShape = None - self.choices = get_width_choices(nOut) - self.register_buffer('choices_tensor', torch.Tensor( self.choices )) + num_conv = 1 - if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - else : self.avg = None - self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) - #if has_bn : self.bn = nn.BatchNorm2d(nOut) - #else : self.bn = None - self.has_bn = has_bn - self.BNs = nn.ModuleList() - for i, _out in enumerate(self.choices): - self.BNs.append(nn.BatchNorm2d(_out)) - if has_relu: self.relu = nn.ReLU(inplace=True) - else : self.relu = None - - if last_max_pool: self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - else : self.maxpool = None - self.in_dim = nIn - self.out_dim = nOut - self.search_mode = 'basic' + def __init__( + self, + nIn, + nOut, + kernel, + stride, + padding, + bias, + has_avg, + has_bn, + has_relu, + last_max_pool=False, + ): + super(ConvBNReLU, self).__init__() + self.InShape = None + self.OutShape = None + self.choices = get_width_choices(nOut) + self.register_buffer("choices_tensor", torch.Tensor(self.choices)) - def get_flops(self, channels, check_range=True, divide=1): - iC, oC = channels - if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels) - assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape) - assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape) - #conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups - conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups) - all_positions = self.OutShape[0] * self.OutShape[1] - flops = (conv_per_position_flops * all_positions / divide) * iC * oC - if self.conv.bias is not None: flops += all_positions / divide - return flops + if has_avg: + self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.avg = None + self.conv = nn.Conv2d( + nIn, + nOut, + kernel_size=kernel, + stride=stride, + padding=padding, + dilation=1, + groups=1, + bias=bias, + ) + # if has_bn : self.bn = nn.BatchNorm2d(nOut) + # else : self.bn = None + 
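A quick sanity check of get_depth_choices as rewritten above: the minimum stage depth sets the number of candidate depths, and every stage gets proportionally scaled block counts. A standalone sketch with a hypothetical layers=[3, 4, 6, 3]:

def get_depth_choices(layers):
    # Mirrors the helper above: each stage i gets min(layers) candidate depths,
    # scaled proportionally to its full depth layers[i].
    min_depth = min(layers)
    info = {"num": min_depth}
    for i, depth in enumerate(layers):
        info[i] = [int(float(depth) * j / min_depth) for j in range(1, min_depth + 1)]
    return info

print(get_depth_choices([3, 4, 6, 3]))
# -> {'num': 3, 0: [1, 2, 3], 1: [1, 2, 4], 2: [2, 4, 6], 3: [1, 2, 3]}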
self.has_bn = has_bn + self.BNs = nn.ModuleList() + for i, _out in enumerate(self.choices): + self.BNs.append(nn.BatchNorm2d(_out)) + if has_relu: + self.relu = nn.ReLU(inplace=True) + else: + self.relu = None - def get_range(self): - return [self.choices] + if last_max_pool: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + else: + self.maxpool = None + self.in_dim = nIn + self.out_dim = nOut + self.search_mode = "basic" - def forward(self, inputs): - if self.search_mode == 'basic': - return self.basic_forward(inputs) - elif self.search_mode == 'search': - return self.search_forward(inputs) - else: - raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_flops(self, channels, check_range=True, divide=1): + iC, oC = channels + if check_range: + assert ( + iC <= self.conv.in_channels and oC <= self.conv.out_channels + ), "{:} vs {:} | {:} vs {:}".format( + iC, self.conv.in_channels, oC, self.conv.out_channels + ) + assert ( + isinstance(self.InShape, tuple) and len(self.InShape) == 2 + ), "invalid in-shape : {:}".format(self.InShape) + assert ( + isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 + ), "invalid out-shape : {:}".format(self.OutShape) + # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups + conv_per_position_flops = ( + self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups + ) + all_positions = self.OutShape[0] * self.OutShape[1] + flops = (conv_per_position_flops * all_positions / divide) * iC * oC + if self.conv.bias is not None: + flops += all_positions / divide + return flops - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, index, prob = tuple_inputs - index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) - probability = torch.squeeze(probability) - assert len(index) == 2, 'invalid length : {:}'.format(index) - # compute expected flop - #coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) - expected_outC = (self.choices_tensor * probability).sum() - expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) - if self.avg : out = self.avg( inputs ) - else : out = inputs - # convolutional layer - out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) - out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] - # merge - out_channel = max([x.size(1) for x in out_bns]) - outA = ChannelWiseInter(out_bns[0], out_channel) - outB = ChannelWiseInter(out_bns[1], out_channel) - out = outA * prob[0] + outB * prob[1] - #out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) + def get_range(self): + return [self.choices] - if self.relu : out = self.relu( out ) - if self.maxpool: out = self.maxpool(out) - return out, expected_outC, expected_flop + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - def basic_forward(self, inputs): - if self.avg : out = self.avg( inputs ) - else : out = inputs - conv = self.conv( out ) - if self.has_bn:out= self.BNs[-1]( conv ) - else : out = conv - if self.relu: out = self.relu( out ) - else : out = out - if self.InShape is None: - self.InShape = (inputs.size(-2), 
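As a rough illustration of the per-layer FLOP estimate computed by get_flops above (kernel area over groups, times output positions, times input and output channels, plus one op per position when a bias exists), here is a standalone version with made-up shapes:

def conv_flops(kernel, out_hw, in_c, out_c, groups=1, bias=False, divide=1):
    # (k_h * k_w / groups) * (H_out * W_out) * iC * oC, as in ConvBNReLU.get_flops.
    per_position = kernel[0] * kernel[1] * 1.0 / groups
    all_positions = out_hw[0] * out_hw[1]
    flops = (per_position * all_positions / divide) * in_c * out_c
    if bias:
        flops += all_positions / divide
    return flops

# Hypothetical 3x3 conv, 32x32 output, 16 -> 32 channels, no bias:
print(conv_flops((3, 3), (32, 32), 16, 32))  # 4718592.0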
inputs.size(-1)) - self.OutShape = (out.size(-2) , out.size(-1)) - if self.maxpool: out = self.maxpool(out) - return out + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, index, prob = tuple_inputs + index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) + probability = torch.squeeze(probability) + assert len(index) == 2, "invalid length : {:}".format(index) + # compute expected flop + # coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) + expected_outC = (self.choices_tensor * probability).sum() + expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) + if self.avg: + out = self.avg(inputs) + else: + out = inputs + # convolutional layer + out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) + out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] + # merge + out_channel = max([x.size(1) for x in out_bns]) + outA = ChannelWiseInter(out_bns[0], out_channel) + outB = ChannelWiseInter(out_bns[1], out_channel) + out = outA * prob[0] + outB * prob[1] + # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) + + if self.relu: + out = self.relu(out) + if self.maxpool: + out = self.maxpool(out) + return out, expected_outC, expected_flop + + def basic_forward(self, inputs): + if self.avg: + out = self.avg(inputs) + else: + out = inputs + conv = self.conv(out) + if self.has_bn: + out = self.BNs[-1](conv) + else: + out = conv + if self.relu: + out = self.relu(out) + else: + out = out + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + self.OutShape = (out.size(-2), out.size(-1)) + if self.maxpool: + out = self.maxpool(out) + return out class ResNetBasicblock(nn.Module): - expansion = 1 - num_conv = 2 - def __init__(self, inplanes, planes, stride): - super(ResNetBasicblock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False) - elif inplanes != planes: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True, has_relu=False) - else: - self.downsample = None - self.out_dim = planes - self.search_mode = 'basic' + expansion = 1 + num_conv = 2 - def get_range(self): - return self.conv_a.get_range() + self.conv_b.get_range() + def __init__(self, inplanes, planes, stride): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_a = ConvBNReLU( + inplanes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_b = ConvBNReLU( + planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=True, + has_relu=False, + ) + elif inplanes != planes: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes + self.search_mode = 
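The merge step in search_forward above evaluates two candidate widths from a single convolution and blends them with their sampling probabilities. A rough standalone sketch; align_channels below is a zero-padding stand-in for the SoftSelect ChannelWiseInter helper, whose actual implementation is not part of this diff and may interpolate channels differently:

import torch

def align_channels(x, out_channel):
    # Stand-in only: zero-pad up to out_channel (ChannelWiseInter may interpolate instead).
    if x.size(1) == out_channel:
        return x
    pad = torch.zeros(x.size(0), out_channel - x.size(1), x.size(2), x.size(3), device=x.device)
    return torch.cat((x, pad), dim=1)

def mix_two_widths(out_bns, prob):
    # out_bns: two BN outputs sliced at the sampled widths; prob: their two probabilities.
    out_channel = max(x.size(1) for x in out_bns)
    outA = align_channels(out_bns[0], out_channel)
    outB = align_channels(out_bns[1], out_channel)
    return outA * prob[0] + outB * prob[1]

# Hypothetical feature maps with 8 and 12 channels, mixed 0.3 / 0.7:
a, b = torch.randn(2, 8, 4, 4), torch.randn(2, 12, 4, 4)
print(mix_two_widths([a, b], torch.tensor([0.3, 0.7])).shape)  # torch.Size([2, 12, 4, 4])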
"basic" - def get_flops(self, channels): - assert len(channels) == 3, 'invalid channels : {:}'.format(channels) - flop_A = self.conv_a.get_flops([channels[0], channels[1]]) - flop_B = self.conv_b.get_flops([channels[1], channels[2]]) - if hasattr(self.downsample, 'get_flops'): - flop_C = self.downsample.get_flops([channels[0], channels[-1]]) - else: - flop_C = 0 - if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train - flop_C = channels[0] * channels[-1] * self.conv_b.OutShape[0] * self.conv_b.OutShape[1] - return flop_A + flop_B + flop_C + def get_range(self): + return self.conv_a.get_range() + self.conv_b.get_range() - def forward(self, inputs): - if self.search_mode == 'basic' : return self.basic_forward(inputs) - elif self.search_mode == 'search': return self.search_forward(inputs) - else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_flops(self, channels): + assert len(channels) == 3, "invalid channels : {:}".format(channels) + flop_A = self.conv_a.get_flops([channels[0], channels[1]]) + flop_B = self.conv_b.get_flops([channels[1], channels[2]]) + if hasattr(self.downsample, "get_flops"): + flop_C = self.downsample.get_flops([channels[0], channels[-1]]) + else: + flop_C = 0 + if ( + channels[0] != channels[-1] and self.downsample is None + ): # this short-cut will be added during the infer-train + flop_C = ( + channels[0] + * channels[-1] + * self.conv_b.OutShape[0] + * self.conv_b.OutShape[1] + ) + return flop_A + flop_B + flop_C - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, indexes, probs = tuple_inputs - assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 - #import pdb; pdb.set_trace() - out_a, expected_inC_a, expected_flop_a = self.conv_a( (inputs, expected_inC , probability[0], indexes[0], probs[0]) ) - out_b, expected_inC_b, expected_flop_b = self.conv_b( (out_a , expected_inC_a, probability[1], indexes[1], probs[1]) ) - if self.downsample is not None: - residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[1], indexes[1], probs[1]) ) - else: - residual, expected_flop_c = inputs, 0 - out = additive_func(residual, out_b) - return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c]) + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - def basic_forward(self, inputs): - basicblock = self.conv_a(inputs) - basicblock = self.conv_b(basicblock) - if self.downsample is not None: residual = self.downsample(inputs) - else : residual = inputs - out = additive_func(residual, basicblock) - return nn.functional.relu(out, inplace=True) + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, indexes, probs = tuple_inputs + assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 + # import pdb; pdb.set_trace() + out_a, expected_inC_a, expected_flop_a = self.conv_a( + (inputs, expected_inC, probability[0], indexes[0], probs[0]) + ) + out_b, expected_inC_b, 
expected_flop_b = self.conv_b( + (out_a, expected_inC_a, probability[1], indexes[1], probs[1]) + ) + if self.downsample is not None: + residual, _, expected_flop_c = self.downsample( + (inputs, expected_inC, probability[1], indexes[1], probs[1]) + ) + else: + residual, expected_flop_c = inputs, 0 + out = additive_func(residual, out_b) + return ( + nn.functional.relu(out, inplace=True), + expected_inC_b, + sum([expected_flop_a, expected_flop_b, expected_flop_c]), + ) + def basic_forward(self, inputs): + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, basicblock) + return nn.functional.relu(out, inplace=True) class ResNetBottleneck(nn.Module): - expansion = 4 - num_conv = 3 - def __init__(self, inplanes, planes, stride): - super(ResNetBottleneck, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False) - elif inplanes != planes*self.expansion: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True, has_relu=False) - else: - self.downsample = None - self.out_dim = planes * self.expansion - self.search_mode = 'basic' + expansion = 4 + num_conv = 3 - def get_range(self): - return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range() + def __init__(self, inplanes, planes, stride): + super(ResNetBottleneck, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_1x1 = ConvBNReLU( + inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True + ) + self.conv_3x3 = ConvBNReLU( + planes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + self.conv_1x4 = ConvBNReLU( + planes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=True, + has_relu=False, + ) + elif inplanes != planes * self.expansion: + self.downsample = ConvBNReLU( + inplanes, + planes * self.expansion, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes * self.expansion + self.search_mode = "basic" - def get_flops(self, channels): - assert len(channels) == 4, 'invalid channels : {:}'.format(channels) - flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) - flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) - flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) - if hasattr(self.downsample, 'get_flops'): - flop_D = self.downsample.get_flops([channels[0], channels[-1]]) - else: - flop_D = 0 - if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train - flop_D = channels[0] * channels[-1] * self.conv_1x4.OutShape[0] * self.conv_1x4.OutShape[1] - return flop_A 
+ flop_B + flop_C + flop_D + def get_range(self): + return ( + self.conv_1x1.get_range() + + self.conv_3x3.get_range() + + self.conv_1x4.get_range() + ) - def forward(self, inputs): - if self.search_mode == 'basic' : return self.basic_forward(inputs) - elif self.search_mode == 'search': return self.search_forward(inputs) - else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_flops(self, channels): + assert len(channels) == 4, "invalid channels : {:}".format(channels) + flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) + flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) + flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) + if hasattr(self.downsample, "get_flops"): + flop_D = self.downsample.get_flops([channels[0], channels[-1]]) + else: + flop_D = 0 + if ( + channels[0] != channels[-1] and self.downsample is None + ): # this short-cut will be added during the infer-train + flop_D = ( + channels[0] + * channels[-1] + * self.conv_1x4.OutShape[0] + * self.conv_1x4.OutShape[1] + ) + return flop_A + flop_B + flop_C + flop_D - def basic_forward(self, inputs): - bottleneck = self.conv_1x1(inputs) - bottleneck = self.conv_3x3(bottleneck) - bottleneck = self.conv_1x4(bottleneck) - if self.downsample is not None: residual = self.downsample(inputs) - else : residual = inputs - out = additive_func(residual, bottleneck) - return nn.functional.relu(out, inplace=True) + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, indexes, probs = tuple_inputs - assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 - out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( (inputs, expected_inC , probability[0], indexes[0], probs[0]) ) - out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( (out_1x1,expected_inC_1x1, probability[1], indexes[1], probs[1]) ) - out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( (out_3x3,expected_inC_3x3, probability[2], indexes[2], probs[2]) ) - if self.downsample is not None: - residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[2], indexes[2], probs[2]) ) - else: - residual, expected_flop_c = inputs, 0 - out = additive_func(residual, out_1x4) - return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c]) + def basic_forward(self, inputs): + bottleneck = self.conv_1x1(inputs) + bottleneck = self.conv_3x3(bottleneck) + bottleneck = self.conv_1x4(bottleneck) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, bottleneck) + return nn.functional.relu(out, inplace=True) + + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, indexes, probs = tuple_inputs + assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 + out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( + (inputs, expected_inC, 
probability[0], indexes[0], probs[0]) + ) + out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( + (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1]) + ) + out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( + (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2]) + ) + if self.downsample is not None: + residual, _, expected_flop_c = self.downsample( + (inputs, expected_inC, probability[2], indexes[2], probs[2]) + ) + else: + residual, expected_flop_c = inputs, 0 + out = additive_func(residual, out_1x4) + return ( + nn.functional.relu(out, inplace=True), + expected_inC_1x4, + sum( + [ + expected_flop_1x1, + expected_flop_3x3, + expected_flop_1x4, + expected_flop_c, + ] + ), + ) class SearchShapeImagenetResNet(nn.Module): + def __init__(self, block_name, layers, deep_stem, num_classes): + super(SearchShapeImagenetResNet, self).__init__() - def __init__(self, block_name, layers, deep_stem, num_classes): - super(SearchShapeImagenetResNet, self).__init__() - - #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model - if block_name == 'BasicBlock': - block = ResNetBasicblock - elif block_name == 'Bottleneck': - block = ResNetBottleneck - else: - raise ValueError('invalid block : {:}'.format(block_name)) - - self.message = 'SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}'.format(sum(layers)*block.num_conv, layers) - self.num_classes = num_classes - if not deep_stem: - self.layers = nn.ModuleList( [ ConvBNReLU(3, 64, 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True, last_max_pool=True) ] ) - self.channels = [64] - else: - self.layers = nn.ModuleList( [ ConvBNReLU(3, 32, 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True) - ,ConvBNReLU(32,64, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True, last_max_pool=True) ] ) - self.channels = [32, 64] - - meta_depth_info = get_depth_choices(layers) - self.InShape = None - self.depth_info = OrderedDict() - self.depth_at_i = OrderedDict() - for stage, layer_blocks in enumerate(layers): - cur_block_choices = meta_depth_info[stage] - assert cur_block_choices[-1] == layer_blocks, 'stage={:}, {:} vs {:}'.format(stage, cur_block_choices, layer_blocks) - block_choices, xstart = [], len(self.layers) - for iL in range(layer_blocks): - iC = self.channels[-1] - planes = 64 * (2**stage) - stride = 2 if stage > 0 and iL == 0 else 1 - module = block(iC, planes, stride) - self.channels.append( module.out_dim ) - self.layers.append ( module ) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) - # added for depth - layer_index = len(self.layers) - 1 - if iL + 1 in cur_block_choices: block_choices.append( layer_index ) - if iL + 1 == layer_blocks: - self.depth_info[layer_index] = {'choices': block_choices, - 'stage' : stage, - 'xstart' : xstart} - self.depth_info_list = [] - for xend, info in self.depth_info.items(): - self.depth_info_list.append( (xend, info) ) - xstart, xstage = info['xstart'], info['stage'] - for ilayer in range(xstart, xend+1): - idx = bisect_right(info['choices'], ilayer-1) - self.depth_at_i[ilayer] = (xstage, idx) - - self.avgpool = nn.AdaptiveAvgPool2d((1,1)) - self.classifier = nn.Linear(module.out_dim, num_classes) - self.InShape = None - self.tau = -1 - self.search_mode = 'basic' - #assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, 
depth) - - # parameters for width - self.Ranges = [] - self.layer2indexRange = [] - for i, layer in enumerate(self.layers): - start_index = len(self.Ranges) - self.Ranges += layer.get_range() - self.layer2indexRange.append( (start_index, len(self.Ranges)) ) - - self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None)))) - self.register_parameter('depth_attentions', nn.Parameter(torch.Tensor(len(layers), meta_depth_info['num']))) - nn.init.normal_(self.width_attentions, 0, 0.01) - nn.init.normal_(self.depth_attentions, 0, 0.01) - self.apply(initialize_resnet) - - def arch_parameters(self, LR=None): - if LR is None: - return [self.width_attentions, self.depth_attentions] - else: - return [ - {"params": self.width_attentions, "lr": LR}, - {"params": self.depth_attentions, "lr": LR}, - ] - - def base_parameters(self): - return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters()) - - def get_flop(self, mode, config_dict, extra_info): - if config_dict is not None: config_dict = config_dict.copy() - # select channels - channels = [3] - for i, weight in enumerate(self.width_attentions): - if mode == 'genotype': - with torch.no_grad(): - probe = nn.functional.softmax(weight, dim=0) - C = self.Ranges[i][ torch.argmax(probe).item() ] - else: - raise ValueError('invalid mode : {:}'.format(mode)) - channels.append( C ) - # select depth - if mode == 'genotype': - with torch.no_grad(): - depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) - choices = torch.argmax(depth_probs, dim=1).cpu().tolist() - else: - raise ValueError('invalid mode : {:}'.format(mode)) - selected_layers = [] - for choice, xvalue in zip(choices, self.depth_info_list): - xtemp = xvalue[1]['choices'][choice] - xvalue[1]['xstart'] + 1 - selected_layers.append(xtemp) - flop = 0 - for i, layer in enumerate(self.layers): - s, e = self.layer2indexRange[i] - xchl = tuple( channels[s:e+1] ) - if i in self.depth_at_i: - xstagei, xatti = self.depth_at_i[i] - if xatti <= choices[xstagei]: # leave this depth - flop+= layer.get_flops(xchl) + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + if block_name == "BasicBlock": + block = ResNetBasicblock + elif block_name == "Bottleneck": + block = ResNetBottleneck else: - flop+= 0 # do not use this layer - else: - flop+= layer.get_flops(xchl) - # the last fc layer - flop += channels[-1] * self.classifier.out_features - if config_dict is None: - return flop / 1e6 - else: - config_dict['xchannels'] = channels - config_dict['xblocks'] = selected_layers - config_dict['super_type'] = 'infer-shape' - config_dict['estimated_FLOP'] = flop / 1e6 - return flop / 1e6, config_dict + raise ValueError("invalid block : {:}".format(block_name)) - def get_arch_info(self): - string = "for depth and width, there are {:} + {:} attention probabilities.".format(len(self.depth_attentions), len(self.width_attentions)) - string+= '\n{:}'.format(self.depth_info) - discrepancy = [] - with torch.no_grad(): - for i, att in enumerate(self.depth_attentions): - prob = nn.functional.softmax(att, dim=0) - prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist() - prob = ['{:.3f}'.format(x) for x in prob] - xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.depth_attentions), ' '.join(prob)) - logt = ['{:.4f}'.format(x) for x in att.cpu().tolist()] - xstring += ' || {:17s}'.format(' '.join(logt)) - prob = sorted( [float(x) for x in prob] ) - disc = prob[-1] - prob[-2] - xstring 
+= ' || discrepancy={:.2f} || select={:}/{:}'.format(disc, selc, len(prob)) - discrepancy.append( disc ) - string += '\n{:}'.format(xstring) - string += '\n-----------------------------------------------' - for i, att in enumerate(self.width_attentions): - prob = nn.functional.softmax(att, dim=0) - prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist() - prob = ['{:.3f}'.format(x) for x in prob] - xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob)) - logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()] - xstring += ' || {:52s}'.format(' '.join(logt)) - prob = sorted( [float(x) for x in prob] ) - disc = prob[-1] - prob[-2] - xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob)) - discrepancy.append( disc ) - string += '\n{:}'.format(xstring) - return string, discrepancy + self.message = ( + "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format( + sum(layers) * block.num_conv, layers + ) + ) + self.num_classes = num_classes + if not deep_stem: + self.layers = nn.ModuleList( + [ + ConvBNReLU( + 3, + 64, + 7, + 2, + 3, + False, + has_avg=False, + has_bn=True, + has_relu=True, + last_max_pool=True, + ) + ] + ) + self.channels = [64] + else: + self.layers = nn.ModuleList( + [ + ConvBNReLU( + 3, 32, 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True + ), + ConvBNReLU( + 32, + 64, + 3, + 1, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + last_max_pool=True, + ), + ] + ) + self.channels = [32, 64] - def set_tau(self, tau_max, tau_min, epoch_ratio): - assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio) - tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 - self.tau = tau + meta_depth_info = get_depth_choices(layers) + self.InShape = None + self.depth_info = OrderedDict() + self.depth_at_i = OrderedDict() + for stage, layer_blocks in enumerate(layers): + cur_block_choices = meta_depth_info[stage] + assert ( + cur_block_choices[-1] == layer_blocks + ), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks) + block_choices, xstart = [], len(self.layers) + for iL in range(layer_blocks): + iC = self.channels[-1] + planes = 64 * (2 ** stage) + stride = 2 if stage > 0 and iL == 0 else 1 + module = block(iC, planes, stride) + self.channels.append(module.out_dim) + self.layers.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( + stage, + iL, + layer_blocks, + len(self.layers) - 1, + iC, + module.out_dim, + stride, + ) + # added for depth + layer_index = len(self.layers) - 1 + if iL + 1 in cur_block_choices: + block_choices.append(layer_index) + if iL + 1 == layer_blocks: + self.depth_info[layer_index] = { + "choices": block_choices, + "stage": stage, + "xstart": xstart, + } + self.depth_info_list = [] + for xend, info in self.depth_info.items(): + self.depth_info_list.append((xend, info)) + xstart, xstage = info["xstart"], info["stage"] + for ilayer in range(xstart, xend + 1): + idx = bisect_right(info["choices"], ilayer - 1) + self.depth_at_i[ilayer] = (xstage, idx) - def get_message(self): - return self.message + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.classifier = nn.Linear(module.out_dim, num_classes) + self.InShape = None + self.tau = -1 + self.search_mode = "basic" + # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) - 
def forward(self, inputs): - if self.search_mode == 'basic': - return self.basic_forward(inputs) - elif self.search_mode == 'search': - return self.search_forward(inputs) - else: - raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + # parameters for width + self.Ranges = [] + self.layer2indexRange = [] + for i, layer in enumerate(self.layers): + start_index = len(self.Ranges) + self.Ranges += layer.get_range() + self.layer2indexRange.append((start_index, len(self.Ranges))) - def search_forward(self, inputs): - flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1) - flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) - flop_depth_probs = torch.flip( torch.cumsum( torch.flip(flop_depth_probs, [1]), 1 ), [1] ) - selected_widths, selected_width_probs = select2withP(self.width_attentions, self.tau) - selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) - with torch.no_grad(): - selected_widths = selected_widths.cpu() + self.register_parameter( + "width_attentions", + nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))), + ) + self.register_parameter( + "depth_attentions", + nn.Parameter(torch.Tensor(len(layers), meta_depth_info["num"])), + ) + nn.init.normal_(self.width_attentions, 0, 0.01) + nn.init.normal_(self.depth_attentions, 0, 0.01) + self.apply(initialize_resnet) - x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] - feature_maps = [] - for i, layer in enumerate(self.layers): - selected_w_index = selected_widths [last_channel_idx: last_channel_idx+layer.num_conv] - selected_w_probs = selected_width_probs[last_channel_idx: last_channel_idx+layer.num_conv] - layer_prob = flop_width_probs [last_channel_idx: last_channel_idx+layer.num_conv] - x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) ) - feature_maps.append( x ) - last_channel_idx += layer.num_conv - if i in self.depth_info: # aggregate the information - choices = self.depth_info[i]['choices'] - xstagei = self.depth_info[i]['stage'] - #print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist())) - #for A, W in zip(choices, selected_depth_probs[xstagei]): - # print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W)) - possible_tensors = [] - max_C = max( feature_maps[A].size(1) for A in choices ) - for tempi, A in enumerate(choices): - xtensor = ChannelWiseInter(feature_maps[A], max_C) - possible_tensors.append( xtensor ) - weighted_sum = sum( xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei]) ) - x = weighted_sum - - if i in self.depth_at_i: - xstagei, xatti = self.depth_at_i[i] - x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop - else: - x_expected_flop = expected_flop - flops.append( x_expected_flop ) - flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = linear_forward(features, self.classifier) - return logits, torch.stack( [sum(flops)] ) + def arch_parameters(self, LR=None): + if LR is None: + return [self.width_attentions, self.depth_attentions] + else: + return [ + {"params": self.width_attentions, "lr": LR}, + {"params": self.depth_attentions, "lr": LR}, + ] - def basic_forward(self, inputs): - if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1)) - x = inputs - for i, layer in enumerate(self.layers): - x = layer( x ) - 
features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = self.classifier(features) - return features, logits + def base_parameters(self): + return ( + list(self.layers.parameters()) + + list(self.avgpool.parameters()) + + list(self.classifier.parameters()) + ) + + def get_flop(self, mode, config_dict, extra_info): + if config_dict is not None: + config_dict = config_dict.copy() + # select channels + channels = [3] + for i, weight in enumerate(self.width_attentions): + if mode == "genotype": + with torch.no_grad(): + probe = nn.functional.softmax(weight, dim=0) + C = self.Ranges[i][torch.argmax(probe).item()] + else: + raise ValueError("invalid mode : {:}".format(mode)) + channels.append(C) + # select depth + if mode == "genotype": + with torch.no_grad(): + depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) + choices = torch.argmax(depth_probs, dim=1).cpu().tolist() + else: + raise ValueError("invalid mode : {:}".format(mode)) + selected_layers = [] + for choice, xvalue in zip(choices, self.depth_info_list): + xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1 + selected_layers.append(xtemp) + flop = 0 + for i, layer in enumerate(self.layers): + s, e = self.layer2indexRange[i] + xchl = tuple(channels[s : e + 1]) + if i in self.depth_at_i: + xstagei, xatti = self.depth_at_i[i] + if xatti <= choices[xstagei]: # leave this depth + flop += layer.get_flops(xchl) + else: + flop += 0 # do not use this layer + else: + flop += layer.get_flops(xchl) + # the last fc layer + flop += channels[-1] * self.classifier.out_features + if config_dict is None: + return flop / 1e6 + else: + config_dict["xchannels"] = channels + config_dict["xblocks"] = selected_layers + config_dict["super_type"] = "infer-shape" + config_dict["estimated_FLOP"] = flop / 1e6 + return flop / 1e6, config_dict + + def get_arch_info(self): + string = ( + "for depth and width, there are {:} + {:} attention probabilities.".format( + len(self.depth_attentions), len(self.width_attentions) + ) + ) + string += "\n{:}".format(self.depth_info) + discrepancy = [] + with torch.no_grad(): + for i, att in enumerate(self.depth_attentions): + prob = nn.functional.softmax(att, dim=0) + prob = prob.cpu() + selc = prob.argmax().item() + prob = prob.tolist() + prob = ["{:.3f}".format(x) for x in prob] + xstring = "{:03d}/{:03d}-th : {:}".format( + i, len(self.depth_attentions), " ".join(prob) + ) + logt = ["{:.4f}".format(x) for x in att.cpu().tolist()] + xstring += " || {:17s}".format(" ".join(logt)) + prob = sorted([float(x) for x in prob]) + disc = prob[-1] - prob[-2] + xstring += " || discrepancy={:.2f} || select={:}/{:}".format( + disc, selc, len(prob) + ) + discrepancy.append(disc) + string += "\n{:}".format(xstring) + string += "\n-----------------------------------------------" + for i, att in enumerate(self.width_attentions): + prob = nn.functional.softmax(att, dim=0) + prob = prob.cpu() + selc = prob.argmax().item() + prob = prob.tolist() + prob = ["{:.3f}".format(x) for x in prob] + xstring = "{:03d}/{:03d}-th : {:}".format( + i, len(self.width_attentions), " ".join(prob) + ) + logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] + xstring += " || {:52s}".format(" ".join(logt)) + prob = sorted([float(x) for x in prob]) + disc = prob[-1] - prob[-2] + xstring += " || dis={:.2f} || select={:}/{:}".format( + disc, selc, len(prob) + ) + discrepancy.append(disc) + string += "\n{:}".format(xstring) + return string, discrepancy + + def set_tau(self, tau_max, tau_min, epoch_ratio): + assert ( 
+ epoch_ratio >= 0 and epoch_ratio <= 1 + ), "invalid epoch-ratio : {:}".format(epoch_ratio) + tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 + self.tau = tau + + def get_message(self): + return self.message + + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) + + def search_forward(self, inputs): + flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1) + flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) + flop_depth_probs = torch.flip( + torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1] + ) + selected_widths, selected_width_probs = select2withP( + self.width_attentions, self.tau + ) + selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) + with torch.no_grad(): + selected_widths = selected_widths.cpu() + + x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] + feature_maps = [] + for i, layer in enumerate(self.layers): + selected_w_index = selected_widths[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + selected_w_probs = selected_width_probs[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + layer_prob = flop_width_probs[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + x, expected_inC, expected_flop = layer( + (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) + ) + feature_maps.append(x) + last_channel_idx += layer.num_conv + if i in self.depth_info: # aggregate the information + choices = self.depth_info[i]["choices"] + xstagei = self.depth_info[i]["stage"] + # print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist())) + # for A, W in zip(choices, selected_depth_probs[xstagei]): + # print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W)) + possible_tensors = [] + max_C = max(feature_maps[A].size(1) for A in choices) + for tempi, A in enumerate(choices): + xtensor = ChannelWiseInter(feature_maps[A], max_C) + possible_tensors.append(xtensor) + weighted_sum = sum( + xtensor * W + for xtensor, W in zip( + possible_tensors, selected_depth_probs[xstagei] + ) + ) + x = weighted_sum + + if i in self.depth_at_i: + xstagei, xatti = self.depth_at_i[i] + x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop + else: + x_expected_flop = expected_flop + flops.append(x_expected_flop) + flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = linear_forward(features, self.classifier) + return logits, torch.stack([sum(flops)]) + + def basic_forward(self, inputs): + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + x = inputs + for i, layer in enumerate(self.layers): + x = layer(x) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/shape_searchs/SearchSimResNet_width.py b/lib/models/shape_searchs/SearchSimResNet_width.py index 18785f8..584ffef 100644 --- a/lib/models/shape_searchs/SearchSimResNet_width.py +++ b/lib/models/shape_searchs/SearchSimResNet_width.py @@ -4,313 +4,463 @@ import math, torch import torch.nn as nn from ..initialization import initialize_resnet -from ..SharedUtils import additive_func -from .SoftSelect 
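The flip/cumsum/flip in search_forward above turns the per-stage depth probabilities into a reversed cumulative sum, which can be read as the probability that at least that many blocks of the stage are kept; this factor then scales each block's expected FLOPs. A small check with made-up numbers:

import torch

depth_probs = torch.tensor([[0.2, 0.3, 0.5]])  # P(stage uses depth choice 0 / 1 / 2)
keep_probs = torch.flip(torch.cumsum(torch.flip(depth_probs, [1]), 1), [1])
print(keep_probs)  # tensor([[1.0000, 0.8000, 0.5000]])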
import select2withP, ChannelWiseInter -from .SoftSelect import linear_forward -from .SoftSelect import get_width_choices as get_choices +from ..SharedUtils import additive_func +from .SoftSelect import select2withP, ChannelWiseInter +from .SoftSelect import linear_forward +from .SoftSelect import get_width_choices as get_choices def conv_forward(inputs, conv, choices): - iC = conv.in_channels - fill_size = list(inputs.size()) - fill_size[1] = iC - fill_size[1] - filled = torch.zeros(fill_size, device=inputs.device) - xinputs = torch.cat((inputs, filled), dim=1) - outputs = conv(xinputs) - selecteds = [outputs[:,:oC] for oC in choices] - return selecteds + iC = conv.in_channels + fill_size = list(inputs.size()) + fill_size[1] = iC - fill_size[1] + filled = torch.zeros(fill_size, device=inputs.device) + xinputs = torch.cat((inputs, filled), dim=1) + outputs = conv(xinputs) + selecteds = [outputs[:, :oC] for oC in choices] + return selecteds class ConvBNReLU(nn.Module): - num_conv = 1 - def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): - super(ConvBNReLU, self).__init__() - self.InShape = None - self.OutShape = None - self.choices = get_choices(nOut) - self.register_buffer('choices_tensor', torch.Tensor( self.choices )) + num_conv = 1 - if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - else : self.avg = None - self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) - #if has_bn : self.bn = nn.BatchNorm2d(nOut) - #else : self.bn = None - self.has_bn = has_bn - self.BNs = nn.ModuleList() - for i, _out in enumerate(self.choices): - self.BNs.append(nn.BatchNorm2d(_out)) - if has_relu: self.relu = nn.ReLU(inplace=True) - else : self.relu = None - self.in_dim = nIn - self.out_dim = nOut - self.search_mode = 'basic' + def __init__( + self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu + ): + super(ConvBNReLU, self).__init__() + self.InShape = None + self.OutShape = None + self.choices = get_choices(nOut) + self.register_buffer("choices_tensor", torch.Tensor(self.choices)) - def get_flops(self, channels, check_range=True, divide=1): - iC, oC = channels - if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels) - assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape) - assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape) - #conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups - conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups) - all_positions = self.OutShape[0] * self.OutShape[1] - flops = (conv_per_position_flops * all_positions / divide) * iC * oC - if self.conv.bias is not None: flops += all_positions / divide - return flops + if has_avg: + self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.avg = None + self.conv = nn.Conv2d( + nIn, + nOut, + kernel_size=kernel, + stride=stride, + padding=padding, + dilation=1, + groups=1, + bias=bias, + ) + # if has_bn : self.bn = nn.BatchNorm2d(nOut) + # else : self.bn = None + self.has_bn = has_bn + self.BNs = nn.ModuleList() + for i, _out in enumerate(self.choices): + self.BNs.append(nn.BatchNorm2d(_out)) + if has_relu: + self.relu = 
nn.ReLU(inplace=True) + else: + self.relu = None + self.in_dim = nIn + self.out_dim = nOut + self.search_mode = "basic" - def get_range(self): - return [self.choices] + def get_flops(self, channels, check_range=True, divide=1): + iC, oC = channels + if check_range: + assert ( + iC <= self.conv.in_channels and oC <= self.conv.out_channels + ), "{:} vs {:} | {:} vs {:}".format( + iC, self.conv.in_channels, oC, self.conv.out_channels + ) + assert ( + isinstance(self.InShape, tuple) and len(self.InShape) == 2 + ), "invalid in-shape : {:}".format(self.InShape) + assert ( + isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 + ), "invalid out-shape : {:}".format(self.OutShape) + # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups + conv_per_position_flops = ( + self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups + ) + all_positions = self.OutShape[0] * self.OutShape[1] + flops = (conv_per_position_flops * all_positions / divide) * iC * oC + if self.conv.bias is not None: + flops += all_positions / divide + return flops - def forward(self, inputs): - if self.search_mode == 'basic': - return self.basic_forward(inputs) - elif self.search_mode == 'search': - return self.search_forward(inputs) - else: - raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_range(self): + return [self.choices] - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, index, prob = tuple_inputs - index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) - probability = torch.squeeze(probability) - assert len(index) == 2, 'invalid length : {:}'.format(index) - # compute expected flop - #coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) - expected_outC = (self.choices_tensor * probability).sum() - expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) - if self.avg : out = self.avg( inputs ) - else : out = inputs - # convolutional layer - out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) - out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] - # merge - out_channel = max([x.size(1) for x in out_bns]) - outA = ChannelWiseInter(out_bns[0], out_channel) - outB = ChannelWiseInter(out_bns[1], out_channel) - out = outA * prob[0] + outB * prob[1] - #out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - if self.relu: out = self.relu( out ) - else : out = out - return out, expected_outC, expected_flop + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, index, prob = tuple_inputs + index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) + probability = torch.squeeze(probability) + assert len(index) == 2, "invalid length : {:}".format(index) + # compute expected flop + # coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) + expected_outC = (self.choices_tensor * probability).sum() + expected_flop = 
self.get_flops([expected_inC, expected_outC], False, 1e6) + if self.avg: + out = self.avg(inputs) + else: + out = inputs + # convolutional layer + out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) + out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] + # merge + out_channel = max([x.size(1) for x in out_bns]) + outA = ChannelWiseInter(out_bns[0], out_channel) + outB = ChannelWiseInter(out_bns[1], out_channel) + out = outA * prob[0] + outB * prob[1] + # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) - def basic_forward(self, inputs): - if self.avg : out = self.avg( inputs ) - else : out = inputs - conv = self.conv( out ) - if self.has_bn:out= self.BNs[-1]( conv ) - else : out = conv - if self.relu: out = self.relu( out ) - else : out = out - if self.InShape is None: - self.InShape = (inputs.size(-2), inputs.size(-1)) - self.OutShape = (out.size(-2) , out.size(-1)) - return out + if self.relu: + out = self.relu(out) + else: + out = out + return out, expected_outC, expected_flop + + def basic_forward(self, inputs): + if self.avg: + out = self.avg(inputs) + else: + out = inputs + conv = self.conv(out) + if self.has_bn: + out = self.BNs[-1](conv) + else: + out = conv + if self.relu: + out = self.relu(out) + else: + out = out + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + self.OutShape = (out.size(-2), out.size(-1)) + return out class SimBlock(nn.Module): - expansion = 1 - num_conv = 1 - def __init__(self, inplanes, planes, stride): - super(SimBlock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) - if stride == 2: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) - elif inplanes != planes: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) - else: - self.downsample = None - self.out_dim = planes - self.search_mode = 'basic' + expansion = 1 + num_conv = 1 - def get_range(self): - return self.conv.get_range() + def __init__(self, inplanes, planes, stride): + super(SimBlock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv = ConvBNReLU( + inplanes, + planes, + 3, + stride, + 1, + False, + has_avg=False, + has_bn=True, + has_relu=True, + ) + if stride == 2: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=True, + has_bn=False, + has_relu=False, + ) + elif inplanes != planes: + self.downsample = ConvBNReLU( + inplanes, + planes, + 1, + 1, + 0, + False, + has_avg=False, + has_bn=True, + has_relu=False, + ) + else: + self.downsample = None + self.out_dim = planes + self.search_mode = "basic" - def get_flops(self, channels): - assert len(channels) == 2, 'invalid channels : {:}'.format(channels) - flop_A = self.conv.get_flops([channels[0], channels[1]]) - if hasattr(self.downsample, 'get_flops'): - flop_C = self.downsample.get_flops([channels[0], channels[-1]]) - else: - flop_C = 0 - if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train - flop_C = channels[0] * channels[-1] * self.conv.OutShape[0] * self.conv.OutShape[1] - return flop_A + flop_C + def get_range(self): + return self.conv.get_range() - def forward(self, inputs): - if self.search_mode == 'basic' : return self.basic_forward(inputs) - 
elif self.search_mode == 'search': return self.search_forward(inputs) - else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + def get_flops(self, channels): + assert len(channels) == 2, "invalid channels : {:}".format(channels) + flop_A = self.conv.get_flops([channels[0], channels[1]]) + if hasattr(self.downsample, "get_flops"): + flop_C = self.downsample.get_flops([channels[0], channels[-1]]) + else: + flop_C = 0 + if ( + channels[0] != channels[-1] and self.downsample is None + ): # this short-cut will be added during the infer-train + flop_C = ( + channels[0] + * channels[-1] + * self.conv.OutShape[0] + * self.conv.OutShape[1] + ) + return flop_A + flop_C - def search_forward(self, tuple_inputs): - assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) - inputs, expected_inC, probability, indexes, probs = tuple_inputs - assert indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1, 'invalid size : {:}, {:}, {:}'.format(indexes.size(), probs.size(), probability.size()) - out, expected_next_inC, expected_flop = self.conv( (inputs, expected_inC , probability[0], indexes[0], probs[0]) ) - if self.downsample is not None: - residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[-1], indexes[-1], probs[-1]) ) - else: - residual, expected_flop_c = inputs, 0 - out = additive_func(residual, out) - return nn.functional.relu(out, inplace=True), expected_next_inC, sum([expected_flop, expected_flop_c]) + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) - def basic_forward(self, inputs): - basicblock = self.conv(inputs) - if self.downsample is not None: residual = self.downsample(inputs) - else : residual = inputs - out = additive_func(residual, basicblock) - return nn.functional.relu(out, inplace=True) + def search_forward(self, tuple_inputs): + assert ( + isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 + ), "invalid type input : {:}".format(type(tuple_inputs)) + inputs, expected_inC, probability, indexes, probs = tuple_inputs + assert ( + indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1 + ), "invalid size : {:}, {:}, {:}".format( + indexes.size(), probs.size(), probability.size() + ) + out, expected_next_inC, expected_flop = self.conv( + (inputs, expected_inC, probability[0], indexes[0], probs[0]) + ) + if self.downsample is not None: + residual, _, expected_flop_c = self.downsample( + (inputs, expected_inC, probability[-1], indexes[-1], probs[-1]) + ) + else: + residual, expected_flop_c = inputs, 0 + out = additive_func(residual, out) + return ( + nn.functional.relu(out, inplace=True), + expected_next_inC, + sum([expected_flop, expected_flop_c]), + ) + def basic_forward(self, inputs): + basicblock = self.conv(inputs) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, basicblock) + return nn.functional.relu(out, inplace=True) class SearchWidthSimResNet(nn.Module): + def __init__(self, depth, num_classes): + super(SearchWidthSimResNet, self).__init__() - def __init__(self, depth, num_classes): - super(SearchWidthSimResNet, self).__init__() + assert ( + depth - 2 + ) % 3 == 0, "depth should be one of 5, 8, 11, 14, ... 
instead of {:}".format( + depth + ) + layer_blocks = (depth - 2) // 3 + self.message = ( + "SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}".format( + depth, layer_blocks + ) + ) + self.num_classes = num_classes + self.channels = [16] + self.layers = nn.ModuleList( + [ + ConvBNReLU( + 3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True + ) + ] + ) + self.InShape = None + for stage in range(3): + for iL in range(layer_blocks): + iC = self.channels[-1] + planes = 16 * (2 ** stage) + stride = 2 if stage > 0 and iL == 0 else 1 + module = SimBlock(iC, planes, stride) + self.channels.append(module.out_dim) + self.layers.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( + stage, + iL, + layer_blocks, + len(self.layers) - 1, + iC, + module.out_dim, + stride, + ) - assert (depth - 2) % 3 == 0, 'depth should be one of 5, 8, 11, 14, ... instead of {:}'.format(depth) - layer_blocks = (depth - 2) // 3 - self.message = 'SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) - self.num_classes = num_classes - self.channels = [16] - self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) - self.InShape = None - for stage in range(3): - for iL in range(layer_blocks): - iC = self.channels[-1] - planes = 16 * (2**stage) - stride = 2 if stage > 0 and iL == 0 else 1 - module = SimBlock(iC, planes, stride) - self.channels.append( module.out_dim ) - self.layers.append ( module ) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) - - self.avgpool = nn.AvgPool2d(8) - self.classifier = nn.Linear(module.out_dim, num_classes) - self.InShape = None - self.tau = -1 - self.search_mode = 'basic' - #assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) - - # parameters for width - self.Ranges = [] - self.layer2indexRange = [] - for i, layer in enumerate(self.layers): - start_index = len(self.Ranges) - self.Ranges += layer.get_range() - self.layer2indexRange.append( (start_index, len(self.Ranges)) ) - assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth) + self.avgpool = nn.AvgPool2d(8) + self.classifier = nn.Linear(module.out_dim, num_classes) + self.InShape = None + self.tau = -1 + self.search_mode = "basic" + # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) - self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None)))) - nn.init.normal_(self.width_attentions, 0, 0.01) - self.apply(initialize_resnet) + # parameters for width + self.Ranges = [] + self.layer2indexRange = [] + for i, layer in enumerate(self.layers): + start_index = len(self.Ranges) + self.Ranges += layer.get_range() + self.layer2indexRange.append((start_index, len(self.Ranges))) + assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format( + len(self.Ranges) + 1, depth + ) - def arch_parameters(self): - return [self.width_attentions] + self.register_parameter( + "width_attentions", + nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))), + ) + nn.init.normal_(self.width_attentions, 0, 0.01) + self.apply(initialize_resnet) - 
def base_parameters(self): - return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters()) + def arch_parameters(self): + return [self.width_attentions] - def get_flop(self, mode, config_dict, extra_info): - if config_dict is not None: config_dict = config_dict.copy() - #weights = [F.softmax(x, dim=0) for x in self.width_attentions] - channels = [3] - for i, weight in enumerate(self.width_attentions): - if mode == 'genotype': + def base_parameters(self): + return ( + list(self.layers.parameters()) + + list(self.avgpool.parameters()) + + list(self.classifier.parameters()) + ) + + def get_flop(self, mode, config_dict, extra_info): + if config_dict is not None: + config_dict = config_dict.copy() + # weights = [F.softmax(x, dim=0) for x in self.width_attentions] + channels = [3] + for i, weight in enumerate(self.width_attentions): + if mode == "genotype": + with torch.no_grad(): + probe = nn.functional.softmax(weight, dim=0) + C = self.Ranges[i][torch.argmax(probe).item()] + elif mode == "max": + C = self.Ranges[i][-1] + elif mode == "fix": + C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) + elif mode == "random": + assert isinstance(extra_info, float), "invalid extra_info : {:}".format( + extra_info + ) + with torch.no_grad(): + prob = nn.functional.softmax(weight, dim=0) + approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) + for j in range(prob.size(0)): + prob[j] = 1 / ( + abs(j - (approximate_C - self.Ranges[i][j])) + 0.2 + ) + C = self.Ranges[i][torch.multinomial(prob, 1, False).item()] + else: + raise ValueError("invalid mode : {:}".format(mode)) + channels.append(C) + flop = 0 + for i, layer in enumerate(self.layers): + s, e = self.layer2indexRange[i] + xchl = tuple(channels[s : e + 1]) + flop += layer.get_flops(xchl) + # the last fc layer + flop += channels[-1] * self.classifier.out_features + if config_dict is None: + return flop / 1e6 + else: + config_dict["xchannels"] = channels + config_dict["super_type"] = "infer-width" + config_dict["estimated_FLOP"] = flop / 1e6 + return flop / 1e6, config_dict + + def get_arch_info(self): + string = "for width, there are {:} attention probabilities.".format( + len(self.width_attentions) + ) + discrepancy = [] with torch.no_grad(): - probe = nn.functional.softmax(weight, dim=0) - C = self.Ranges[i][ torch.argmax(probe).item() ] - elif mode == 'max': - C = self.Ranges[i][-1] - elif mode == 'fix': - C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] ) - elif mode == 'random': - assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info) + for i, att in enumerate(self.width_attentions): + prob = nn.functional.softmax(att, dim=0) + prob = prob.cpu() + selc = prob.argmax().item() + prob = prob.tolist() + prob = ["{:.3f}".format(x) for x in prob] + xstring = "{:03d}/{:03d}-th : {:}".format( + i, len(self.width_attentions), " ".join(prob) + ) + logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] + xstring += " || {:52s}".format(" ".join(logt)) + prob = sorted([float(x) for x in prob]) + disc = prob[-1] - prob[-2] + xstring += " || dis={:.2f} || select={:}/{:}".format( + disc, selc, len(prob) + ) + discrepancy.append(disc) + string += "\n{:}".format(xstring) + return string, discrepancy + + def set_tau(self, tau_max, tau_min, epoch_ratio): + assert ( + epoch_ratio >= 0 and epoch_ratio <= 1 + ), "invalid epoch-ratio : {:}".format(epoch_ratio) + tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 + self.tau = tau + + def 
get_message(self): + return self.message + + def forward(self, inputs): + if self.search_mode == "basic": + return self.basic_forward(inputs) + elif self.search_mode == "search": + return self.search_forward(inputs) + else: + raise ValueError("invalid search_mode = {:}".format(self.search_mode)) + + def search_forward(self, inputs): + flop_probs = nn.functional.softmax(self.width_attentions, dim=1) + selected_widths, selected_probs = select2withP(self.width_attentions, self.tau) with torch.no_grad(): - prob = nn.functional.softmax(weight, dim=0) - approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] ) - for j in range(prob.size(0)): - prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2) - C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ] - else: - raise ValueError('invalid mode : {:}'.format(mode)) - channels.append( C ) - flop = 0 - for i, layer in enumerate(self.layers): - s, e = self.layer2indexRange[i] - xchl = tuple( channels[s:e+1] ) - flop+= layer.get_flops(xchl) - # the last fc layer - flop += channels[-1] * self.classifier.out_features - if config_dict is None: - return flop / 1e6 - else: - config_dict['xchannels'] = channels - config_dict['super_type'] = 'infer-width' - config_dict['estimated_FLOP'] = flop / 1e6 - return flop / 1e6, config_dict + selected_widths = selected_widths.cpu() - def get_arch_info(self): - string = "for width, there are {:} attention probabilities.".format(len(self.width_attentions)) - discrepancy = [] - with torch.no_grad(): - for i, att in enumerate(self.width_attentions): - prob = nn.functional.softmax(att, dim=0) - prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist() - prob = ['{:.3f}'.format(x) for x in prob] - xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob)) - logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()] - xstring += ' || {:52s}'.format(' '.join(logt)) - prob = sorted( [float(x) for x in prob] ) - disc = prob[-1] - prob[-2] - xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob)) - discrepancy.append( disc ) - string += '\n{:}'.format(xstring) - return string, discrepancy + x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] + for i, layer in enumerate(self.layers): + selected_w_index = selected_widths[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + selected_w_probs = selected_probs[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + layer_prob = flop_probs[ + last_channel_idx : last_channel_idx + layer.num_conv + ] + x, expected_inC, expected_flop = layer( + (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) + ) + last_channel_idx += layer.num_conv + flops.append(expected_flop) + flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = linear_forward(features, self.classifier) + return logits, torch.stack([sum(flops)]) - def set_tau(self, tau_max, tau_min, epoch_ratio): - assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio) - tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 - self.tau = tau - - def get_message(self): - return self.message - - def forward(self, inputs): - if self.search_mode == 'basic': - return self.basic_forward(inputs) - elif self.search_mode == 'search': - return self.search_forward(inputs) - else: - raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) - - def 
search_forward(self, inputs): - flop_probs = nn.functional.softmax(self.width_attentions, dim=1) - selected_widths, selected_probs = select2withP(self.width_attentions, self.tau) - with torch.no_grad(): - selected_widths = selected_widths.cpu() - - x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] - for i, layer in enumerate(self.layers): - selected_w_index = selected_widths[last_channel_idx: last_channel_idx+layer.num_conv] - selected_w_probs = selected_probs[last_channel_idx: last_channel_idx+layer.num_conv] - layer_prob = flop_probs[last_channel_idx: last_channel_idx+layer.num_conv] - x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) ) - last_channel_idx += layer.num_conv - flops.append( expected_flop ) - flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = linear_forward(features, self.classifier) - return logits, torch.stack( [sum(flops)] ) - - def basic_forward(self, inputs): - if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1)) - x = inputs - for i, layer in enumerate(self.layers): - x = layer( x ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = self.classifier(features) - return features, logits + def basic_forward(self, inputs): + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + x = inputs + for i, layer in enumerate(self.layers): + x = layer(x) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/shape_searchs/SoftSelect.py b/lib/models/shape_searchs/SoftSelect.py index 802dcb6..3cdfa45 100644 --- a/lib/models/shape_searchs/SoftSelect.py +++ b/lib/models/shape_searchs/SoftSelect.py @@ -6,106 +6,123 @@ import torch.nn as nn def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7): - if tau <= 0: - new_logits = logits - probs = nn.functional.softmax(new_logits, dim=1) - else : - while True: # a trick to avoid the gumbels bug - gumbels = -torch.empty_like(logits).exponential_().log() - new_logits = (logits.log_softmax(dim=1) + gumbels) / tau - probs = nn.functional.softmax(new_logits, dim=1) - if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break + if tau <= 0: + new_logits = logits + probs = nn.functional.softmax(new_logits, dim=1) + else: + while True: # a trick to avoid the gumbels bug + gumbels = -torch.empty_like(logits).exponential_().log() + new_logits = (logits.log_softmax(dim=1) + gumbels) / tau + probs = nn.functional.softmax(new_logits, dim=1) + if ( + (not torch.isinf(gumbels).any()) + and (not torch.isinf(probs).any()) + and (not torch.isnan(probs).any()) + ): + break - if just_prob: return probs + if just_prob: + return probs - #with torch.no_grad(): # add eps for unexpected torch error - # probs = nn.functional.softmax(new_logits, dim=1) - # selected_index = torch.multinomial(probs + eps, 2, False) - with torch.no_grad(): # add eps for unexpected torch error - probs = probs.cpu() - selected_index = torch.multinomial(probs + eps, num, False).to(logits.device) - selected_logit = torch.gather(new_logits, 1, selected_index) - selcted_probs = nn.functional.softmax(selected_logit, dim=1) - return selected_index, selcted_probs + # with torch.no_grad(): # add eps for unexpected torch error + # probs = nn.functional.softmax(new_logits, 
dim=1) + # selected_index = torch.multinomial(probs + eps, 2, False) + with torch.no_grad(): # add eps for unexpected torch error + probs = probs.cpu() + selected_index = torch.multinomial(probs + eps, num, False).to(logits.device) + selected_logit = torch.gather(new_logits, 1, selected_index) + selcted_probs = nn.functional.softmax(selected_logit, dim=1) + return selected_index, selcted_probs -def ChannelWiseInter(inputs, oC, mode='v2'): - if mode == 'v1': - return ChannelWiseInterV1(inputs, oC) - elif mode == 'v2': - return ChannelWiseInterV2(inputs, oC) - else: - raise ValueError('invalid mode : {:}'.format(mode)) +def ChannelWiseInter(inputs, oC, mode="v2"): + if mode == "v1": + return ChannelWiseInterV1(inputs, oC) + elif mode == "v2": + return ChannelWiseInterV2(inputs, oC) + else: + raise ValueError("invalid mode : {:}".format(mode)) def ChannelWiseInterV1(inputs, oC): - assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size()) - def start_index(a, b, c): - return int( math.floor(float(a * c) / b) ) - def end_index(a, b, c): - return int( math.ceil(float((a + 1) * c) / b) ) - batch, iC, H, W = inputs.size() - outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device) - if iC == oC: return inputs - for ot in range(oC): - istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC) - values = inputs[:, istartT:iendT].mean(dim=1) - outputs[:, ot, :, :] = values - return outputs + assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size()) + + def start_index(a, b, c): + return int(math.floor(float(a * c) / b)) + + def end_index(a, b, c): + return int(math.ceil(float((a + 1) * c) / b)) + + batch, iC, H, W = inputs.size() + outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device) + if iC == oC: + return inputs + for ot in range(oC): + istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC) + values = inputs[:, istartT:iendT].mean(dim=1) + outputs[:, ot, :, :] = values + return outputs def ChannelWiseInterV2(inputs, oC): - assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size()) - batch, C, H, W = inputs.size() - if C == oC: return inputs - else : return nn.functional.adaptive_avg_pool3d(inputs, (oC,H,W)) - #inputs_5D = inputs.view(batch, 1, C, H, W) - #otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None) - #otputs = otputs_5D.view(batch, oC, H, W) - #otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False) - #return otputs + assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size()) + batch, C, H, W = inputs.size() + if C == oC: + return inputs + else: + return nn.functional.adaptive_avg_pool3d(inputs, (oC, H, W)) + # inputs_5D = inputs.view(batch, 1, C, H, W) + # otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None) + # otputs = otputs_5D.view(batch, oC, H, W) + # otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False) + # return otputs def linear_forward(inputs, linear): - if linear is None: return inputs - iC = inputs.size(1) - weight = linear.weight[:, :iC] - if linear.bias is None: bias = None - else : bias = linear.bias - return nn.functional.linear(inputs, weight, bias) + if linear is None: + return inputs + iC = inputs.size(1) + weight = linear.weight[:, :iC] + if linear.bias is None: + bias = None + else: + bias = linear.bias + return nn.functional.linear(inputs, weight, bias) def get_width_choices(nOut): - xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 
1.0] - if nOut is None: - return len(xsrange) - else: - Xs = [int(nOut * i) for i in xsrange] - #xs = [ int(nOut * i // 10) for i in range(2, 11)] - #Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1] - Xs = sorted( list( set(Xs) ) ) - return tuple(Xs) + xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + if nOut is None: + return len(xsrange) + else: + Xs = [int(nOut * i) for i in xsrange] + # xs = [ int(nOut * i // 10) for i in range(2, 11)] + # Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1] + Xs = sorted(list(set(Xs))) + return tuple(Xs) def get_depth_choices(nDepth): - if nDepth is None: - return 3 - else: - assert nDepth >= 3, 'nDepth should be greater than 2 vs {:}'.format(nDepth) - if nDepth == 1 : return (1, 1, 1) - elif nDepth == 2: return (1, 1, 2) - elif nDepth >= 3: - return (nDepth//3, nDepth*2//3, nDepth) + if nDepth is None: + return 3 else: - raise ValueError('invalid Depth : {:}'.format(nDepth)) + assert nDepth >= 3, "nDepth should be greater than 2 vs {:}".format(nDepth) + if nDepth == 1: + return (1, 1, 1) + elif nDepth == 2: + return (1, 1, 2) + elif nDepth >= 3: + return (nDepth // 3, nDepth * 2 // 3, nDepth) + else: + raise ValueError("invalid Depth : {:}".format(nDepth)) def drop_path(x, drop_prob): - if drop_prob > 0.: - keep_prob = 1. - drop_prob - mask = x.new_zeros(x.size(0), 1, 1, 1) - mask = mask.bernoulli_(keep_prob) - x = x * (mask / keep_prob) - #x.div_(keep_prob) - #x.mul_(mask) - return x + if drop_prob > 0.0: + keep_prob = 1.0 - drop_prob + mask = x.new_zeros(x.size(0), 1, 1, 1) + mask = mask.bernoulli_(keep_prob) + x = x * (mask / keep_prob) + # x.div_(keep_prob) + # x.mul_(mask) + return x diff --git a/lib/models/shape_searchs/__init__.py b/lib/models/shape_searchs/__init__.py index 500167d..15e2260 100644 --- a/lib/models/shape_searchs/__init__.py +++ b/lib/models/shape_searchs/__init__.py @@ -3,7 +3,7 @@ ################################################## from .SearchCifarResNet_width import SearchWidthCifarResNet from .SearchCifarResNet_depth import SearchDepthCifarResNet -from .SearchCifarResNet import SearchShapeCifarResNet -from .SearchSimResNet_width import SearchWidthSimResNet -from .SearchImagenetResNet import SearchShapeImagenetResNet +from .SearchCifarResNet import SearchShapeCifarResNet +from .SearchSimResNet_width import SearchWidthSimResNet +from .SearchImagenetResNet import SearchShapeImagenetResNet from .generic_size_tiny_cell_model import GenericNAS301Model diff --git a/lib/models/shape_searchs/generic_size_tiny_cell_model.py b/lib/models/shape_searchs/generic_size_tiny_cell_model.py index ee887cc..3a3a354 100644 --- a/lib/models/shape_searchs/generic_size_tiny_cell_model.py +++ b/lib/models/shape_searchs/generic_size_tiny_cell_model.py @@ -15,152 +15,195 @@ from models.shape_searchs.SoftSelect import select2withP, ChannelWiseInter class GenericNAS301Model(nn.Module): + def __init__( + self, + candidate_Cs: List[int], + max_num_Cs: int, + genotype: Any, + num_classes: int, + affine: bool, + track_running_stats: bool, + ): + super(GenericNAS301Model, self).__init__() + self._max_num_Cs = max_num_Cs + self._candidate_Cs = candidate_Cs + if max_num_Cs % 3 != 2: + raise ValueError("invalid number of layers : {:}".format(max_num_Cs)) + self._num_stage = N = max_num_Cs // 3 + self._max_C = max(candidate_Cs) - def __init__(self, candidate_Cs: List[int], max_num_Cs: int, genotype: Any, num_classes: int, affine: bool, track_running_stats: bool): - super(GenericNAS301Model, self).__init__() - 
self._max_num_Cs = max_num_Cs - self._candidate_Cs = candidate_Cs - if max_num_Cs % 3 != 2: - raise ValueError('invalid number of layers : {:}'.format(max_num_Cs)) - self._num_stage = N = max_num_Cs // 3 - self._max_C = max(candidate_Cs) + stem = nn.Sequential( + nn.Conv2d(3, self._max_C, kernel_size=3, padding=1, bias=not affine), + nn.BatchNorm2d( + self._max_C, affine=affine, track_running_stats=track_running_stats + ), + ) - stem = nn.Sequential( - nn.Conv2d(3, self._max_C, kernel_size=3, padding=1, bias=not affine), - nn.BatchNorm2d(self._max_C, affine=affine, track_running_stats=track_running_stats)) + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + c_prev = self._max_C + self._cells = nn.ModuleList() + self._cells.append(stem) + for index, reduction in enumerate(layer_reductions): + if reduction: + cell = ResNetBasicblock(c_prev, self._max_C, 2, True) + else: + cell = InferCell( + genotype, c_prev, self._max_C, 1, affine, track_running_stats + ) + self._cells.append(cell) + c_prev = cell.out_dim + self._num_layer = len(self._cells) - c_prev = self._max_C - self._cells = nn.ModuleList() - self._cells.append(stem) - for index, reduction in enumerate(layer_reductions): - if reduction : cell = ResNetBasicblock(c_prev, self._max_C, 2, True) - else : cell = InferCell(genotype, c_prev, self._max_C, 1, affine, track_running_stats) - self._cells.append(cell) - c_prev = cell.out_dim - self._num_layer = len(self._cells) + self.lastact = nn.Sequential( + nn.BatchNorm2d( + c_prev, affine=affine, track_running_stats=track_running_stats + ), + nn.ReLU(inplace=True), + ) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(c_prev, num_classes) + # algorithm related + self.register_buffer("_tau", torch.zeros(1)) + self._algo = None + self._warmup_ratio = None - self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev, affine=affine, track_running_stats=track_running_stats), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(c_prev, num_classes) - # algorithm related - self.register_buffer('_tau', torch.zeros(1)) - self._algo = None - self._warmup_ratio = None + def set_algo(self, algo: Text): + # used for searching + assert self._algo is None, "This functioin can only be called once." + assert algo in ["mask_gumbel", "mask_rl", "tas"], "invalid algo : {:}".format( + algo + ) + self._algo = algo + self._arch_parameters = nn.Parameter( + 1e-3 * torch.randn(self._max_num_Cs, len(self._candidate_Cs)) + ) + # if algo == 'mask_gumbel' or algo == 'mask_rl': + self.register_buffer( + "_masks", torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)) + ) + for i in range(len(self._candidate_Cs)): + self._masks.data[i, : self._candidate_Cs[i]] = 1 - def set_algo(self, algo: Text): - # used for searching - assert self._algo is None, 'This functioin can only be called once.' 
- assert algo in ['mask_gumbel', 'mask_rl', 'tas'], 'invalid algo : {:}'.format(algo) - self._algo = algo - self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs))) - # if algo == 'mask_gumbel' or algo == 'mask_rl': - self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))) - for i in range(len(self._candidate_Cs)): - self._masks.data[i, :self._candidate_Cs[i]] = 1 - - @property - def tau(self): - return self._tau + @property + def tau(self): + return self._tau - def set_tau(self, tau): - self._tau.data[:] = tau + def set_tau(self, tau): + self._tau.data[:] = tau - @property - def warmup_ratio(self): - return self._warmup_ratio + @property + def warmup_ratio(self): + return self._warmup_ratio - def set_warmup_ratio(self, ratio: float): - self._warmup_ratio = ratio + def set_warmup_ratio(self, ratio: float): + self._warmup_ratio = ratio - @property - def weights(self): - xlist = list(self._cells.parameters()) - xlist+= list(self.lastact.parameters()) - xlist+= list(self.global_pooling.parameters()) - xlist+= list(self.classifier.parameters()) - return xlist + @property + def weights(self): + xlist = list(self._cells.parameters()) + xlist += list(self.lastact.parameters()) + xlist += list(self.global_pooling.parameters()) + xlist += list(self.classifier.parameters()) + return xlist - @property - def alphas(self): - return [self._arch_parameters] + @property + def alphas(self): + return [self._arch_parameters] - def show_alphas(self): - with torch.no_grad(): - return 'arch-parameters :\n{:}'.format(nn.functional.softmax(self._arch_parameters, dim=-1).cpu()) - - @property - def random(self): - cs = [] - for i in range(self._max_num_Cs): - index = random.randint(0, len(self._candidate_Cs)-1) - cs.append(str(self._candidate_Cs[index])) - return ':'.join(cs) - - @property - def genotype(self): - cs = [] - for i in range(self._max_num_Cs): - with torch.no_grad(): - index = self._arch_parameters[i].argmax().item() - cs.append(str(self._candidate_Cs[index])) - return ':'.join(cs) - - def get_message(self) -> Text: - string = self.extra_repr() - for i, cell in enumerate(self._cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self._cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(candidates={_candidate_Cs}, num={_max_num_Cs}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def forward(self, inputs): - feature = inputs - - log_probs = [] - for i, cell in enumerate(self._cells): - feature = cell(feature) - # apply different searching algorithms - idx = max(0, i-1) - if self._warmup_ratio is not None: - if random.random() < self._warmup_ratio: - mask = self._masks[-1] - else: - mask = self._masks[random.randint(0, len(self._masks)-1)] - feature = feature * mask.view(1, -1, 1, 1) - elif self._algo == 'mask_gumbel': - weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1) - mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1) - feature = feature * mask - elif self._algo == 'tas': - selected_cs, selected_probs = select2withP(self._arch_parameters[idx:idx+1], self.tau, num=2) + def show_alphas(self): with torch.no_grad(): - i1, i2 = selected_cs.cpu().view(-1).tolist() - c1, c2 = self._candidate_Cs[i1], self._candidate_Cs[i2] - out_channel = max(c1, c2) - out1 = ChannelWiseInter(feature[:, :c1], out_channel) - out2 = ChannelWiseInter(feature[:, :c2], out_channel) - out = out1 * selected_probs[0, 
0] + out2 * selected_probs[0, 1] - if feature.shape[1] == out.shape[1]: - feature = out - else: - miss = torch.zeros(feature.shape[0], feature.shape[1]-out.shape[1], feature.shape[2], feature.shape[3], device=feature.device) - feature = torch.cat((out, miss), dim=1) - elif self._algo == 'mask_rl': - prob = nn.functional.softmax(self._arch_parameters[idx:idx+1], dim=-1) - dist = torch.distributions.Categorical(prob) - action = dist.sample() - log_probs.append(dist.log_prob(action)) - mask = self._masks[action.item()].view(1, -1, 1, 1) - feature = feature * mask - else: - raise ValueError('invalid algorithm : {:}'.format(self._algo)) + return "arch-parameters :\n{:}".format( + nn.functional.softmax(self._arch_parameters, dim=-1).cpu() + ) - out = self.lastact(feature) - out = self.global_pooling(out) - out = out.view(out.size(0), -1) - logits = self.classifier(out) + @property + def random(self): + cs = [] + for i in range(self._max_num_Cs): + index = random.randint(0, len(self._candidate_Cs) - 1) + cs.append(str(self._candidate_Cs[index])) + return ":".join(cs) - return out, logits, log_probs + @property + def genotype(self): + cs = [] + for i in range(self._max_num_Cs): + with torch.no_grad(): + index = self._arch_parameters[i].argmax().item() + cs.append(str(self._candidate_Cs[index])) + return ":".join(cs) + + def get_message(self) -> Text: + string = self.extra_repr() + for i, cell in enumerate(self._cells): + string += "\n {:02d}/{:02d} :: {:}".format( + i, len(self._cells), cell.extra_repr() + ) + return string + + def extra_repr(self): + return "{name}(candidates={_candidate_Cs}, num={_max_num_Cs}, N={_num_stage}, L={_num_layer})".format( + name=self.__class__.__name__, **self.__dict__ + ) + + def forward(self, inputs): + feature = inputs + + log_probs = [] + for i, cell in enumerate(self._cells): + feature = cell(feature) + # apply different searching algorithms + idx = max(0, i - 1) + if self._warmup_ratio is not None: + if random.random() < self._warmup_ratio: + mask = self._masks[-1] + else: + mask = self._masks[random.randint(0, len(self._masks) - 1)] + feature = feature * mask.view(1, -1, 1, 1) + elif self._algo == "mask_gumbel": + weights = nn.functional.gumbel_softmax( + self._arch_parameters[idx : idx + 1], tau=self.tau, dim=-1 + ) + mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1) + feature = feature * mask + elif self._algo == "tas": + selected_cs, selected_probs = select2withP( + self._arch_parameters[idx : idx + 1], self.tau, num=2 + ) + with torch.no_grad(): + i1, i2 = selected_cs.cpu().view(-1).tolist() + c1, c2 = self._candidate_Cs[i1], self._candidate_Cs[i2] + out_channel = max(c1, c2) + out1 = ChannelWiseInter(feature[:, :c1], out_channel) + out2 = ChannelWiseInter(feature[:, :c2], out_channel) + out = out1 * selected_probs[0, 0] + out2 * selected_probs[0, 1] + if feature.shape[1] == out.shape[1]: + feature = out + else: + miss = torch.zeros( + feature.shape[0], + feature.shape[1] - out.shape[1], + feature.shape[2], + feature.shape[3], + device=feature.device, + ) + feature = torch.cat((out, miss), dim=1) + elif self._algo == "mask_rl": + prob = nn.functional.softmax( + self._arch_parameters[idx : idx + 1], dim=-1 + ) + dist = torch.distributions.Categorical(prob) + action = dist.sample() + log_probs.append(dist.log_prob(action)) + mask = self._masks[action.item()].view(1, -1, 1, 1) + feature = feature * mask + else: + raise ValueError("invalid algorithm : {:}".format(self._algo)) + + out = self.lastact(feature) + out = self.global_pooling(out) + out = 
out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits, log_probs diff --git a/lib/models/shape_searchs/test.py b/lib/models/shape_searchs/test.py index 4f77f4c..5bcf46e 100644 --- a/lib/models/shape_searchs/test.py +++ b/lib/models/shape_searchs/test.py @@ -6,15 +6,15 @@ import torch.nn as nn from SoftSelect import ChannelWiseInter -if __name__ == '__main__': +if __name__ == "__main__": - tensors = torch.rand((16, 128, 7, 7)) - - for oc in range(200, 210): - out_v1 = ChannelWiseInter(tensors, oc, 'v1') - out_v2 = ChannelWiseInter(tensors, oc, 'v2') - assert (out_v1 == out_v2).any().item() == 1 - for oc in range(48, 160): - out_v1 = ChannelWiseInter(tensors, oc, 'v1') - out_v2 = ChannelWiseInter(tensors, oc, 'v2') - assert (out_v1 == out_v2).any().item() == 1 + tensors = torch.rand((16, 128, 7, 7)) + + for oc in range(200, 210): + out_v1 = ChannelWiseInter(tensors, oc, "v1") + out_v2 = ChannelWiseInter(tensors, oc, "v2") + assert (out_v1 == out_v2).any().item() == 1 + for oc in range(48, 160): + out_v1 = ChannelWiseInter(tensors, oc, "v1") + out_v2 = ChannelWiseInter(tensors, oc, "v2") + assert (out_v1 == out_v2).any().item() == 1 diff --git a/lib/models/xcore.py b/lib/models/xcore.py index b547554..08c03a6 100644 --- a/lib/models/xcore.py +++ b/lib/models/xcore.py @@ -35,6 +35,22 @@ def get_model(config: Dict[Text, Any], **kwargs): act_cls(), SuperLinear(hidden_dim2, kwargs["output_dim"]), ) + elif model_type == "norm_mlp": + act_cls = super_name2activation[kwargs["act_cls"]] + norm_cls = super_name2norm[kwargs["norm_cls"]] + sub_layers, last_dim = [], kwargs["input_dim"] + for i, hidden_dim in enumerate(kwargs["hidden_dims"]): + sub_layers.extend( + [ + norm_cls(last_dim, elementwise_affine=False), + SuperLinear(last_dim, hidden_dim), + act_cls(), + ] + ) + last_dim = hidden_dim + sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"])) + model = SuperSequential(*sub_layers) + else: raise TypeError("Unkonwn model type: {:}".format(model_type)) - return model \ No newline at end of file + return model
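
Note (editorial): the final hunk above adds a "norm_mlp" branch to get_model in lib/models/xcore.py, which interleaves an affine-free normalization layer before each SuperLinear and ends with a bare linear projection. The snippet below is a minimal plain-PyTorch sketch of that same layer pattern only: it swaps the repository's SuperLinear/SuperSequential and super_name2norm/super_name2activation lookups for their torch.nn counterparts, and the dimensions and the build_norm_mlp helper name are illustrative assumptions, not part of the patch.

    # Sketch of the layer stack assembled by the new "norm_mlp" branch,
    # using standard torch.nn modules instead of the project's super_core classes.
    import torch
    import torch.nn as nn


    def build_norm_mlp(input_dim, hidden_dims, output_dim, act_cls=nn.ReLU):
        layers, last_dim = [], input_dim
        for hidden_dim in hidden_dims:
            layers.extend(
                [
                    # norm without learnable affine parameters, as in the patch
                    nn.LayerNorm(last_dim, elementwise_affine=False),
                    nn.Linear(last_dim, hidden_dim),
                    act_cls(),
                ]
            )
            last_dim = hidden_dim
        # final projection has no trailing norm or activation
        layers.append(nn.Linear(last_dim, output_dim))
        return nn.Sequential(*layers)


    if __name__ == "__main__":
        # hypothetical dimensions, chosen only to exercise the structure
        net = build_norm_mlp(input_dim=1, hidden_dims=[16, 16], output_dim=1)
        x = torch.randn(4, 1)
        print(net(x).shape)  # torch.Size([4, 1])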