update CVPR-2019-GDAS re-train NASNet-search-space searched models

This commit is contained in:
D-X-Y
2020-03-06 19:29:07 +11:00
parent 8b6df42f1f
commit 9a83814a46
17 changed files with 278 additions and 21 deletions

View File

@@ -12,6 +12,7 @@ def obtain_basic_args():
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--model_source', type=str, default='normal',help='The source of model defination.')
parser.add_argument('--extra_model_path', type=str, default=None, help='The extra model ckp file (help to indicate the searched architecture).')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')

View File

@@ -29,7 +29,8 @@ def convert_param(original_lists):
elif ctype == 'float':
x = float(x)
elif ctype == 'none':
assert x == 'None', 'for none type, the value must be None instead of {:}'.format(x)
if x.lower() != 'none':
raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
x = None
else:
raise TypeError('Does not know this type : {:}'.format(ctype))

View File

@@ -3,6 +3,7 @@
##################################################
from os import path as osp
from typing import List, Text
import torch
__all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \
@@ -38,6 +39,9 @@ def get_cell_based_tiny_net(config):
genotype = CellStructure.str2structure(config.arch_str)
else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
elif config.name == 'infer.nasnet-cifar':
from .cell_infers import NASNetonCIFAR
raise NotImplementedError
else:
raise ValueError('invalid network name : {:}'.format(config.name))
@@ -52,13 +56,12 @@ def get_search_spaces(xtype, name) -> List[Text]:
raise ValueError('invalid search-space type is {:}'.format(xtype))
def get_cifar_models(config):
from .CifarResNet import CifarResNet
from .CifarDenseNet import DenseNet
from .CifarWideResNet import CifarWideResNet
def get_cifar_models(config, extra_path=None):
super_type = getattr(config, 'super_type', 'basic')
if super_type == 'basic':
from .CifarResNet import CifarResNet
from .CifarDenseNet import DenseNet
from .CifarWideResNet import CifarWideResNet
if config.arch == 'resnet':
return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
elif config.arch == 'densenet':
@@ -71,6 +74,7 @@ def get_cifar_models(config):
from .shape_infers import InferWidthCifarResNet
from .shape_infers import InferDepthCifarResNet
from .shape_infers import InferCifarResNet
from .cell_infers import NASNetonCIFAR
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
infer_mode = super_type.split('-')[1]
if infer_mode == 'width':
@@ -79,6 +83,16 @@ def get_cifar_models(config):
return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual)
elif infer_mode == 'shape':
return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual)
elif infer_mode == 'nasnet.cifar':
genotype = config.genotype
if extra_path is not None: # reload genotype by extra_path
if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path))
xdata = torch.load(extra_path)
current_epoch = xdata['epoch']
genotype = xdata['genotypes'][current_epoch-1]
C = config.C if hasattr(config, 'C') else config.ichannel
N = config.N if hasattr(config, 'N') else config.layers
return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary)
else:
raise ValueError('invalid infer-mode : {:}'.format(infer_mode))
else:
@@ -111,9 +125,10 @@ def get_imagenet_models(config):
raise ValueError('invalid super-type : {:}'.format(super_type))
def obtain_model(config):
# Try to obtain the network by config.
def obtain_model(config, extra_path=None):
if config.dataset == 'cifar':
return get_cifar_models(config)
return get_cifar_models(config, extra_path)
elif config.dataset == 'imagenet':
return get_imagenet_models(config)
else:
@@ -152,7 +167,6 @@ def obtain_search_model(config):
def load_net_from_checkpoint(checkpoint):
import torch
assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint)
checkpoint = torch.load(checkpoint)
model_config = dict2config(checkpoint['model-config'], None)

View File

@@ -2,3 +2,4 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork
from .nasnet_cifar import NASNetonCIFAR

View File

@@ -2,6 +2,7 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
@@ -50,3 +51,70 @@ class InferCell(nn.Module):
node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
nodes.append( node_feature )
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
super(NASNetInferCell, self).__init__()
self.reduction = reduction
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
if not reduction:
nodes, concats = genotype['normal'], genotype['normal_concat']
else:
nodes, concats = genotype['reduce'], genotype['reduce_concat']
self._multiplier = len(concats)
self._concats = concats
self._steps = len(nodes)
self._nodes = nodes
self.edges = nn.ModuleDict()
for i, node in enumerate(nodes):
for in_node in node:
name, j = in_node[0], in_node[1]
stride = 2 if reduction and j < 2 else 1
node_str = '{:}<-{:}'.format(i+2, j)
self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)
# [TODO] to support drop_prob in this function..
def forward(self, s0, s1, unused_drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i, node in enumerate(self._nodes):
clist = []
for in_node in node:
name, j = in_node[0], in_node[1]
node_str = '{:}<-{:}'.format(i+2, j)
op = self.edges[ node_str ]
clist.append( op(states[j]) )
states.append( sum(clist) )
return torch.cat([states[x] for x in self._concats], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x

View File

@@ -0,0 +1,71 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
# The macro structure is based on NASNet
class NASNetonCIFAR(nn.Module):
def __init__(self, C, N, stem_multiplier, num_classes, genotype, auxiliary, affine=True, track_running_stats=True):
super(NASNetonCIFAR, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C*stem_multiplier))
# config for each layer
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
self.auxiliary_index = None
self.auxiliary_head = None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
self.cells.append( cell )
C_prev_prev, C_prev, reduction_prev = C_prev, cell._multiplier*C_curr, reduction
if reduction and C_curr == C*4 and auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes)
self.auxiliary_index = index
self._Layer = len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.drop_path_prob = -1
def update_drop_path(self, drop_path_prob):
self.drop_path_prob = drop_path_prob
def auxiliary_param(self):
if self.auxiliary_head is None: return []
else: return list( self.auxiliary_head.parameters() )
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
stem_feature, logits_aux = self.stem(inputs), None
cell_results = [stem_feature, stem_feature]
for i, cell in enumerate(self.cells):
cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
cell_results.append( cell_feature )
if self.auxiliary_index is not None and i == self.auxiliary_index and self.training:
logits_aux = self.auxiliary_head( cell_results[-1] )
out = self.lastact(cell_results[-1])
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
if logits_aux is None: return out, logits
else: return out, [logits, logits_aux]

View File

@@ -155,7 +155,7 @@ class NASNetSearchCell(nn.Module):
self.edges = nn.ModuleDict()
for i in range(self._steps):
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j)
node_str = '{:}<-{:}'.format(i, j) # indicate the edge from node-(j) to node-(i+2)
stride = 2 if reduction and j < 2 else 1
op = MixedOp(space, C, stride, affine, track_running_stats)
self.edges[ node_str ] = op

View File

@@ -5,8 +5,7 @@ import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
from .genotypes import Structure
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet

View File

@@ -4,8 +4,7 @@
import torch
import torch.nn as nn
from copy import deepcopy
from .search_cells import NASNetSearchCell as SearchCell
from .genotypes import Structure
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet

View File

@@ -168,5 +168,15 @@ Networks = {'DARTS_V1': DARTS_V1,
'SETN' : SETN,
}
# This function will return a Genotype from a dict.
def build_genotype_from_dict(xdict):
import pdb; pdb.set_trace()
def remove_value(nodes):
return [tuple([(x[0], x[1]) for x in node]) for node in nodes]
genotype = Genotype(
normal=remove_value(xdict['normal']),
normal_concat=xdict['normal_concat'],
reduce=remove_value(xdict['reduce']),
reduce_concat=xdict['reduce_concat'],
connectN=None, connects=None
)
return genotype

View File

@@ -6,12 +6,22 @@
# Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
##################################################
import torch
import os, torch
def obtain_nas_infer_model(config):
def obtain_nas_infer_model(config, extra_model_path=None):
if config.arch == 'dxys':
from .DXYs import CifarNet, ImageNet, Networks
genotype = Networks[config.genotype]
from .DXYs import build_genotype_from_dict
if config.genotype is None:
if extra_model_path is not None and not os.path.isfile(extra_model_path):
raise ValueError('When genotype in confiig is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path))
xdata = torch.load(extra_model_path)
current_epoch = xdata['epoch']
genotype_dict = xdata['genotypes'][current_epoch-1]
genotype = build_genotype_from_dict(genotype_dict)
else:
genotype = Networks[config.genotype]
if config.dataset == 'cifar':
return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
elif config.dataset == 'imagenet':