v2
This commit is contained in:

  0  pycls/core/__init__.py   Normal file
136  pycls/core/benchmark.py  Normal file
@@ -0,0 +1,136 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Benchmarking functions."""

import pycls.core.logging as logging
import pycls.datasets.loader as loader
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer


logger = logging.get_logger(__name__)


@torch.no_grad()
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data."""
    # Use eval mode
    model.eval()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        inputs = torch.rand(
            batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    else:
        inputs = torch.zeros(
            batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    # Compute precise forward pass time
    timer = Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        # Forward
        timer.tic()
        model(inputs)
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time


def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data."""
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        inputs = torch.rand(
            batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    else:
        inputs = torch.rand(
            batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    if cfg.TASK in ["col", "seg"]:
        labels = torch.zeros(
            batch_size, im_size, im_size, dtype=torch.int64
        ).cuda(non_blocking=False)
    else:
        labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    # Cache BatchNorm2D running stats
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
    # Compute precise forward backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        preds = model(inputs)
        if isinstance(preds, tuple):
            loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
            preds = preds[0]
        else:
            loss = loss_fun(preds, labels)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time


def compute_time_loader(data_loader):
    """Computes loader time."""
    timer = Timer()
    loader.shuffle(data_loader, 0)
    data_loader_iterator = iter(data_loader)
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    total_iter = min(total_iter, len(data_loader))
    for cur_iter in range(total_iter):
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        timer.tic()
        next(data_loader_iterator)
        timer.toc()
    return timer.average_time


def compute_time_full(model, loss_fun, train_loader, test_loader):
    """Times model and data loader."""
    logger.info("Computing model and loader timings...")
    # Compute timings
    test_fw_time = compute_time_eval(model)
    train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
    train_fw_bw_time = train_fw_time + train_bw_time
    train_loader_time = compute_time_loader(train_loader)
    # Output iter timing
    iter_times = {
        "test_fw_time": test_fw_time,
        "train_fw_time": train_fw_time,
        "train_bw_time": train_bw_time,
        "train_fw_bw_time": train_fw_bw_time,
        "train_loader_time": train_loader_time,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))
    # Output epoch timing
    epoch_times = {
        "test_fw_time": test_fw_time * len(test_loader),
        "train_fw_time": train_fw_time * len(train_loader),
        "train_bw_time": train_bw_time * len(train_loader),
        "train_fw_bw_time": train_fw_bw_time * len(train_loader),
        "train_loader_time": train_loader_time * len(train_loader),
    }
    logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
    # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS > 1)
    overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
    logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
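Usage note (not part of the commit): a minimal benchmarking session might look like the sketch below, assuming cfg has already been loaded and that the loader constructors follow pycls conventions (construct_train_loader / construct_test_loader are assumed names here).

# Hypothetical usage sketch built on the functions above.
import pycls.core.benchmark as benchmark
import pycls.core.builders as builders
import pycls.datasets.loader as data_loader

model = builders.build_model().cuda()        # model class chosen via cfg.MODEL.TYPE
loss_fun = builders.build_loss_fun().cuda()  # loss chosen via cfg.MODEL.LOSS_FUN
train_loader = data_loader.construct_train_loader()
test_loader = data_loader.construct_test_loader()
benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)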
 88  pycls/core/builders.py  Normal file
@@ -0,0 +1,88 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Model and loss construction functions."""

import torch
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
from pycls.models.resnet import ResNet
from pycls.models.nas.nas import NAS
from pycls.models.nas.nas_search import NAS_Search
from pycls.models.nas_bench.model_builder import NAS_Bench


class LabelSmoothedCrossEntropyLoss(torch.nn.Module):
    """CrossEntropyLoss with label smoothing."""

    def __init__(self):
        super(LabelSmoothedCrossEntropyLoss, self).__init__()
        self.eps = cfg.MODEL.LABEL_SMOOTHING_EPS
        self.num_classes = cfg.MODEL.NUM_CLASSES

    def forward(self, logits, target):
        pred = logits.log_softmax(dim=-1)
        with torch.no_grad():
            target_dist = torch.ones_like(pred) * self.eps / (self.num_classes - 1)
            target_dist.scatter_(-1, target.unsqueeze(-1), 1 - self.eps)
        return (-target_dist * pred).sum(dim=-1).mean()


# Supported models
_models = {
    "anynet": AnyNet,
    "effnet": EffNet,
    "resnet": ResNet,
    "regnet": RegNet,
    "nas": NAS,
    "nas_search": NAS_Search,
    "nas_bench": NAS_Bench,
}

# Supported loss functions
_loss_funs = {
    "cross_entropy": torch.nn.CrossEntropyLoss,
    "label_smoothed_cross_entropy": LabelSmoothedCrossEntropyLoss,
}


def get_model():
    """Gets the model class specified in the config."""
    err_str = "Model type '{}' not supported"
    assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE)
    return _models[cfg.MODEL.TYPE]


def get_loss_fun():
    """Gets the loss function class specified in the config."""
    err_str = "Loss function type '{}' not supported"
    assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
    return _loss_funs[cfg.MODEL.LOSS_FUN]


def build_model():
    """Builds the model."""
    return get_model()()


def build_loss_fun():
    """Builds the loss function."""
    if cfg.TASK == "seg":
        return get_loss_fun()(ignore_index=255)
    else:
        return get_loss_fun()()


def register_model(name, ctor):
    """Registers a model dynamically."""
    _models[name] = ctor


def register_loss_fun(name, ctor):
    """Registers a loss function dynamically."""
    _loss_funs[name] = ctor
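For intuition (an aside, not part of the commit): the smoothed target distribution places 1 - eps on the true class and eps / (num_classes - 1) on every other class, so each row still sums to 1; with eps = 0 it reduces to ordinary cross entropy. A small self-contained check of that rule, with illustrative values:

# Standalone sketch of the smoothing rule used in forward() above.
import torch

eps, C = 0.1, 5
logits = torch.randn(2, C)
target = torch.tensor([1, 3])
dist = torch.full((2, C), eps / (C - 1))          # eps/(C-1) mass on each wrong class
dist.scatter_(-1, target.unsqueeze(-1), 1 - eps)  # 1-eps mass on the true class
loss = (-dist * logits.log_softmax(dim=-1)).sum(dim=-1).mean()
print(loss)  # with eps == 0 this matches torch.nn.CrossEntropyLoss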
 98  pycls/core/checkpoint.py  Normal file
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions that handle saving and loading of checkpoints."""

import os

import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg


# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"


def get_checkpoint_dir():
    """Retrieves the location for storing checkpoints."""
    return os.path.join(cfg.OUT_DIR, _DIR_NAME)


def get_checkpoint(epoch):
    """Retrieves the path to a checkpoint file."""
    name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
    return os.path.join(get_checkpoint_dir(), name)


def get_last_checkpoint():
    """Retrieves the most recent checkpoint (highest epoch number)."""
    checkpoint_dir = get_checkpoint_dir()
    # Checkpoint file names are in lexicographic order
    checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
    last_checkpoint_name = sorted(checkpoints)[-1]
    return os.path.join(checkpoint_dir, last_checkpoint_name)


def has_checkpoint():
    """Determines if there are checkpoints available."""
    checkpoint_dir = get_checkpoint_dir()
    if not os.path.exists(checkpoint_dir):
        return False
    return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))


def save_checkpoint(model, optimizer, epoch):
    """Saves a checkpoint."""
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    os.makedirs(get_checkpoint_dir(), exist_ok=True)
    # Omit the DDP wrapper in the multi-gpu setting
    sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
    # Record the state
    if isinstance(optimizer, list):
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_w_state": optimizer[0].state_dict(),
            "optimizer_a_state": optimizer[1].state_dict(),
            "cfg": cfg.dump(),
        }
    else:
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_state": optimizer.state_dict(),
            "cfg": cfg.dump(),
        }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    torch.save(checkpoint, checkpoint_file)
    return checkpoint_file


def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
    # Load the checkpoint on CPU to avoid GPU mem spike
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    # Account for the DDP wrapper in the multi-gpu setting
    ms = model.module if cfg.NUM_GPUS > 1 else model
    ms.load_state_dict(checkpoint["model_state"])
    # Load the optimizer state (commonly not done when fine-tuning)
    if optimizer:
        if isinstance(optimizer, list):
            optimizer[0].load_state_dict(checkpoint["optimizer_w_state"])
            optimizer[1].load_state_dict(checkpoint["optimizer_a_state"])
        else:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]
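One way these helpers typically fit together at the start of training (a sketch, not part of the commit; model, optimizer, and a loaded cfg are assumed to exist):

# Hypothetical resume flow built from the helpers above.
import pycls.core.checkpoint as checkpoint
from pycls.core.config import cfg

if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
    last_checkpoint = checkpoint.get_last_checkpoint()
    epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
    start_epoch = epoch + 1  # continue after the last saved epoch
elif cfg.TRAIN.WEIGHTS:
    checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)  # weights only, no optimizer
    start_epoch = 0
else:
    start_epoch = 0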
500  pycls/core/config.py  Normal file
@@ -0,0 +1,500 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Configuration file (powered by YACS)."""

import argparse
import os
import sys

from pycls.core.io import cache_url
from yacs.config import CfgNode as CfgNode


# Global config object
_C = CfgNode()

# Example usage:
#   from core.config import cfg
cfg = _C


# ------------------------------------------------------------------------------------ #
# Model options
# ------------------------------------------------------------------------------------ #
_C.MODEL = CfgNode()

# Model type
_C.MODEL.TYPE = ""

# Number of weight layers
_C.MODEL.DEPTH = 0

# Number of input channels
_C.MODEL.INPUT_CHANNELS = 3

# Number of classes
_C.MODEL.NUM_CLASSES = 10

# Loss function (see pycls/core/builders.py for options)
_C.MODEL.LOSS_FUN = "cross_entropy"

# Label smoothing eps
_C.MODEL.LABEL_SMOOTHING_EPS = 0.0

# ASPP channels
_C.MODEL.ASPP_CHANNELS = 256

# ASPP dilation rates
_C.MODEL.ASPP_RATES = [6, 12, 18]


# ------------------------------------------------------------------------------------ #
# ResNet options
# ------------------------------------------------------------------------------------ #
_C.RESNET = CfgNode()

# Transformation function (see pycls/models/resnet.py for options)
_C.RESNET.TRANS_FUN = "basic_transform"

# Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
_C.RESNET.NUM_GROUPS = 1

# Width of each group (64 -> ResNet; 4 -> ResNeXt)
_C.RESNET.WIDTH_PER_GROUP = 64

# Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
_C.RESNET.STRIDE_1X1 = True


# ------------------------------------------------------------------------------------ #
# AnyNet options
# ------------------------------------------------------------------------------------ #
_C.ANYNET = CfgNode()

# Stem type
_C.ANYNET.STEM_TYPE = "simple_stem_in"

# Stem width
_C.ANYNET.STEM_W = 32

# Block type
_C.ANYNET.BLOCK_TYPE = "res_bottleneck_block"

# Depth for each stage (number of blocks in the stage)
_C.ANYNET.DEPTHS = []

# Width for each stage (width of each block in the stage)
_C.ANYNET.WIDTHS = []

# Strides for each stage (applies to the first block of each stage)
_C.ANYNET.STRIDES = []

# Bottleneck multipliers for each stage (applies to bottleneck block)
_C.ANYNET.BOT_MULS = []

# Group widths for each stage (applies to bottleneck block)
_C.ANYNET.GROUP_WS = []

# Whether SE is enabled for res_bottleneck_block
_C.ANYNET.SE_ON = False

# SE ratio
_C.ANYNET.SE_R = 0.25


# ------------------------------------------------------------------------------------ #
# RegNet options
# ------------------------------------------------------------------------------------ #
_C.REGNET = CfgNode()

# Stem type
_C.REGNET.STEM_TYPE = "simple_stem_in"

# Stem width
_C.REGNET.STEM_W = 32

# Block type
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"

# Stride of each stage
_C.REGNET.STRIDE = 2

# Squeeze-and-Excitation (RegNetY)
_C.REGNET.SE_ON = False
_C.REGNET.SE_R = 0.25

# Depth
_C.REGNET.DEPTH = 10

# Initial width
_C.REGNET.W0 = 32

# Slope
_C.REGNET.WA = 5.0

# Quantization
_C.REGNET.WM = 2.5

# Group width
_C.REGNET.GROUP_W = 16

# Bottleneck multiplier (bm = 1 / b from the paper)
_C.REGNET.BOT_MUL = 1.0
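An aside for intuition (not part of the commit): in the RegNet parameterization these options generate per-block widths from a linear rule, u_j = W0 + WA * j for j = 0..DEPTH-1, which is then quantized to powers of WM times W0, per the RegNet paper ("Designing Network Design Spaces"). The sketch below shows the computation with the default values above; the final rounding to multiples of 8 is an assumption based on the reference implementation.

# Sketch of RegNet width generation (illustrative, not the code in this commit).
import numpy as np

w_0, w_a, w_m, d = 32, 5.0, 2.5, 10                  # REGNET.W0, WA, WM, DEPTH
ws_cont = w_0 + w_a * np.arange(d)                   # linear widths u_j = w_0 + w_a * j
ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))   # quantization exponents
ws = np.round(w_0 * np.power(w_m, ks) / 8) * 8       # snap to powers of w_m, mult. of 8
print(ws.astype(int))  # [32 32 32 32 80 80 80 80 80 80]: equal widths form a stage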
# ------------------------------------------------------------------------------------ #
# EfficientNet options
# ------------------------------------------------------------------------------------ #
_C.EN = CfgNode()

# Stem width
_C.EN.STEM_W = 32

# Depth for each stage (number of blocks in the stage)
_C.EN.DEPTHS = []

# Width for each stage (width of each block in the stage)
_C.EN.WIDTHS = []

# Expansion ratios for MBConv blocks in each stage
_C.EN.EXP_RATIOS = []

# Squeeze-and-Excitation (SE) ratio
_C.EN.SE_R = 0.25

# Strides for each stage (applies to the first block of each stage)
_C.EN.STRIDES = []

# Kernel sizes for each stage
_C.EN.KERNELS = []

# Head width
_C.EN.HEAD_W = 1280

# Drop connect ratio
_C.EN.DC_RATIO = 0.0

# Dropout ratio
_C.EN.DROPOUT_RATIO = 0.0


# ------------------------------------------------------------------------------------ #
# NAS options
# ------------------------------------------------------------------------------------ #
_C.NAS = CfgNode()

# Cell genotype
_C.NAS.GENOTYPE = "nas"

# Custom genotype
_C.NAS.CUSTOM_GENOTYPE = []

# Base NAS width
_C.NAS.WIDTH = 16

# Total number of cells
_C.NAS.DEPTH = 20

# Auxiliary heads
_C.NAS.AUX = False

# Weight for auxiliary heads
_C.NAS.AUX_WEIGHT = 0.4

# Drop path probability
_C.NAS.DROP_PROB = 0.0

# Matrix in NAS Bench
_C.NAS.MATRIX = []

# Operations in NAS Bench
_C.NAS.OPS = []

# Number of stacks in NAS Bench
_C.NAS.NUM_STACKS = 3

# Number of modules per stack in NAS Bench
_C.NAS.NUM_MODULES_PER_STACK = 3


# ------------------------------------------------------------------------------------ #
# Batch norm options
# ------------------------------------------------------------------------------------ #
_C.BN = CfgNode()

# BN epsilon
_C.BN.EPS = 1e-5

# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1

# Precise BN stats
_C.BN.USE_PRECISE_STATS = False
_C.BN.NUM_SAMPLES_PRECISE = 1024

# Initialize the gamma of the final BN of each block to zero
_C.BN.ZERO_INIT_FINAL_GAMMA = False

# Use a different weight decay for BN layers
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0


# ------------------------------------------------------------------------------------ #
# Optimizer options
# ------------------------------------------------------------------------------------ #
_C.OPTIM = CfgNode()

# Base learning rate
_C.OPTIM.BASE_LR = 0.1

# Learning rate policy; select from {'cos', 'exp', 'steps'}
_C.OPTIM.LR_POLICY = "cos"

# Exponential decay factor
_C.OPTIM.GAMMA = 0.1

# Steps for 'steps' policy (in epochs)
_C.OPTIM.STEPS = []

# Learning rate multiplier for 'steps' policy
_C.OPTIM.LR_MULT = 0.1

# Maximal number of epochs
_C.OPTIM.MAX_EPOCH = 200

# Momentum
_C.OPTIM.MOMENTUM = 0.9

# Momentum dampening
_C.OPTIM.DAMPENING = 0.0

# Nesterov momentum
_C.OPTIM.NESTEROV = True

# L2 regularization
_C.OPTIM.WEIGHT_DECAY = 5e-4

# Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
_C.OPTIM.WARMUP_FACTOR = 0.1

# Gradually warm up the OPTIM.BASE_LR over this number of epochs
_C.OPTIM.WARMUP_EPOCHS = 0

# Update the learning rate per iter
_C.OPTIM.ITER_LR = False

# Base learning rate for arch
_C.OPTIM.ARCH_BASE_LR = 0.0003

# L2 regularization for arch
_C.OPTIM.ARCH_WEIGHT_DECAY = 0.001

# Optimizer for arch
_C.OPTIM.ARCH_OPTIM = "adam"

# Epoch to start optimizing arch
_C.OPTIM.ARCH_EPOCH = 0.0


# ------------------------------------------------------------------------------------ #
# Training options
# ------------------------------------------------------------------------------------ #
_C.TRAIN = CfgNode()

# Dataset and split
_C.TRAIN.DATASET = ""
_C.TRAIN.SPLIT = "train"

# Total mini-batch size
_C.TRAIN.BATCH_SIZE = 128

# Image size
_C.TRAIN.IM_SIZE = 224

# Evaluate model on test data every eval period epochs
_C.TRAIN.EVAL_PERIOD = 1

# Save model checkpoint every checkpoint period epochs
_C.TRAIN.CHECKPOINT_PERIOD = 1

# Resume training from the latest checkpoint in the output directory
_C.TRAIN.AUTO_RESUME = True

# Weights to start training from
_C.TRAIN.WEIGHTS = ""

# Percentage of gray images in jig
_C.TRAIN.GRAY_PERCENTAGE = 0.0

# Portion to create trainA/trainB split
_C.TRAIN.PORTION = 1.0


# ------------------------------------------------------------------------------------ #
# Testing options
# ------------------------------------------------------------------------------------ #
_C.TEST = CfgNode()

# Dataset and split
_C.TEST.DATASET = ""
_C.TEST.SPLIT = "val"

# Total mini-batch size
_C.TEST.BATCH_SIZE = 200

# Image size
_C.TEST.IM_SIZE = 256

# Weights to use for testing
_C.TEST.WEIGHTS = ""


# ------------------------------------------------------------------------------------ #
# Common train/test data loader options
# ------------------------------------------------------------------------------------ #
_C.DATA_LOADER = CfgNode()

# Number of data loader workers per process
_C.DATA_LOADER.NUM_WORKERS = 8

# Load data to pinned host memory
_C.DATA_LOADER.PIN_MEMORY = True


# ------------------------------------------------------------------------------------ #
# Memory options
# ------------------------------------------------------------------------------------ #
_C.MEM = CfgNode()

# Perform ReLU inplace
_C.MEM.RELU_INPLACE = True


# ------------------------------------------------------------------------------------ #
# CUDNN options
# ------------------------------------------------------------------------------------ #
_C.CUDNN = CfgNode()

# Perform benchmarking to select the fastest CUDNN algorithms to use
# Note that this may increase the memory usage and will likely not result
# in overall speedups when variable size inputs are used (e.g. COCO training)
_C.CUDNN.BENCHMARK = True


# ------------------------------------------------------------------------------------ #
# Precise timing options
# ------------------------------------------------------------------------------------ #
_C.PREC_TIME = CfgNode()

# Number of iterations to warm up the caches
_C.PREC_TIME.WARMUP_ITER = 3

# Number of iterations to compute avg time
_C.PREC_TIME.NUM_ITER = 30


# ------------------------------------------------------------------------------------ #
# Misc options
# ------------------------------------------------------------------------------------ #

# Number of GPUs to use (applies to both training and testing)
_C.NUM_GPUS = 1

# Task (cls, seg, rot, col, jig)
_C.TASK = "cls"

# Grid in Jigsaw (2, 3); no effect if TASK is not jig
_C.JIGSAW_GRID = 3

# Output directory
_C.OUT_DIR = "/tmp"

# Config destination (in OUT_DIR)
_C.CFG_DEST = "config.yaml"

# Note that non-determinism may still be present due to non-deterministic
# operator implementations in GPU operator libraries
_C.RNG_SEED = 1

# Log destination ('stdout' or 'file')
_C.LOG_DEST = "stdout"

# Log period in iters
_C.LOG_PERIOD = 10

# Distributed backend
_C.DIST_BACKEND = "nccl"

# Hostname and port for initializing multi-process groups
_C.HOST = "localhost"
_C.PORT = 10001

# Model weights referred to by URL are downloaded to this local cache
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"


# ------------------------------------------------------------------------------------ #
# Deprecated keys
# ------------------------------------------------------------------------------------ #

_C.register_deprecated_key("PREC_TIME.BATCH_SIZE")
_C.register_deprecated_key("PREC_TIME.ENABLED")


def assert_and_infer_cfg(cache_urls=True):
    """Checks config values invariants."""
    err_str = "The first lr step must start at 0"
    assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
    data_splits = ["train", "val", "test"]
    err_str = "Data split '{}' not supported"
    assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT)
    assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT)
    err_str = "Mini-batch size should be a multiple of NUM_GPUS."
    assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
    assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
    err_str = "Precise BN stats computation not verified for > 1 GPU"
    assert not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1, err_str
    err_str = "Log destination '{}' not supported"
    assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
    if cache_urls:
        cache_cfg_urls()


def cache_cfg_urls():
    """Downloads URLs in the config, caches them, and rewrites cfg to use cached files."""
    _C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
    _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)


def dump_cfg():
    """Dumps the config to the output directory."""
    cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
    with open(cfg_file, "w") as f:
        _C.dump(stream=f)


def load_cfg(out_dir, cfg_dest="config.yaml"):
    """Loads config from specified output directory."""
    cfg_file = os.path.join(out_dir, cfg_dest)
    _C.merge_from_file(cfg_file)


def load_cfg_fom_args(description="Config file options."):
    """Loads config from command line arguments and sets any specified options."""
    parser = argparse.ArgumentParser(description=description)
    help_s = "Config file location"
    parser.add_argument("--cfg", dest="cfg_file", help=help_s, required=True, type=str)
    help_s = "See pycls/core/config.py for all options"
    parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()
    _C.merge_from_file(args.cfg_file)
    _C.merge_from_list(args.opts)
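Usage note (not part of the commit): load_cfg_fom_args first merges a YAML file and then applies key-value overrides from the remaining argv. The programmatic equivalent looks like the sketch below (the YAML path is illustrative):

# Hypothetical programmatic config setup mirroring load_cfg_fom_args.
from pycls.core.config import assert_and_infer_cfg, cfg

cfg.merge_from_file("configs/my_exp.yaml")                   # YAML overrides defaults
cfg.merge_from_list(["OPTIM.BASE_LR", 0.05, "NUM_GPUS", 2])  # CLI-style key/value pairs
assert_and_infer_cfg(cache_urls=False)                       # validate invariants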
157  pycls/core/distributed.py  Normal file
@@ -0,0 +1,157 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Distributed helpers."""

import multiprocessing
import os
import signal
import threading
import traceback

import torch
from pycls.core.config import cfg


def is_master_proc():
    """Determines if the current process is the master process.

    Master process is responsible for logging, writing and loading checkpoints. In
    the multi GPU setting, we assign the master role to the rank 0 process. When
    training using a single GPU, there is a single process which is considered master.
    """
    return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0


def init_process_group(proc_rank, world_size):
    """Initializes the default process group."""
    # Set the GPU to use
    torch.cuda.set_device(proc_rank)
    # Initialize the process group
    torch.distributed.init_process_group(
        backend=cfg.DIST_BACKEND,
        init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
        world_size=world_size,
        rank=proc_rank,
    )


def destroy_process_group():
    """Destroys the default process group."""
    torch.distributed.destroy_process_group()


def scaled_all_reduce(tensors):
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of the
    process group (equivalent to cfg.NUM_GPUS).
    """
    # There is no need for reduction in the single-proc case
    if cfg.NUM_GPUS == 1:
        return tensors
    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / cfg.NUM_GPUS)
    return tensors


class ChildException(Exception):
    """Wraps an exception from a child process."""

    def __init__(self, child_trace):
        super(ChildException, self).__init__(child_trace)


class ErrorHandler(object):
    """Multiprocessing error handler (based on fairseq's).

    Listens for errors in child processes and propagates the tracebacks to the parent.
    """

    def __init__(self, error_queue):
        # Shared error queue
        self.error_queue = error_queue
        # Child processes sharing the error queue
        self.children_pids = []
        # Start a thread listening for errors
        self.error_listener = threading.Thread(target=self.listen, daemon=True)
        self.error_listener.start()
        # Register the signal handler
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        """Registers a child process."""
        self.children_pids.append(pid)

    def listen(self):
        """Listens for errors in the error queue."""
        # Wait until there is an error in the queue
        child_trace = self.error_queue.get()
        # Put the error back for the signal handler
        self.error_queue.put(child_trace)
        # Invoke the signal handler
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, _sig_num, _stack_frame):
        """Signal handler."""
        # Kill child processes
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)
        # Propagate the error from the child process
        raise ChildException(self.error_queue.get())


def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
    """Runs a function from a child process."""
    try:
        # Initialize the process group
        init_process_group(proc_rank, world_size)
        # Run the function
        fun(*fun_args, **fun_kwargs)
    except KeyboardInterrupt:
        # Killed by the parent process
        pass
    except Exception:
        # Propagate exception to the parent process
        error_queue.put(traceback.format_exc())
    finally:
        # Destroy the process group
        destroy_process_group()


def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
    """Runs a function in a multi-proc setting (unless num_proc == 1)."""
    # There is no need for multi-proc in the single-proc case
    fun_kwargs = fun_kwargs if fun_kwargs else {}
    if num_proc == 1:
        fun(*fun_args, **fun_kwargs)
        return
    # Handle errors from training subprocesses
    error_queue = multiprocessing.SimpleQueue()
    error_handler = ErrorHandler(error_queue)
    # Run each training subprocess
    ps = []
    for i in range(num_proc):
        p_i = multiprocessing.Process(
            target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
        )
        ps.append(p_i)
        p_i.start()
        error_handler.add_child(p_i.pid)
    # Wait for each subprocess to finish
    for p in ps:
        p.join()
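A usage sketch (not part of the commit): training entry points typically wrap their main function with multi_proc_run, which forks one process per GPU and routes child tracebacks through the ErrorHandler. The function name train_model below is illustrative.

# Hypothetical entry point built on the helpers above.
import pycls.core.distributed as dist
from pycls.core.config import cfg

def train_model():
    # ... build the model/loaders and run the training loop ...
    pass

if __name__ == "__main__":
    dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=train_model)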
 77  pycls/core/io.py  Normal file
@@ -0,0 +1,77 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""IO utilities (adapted from Detectron)"""

import logging
import os
import re
import sys
from urllib import request as urlrequest


logger = logging.getLogger(__name__)

_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"


def cache_url(url_or_file, cache_dir):
    """Download the file specified by the URL to the cache_dir and return the path to
    the cached file. If the argument is not a URL, simply return it as is.
    """
    is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
    if not is_url:
        return url_or_file
    url = url_or_file
    err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}"
    assert url.startswith(_PYCLS_BASE_URL), err_str.format(_PYCLS_BASE_URL)
    cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
    if os.path.exists(cache_file_path):
        return cache_file_path
    cache_file_dir = os.path.dirname(cache_file_path)
    if not os.path.exists(cache_file_dir):
        os.makedirs(cache_file_dir)
    logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
    download_url(url, cache_file_path)
    return cache_file_path


def _progress_bar(count, total):
    """Report download progress. Credit:
    https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
    """
    bar_len = 60
    filled_len = int(round(bar_len * count / float(total)))
    percents = round(100.0 * count / float(total), 1)
    bar = "=" * filled_len + "-" * (bar_len - filled_len)
    sys.stdout.write(
        "  [{}] {}% of {:.1f}MB file  \r".format(bar, percents, total / 1024 / 1024)
    )
    sys.stdout.flush()
    if count >= total:
        sys.stdout.write("\n")


def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
    """Download url and write it to dst_file_path. Credit:
    https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
    """
    req = urlrequest.Request(url)
    response = urlrequest.urlopen(req)
    total_size = response.info().get("Content-Length").strip()
    total_size = int(total_size)
    bytes_so_far = 0
    with open(dst_file_path, "wb") as f:
        while 1:
            chunk = response.read(chunk_size)
            bytes_so_far += len(chunk)
            if not chunk:
                break
            if progress_hook:
                progress_hook(bytes_so_far, total_size)
            f.write(chunk)
    return bytes_so_far
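For example (not part of the commit), cache_url mirrors a bucket URL beneath the cache directory and passes plain local paths through unchanged; the weights file name below is hypothetical:

# Illustrative behavior of cache_url.
from pycls.core.io import cache_url

# A non-URL argument is returned as-is (no download attempted).
local = cache_url("/data/weights.pyth", "/tmp/pycls-download-cache")
assert local == "/data/weights.pyth"
# A URL under the pycls bucket would be downloaded to a mirrored path, e.g.
#   https://dl.fbaipublicfiles.com/pycls/models/R-50.pyth
#   -> /tmp/pycls-download-cache/models/R-50.pyth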
138  pycls/core/logging.py  Normal file
@@ -0,0 +1,138 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Logging."""

import builtins
import decimal
import logging
import os
import sys

import pycls.core.distributed as dist
import simplejson
from pycls.core.config import cfg


# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"

# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"

# Data output with dump_log_data(data, data_type) will be tagged w/ this
_TAG = "json_stats: "

# Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
_TYPE = "_type"


def _suppress_print():
    """Suppresses printing from the current process."""

    def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
        pass

    builtins.print = ignore


def setup_logging():
    """Sets up the logging."""
    # Enable logging only for the master process
    if dist.is_master_proc():
        # Clear the root logger to prevent any existing logging config
        # (e.g. set by another module) from messing with our setup
        logging.root.handlers = []
        # Construct logging configuration
        logging_config = {"level": logging.INFO, "format": _FORMAT}
        # Log either to stdout or to a file
        if cfg.LOG_DEST == "stdout":
            logging_config["stream"] = sys.stdout
        else:
            logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
        # Configure logging
        logging.basicConfig(**logging_config)
    else:
        _suppress_print()


def get_logger(name):
    """Retrieves the logger."""
    return logging.getLogger(name)


def dump_log_data(data, data_type, prec=4):
    """Converts data (a dictionary) into a tagged json string for logging."""
    data[_TYPE] = data_type
    data = float_to_decimal(data, prec)
    data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
    return "{:s}{:s}".format(_TAG, data_json)


def float_to_decimal(data, prec=4):
    """Converts floats to decimals, which allows for fixed-width json."""
    if isinstance(data, dict):
        return {k: float_to_decimal(v, prec) for k, v in data.items()}
    if isinstance(data, float):
        return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
    else:
        return data


def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
    """Gets all log files in a directory containing subdirs of trained models."""
    names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
    files = [os.path.join(log_dir, n, log_file) for n in names]
    f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
    files, names = zip(*f_n_ps) if f_n_ps else ([], [])
    return files, names


def load_log_data(log_file, data_types_to_skip=()):
    """Loads log data into a dictionary of the form data[data_type][metric][index]."""
    # Load log_file
    assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
    with open(log_file, "r") as f:
        lines = f.readlines()
    # Extract and parse lines that start with _TAG and have a type specified
    lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
    lines = [simplejson.loads(l) for l in lines]
    lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip]
    # Generate data structure accessed by data[data_type][index][metric]
    data_types = [l[_TYPE] for l in lines]
    data = {t: [] for t in data_types}
    for t, line in zip(data_types, lines):
        del line[_TYPE]
        data[t].append(line)
    # Generate data structure accessed by data[data_type][metric][index]
    for t in data:
        metrics = sorted(data[t][0].keys())
        err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
        assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
        data[t] = {m: [d[m] for d in data[t]] for m in metrics}
    return data


def sort_log_data(data):
    """Sorts each data[data_type][metric] by epoch, or keeps only the first instance."""
    for t in data:
        if "epoch" in data[t]:
            assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
            data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
            data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
            epoch = data[t]["epoch_ind"]
            if "iter" in data[t]:
                assert "iter_ind" not in data[t] and "iter_max" not in data[t]
                data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
                data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
                itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
                epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
            for m in data[t]:
                data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
        else:
            data[t] = {m: d[0] for m, d in data[t].items()}
    return data
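For reference (not part of the commit), dump_log_data produces one parseable "json_stats" line per record, which is exactly what load_log_data scans for later:

# Round-trip sketch: a stats dict becomes a tagged json_stats line.
import pycls.core.logging as logging

line = logging.dump_log_data({"top1_err": 7.123456, "epoch": "10/200"}, "test_epoch")
print(line)
# json_stats: {"_type": "test_epoch", "epoch": "10/200", "top1_err": 7.1235}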
435  pycls/core/meters.py  Normal file
@@ -0,0 +1,435 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Meters."""
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import pycls.core.logging as logging
|
||||
import torch
|
||||
from pycls.core.config import cfg
|
||||
from pycls.core.timer import Timer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def time_string(seconds):
|
||||
"""Converts time in seconds to a fixed-width string format."""
|
||||
days, rem = divmod(int(seconds), 24 * 3600)
|
||||
hrs, rem = divmod(rem, 3600)
|
||||
mins, secs = divmod(rem, 60)
|
||||
return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)
|
||||
|
||||
|
||||
def inter_union(preds, labels, num_classes):
|
||||
_, preds = torch.max(preds, 1)
|
||||
preds = preds.type(torch.uint8) + 1
|
||||
labels = labels.type(torch.uint8) + 1
|
||||
preds = preds * (labels > 0).type(torch.uint8)
|
||||
|
||||
inter = preds * (preds == labels).type(torch.uint8)
|
||||
area_inter = torch.histc(inter.type(torch.int64), bins=num_classes, min=1, max=num_classes)
|
||||
area_preds = torch.histc(preds.type(torch.int64), bins=num_classes, min=1, max=num_classes)
|
||||
area_labels = torch.histc(labels.type(torch.int64), bins=num_classes, min=1, max=num_classes)
|
||||
area_union = area_preds + area_labels - area_inter
|
||||
|
||||
return [area_inter.type(torch.float64) / labels.size(0), area_union.type(torch.float64) / labels.size(0)]
|
||||
|
||||
|
||||
def topk_errors(preds, labels, ks):
|
||||
"""Computes the top-k error for each k."""
|
||||
err_str = "Batch dim of predictions and labels must match"
|
||||
assert preds.size(0) == labels.size(0), err_str
|
||||
# Find the top max_k predictions for each sample
|
||||
_top_max_k_vals, top_max_k_inds = torch.topk(
|
||||
preds, max(ks), dim=1, largest=True, sorted=True
|
||||
)
|
||||
# (batch_size, max_k) -> (max_k, batch_size)
|
||||
top_max_k_inds = top_max_k_inds.t()
|
||||
# (batch_size, ) -> (max_k, batch_size)
|
||||
rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
|
||||
# (i, j) = 1 if top i-th prediction for the j-th sample is correct
|
||||
top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
|
||||
# Compute the number of topk correct predictions for each k
|
||||
topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks]
|
||||
return [(1.0 - x / preds.size(0)) * 100.0 for x in topks_correct]
|
||||
|
||||
|
||||
def gpu_mem_usage():
|
||||
"""Computes the GPU memory usage for the current device (MB)."""
|
||||
mem_usage_bytes = torch.cuda.max_memory_allocated()
|
||||
return mem_usage_bytes / 1024 / 1024
|
||||
|
||||
|
||||
class ScalarMeter(object):
|
||||
"""Measures a scalar value (adapted from Detectron)."""
|
||||
|
||||
def __init__(self, window_size):
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
|
||||
def reset(self):
|
||||
self.deque.clear()
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
|
||||
def add_value(self, value):
|
||||
self.deque.append(value)
|
||||
self.count += 1
|
||||
self.total += value
|
||||
|
||||
def get_win_median(self):
|
||||
return np.median(self.deque)
|
||||
|
||||
def get_win_avg(self):
|
||||
return np.mean(self.deque)
|
||||
|
||||
def get_global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
|
||||
class TrainMeter(object):
|
||||
"""Measures training stats."""
|
||||
|
||||
def __init__(self, epoch_iters):
|
||||
self.epoch_iters = epoch_iters
|
||||
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
|
||||
self.iter_timer = Timer()
|
||||
self.loss = ScalarMeter(cfg.LOG_PERIOD)
|
||||
self.loss_total = 0.0
|
||||
self.lr = None
|
||||
# Current minibatch errors (smoothed over a window)
|
||||
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
|
||||
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
|
||||
# Number of misclassified examples
|
||||
self.num_top1_mis = 0
|
||||
self.num_top5_mis = 0
|
||||
self.num_samples = 0
|
||||
|
||||
def reset(self, timer=False):
|
||||
if timer:
|
||||
self.iter_timer.reset()
|
||||
self.loss.reset()
|
||||
self.loss_total = 0.0
|
||||
self.lr = None
|
||||
self.mb_top1_err.reset()
|
||||
self.mb_top5_err.reset()
|
||||
self.num_top1_mis = 0
|
||||
self.num_top5_mis = 0
|
||||
self.num_samples = 0
|
||||
|
||||
def iter_tic(self):
|
||||
self.iter_timer.tic()
|
||||
|
||||
def iter_toc(self):
|
||||
self.iter_timer.toc()
|
||||
|
||||
def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
|
||||
# Current minibatch stats
|
||||
self.mb_top1_err.add_value(top1_err)
|
||||
self.mb_top5_err.add_value(top5_err)
|
||||
self.loss.add_value(loss)
|
||||
self.lr = lr
|
||||
# Aggregate stats
|
||||
self.num_top1_mis += top1_err * mb_size
|
||||
self.num_top5_mis += top5_err * mb_size
|
||||
self.loss_total += loss * mb_size
|
||||
self.num_samples += mb_size
|
||||
|
||||
def get_iter_stats(self, cur_epoch, cur_iter):
|
||||
cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
|
||||
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
|
||||
mem_usage = gpu_mem_usage()
|
||||
stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"time_diff": self.iter_timer.diff,
|
||||
"eta": time_string(eta_sec),
|
||||
"top1_err": self.mb_top1_err.get_win_median(),
|
||||
"top5_err": self.mb_top5_err.get_win_median(),
|
||||
"loss": self.loss.get_win_median(),
|
||||
"lr": self.lr,
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return stats
|
||||
|
||||
def log_iter_stats(self, cur_epoch, cur_iter):
|
||||
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
|
||||
return
|
||||
stats = self.get_iter_stats(cur_epoch, cur_iter)
|
||||
logger.info(logging.dump_log_data(stats, "train_iter"))
|
||||
|
||||
def get_epoch_stats(self, cur_epoch):
|
||||
cur_iter_total = (cur_epoch + 1) * self.epoch_iters
|
||||
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
|
||||
mem_usage = gpu_mem_usage()
|
||||
top1_err = self.num_top1_mis / self.num_samples
|
||||
top5_err = self.num_top5_mis / self.num_samples
|
||||
avg_loss = self.loss_total / self.num_samples
|
||||
stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"eta": time_string(eta_sec),
|
||||
"top1_err": top1_err,
|
||||
"top5_err": top5_err,
|
||||
"loss": avg_loss,
|
||||
"lr": self.lr,
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return stats
|
||||
|
||||
def log_epoch_stats(self, cur_epoch):
|
||||
stats = self.get_epoch_stats(cur_epoch)
|
||||
logger.info(logging.dump_log_data(stats, "train_epoch"))
|
||||
|
||||
|
||||
class TestMeter(object):
|
||||
"""Measures testing stats."""
|
||||
|
||||
def __init__(self, max_iter):
|
||||
self.max_iter = max_iter
|
||||
self.iter_timer = Timer()
|
||||
# Current minibatch errors (smoothed over a window)
|
||||
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
|
||||
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
|
||||
# Min errors (over the full test set)
|
||||
self.min_top1_err = 100.0
|
||||
self.min_top5_err = 100.0
|
||||
# Number of misclassified examples
|
||||
self.num_top1_mis = 0
|
||||
self.num_top5_mis = 0
|
||||
self.num_samples = 0
|
||||
|
||||
def reset(self, min_errs=False):
|
||||
if min_errs:
|
||||
self.min_top1_err = 100.0
|
||||
self.min_top5_err = 100.0
|
||||
self.iter_timer.reset()
|
||||
self.mb_top1_err.reset()
|
||||
self.mb_top5_err.reset()
|
||||
self.num_top1_mis = 0
|
||||
self.num_top5_mis = 0
|
||||
self.num_samples = 0
|
||||
|
||||
def iter_tic(self):
|
||||
self.iter_timer.tic()
|
||||
|
||||
def iter_toc(self):
|
||||
self.iter_timer.toc()
|
||||
|
||||
def update_stats(self, top1_err, top5_err, mb_size):
|
||||
self.mb_top1_err.add_value(top1_err)
|
||||
self.mb_top5_err.add_value(top5_err)
|
||||
self.num_top1_mis += top1_err * mb_size
|
||||
self.num_top5_mis += top5_err * mb_size
|
||||
self.num_samples += mb_size
|
||||
|
||||
def get_iter_stats(self, cur_epoch, cur_iter):
|
||||
mem_usage = gpu_mem_usage()
|
||||
iter_stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"time_diff": self.iter_timer.diff,
|
||||
"top1_err": self.mb_top1_err.get_win_median(),
|
||||
"top5_err": self.mb_top5_err.get_win_median(),
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return iter_stats
|
||||
|
||||
def log_iter_stats(self, cur_epoch, cur_iter):
|
||||
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
|
||||
return
|
||||
stats = self.get_iter_stats(cur_epoch, cur_iter)
|
||||
logger.info(logging.dump_log_data(stats, "test_iter"))
|
||||
|
||||
def get_epoch_stats(self, cur_epoch):
|
||||
top1_err = self.num_top1_mis / self.num_samples
|
||||
top5_err = self.num_top5_mis / self.num_samples
|
||||
self.min_top1_err = min(self.min_top1_err, top1_err)
|
||||
self.min_top5_err = min(self.min_top5_err, top5_err)
|
||||
mem_usage = gpu_mem_usage()
|
||||
stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"top1_err": top1_err,
|
||||
"top5_err": top5_err,
|
||||
"min_top1_err": self.min_top1_err,
|
||||
"min_top5_err": self.min_top5_err,
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return stats
|
||||
|
||||
def log_epoch_stats(self, cur_epoch):
|
||||
stats = self.get_epoch_stats(cur_epoch)
|
||||
logger.info(logging.dump_log_data(stats, "test_epoch"))
|
||||
|
||||
|
||||
class TrainMeterIoU(object):
|
||||
"""Measures training stats."""
|
||||
|
||||
def __init__(self, epoch_iters):
|
||||
self.epoch_iters = epoch_iters
|
||||
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None

        self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)

        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def reset(self, timer=False):
        if timer:
            self.iter_timer.reset()
        self.loss.reset()
        self.loss_total = 0.0
        self.lr = None
        self.mb_miou.reset()
        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, inter, union, loss, lr, mb_size):
        # Current minibatch stats
        self.mb_miou.add_value((inter / (union + 1e-10)).mean())
        self.loss.add_value(loss)
        self.lr = lr
        # Aggregate stats
        self.num_inter += inter * mb_size
        self.num_union += union * mb_size
        self.loss_total += loss * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "eta": time_string(eta_sec),
            "miou": self.mb_miou.get_win_median(),
            "loss": self.loss.get_win_median(),
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "train_iter"))

    def get_epoch_stats(self, cur_epoch):
        cur_iter_total = (cur_epoch + 1) * self.epoch_iters
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        miou = (self.num_inter / (self.num_union + 1e-10)).mean()
        avg_loss = self.loss_total / self.num_samples
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "eta": time_string(eta_sec),
            "miou": miou,
            "loss": avg_loss,
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "train_epoch"))


class TestMeterIoU(object):
    """Measures testing stats."""

    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter_timer = Timer()

        self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)

        self.max_miou = 0.0

        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def reset(self, min_errs=False):
        if min_errs:
            self.max_miou = 0.0
        self.iter_timer.reset()
        self.mb_miou.reset()
        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, inter, union, mb_size):
        self.mb_miou.add_value((inter / (union + 1e-10)).mean())
        self.num_inter += inter * mb_size
        self.num_union += union * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        mem_usage = gpu_mem_usage()
        iter_stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "miou": self.mb_miou.get_win_median(),
            "mem": int(np.ceil(mem_usage)),
        }
        return iter_stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "test_iter"))

    def get_epoch_stats(self, cur_epoch):
        miou = (self.num_inter / (self.num_union + 1e-10)).mean()
        self.max_miou = max(self.max_miou, miou)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "miou": miou,
            "max_miou": self.max_miou,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "test_epoch"))
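For intuition, the epoch-level mIoU above is the ratio of accumulated per-class intersections to accumulated unions, averaged over classes, rather than an average of per-batch mIoUs. A minimal standalone NumPy sketch with made-up counts:

import numpy as np

num_inter = np.array([50.0, 30.0, 0.0])    # accumulated per-class intersection
num_union = np.array([100.0, 60.0, 10.0])  # accumulated per-class union
# The 1e-10 epsilon guards classes whose union is zero
miou = (num_inter / (num_union + 1e-10)).mean()
print(miou)  # ~0.333 = (0.5 + 0.5 + 0.0) / 3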
129
pycls/core/net.py
Normal file
@@ -0,0 +1,129 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions for manipulating networks."""

import itertools
import math

import torch
import torch.nn as nn
from pycls.core.config import cfg


def init_weights(m):
    """Performs ResNet-style weight initialization."""
    if isinstance(m, nn.Conv2d):
        # Note that there is no bias due to BN
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
        zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
        m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(mean=0.0, std=0.01)
        m.bias.data.zero_()


@torch.no_grad()
def compute_precise_bn_stats(model, loader):
    """Computes precise BN stats on training data."""
    # Compute the number of minibatches to use
    num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    # Initialize stats storage
    mus = [torch.zeros_like(bn.running_mean) for bn in bns]
    sqs = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values
    moms = [bn.momentum for bn in bns]
    # Disable momentum
    for bn in bns:
        bn.momentum = 1.0
    # Accumulate the stats across the data samples
    for inputs, _labels in itertools.islice(loader, num_iter):
        model(inputs.cuda())
        # Accumulate the stats for each BN layer
        for i, bn in enumerate(bns):
            m, v = bn.running_mean, bn.running_var
            sqs[i] += (v + m * m) / num_iter
            mus[i] += m / num_iter
    # Set the stats and restore momentum values
    for i, bn in enumerate(bns):
        bn.running_var = sqs[i] - mus[i] * mus[i]
        bn.running_mean = mus[i]
        bn.momentum = moms[i]
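The loop above averages per-batch first and second moments (with bn.momentum set to 1.0, each running stat is exactly the batch stat) and recombines them at the end via Var[x] = E[x^2] - E[x]^2. A small self-contained check of that identity, valid when all batches have equal size:

import torch

x = torch.randn(4, 1000)            # four equal-sized "batches"
mus = x.mean(dim=1)                 # per-batch means
sqs = (x ** 2).mean(dim=1)          # per-batch second moments
var = sqs.mean() - mus.mean() ** 2  # E[x^2] - E[x]^2 over all samples
print(torch.allclose(var, x.var(unbiased=False), atol=1e-5))  # True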
def reset_bn_stats(model):
    """Resets running BN stats."""
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.reset_running_stats()


def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
    """Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts)."""
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    h = (h + 2 * padding - k) // stride + 1
    w = (w + 2 * padding - k) // stride + 1
    flops += k * k * w_in * w_out * h * w // groups
    params += k * k * w_in * w_out // groups
    flops += w_out if bias else 0
    params += w_out if bias else 0
    acts += w_out * h * w
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def complexity_batchnorm2d(cx, w_in):
    """Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts)."""
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    params += 2 * w_in
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def complexity_maxpool2d(cx, k, stride, padding):
    """Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts)."""
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    h = (h + 2 * padding - k) // stride + 1
    w = (w + 2 * padding - k) // stride + 1
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def complexity(model):
    """Compute model complexity (model can be model instance or model class)."""
    size = cfg.TRAIN.IM_SIZE
    cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
    cx = model.complexity(cx)
    return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}
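As a concrete check of the accounting, a sketch tracing cx through a single 3x3 stride-2 convolution on a 224x224 input, using complexity_conv2d above (the numbers follow directly from its formulas):

cx = {"h": 224, "w": 224, "flops": 0, "params": 0, "acts": 0}
cx = complexity_conv2d(cx, w_in=3, w_out=64, k=3, stride=2, padding=1)
# h = w = (224 + 2*1 - 3) // 2 + 1 = 112
# params = 3*3*3*64 = 1,728
# flops  = 3*3*3*64 * 112*112 = 21,676,032
# acts   = 64 * 112*112 = 802,816
print(cx)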
def drop_connect(x, drop_ratio):
    """Drop connect (adapted from DARTS)."""
    keep_ratio = 1.0 - drop_ratio
    mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
    mask.bernoulli_(keep_ratio)
    x.div_(keep_ratio)
    x.mul_(mask)
    return x
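Note that drop_connect mutates x in place (div_/mul_) and zeroes whole samples while rescaling survivors to keep the expected magnitude; a short usage sketch:

import torch

x = torch.ones(8, 16, 4, 4)
out = drop_connect(x.clone(), drop_ratio=0.2)
print(out[:, 0, 0, 0])  # each sample is either 0.0 or 1/0.8 = 1.25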
def get_flat_weights(model):
    """Gets all model weights as a single flat vector."""
    return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)


def set_flat_weights(model, flat_weights):
    """Sets all model weights from a single flat vector."""
    k = 0
    for p in model.parameters():
        n = p.data.numel()
        p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
        k += n
    assert k == flat_weights.numel()
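These two helpers are inverses of each other; a quick round-trip sketch on a toy module:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
w = get_flat_weights(model)         # shape (4*2 + 2, 1)
set_flat_weights(model, w.clone())  # writing the same vector back is a no-op
assert torch.equal(get_flat_weights(model), w)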
95
pycls/core/optimizer.py
Normal file
@@ -0,0 +1,95 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Optimizer."""

import numpy as np
import torch
from pycls.core.config import cfg


def construct_optimizer(model):
    """Constructs the optimizer.

    Note that the momentum update in PyTorch differs from the one in Caffe2.
    In particular,

    Caffe2:
        V := mu * V + lr * g
        p := p - V

    PyTorch:
        V := mu * V + g
        p := p - lr * V

    where V is the velocity, mu is the momentum factor, lr is the learning rate,
    g is the gradient and p are the parameters.

    Since V is defined independently of the learning rate in PyTorch,
    when the learning rate is changed there is no need to perform the
    momentum correction by scaling V (unlike in the Caffe2 case).
    """
    if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
        # Apply different weight decay to Batchnorm and non-batchnorm parameters.
        p_bn = [p for n, p in model.named_parameters() if "bn" in n]
        p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n]
        optim_params = [
            {"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
            {"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
        ]
    else:
        optim_params = model.parameters()
    return torch.optim.SGD(
        optim_params,
        lr=cfg.OPTIM.BASE_LR,
        momentum=cfg.OPTIM.MOMENTUM,
        weight_decay=cfg.OPTIM.WEIGHT_DECAY,
        dampening=cfg.OPTIM.DAMPENING,
        nesterov=cfg.OPTIM.NESTEROV,
    )
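To make the docstring concrete: while the learning rate is held fixed, the two formulations yield identical trajectories, since PyTorch's velocity is just Caffe2's velocity divided by lr. A toy numeric sketch in plain Python:

mu, lr, g, steps = 0.9, 0.1, 1.0, 3

v_c2 = p_c2 = 0.0
for _ in range(steps):  # Caffe2: velocity carries the lr
    v_c2 = mu * v_c2 + lr * g
    p_c2 = p_c2 - v_c2

v_pt = p_pt = 0.0
for _ in range(steps):  # PyTorch: lr applied at the update
    v_pt = mu * v_pt + g
    p_pt = p_pt - lr * v_pt

print(abs(p_c2 - p_pt) < 1e-12)  # True while lr is constant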
def lr_fun_steps(cur_epoch):
    """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
    ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
    return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind)


def lr_fun_exp(cur_epoch):
    """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
    return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch)


def lr_fun_cos(cur_epoch):
    """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
    base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH
    return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))


def get_lr_fun():
    """Retrieves the specified lr policy function."""
    lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
    if lr_fun not in globals():
        raise NotImplementedError("Unknown LR policy: " + cfg.OPTIM.LR_POLICY)
    return globals()[lr_fun]


def get_epoch_lr(cur_epoch):
    """Retrieves the lr for the given epoch according to the policy."""
    lr = get_lr_fun()(cur_epoch)
    # Linear warmup
    if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
        alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
        warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
        lr *= warmup_factor
    return lr


def set_lr(optimizer, new_lr):
    """Sets the optimizer lr to the specified value."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr
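For example, with hypothetical settings BASE_LR = 0.1, MAX_EPOCH = 100, WARMUP_EPOCHS = 5 and WARMUP_FACTOR = 0.1 under the cosine policy, the warmup interpolates linearly from 10% of the scheduled lr up to the full schedule (a standalone re-computation of get_epoch_lr):

import numpy as np

base_lr, max_epoch = 0.1, 100
warmup_epochs, warmup_factor = 5, 0.1
for cur_epoch in [0, 1, 4, 5]:
    lr = 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))  # cosine
    if cur_epoch < warmup_epochs:
        alpha = cur_epoch / warmup_epochs
        lr *= warmup_factor * (1.0 - alpha) + alpha
    print(cur_epoch, round(lr, 5))
# 0 -> 0.01, 1 -> ~0.02799, 4 -> ~0.08168, 5 -> ~0.09938 (full cosine lr)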
132
pycls/core/plotting.py
Normal file
@@ -0,0 +1,132 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Plotting functions."""

import colorlover as cl
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.offline as offline
import pycls.core.logging as logging


def get_plot_colors(max_colors, color_format="pyplot"):
    """Generate colors for plotting."""
    colors = cl.scales["11"]["qual"]["Paired"]
    if max_colors > len(colors):
        colors = cl.to_rgb(cl.interp(colors, max_colors))
    if color_format == "pyplot":
        return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
    return colors
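get_plot_colors returns pyplot-style 0-1 RGB triples by default, or plotly 'rgb(...)' strings; a quick sketch (requires the colorlover package):

colors = get_plot_colors(3)                   # [[r, g, b], ...] with values in 0-1
colors_plotly = get_plot_colors(3, "plotly")  # e.g. ['rgb(166,206,227)', ...]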
def prepare_plot_data(log_files, names, metric="top1_err"):
    """Load logs and extract data for plotting error curves."""
    plot_data = []
    for file, name in zip(log_files, names):
        d, data = {}, logging.sort_log_data(logging.load_log_data(file))
        for phase in ["train", "test"]:
            x = data[phase + "_epoch"]["epoch_ind"]
            y = data[phase + "_epoch"][metric]
            d["x_" + phase], d["y_" + phase] = x, y
            d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
        plot_data.append(d)
    assert len(plot_data) > 0, "No data to plot"
    return plot_data


def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
    """Plot error curves using plotly and save to file."""
    plot_data = prepare_plot_data(log_files, names, metric)
    colors = get_plot_colors(len(plot_data), "plotly")
    # Prepare data for plots (3 sets, train duplicated w and w/o legend)
    data = []
    for i, d in enumerate(plot_data):
        s = str(i)
        line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
        line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
        data.append(
            go.Scatter(
                x=d["x_train"],
                y=d["y_train"],
                mode="lines",
                name=d["train_label"],
                line=line_train,
                legendgroup=s,
                visible=True,
                showlegend=False,
            )
        )
        data.append(
            go.Scatter(
                x=d["x_test"],
                y=d["y_test"],
                mode="lines",
                name=d["test_label"],
                line=line_test,
                legendgroup=s,
                visible=True,
                showlegend=True,
            )
        )
        data.append(
            go.Scatter(
                x=d["x_train"],
                y=d["y_train"],
                mode="lines",
                name=d["train_label"],
                line=line_train,
                legendgroup=s,
                visible=False,
                showlegend=True,
            )
        )
    # Prepare layout w ability to toggle 'all', 'train', 'test'
    titlefont = {"size": 18, "color": "#7f7f7f"}
    vis = [[True, True, False], [False, False, True], [False, True, False]]
    buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
    buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
    layout = go.Layout(
        title=metric + " vs. epoch<br>[dash=train, solid=test]",
        xaxis={"title": "epoch", "titlefont": titlefont},
        yaxis={"title": metric, "titlefont": titlefont},
        showlegend=True,
        hoverlabel={"namelength": -1},
        updatemenus=[
            {
                "buttons": buttons,
                "direction": "down",
                "showactive": True,
                "x": 1.02,
                "xanchor": "left",
                "y": 1.08,
                "yanchor": "top",
            }
        ],
    )
    # Create plotly plot
    offline.plot({"data": data, "layout": layout}, filename=filename)


def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
    """Plot error curves using matplotlib.pyplot and save to file."""
    plot_data = prepare_plot_data(log_files, names, metric)
    colors = get_plot_colors(len(names))
    for ind, d in enumerate(plot_data):
        c, lbl = colors[ind], d["test_label"]
        plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
        plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
    plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
    plt.xlabel("epoch", fontsize=14)
    plt.ylabel(metric, fontsize=14)
    plt.grid(alpha=0.4)
    plt.legend()
    if filename:
        plt.savefig(filename)
        plt.clf()
    else:
        plt.show()
39
pycls/core/timer.py
Normal file
@@ -0,0 +1,39 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Timer."""

import time


class Timer(object):
    """A simple timer (adapted from Detectron)."""

    def __init__(self):
        self.total_time = None
        self.calls = None
        self.start_time = None
        self.diff = None
        self.average_time = None
        self.reset()

    def tic(self):
        # Using time.time as time.clock does not normalize for multithreading
        self.start_time = time.time()

    def toc(self):
        self.diff = time.time() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls

    def reset(self):
        self.total_time = 0.0
        self.calls = 0
        self.start_time = 0.0
        self.diff = 0.0
        self.average_time = 0.0
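A minimal usage sketch of the Timer (tic/toc pairs accumulate into average_time until reset):

import time

timer = Timer()
for _ in range(3):
    timer.tic()
    time.sleep(0.01)  # stand-in workload
    timer.toc()
print(timer.average_time)  # roughly 0.01 seconds per call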
419
pycls/core/trainer.py
Normal file
@@ -0,0 +1,419 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Tools for training and testing a model."""

import os

import numpy as np
import pycls.core.benchmark as benchmark
import pycls.core.builders as builders
import pycls.core.checkpoint as checkpoint
import pycls.core.config as config
import pycls.core.distributed as dist
import pycls.core.logging as logging
import pycls.core.meters as meters
import pycls.core.net as net
import pycls.core.optimizer as optim
import pycls.datasets.loader as loader
import torch
import torch.nn.functional as F
from pycls.core.config import cfg
from thop import profile


logger = logging.get_logger(__name__)


def setup_env():
    """Sets up environment for training or testing."""
    if dist.is_master_proc():
        # Ensure that the output dir exists
        os.makedirs(cfg.OUT_DIR, exist_ok=True)
        # Save the config
        config.dump_cfg()
        # Setup logging
        logging.setup_logging()
        # Log the config as both human readable and as a json
        logger.info("Config:\n{}".format(cfg))
        logger.info(logging.dump_log_data(cfg, "cfg"))
    # Fix the RNG seeds (see RNG comment in core/config.py for discussion)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Configure the CUDNN backend
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK


def setup_model():
    """Sets up a model for training or testing and logs the results."""
    # Build the model
    model = builders.build_model()
    logger.info("Model:\n{}".format(model))
    # Log model complexity
    # logger.info(logging.dump_log_data(net.complexity(model), "complexity"))
    if cfg.TASK == "seg" and cfg.TRAIN.DATASET == "cityscapes":
        h, w = 1025, 2049
    else:
        h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
    if cfg.TASK == "jig":
        x = torch.randn(1, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, h, w)
    else:
        x = torch.randn(1, cfg.MODEL.INPUT_CHANNELS, h, w)
    macs, params = profile(model, inputs=(x,), verbose=False)
    logger.info("Params: {:,}".format(params))
    logger.info("Flops: {:,}".format(macs))
    # Transfer the model to the current GPU device
    err_str = "Cannot use more GPU devices than available"
    assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str
    cur_device = torch.cuda.current_device()
    model = model.cuda(device=cur_device)
    # Use multi-process data parallel model in the multi-gpu setting
    if cfg.NUM_GPUS > 1:
        # Make model replica operate on the current device
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device
        )
        # Set complexity function to be module's complexity function
        # model.complexity = model.module.complexity
    return model
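The thop profiler above derives the counts from a dummy forward pass; a standalone sketch on a toy layer (requires the thop package; the values match the hand count of k*k*w_in*w_out*h_out*w_out multiply-accumulates):

import torch
import torch.nn as nn
from thop import profile

layer = nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)
x = torch.randn(1, 3, 224, 224)
macs, params = profile(layer, inputs=(x,), verbose=False)
print("Params: {:,.0f}".format(params))  # 1,728
print("Flops: {:,.0f}".format(macs))     # 21,676,032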
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Performs one epoch of training."""
    # Update drop path prob for NAS
    if cfg.MODEL.TYPE == "nas":
        m = model.module if cfg.NUM_GPUS > 1 else model
        m.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH)
    # Shuffle the data
    loader.shuffle(train_loader, cur_epoch)
    # Update the learning rate per epoch
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer, lr)
    # Enable training mode
    model.train()
    train_meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(train_loader):
        # Update the learning rate per iter
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader))
            optim.set_lr(optimizer, lr)
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Perform the forward pass
        preds = model(inputs)
        # Compute the loss
        if isinstance(preds, tuple):
            loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
            preds = preds[0]
        else:
            loss = loss_fun(preds, labels)
        # Perform the backward pass
        optimizer.zero_grad()
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Compute the errors
        if cfg.TASK == "col":
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss = loss.item()
        if cfg.TASK == "seg":
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()


def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Performs one epoch of differentiable architecture search."""
    m = model.module if cfg.NUM_GPUS > 1 else model
    # Shuffle the data
    loader.shuffle(train_loader[0], cur_epoch)
    loader.shuffle(train_loader[1], cur_epoch)
    # Update the learning rate per epoch
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer[0], lr)
    # Enable training mode
    model.train()
    train_meter.iter_tic()
    trainB_iter = iter(train_loader[1])
    for cur_iter, (inputs, labels) in enumerate(train_loader[0]):
        # Update the learning rate per iter
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader[0]))
            optim.set_lr(optimizer[0], lr)
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Update architecture
        if cur_epoch + cur_iter / len(train_loader[0]) >= cfg.OPTIM.ARCH_EPOCH:
            try:
                inputsB, labelsB = next(trainB_iter)
            except StopIteration:
                trainB_iter = iter(train_loader[1])
                inputsB, labelsB = next(trainB_iter)
            inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True)
            optimizer[1].zero_grad()
            loss = m._loss(inputsB, labelsB)
            loss.backward()
            optimizer[1].step()
        # Perform the forward pass
        preds = model(inputs)
        # Compute the loss
        loss = loss_fun(preds, labels)
        # Perform the backward pass
        optimizer[0].zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        # Update the parameters
        optimizer[0].step()
        # Compute the errors
        if cfg.TASK == "col":
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss = loss.item()
        if cfg.TASK == "seg":
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
    # Log genotype
    genotype = m.genotype()
    logger.info("genotype = %s", genotype)
    logger.info(F.softmax(m.net_.alphas_normal, dim=-1))
    logger.info(F.softmax(m.net_.alphas_reduce, dim=-1))
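search_epoch alternates an architecture step on split B with a weight step on split A, each with its own optimizer. A minimal standalone sketch of that alternating pattern on a toy objective (an illustration of the two-optimizer loop, not the full DARTS bilevel derivation):

import torch

w = torch.zeros(2, requires_grad=True)  # stand-in for network weights
a = torch.zeros(2, requires_grad=True)  # stand-in for architecture params
opt_w = torch.optim.SGD([w], lr=0.1, momentum=0.9)
opt_a = torch.optim.Adam([a], lr=0.01, betas=(0.5, 0.999))

for step in range(100):
    # Architecture step (split B): weights frozen via detach
    opt_a.zero_grad()
    ((w.detach() - a) ** 2).sum().backward()
    opt_a.step()
    # Weight step (split A): architecture frozen via detach
    opt_w.zero_grad()
    ((w - 1.0) ** 2 + 0.1 * (w - a.detach()) ** 2).sum().backward()
    torch.nn.utils.clip_grad_norm_([w], 5.0)
    opt_w.step()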
@torch.no_grad()
def test_epoch(test_loader, model, test_meter, cur_epoch):
    """Evaluates the model on the test set."""
    # Enable eval mode
    model.eval()
    test_meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(test_loader):
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Compute the predictions
        preds = model(inputs)
        # Compute the errors
        if cfg.TASK == "col":
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the errors across the GPUs (no reduction if 1 GPU used)
        top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
        # Copy the errors from GPU to CPU (sync point)
        if cfg.TASK == "seg":
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        test_meter.iter_toc()
        # Update and log stats
        test_meter.update_stats(top1_err, top5_err, mb_size)
        test_meter.log_iter_stats(cur_epoch, cur_iter)
        test_meter.iter_tic()
    # Log epoch stats
    test_meter.log_epoch_stats(cur_epoch)
    test_meter.reset()


def train_model():
    """Trains the model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    if "search" in cfg.MODEL.TYPE:
        params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
        params_a = [v for k, v in model.named_parameters() if "alphas" in k]
        optimizer_w = torch.optim.SGD(
            params=params_w,
            lr=cfg.OPTIM.BASE_LR,
            momentum=cfg.OPTIM.MOMENTUM,
            weight_decay=cfg.OPTIM.WEIGHT_DECAY,
            dampening=cfg.OPTIM.DAMPENING,
            nesterov=cfg.OPTIM.NESTEROV,
        )
        if cfg.OPTIM.ARCH_OPTIM == "adam":
            optimizer_a = torch.optim.Adam(
                params=params_a,
                lr=cfg.OPTIM.ARCH_BASE_LR,
                betas=(0.5, 0.999),
                weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
            )
        elif cfg.OPTIM.ARCH_OPTIM == "sgd":
            optimizer_a = torch.optim.SGD(
                params=params_a,
                lr=cfg.OPTIM.ARCH_BASE_LR,
                momentum=cfg.OPTIM.MOMENTUM,
                weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
                dampening=cfg.OPTIM.DAMPENING,
                nesterov=cfg.OPTIM.NESTEROV,
            )
        optimizer = [optimizer_w, optimizer_a]
    else:
        optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters
    if cfg.TRAIN.PORTION < 1:
        if "search" in cfg.MODEL.TYPE:
            train_loader = [
                loader._construct_loader(
                    dataset_name=cfg.TRAIN.DATASET,
                    split=cfg.TRAIN.SPLIT,
                    batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                    shuffle=True,
                    drop_last=True,
                    portion=cfg.TRAIN.PORTION,
                    side="l",
                ),
                loader._construct_loader(
                    dataset_name=cfg.TRAIN.DATASET,
                    split=cfg.TRAIN.SPLIT,
                    batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                    shuffle=True,
                    drop_last=True,
                    portion=cfg.TRAIN.PORTION,
                    side="r",
                ),
            ]
        else:
            train_loader = loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=True,
                drop_last=True,
                portion=cfg.TRAIN.PORTION,
                side="l",
            )
        test_loader = loader._construct_loader(
            dataset_name=cfg.TRAIN.DATASET,
            split=cfg.TRAIN.SPLIT,
            batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
            shuffle=False,
            drop_last=False,
            portion=cfg.TRAIN.PORTION,
            side="r",
        )
    else:
        train_loader = loader.construct_train_loader()
        test_loader = loader.construct_test_loader()
    train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
    test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
    l = train_loader[0] if isinstance(train_loader, list) else train_loader
    train_meter = train_meter_type(len(l))
    test_meter = test_meter_type(len(test_loader))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        l = train_loader[0] if isinstance(train_loader, list) else train_loader
        benchmark.compute_time_full(model, loss_fun, l, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        f = search_epoch if "search" in cfg.MODEL.TYPE else train_epoch
        f(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            test_epoch(test_loader, model, test_meter, cur_epoch)


def test_model():
    """Evaluates a trained model."""
    # Setup training/testing environment
    setup_env()
    # Construct the model
    model = setup_model()
    # Load model weights
    checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    # Create data loaders and meters
    test_loader = loader.construct_test_loader()
    test_meter = meters.TestMeter(len(test_loader))
    # Evaluate the model
    test_epoch(test_loader, model, test_meter, 0)


def time_model():
    """Times model and data loader."""
    # Setup training/testing environment
    setup_env()
    # Construct the model and loss_fun
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    # Compute model and loader timings
    benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
0
pycls/models/__init__.py
Normal file
406
pycls/models/anynet.py
Normal file
@@ -0,0 +1,406 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""AnyNet models."""

import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg


def get_stem_fun(stem_type):
    """Retrieves the stem function by name."""
    stem_funs = {
        "res_stem_cifar": ResStemCifar,
        "res_stem_in": ResStemIN,
        "simple_stem_in": SimpleStemIN,
    }
    err_str = "Stem type '{}' not supported"
    assert stem_type in stem_funs.keys(), err_str.format(stem_type)
    return stem_funs[stem_type]


def get_block_fun(block_type):
    """Retrieves the block function by name."""
    block_funs = {
        "vanilla_block": VanillaBlock,
        "res_basic_block": ResBasicBlock,
        "res_bottleneck_block": ResBottleneckBlock,
    }
    err_str = "Block type '{}' not supported"
    assert block_type in block_funs.keys(), err_str.format(block_type)
    return block_funs[block_type]
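The stem/block registries keep the network shape fully config-driven; the lookup is a plain dict from config string to class (a quick sketch using the two functions above):

stem_fun = get_stem_fun("res_stem_cifar")     # -> the ResStemCifar class
block_fun = get_block_fun("res_basic_block")  # -> the ResBasicBlock class
stem = stem_fun(3, 16)                        # instantiate: 3 -> 16 channels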
class AnyHead(nn.Module):
    """AnyNet head: AvgPool, 1x1."""

    def __init__(self, w_in, nc):
        super(AnyHead, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    @staticmethod
    def complexity(cx, w_in, nc):
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
        return cx


class VanillaBlock(nn.Module):
    """Vanilla block: [3x3 conv, BN, Relu] x2."""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Vanilla block does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        super(VanillaBlock, self).__init__()
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Vanilla block does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class BasicTransform(nn.Module):
    """Basic transformation: [3x3 conv, BN, Relu] x2."""

    def __init__(self, w_in, w_out, stride):
        super(BasicTransform, self).__init__()
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class ResBasicBlock(nn.Module):
    """Residual basic block: x + F(x), F = basic transform."""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Basic transform does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        super(ResBasicBlock, self).__init__()
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = BasicTransform(w_in, w_out, stride)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Basic transform does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        proj_block = (w_in != w_out) or (stride != 1)
        if proj_block:
            h, w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = h, w  # parallel branch
        cx = BasicTransform.complexity(cx, w_in, w_out, stride)
        return cx


class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU, FC, Sigmoid."""

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.f_ex = nn.Sequential(
            nn.Conv2d(w_in, w_se, 1, bias=True),
            nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
            nn.Conv2d(w_se, w_in, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.f_ex(self.avg_pool(x))

    @staticmethod
    def complexity(cx, w_in, w_se):
        h, w = cx["h"], cx["w"]
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
        cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
        cx["h"], cx["w"] = h, w
        return cx


class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""

    def __init__(self, w_in, w_out, stride, bm, gw, se_r):
        super(BottleneckTransform, self).__init__()
        w_b = int(round(w_out * bm))
        g = w_b // gw
        self.a = nn.Conv2d(w_in, w_b, 1, stride=1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_b, w_b, 3, stride=stride, padding=1, groups=g, bias=False)
        self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        if se_r:
            w_se = int(round(w_in * se_r))
            self.se = SE(w_b, w_se)
        self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.c_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm, gw, se_r):
        w_b = int(round(w_out * bm))
        g = w_b // gw
        cx = net.complexity_conv2d(cx, w_in, w_b, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride, 1, g)
        cx = net.complexity_batchnorm2d(cx, w_b)
        if se_r:
            w_se = int(round(w_in * se_r))
            cx = SE.complexity(cx, w_b, w_se)
        cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: x + F(x), F = bottleneck transform."""

    def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        super(ResBottleneckBlock, self).__init__()
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        proj_block = (w_in != w_out) or (stride != 1)
        if proj_block:
            h, w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = h, w  # parallel branch
        cx = BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r)
        return cx


class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR: 3x3, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super(ResStemCifar, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""

    def __init__(self, w_in, w_out):
        super(ResStemIN, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_maxpool2d(cx, 3, 2, 1)
        return cx


class SimpleStemIN(nn.Module):
    """Simple stem for ImageNet: 3x3, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super(SimpleStemIN, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class AnyStage(nn.Module):
    """AnyNet stage (sequence of blocks w/ the same output shape)."""

    def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        super(AnyStage, self).__init__()
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            name = "b{}".format(i + 1)
            self.add_module(name, block_fun(b_w_in, w_out, b_stride, bm, gw, se_r))

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            cx = block_fun.complexity(cx, b_w_in, w_out, b_stride, bm, gw, se_r)
        return cx


class AnyNet(nn.Module):
    """AnyNet model."""

    @staticmethod
    def get_args():
        return {
            "stem_type": cfg.ANYNET.STEM_TYPE,
            "stem_w": cfg.ANYNET.STEM_W,
            "block_type": cfg.ANYNET.BLOCK_TYPE,
            "ds": cfg.ANYNET.DEPTHS,
            "ws": cfg.ANYNET.WIDTHS,
            "ss": cfg.ANYNET.STRIDES,
            "bms": cfg.ANYNET.BOT_MULS,
            "gws": cfg.ANYNET.GROUP_WS,
            "se_r": cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self, **kwargs):
        super(AnyNet, self).__init__()
        kwargs = self.get_args() if not kwargs else kwargs
        self._construct(**kwargs)
        self.apply(net.init_weights)

    def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        # Generate dummy bot muls and gs for models that do not use them
        bms = bms if bms else [None for _d in ds]
        gws = gws if gws else [None for _d in ds]
        stage_params = list(zip(ds, ws, ss, bms, gws))
        stem_fun = get_stem_fun(stem_type)
        self.stem = stem_fun(3, stem_w)
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for i, (d, w, s, bm, gw) in enumerate(stage_params):
            name = "s{}".format(i + 1)
            self.add_module(name, AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r))
            prev_w = w
        self.head = AnyHead(w_in=prev_w, nc=nc)

    def forward(self, x, get_ints=False):
        for module in self.children():
            x = module(x)
        return x

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update."""
        kwargs = AnyNet.get_args() if not kwargs else kwargs
        return AnyNet._complexity(cx, **kwargs)

    @staticmethod
    def _complexity(cx, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        bms = bms if bms else [None for _d in ds]
        gws = gws if gws else [None for _d in ds]
        stage_params = list(zip(ds, ws, ss, bms, gws))
        stem_fun = get_stem_fun(stem_type)
        cx = stem_fun.complexity(cx, 3, stem_w)
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for d, w, s, bm, gw in stage_params:
            cx = AnyStage.complexity(cx, prev_w, w, s, d, block_fun, bm, gw, se_r)
            prev_w = w
        cx = AnyHead.complexity(cx, prev_w, nc)
        return cx
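Putting the pieces together, a hedged sketch of how complexity accumulates for a tiny hypothetical AnyNet-style stack (3x3/s2 simple stem to 32 channels, one basic residual stage of depth 2, then the head), using the classes above:

cx = {"h": 224, "w": 224, "flops": 0, "params": 0, "acts": 0}
cx = SimpleStemIN.complexity(cx, 3, 32)  # 224 -> 112
cx = AnyStage.complexity(cx, 32, 64, 2, 2, ResBasicBlock, None, None, None)  # 112 -> 56
cx = AnyHead.complexity(cx, 64, 1000)
print(cx["params"], cx["flops"])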
108
pycls/models/common.py
Normal file
@@ -0,0 +1,108 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from pycls.core.config import cfg


def Preprocess(x):
    if cfg.TASK == 'jig':
        assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
        assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
        x = x.view([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]])
    return x
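For the jigsaw task the grid dimension is folded into the batch, so each tile passes through the backbone independently; a shape-only sketch for a 3x3 grid (9 tiles):

import torch

x = torch.randn(8, 9, 3, 64, 64)  # (N, grid^2, C, H, W)
x = x.view([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]])
print(x.shape)                    # torch.Size([72, 3, 64, 64])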
class Classifier(nn.Module):
    def __init__(self, channels, num_classes):
        super(Classifier, self).__init__()
        if cfg.TASK == 'jig':
            self.jig_sq = cfg.JIGSAW_GRID ** 2
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(channels * self.jig_sq, num_classes)
        elif cfg.TASK == 'col':
            self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
        elif cfg.TASK == 'seg':
            self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
        else:
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(channels, num_classes)

    def forward(self, x, shape):
        if cfg.TASK == 'jig':
            x = self.pooling(x)
            x = x.view([x.shape[0] // self.jig_sq, x.shape[1] * self.jig_sq, x.shape[2], x.shape[3]])
            x = self.classifier(x.view(x.size(0), -1))
        elif cfg.TASK in ['col', 'seg']:
            x = self.classifier(x)
            x = nn.Upsample(shape, mode='bilinear', align_corners=True)(x)
        else:
            x = self.pooling(x)
            x = self.classifier(x.view(x.size(0), -1))
        return x


class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, num_classes, rates):
        super(ASPP, self).__init__()
        assert len(rates) in [1, 3]
        self.rates = rates
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.aspp1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, dilation=rates[0],
                      padding=rates[0], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        if len(self.rates) == 3:
            self.aspp3 = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, dilation=rates[1],
                          padding=rates[1], bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
            self.aspp4 = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, dilation=rates[2],
                          padding=rates[2], bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        self.aspp5 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, num_classes, 1)
        )

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x5 = self.global_pooling(x)
        x5 = self.aspp5(x5)
        x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
                         align_corners=True)(x5)
        if len(self.rates) == 3:
            x3 = self.aspp3(x)
            x4 = self.aspp4(x)
            x = torch.cat((x1, x2, x3, x4, x5), 1)
        else:
            x = torch.cat((x1, x2, x5), 1)
        x = self.classifier(x)
        return x
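ASPP concatenates len(rates) + 2 parallel branches (a 1x1 conv, the dilated 3x3 convs, and a globally pooled 1x1 branch upsampled back) before the classifier; a shape check with hypothetical sizes:

import torch

aspp = ASPP(in_channels=256, out_channels=64, num_classes=19, rates=[6, 12, 18])
x = torch.randn(2, 256, 33, 33)
print(aspp(x).shape)  # torch.Size([2, 19, 33, 33])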
232
pycls/models/effnet.py
Normal file
@@ -0,0 +1,232 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""EfficientNet models."""

import pycls.core.net as net
import torch
import torch.nn as nn
from pycls.core.config import cfg


class EffHead(nn.Module):
    """EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC."""

    def __init__(self, w_in, w_out, nc):
        super(EffHead, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False)
        self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.conv_swish = Swish()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        if cfg.EN.DROPOUT_RATIO > 0.0:
            self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
        self.fc = nn.Linear(w_out, nc, bias=True)

    def forward(self, x):
        x = self.conv_swish(self.conv_bn(self.conv(x)))
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x) if hasattr(self, "dropout") else x
        x = self.fc(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, nc):
        cx = net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_out, nc, 1, 1, 0, bias=True)
        return cx


class Swish(nn.Module):
    """Swish activation function: x * sigmoid(x)."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.f_ex = nn.Sequential(
            nn.Conv2d(w_in, w_se, 1, bias=True),
            Swish(),
            nn.Conv2d(w_se, w_in, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.f_ex(self.avg_pool(x))

    @staticmethod
    def complexity(cx, w_in, w_se):
        h, w = cx["h"], cx["w"]
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
        cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
        cx["h"], cx["w"] = h, w
        return cx


class MBConv(nn.Module):
    """Mobile inverted bottleneck block w/ SE (MBConv)."""

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
        # expansion, 3x3 dwise, BN, Swish, SE, 1x1, BN, skip_connection
        super(MBConv, self).__init__()
        self.exp = None
        w_exp = int(w_in * exp_r)
        if w_exp != w_in:
            self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False)
            self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
            self.exp_swish = Swish()
        dwise_args = {"groups": w_exp, "padding": (kernel - 1) // 2, "bias": False}
        self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args)
        self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.dwise_swish = Swish()
        self.se = SE(w_exp, int(w_in * se_r))
        self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False)
        self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # Skip connection if in and out shapes are the same (MN-V2 style)
        self.has_skip = stride == 1 and w_in == w_out

    def forward(self, x):
        f_x = x
        if self.exp:
            f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
        f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
        f_x = self.se(f_x)
        f_x = self.lin_proj_bn(self.lin_proj(f_x))
        if self.has_skip:
            if self.training and cfg.EN.DC_RATIO > 0.0:
                f_x = net.drop_connect(f_x, cfg.EN.DC_RATIO)
            f_x = x + f_x
        return f_x

    @staticmethod
    def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out):
        w_exp = int(w_in * exp_r)
        if w_exp != w_in:
            cx = net.complexity_conv2d(cx, w_in, w_exp, 1, 1, 0)
            cx = net.complexity_batchnorm2d(cx, w_exp)
        padding = (kernel - 1) // 2
        cx = net.complexity_conv2d(cx, w_exp, w_exp, kernel, stride, padding, w_exp)
        cx = net.complexity_batchnorm2d(cx, w_exp)
        cx = SE.complexity(cx, w_exp, int(w_in * se_r))
        cx = net.complexity_conv2d(cx, w_exp, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
|
||||
|
||||
class EffStage(nn.Module):
|
||||
"""EfficientNet stage."""
|
||||
|
||||
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
|
||||
super(EffStage, self).__init__()
|
||||
for i in range(d):
|
||||
b_stride = stride if i == 0 else 1
|
||||
b_w_in = w_in if i == 0 else w_out
|
||||
name = "b{}".format(i + 1)
|
||||
self.add_module(name, MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out))
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.children():
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out, d):
|
||||
for i in range(d):
|
||||
b_stride = stride if i == 0 else 1
|
||||
b_w_in = w_in if i == 0 else w_out
|
||||
cx = MBConv.complexity(cx, b_w_in, exp_r, kernel, b_stride, se_r, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class StemIN(nn.Module):
|
||||
"""EfficientNet stem for ImageNet: 3x3, BN, Swish."""
|
||||
|
||||
def __init__(self, w_in, w_out):
|
||||
super(StemIN, self).__init__()
|
||||
self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
|
||||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.swish = Swish()
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out):
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class EffNet(nn.Module):
|
||||
"""EfficientNet model."""
|
||||
|
||||
@staticmethod
|
||||
def get_args():
|
||||
return {
|
||||
"stem_w": cfg.EN.STEM_W,
|
||||
"ds": cfg.EN.DEPTHS,
|
||||
"ws": cfg.EN.WIDTHS,
|
||||
"exp_rs": cfg.EN.EXP_RATIOS,
|
||||
"se_r": cfg.EN.SE_R,
|
||||
"ss": cfg.EN.STRIDES,
|
||||
"ks": cfg.EN.KERNELS,
|
||||
"head_w": cfg.EN.HEAD_W,
|
||||
"nc": cfg.MODEL.NUM_CLASSES,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
err_str = "Dataset {} is not supported"
|
||||
assert cfg.TRAIN.DATASET in ["imagenet"], err_str.format(cfg.TRAIN.DATASET)
|
||||
assert cfg.TEST.DATASET in ["imagenet"], err_str.format(cfg.TEST.DATASET)
|
||||
super(EffNet, self).__init__()
|
||||
self._construct(**EffNet.get_args())
|
||||
self.apply(net.init_weights)
|
||||
|
||||
def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
|
||||
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
|
||||
self.stem = StemIN(3, stem_w)
|
||||
prev_w = stem_w
|
||||
for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
|
||||
name = "s{}".format(i + 1)
|
||||
self.add_module(name, EffStage(prev_w, exp_r, kernel, stride, se_r, w, d))
|
||||
prev_w = w
|
||||
self.head = EffHead(prev_w, head_w, nc)
|
||||
|
||||
def forward(self, x):
|
||||
for module in self.children():
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx):
|
||||
"""Computes model complexity. If you alter the model, make sure to update."""
|
||||
return EffNet._complexity(cx, **EffNet.get_args())
|
||||
|
||||
@staticmethod
|
||||
def _complexity(cx, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
|
||||
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
|
||||
cx = StemIN.complexity(cx, 3, stem_w)
|
||||
prev_w = stem_w
|
||||
for d, w, exp_r, stride, kernel in stage_params:
|
||||
cx = EffStage.complexity(cx, prev_w, exp_r, kernel, stride, se_r, w, d)
|
||||
prev_w = w
|
||||
cx = EffHead.complexity(cx, prev_w, head_w, nc)
|
||||
return cx
|
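A quick shape check of the block above (a sketch, not part of the commit; it assumes the default pycls config is loaded so that cfg.BN and cfg.EN are populated): with stride 1 and equal input/output widths, MBConv takes the skip path and preserves the input shape.

import torch

# Hypothetical smoke test, assuming cfg defaults are set before construction.
block = MBConv(w_in=16, exp_r=6.0, kernel=3, stride=1, se_r=0.25, w_out=16)
block.eval()  # avoid BN stat updates and drop-connect in this check
x = torch.rand(2, 16, 32, 32)
assert block(x).shape == x.shape  # dwise stride 1 + linear projection back to 16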
634
pycls/models/nas/genotypes.py
Normal file
@@ -0,0 +1,634 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""NAS genotypes (adopted from DARTS)."""

from collections import namedtuple


Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')


# NASNet ops
NASNET_OPS = [
    'skip_connect',
    'conv_3x1_1x3',
    'conv_7x1_1x7',
    'dil_conv_3x3',
    'avg_pool_3x3',
    'max_pool_3x3',
    'max_pool_5x5',
    'max_pool_7x7',
    'conv_1x1',
    'conv_3x3',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'sep_conv_7x7',
]

# ENAS ops
ENAS_OPS = [
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'avg_pool_3x3',
    'max_pool_3x3',
]

# AmoebaNet ops
AMOEBA_OPS = [
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'sep_conv_7x7',
    'avg_pool_3x3',
    'max_pool_3x3',
    'dil_sep_conv_3x3',
    'conv_7x1_1x7',
]

# NAO ops
NAO_OPS = [
    'skip_connect',
    'conv_1x1',
    'conv_3x3',
    'conv_3x1_1x3',
    'conv_7x1_1x7',
    'max_pool_2x2',
    'max_pool_3x3',
    'max_pool_5x5',
    'avg_pool_2x2',
    'avg_pool_3x3',
    'avg_pool_5x5',
]

# PNAS ops
PNAS_OPS = [
    'sep_conv_3x3',
    'sep_conv_5x5',
    'sep_conv_7x7',
    'conv_7x1_1x7',
    'skip_connect',
    'avg_pool_3x3',
    'max_pool_3x3',
    'dil_conv_3x3',
]

# DARTS ops
DARTS_OPS = [
    'none',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
]


NASNet = Genotype(
    normal=[
        ('sep_conv_5x5', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 0),
        ('avg_pool_3x3', 1),
        ('skip_connect', 0),
        ('avg_pool_3x3', 0),
        ('avg_pool_3x3', 0),
        ('sep_conv_3x3', 1),
        ('skip_connect', 1),
    ],
    normal_concat=[2, 3, 4, 5, 6],
    reduce=[
        ('sep_conv_5x5', 1),
        ('sep_conv_7x7', 0),
        ('max_pool_3x3', 1),
        ('sep_conv_7x7', 0),
        ('avg_pool_3x3', 1),
        ('sep_conv_5x5', 0),
        ('skip_connect', 3),
        ('avg_pool_3x3', 2),
        ('sep_conv_3x3', 2),
        ('max_pool_3x3', 1),
    ],
    reduce_concat=[4, 5, 6],
)


PNASNet = Genotype(
    normal=[
        ('sep_conv_5x5', 0),
        ('max_pool_3x3', 0),
        ('sep_conv_7x7', 1),
        ('max_pool_3x3', 1),
        ('sep_conv_5x5', 1),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 4),
        ('max_pool_3x3', 1),
        ('sep_conv_3x3', 0),
        ('skip_connect', 1),
    ],
    normal_concat=[2, 3, 4, 5, 6],
    reduce=[
        ('sep_conv_5x5', 0),
        ('max_pool_3x3', 0),
        ('sep_conv_7x7', 1),
        ('max_pool_3x3', 1),
        ('sep_conv_5x5', 1),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 4),
        ('max_pool_3x3', 1),
        ('sep_conv_3x3', 0),
        ('skip_connect', 1),
    ],
    reduce_concat=[2, 3, 4, 5, 6],
)


AmoebaNet = Genotype(
    normal=[
        ('avg_pool_3x3', 0),
        ('max_pool_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_5x5', 2),
        ('sep_conv_3x3', 0),
        ('avg_pool_3x3', 3),
        ('sep_conv_3x3', 1),
        ('skip_connect', 1),
        ('skip_connect', 0),
        ('avg_pool_3x3', 1),
    ],
    normal_concat=[4, 5, 6],
    reduce=[
        ('avg_pool_3x3', 0),
        ('sep_conv_3x3', 1),
        ('max_pool_3x3', 0),
        ('sep_conv_7x7', 2),
        ('sep_conv_7x7', 0),
        ('avg_pool_3x3', 1),
        ('max_pool_3x3', 0),
        ('max_pool_3x3', 1),
        ('conv_7x1_1x7', 0),
        ('sep_conv_3x3', 5),
    ],
    reduce_concat=[3, 4, 6]
)


DARTS_V1 = Genotype(
    normal=[
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('skip_connect', 0),
        ('sep_conv_3x3', 1),
        ('skip_connect', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('skip_connect', 2)
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0),
        ('max_pool_3x3', 1),
        ('skip_connect', 2),
        ('max_pool_3x3', 0),
        ('max_pool_3x3', 0),
        ('skip_connect', 2),
        ('skip_connect', 2),
        ('avg_pool_3x3', 0)
    ],
    reduce_concat=[2, 3, 4, 5]
)


DARTS_V2 = Genotype(
    normal=[
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1),
        ('skip_connect', 0),
        ('skip_connect', 0),
        ('dil_conv_3x3', 2)
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0),
        ('max_pool_3x3', 1),
        ('skip_connect', 2),
        ('max_pool_3x3', 1),
        ('max_pool_3x3', 0),
        ('skip_connect', 2),
        ('skip_connect', 2),
        ('max_pool_3x3', 1)
    ],
    reduce_concat=[2, 3, 4, 5]
)


PDARTS = Genotype(
    normal=[
        ('skip_connect', 0),
        ('dil_conv_3x3', 1),
        ('skip_connect', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 0),
        ('dil_conv_5x5', 4)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('avg_pool_3x3', 0),
        ('sep_conv_5x5', 1),
        ('sep_conv_3x3', 0),
        ('dil_conv_5x5', 2),
        ('max_pool_3x3', 0),
        ('dil_conv_3x3', 1),
        ('dil_conv_3x3', 1),
        ('dil_conv_5x5', 3)
    ],
    reduce_concat=range(2, 6)
)


PCDARTS_C10 = Genotype(
    normal=[
        ('sep_conv_3x3', 1),
        ('skip_connect', 0),
        ('sep_conv_3x3', 0),
        ('dil_conv_3x3', 1),
        ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 1),
        ('avg_pool_3x3', 0),
        ('dil_conv_3x3', 1)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_5x5', 1),
        ('max_pool_3x3', 0),
        ('sep_conv_5x5', 1),
        ('sep_conv_5x5', 2),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 2)
    ],
    reduce_concat=range(2, 6)
)


PCDARTS_IN1K = Genotype(
    normal=[
        ('skip_connect', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0),
        ('skip_connect', 1),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1),
        ('dil_conv_5x5', 4)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_3x3', 0),
        ('skip_connect', 1),
        ('dil_conv_5x5', 2),
        ('max_pool_3x3', 1),
        ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 1),
        ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 3)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_IMAGENET_CLS = Genotype(
    normal=[
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 2),
        ('sep_conv_5x5', 1),
        ('sep_conv_3x3', 0)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0),
        ('skip_connect', 1),
        ('max_pool_3x3', 0),
        ('dil_conv_5x5', 2),
        ('max_pool_3x3', 0),
        ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 4),
        ('dil_conv_5x5', 3)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_IMAGENET_ROT = Genotype(
    normal=[
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 4),
        ('sep_conv_5x5', 2)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_IMAGENET_COL = Genotype(
    normal=[
        ('skip_connect', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1),
        ('skip_connect', 0),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 2)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0),
        ('sep_conv_3x3', 1),
        ('max_pool_3x3', 0),
        ('sep_conv_3x3', 1),
        ('max_pool_3x3', 0),
        ('sep_conv_5x5', 3),
        ('max_pool_3x3', 0),
        ('sep_conv_3x3', 4)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_IMAGENET_JIG = Genotype(
    normal=[
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1),
        ('sep_conv_5x5', 0)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 1)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_IMAGENET22K_CLS = Genotype(
    normal=[
        ('sep_conv_3x3', 1),
        ('skip_connect', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0),
        ('max_pool_3x3', 1),
        ('dil_conv_5x5', 2),
        ('max_pool_3x3', 0),
        ('dil_conv_5x5', 3),
        ('dil_conv_5x5', 2),
        ('dil_conv_5x5', 4),
        ('dil_conv_5x5', 3)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_IMAGENET22K_ROT = Genotype(
    normal=[
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0),
        ('sep_conv_5x5', 1),
        ('dil_conv_5x5', 2),
        ('sep_conv_5x5', 0),
        ('dil_conv_5x5', 3),
        ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 4),
        ('sep_conv_3x3', 3)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_IMAGENET22K_COL = Genotype(
    normal=[
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 0)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0),
        ('skip_connect', 1),
        ('dil_conv_5x5', 2),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 4),
        ('sep_conv_5x5', 1)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_IMAGENET22K_JIG = Genotype(
    normal=[
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 4)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_5x5', 0),
        ('skip_connect', 1),
        ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 2),
        ('sep_conv_5x5', 0),
        ('sep_conv_5x5', 3),
        ('sep_conv_5x5', 0),
        ('sep_conv_5x5', 4)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_CITYSCAPES_SEG = Genotype(
    normal=[
        ('skip_connect', 0),
        ('sep_conv_5x5', 1),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_3x3', 0),
        ('avg_pool_3x3', 1),
        ('avg_pool_3x3', 1),
        ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 2),
        ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 4),
        ('sep_conv_5x5', 2)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_CITYSCAPES_ROT = Genotype(
    normal=[
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 3),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('max_pool_3x3', 0),
        ('sep_conv_5x5', 1),
        ('sep_conv_5x5', 2),
        ('sep_conv_5x5', 1),
        ('sep_conv_5x5', 3),
        ('dil_conv_5x5', 2),
        ('sep_conv_5x5', 2),
        ('sep_conv_5x5', 0)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_CITYSCAPES_COL = Genotype(
    normal=[
        ('dil_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('skip_connect', 0),
        ('sep_conv_5x5', 2),
        ('dil_conv_3x3', 3),
        ('skip_connect', 0),
        ('skip_connect', 0),
        ('sep_conv_3x3', 1)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('avg_pool_3x3', 1),
        ('avg_pool_3x3', 0),
        ('avg_pool_3x3', 1),
        ('avg_pool_3x3', 0),
        ('avg_pool_3x3', 1),
        ('avg_pool_3x3', 0),
        ('avg_pool_3x3', 1),
        ('skip_connect', 4)
    ],
    reduce_concat=range(2, 6)
)


UNNAS_CITYSCAPES_JIG = Genotype(
    normal=[
        ('dil_conv_5x5', 1),
        ('sep_conv_5x5', 0),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 1),
        ('sep_conv_3x3', 0),
        ('sep_conv_3x3', 2),
        ('sep_conv_3x3', 0),
        ('dil_conv_5x5', 1)
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('avg_pool_3x3', 0),
        ('skip_connect', 1),
        ('dil_conv_5x5', 1),
        ('dil_conv_5x5', 2),
        ('dil_conv_5x5', 2),
        ('dil_conv_5x5', 0),
        ('dil_conv_5x5', 3),
        ('dil_conv_5x5', 2)
    ],
    reduce_concat=range(2, 6)
)


# Supported genotypes
GENOTYPES = {
    'nas': NASNet,
    'pnas': PNASNet,
    'amoeba': AmoebaNet,
    'darts_v1': DARTS_V1,
    'darts_v2': DARTS_V2,
    'pdarts': PDARTS,
    'pcdarts_c10': PCDARTS_C10,
    'pcdarts_in1k': PCDARTS_IN1K,
    'unnas_imagenet_cls': UNNAS_IMAGENET_CLS,
    'unnas_imagenet_rot': UNNAS_IMAGENET_ROT,
    'unnas_imagenet_col': UNNAS_IMAGENET_COL,
    'unnas_imagenet_jig': UNNAS_IMAGENET_JIG,
    'unnas_imagenet22k_cls': UNNAS_IMAGENET22K_CLS,
    'unnas_imagenet22k_rot': UNNAS_IMAGENET22K_ROT,
    'unnas_imagenet22k_col': UNNAS_IMAGENET22K_COL,
    'unnas_imagenet22k_jig': UNNAS_IMAGENET22K_JIG,
    'unnas_cityscapes_seg': UNNAS_CITYSCAPES_SEG,
    'unnas_cityscapes_rot': UNNAS_CITYSCAPES_ROT,
    'unnas_cityscapes_col': UNNAS_CITYSCAPES_COL,
    'unnas_cityscapes_jig': UNNAS_CITYSCAPES_JIG,
    'custom': None,
}
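As a sanity check on the structure encoded above (a sketch, not part of the commit; it assumes the file is importable as pycls.models.nas.genotypes): every genotype lists two (op, input_index) pairs per intermediate node, and the concat fields name the nodes whose outputs are concatenated, which determines a cell's output width.

from pycls.models.nas.genotypes import GENOTYPES

g = GENOTYPES['darts_v2']
assert len(g.normal) == 8                      # 4 intermediate nodes x 2 inputs each
assert list(g.normal_concat) == [2, 3, 4, 5]   # concat all 4 intermediate nodes
# A cell's output width is therefore len(normal_concat) * C (the "multiplier").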
299
pycls/models/nas/nas.py
Normal file
@@ -0,0 +1,299 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""NAS network (adopted from DARTS)."""

from torch.autograd import Variable

import torch
import torch.nn as nn

import pycls.core.logging as logging

from pycls.core.config import cfg
from pycls.models.common import Preprocess
from pycls.models.common import Classifier
from pycls.models.nas.genotypes import GENOTYPES
from pycls.models.nas.genotypes import Genotype
from pycls.models.nas.operations import FactorizedReduce
from pycls.models.nas.operations import OPS
from pycls.models.nas.operations import ReLUConvBN
from pycls.models.nas.operations import Identity


logger = logging.get_logger(__name__)


def drop_path(x, drop_prob):
    """Drop path (ported from DARTS)."""
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        mask = Variable(
            torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
        )
        x.div_(keep_prob)
        x.mul_(mask)
    return x


class Cell(nn.Module):
    """NAS cell (ported from DARTS)."""

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        logger.info('{}, {}, {}'.format(C_prev_prev, C_prev, C))

        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        if reduction:
            op_names, indices = zip(*genotype.reduce)
            concat = genotype.reduce_concat
        else:
            op_names, indices = zip(*genotype.normal)
            concat = genotype.normal_concat
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        assert len(op_names) == len(indices)
        self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)

        self._ops = nn.ModuleList()
        for name, index in zip(op_names, indices):
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, stride, True)
            self._ops += [op]
        self._indices = indices

    def forward(self, s0, s1, drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for i in range(self._steps):
            h1 = states[self._indices[2*i]]
            h2 = states[self._indices[2*i+1]]
            op1 = self._ops[2*i]
            op2 = self._ops[2*i+1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            s = h1 + h2
            states += [s]
        return torch.cat([states[i] for i in self._concat], dim=1)


class AuxiliaryHeadCIFAR(nn.Module):

    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x


class AuxiliaryHeadImageNet(nn.Module):

    def __init__(self, C, num_classes):
        """assuming input size 14x14"""
        super(AuxiliaryHeadImageNet, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            # NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
            # Commenting it out for consistency with the experiments in the paper.
            # nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x


class NetworkCIFAR(nn.Module):
    """CIFAR network (ported from DARTS)."""

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkCIFAR, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary

        stem_multiplier = 3
        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )

        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                C_to_auxiliary = C_prev

        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        input = Preprocess(input)
        logits_aux = None
        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2 * self._layers // 3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits


class NetworkImageNet(nn.Module):
    """ImageNet network (ported from DARTS)."""

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkImageNet, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary

        self.stem0 = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        self.stem1 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        C_prev_prev, C_prev, C_curr = C, C, C

        self.cells = nn.ModuleList()
        reduction_prev = True
        reduction_layers = [layers // 3] if cfg.TASK == 'seg' else [layers // 3, 2 * layers // 3]
        for i in range(layers):
            if i in reduction_layers:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                C_to_auxiliary = C_prev

        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        input = Preprocess(input)
        logits_aux = None
        s0 = self.stem0(input)
        s1 = self.stem1(s0)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2 * self._layers // 3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits


class NAS(nn.Module):
    """NAS net wrapper (delegates to nets from DARTS)."""

    def __init__(self):
        assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Testing on {} is not supported'.format(cfg.TEST.DATASET)
        assert cfg.NAS.GENOTYPE in GENOTYPES, \
            'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
        super(NAS, self).__init__()
        logger.info('Constructing NAS: {}'.format(cfg.NAS))
        # Use a custom or predefined genotype
        if cfg.NAS.GENOTYPE == 'custom':
            genotype = Genotype(
                normal=cfg.NAS.CUSTOM_GENOTYPE[0],
                normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
                reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
                reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
            )
        else:
            genotype = GENOTYPES[cfg.NAS.GENOTYPE]
        # Determine the network constructor for dataset
        if 'cifar' in cfg.TRAIN.DATASET:
            net_ctor = NetworkCIFAR
        else:
            net_ctor = NetworkImageNet
        # Construct the network
        self.net_ = net_ctor(
            C=cfg.NAS.WIDTH,
            num_classes=cfg.MODEL.NUM_CLASSES,
            layers=cfg.NAS.DEPTH,
            auxiliary=cfg.NAS.AUX,
            genotype=genotype
        )
        # Drop path probability (set / annealed based on epoch)
        self.net_.drop_path_prob = 0.0

    def set_drop_path_prob(self, drop_path_prob):
        self.net_.drop_path_prob = drop_path_prob

    def forward(self, x):
        return self.net_.forward(x)
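For intuition on drop_path above (a sketch, not part of the commit; it requires a CUDA device, since the mask is built with torch.cuda.FloatTensor): each sample's residual branch is zeroed with probability drop_prob, and the 1/keep_prob rescaling keeps the expected activation unchanged.

import torch

# Hypothetical check on a CUDA machine with the functions above in scope.
x = torch.ones(1000, 1, 1, 1, device='cuda')
y = drop_path(x.clone(), drop_prob=0.2)
# Surviving samples are scaled by 1 / 0.8 = 1.25; dropped ones are zero,
# so the mean stays near 1.0 in expectation.
print(y.mean().item())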
201
pycls/models/nas/operations.py
Normal file
@@ -0,0 +1,201 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


"""NAS ops (adopted from DARTS)."""

import torch
import torch.nn as nn


OPS = {
    'none': lambda C, stride, affine:
        Zero(stride),
    'avg_pool_2x2': lambda C, stride, affine:
        nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False),
    'avg_pool_3x3': lambda C, stride, affine:
        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
    'avg_pool_5x5': lambda C, stride, affine:
        nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False),
    'max_pool_2x2': lambda C, stride, affine:
        nn.MaxPool2d(2, stride=stride, padding=0),
    'max_pool_3x3': lambda C, stride, affine:
        nn.MaxPool2d(3, stride=stride, padding=1),
    'max_pool_5x5': lambda C, stride, affine:
        nn.MaxPool2d(5, stride=stride, padding=2),
    'max_pool_7x7': lambda C, stride, affine:
        nn.MaxPool2d(7, stride=stride, padding=3),
    'skip_connect': lambda C, stride, affine:
        Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
    'conv_1x1': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
    'conv_3x3': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
    'sep_conv_3x3': lambda C, stride, affine:
        SepConv(C, C, 3, stride, 1, affine=affine),
    'sep_conv_5x5': lambda C, stride, affine:
        SepConv(C, C, 5, stride, 2, affine=affine),
    'sep_conv_7x7': lambda C, stride, affine:
        SepConv(C, C, 7, stride, 3, affine=affine),
    'dil_conv_3x3': lambda C, stride, affine:
        DilConv(C, C, 3, stride, 2, 2, affine=affine),
    'dil_conv_5x5': lambda C, stride, affine:
        DilConv(C, C, 5, stride, 4, 2, affine=affine),
    'dil_sep_conv_3x3': lambda C, stride, affine:
        DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
    'conv_3x1_1x3': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, (1,3), stride=(1, stride), padding=(0, 1), bias=False),
            nn.Conv2d(C, C, (3,1), stride=(stride, 1), padding=(1, 0), bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
    'conv_7x1_1x7': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
            nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
}


class ReLUConvBN(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_out, kernel_size, stride=stride,
                padding=padding, bias=False
            ),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.op(x)


class DilConv(nn.Module):

    def __init__(
        self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
    ):
        super(DilConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        return self.op(x)


class SepConv(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=1,
                padding=padding, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        return self.op(x)


class DilSepConv(nn.Module):

    def __init__(
        self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
    ):
        super(DilSepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=1,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        return self.op(x)


class Identity(nn.Module):

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Zero(nn.Module):

    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 1:
            return x.mul(0.)
        return x[:,:,::self.stride,::self.stride].mul(0.)


class FactorizedReduce(nn.Module):

    def __init__(self, C_in, C_out, affine=True):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)

    def forward(self, x):
        x = self.relu(x)
        y = self.pad(x)
        out = torch.cat([self.conv_1(x), self.conv_2(y[:,:,1:,1:])], dim=1)
        out = self.bn(out)
        return out
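A shape sketch for FactorizedReduce (a sketch, not part of the commit; it assumes the classes above are in scope): the two stride-2 1x1 convolutions each produce C_out // 2 channels, one on the input and one on a one-pixel-shifted copy, so the concatenation halves the spatial size while mixing both spatial phases.

import torch

fr = FactorizedReduce(C_in=16, C_out=32)
fr.eval()  # freeze BN running stats for the check
y = fr(torch.rand(1, 16, 8, 8))
assert y.shape == (1, 32, 4, 4)  # channels 16 -> 32, spatial 8x8 -> 4x4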
89
pycls/models/regnet.py
Normal file
@@ -0,0 +1,89 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""RegNet models."""

import numpy as np
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet


def quantize_float(f, q):
    """Converts a float to closest non-zero int divisible by q."""
    return int(round(f / q) * q)


def adjust_ws_gs_comp(ws, bms, gs):
    """Adjusts the compatibility of widths and groups."""
    ws_bot = [int(w * b) for w, b in zip(ws, bms)]
    gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
    ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
    ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
    return ws, gs


def get_stages_from_blocks(ws, rs):
    """Gets ws/ds of network at each stage from per block values."""
    ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
    ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
    s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
    s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
    return s_ws, s_ds


def generate_regnet(w_a, w_0, w_m, d, q=8):
    """Generates per block ws from RegNet parameters."""
    assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
    ws_cont = np.arange(d) * w_a + w_0
    ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
    ws = w_0 * np.power(w_m, ks)
    ws = np.round(np.divide(ws, q)) * q
    num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
    ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
    return ws, num_stages, max_stage, ws_cont


class RegNet(AnyNet):
    """RegNet model."""

    @staticmethod
    def get_args():
        """Convert RegNet to AnyNet parameter format."""
        # Generate RegNet ws per block
        w_a, w_0, w_m, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
        ws, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
        # Convert to per stage format
        s_ws, s_ds = get_stages_from_blocks(ws, ws)
        # Use the same gw, bm and ss for each stage
        s_gs = [cfg.REGNET.GROUP_W for _ in range(num_stages)]
        s_bs = [cfg.REGNET.BOT_MUL for _ in range(num_stages)]
        s_ss = [cfg.REGNET.STRIDE for _ in range(num_stages)]
        # Adjust the compatibility of ws and gws
        s_ws, s_gs = adjust_ws_gs_comp(s_ws, s_bs, s_gs)
        # Get AnyNet arguments defining the RegNet
        return {
            "stem_type": cfg.REGNET.STEM_TYPE,
            "stem_w": cfg.REGNET.STEM_W,
            "block_type": cfg.REGNET.BLOCK_TYPE,
            "ds": s_ds,
            "ws": s_ws,
            "ss": s_ss,
            "bms": s_bs,
            "gws": s_gs,
            "se_r": cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self):
        kwargs = RegNet.get_args()
        super(RegNet, self).__init__(**kwargs)

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update."""
        kwargs = RegNet.get_args() if not kwargs else kwargs
        return AnyNet.complexity(cx, **kwargs)
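The width quantization above can be followed with small numbers (a sketch, not part of the commit): quantize_float snaps a float width to the nearest multiple of q, and adjust_ws_gs_comp then makes each bottleneck width divisible by its group width.

assert quantize_float(30.0, 8) == 32   # round(30 / 8) * 8
# One stage with width 64, bottleneck multiplier 1.0, requested group width 24:
# the group width is capped at 64, then 64 is re-quantized to a multiple of 24.
assert adjust_ws_gs_comp([64], [1.0], [24]) == ([72], [24])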
280
pycls/models/resnet.py
Normal file
@@ -0,0 +1,280 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""ResNe(X)t models."""

import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg


# Stage depths for ImageNet models
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}


def get_trans_fun(name):
    """Retrieves the transformation function by name."""
    trans_funs = {
        "basic_transform": BasicTransform,
        "bottleneck_transform": BottleneckTransform,
    }
    err_str = "Transformation function '{}' not supported"
    assert name in trans_funs.keys(), err_str.format(name)
    return trans_funs[name]


class ResHead(nn.Module):
    """ResNet head: AvgPool, 1x1."""

    def __init__(self, w_in, nc):
        super(ResHead, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    @staticmethod
    def complexity(cx, w_in, nc):
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
        return cx


class BasicTransform(nn.Module):
    """Basic transformation: 3x3, BN, ReLU, 3x3, BN."""

    def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        super(BasicTransform, self).__init__()
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1):
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN."""

    def __init__(self, w_in, w_out, stride, w_b, num_gs):
        super(BottleneckTransform, self).__init__()
        # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
        (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
        self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False)
        self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.c_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, w_b, num_gs):
        (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
        cx = net.complexity_conv2d(cx, w_in, w_b, 1, s1, 0)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_b, 3, s3, 1, num_gs)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class ResBlock(nn.Module):
    """Residual block: x + F(x)."""

    def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
        super(ResBlock, self).__init__()
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs):
        proj_block = (w_in != w_out) or (stride != 1)
        if proj_block:
            h, w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = h, w  # parallel branch
        cx = trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs)
        return cx


class ResStage(nn.Module):
    """Stage of ResNet."""

    def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
        super(ResStage, self).__init__()
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN)
            res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs)
            self.add_module("b{}".format(i + 1), res_block)

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1):
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            trans_f = get_trans_fun(cfg.RESNET.TRANS_FUN)
            cx = ResBlock.complexity(cx, b_w_in, w_out, b_stride, trans_f, w_b, num_gs)
        return cx


class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR: 3x3, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super(ResStemCifar, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""

    def __init__(self, w_in, w_out):
        super(ResStemIN, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_maxpool2d(cx, 3, 2, 1)
        return cx


class ResNet(nn.Module):
    """ResNet model."""

    def __init__(self):
        datasets = ["cifar10", "imagenet"]
        err_str = "Dataset {} is not supported"
        assert cfg.TRAIN.DATASET in datasets, err_str.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in datasets, err_str.format(cfg.TEST.DATASET)
        super(ResNet, self).__init__()
        if "cifar" in cfg.TRAIN.DATASET:
            self._construct_cifar()
        else:
            self._construct_imagenet()
        self.apply(net.init_weights)

    def _construct_cifar(self):
        err_str = "Model depth should be of the format 6n + 2 for cifar"
        assert (cfg.MODEL.DEPTH - 2) % 6 == 0, err_str
        d = int((cfg.MODEL.DEPTH - 2) / 6)
        self.stem = ResStemCifar(3, 16)
        self.s1 = ResStage(16, 16, stride=1, d=d)
        self.s2 = ResStage(16, 32, stride=2, d=d)
        self.s3 = ResStage(32, 64, stride=2, d=d)
        self.head = ResHead(64, nc=cfg.MODEL.NUM_CLASSES)

    def _construct_imagenet(self):
        g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
        (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
        w_b = gw * g
        self.stem = ResStemIN(3, 64)
        self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g)
        self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g)
        self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g)
        self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g)
        self.head = ResHead(2048, nc=cfg.MODEL.NUM_CLASSES)

    def forward(self, x):
        for module in self.children():
            x = module(x)
        return x

    @staticmethod
    def complexity(cx):
        """Computes model complexity. If you alter the model, make sure to update."""
        if "cifar" in cfg.TRAIN.DATASET:
            d = int((cfg.MODEL.DEPTH - 2) / 6)
            cx = ResStemCifar.complexity(cx, 3, 16)
            cx = ResStage.complexity(cx, 16, 16, stride=1, d=d)
            cx = ResStage.complexity(cx, 16, 32, stride=2, d=d)
            cx = ResStage.complexity(cx, 32, 64, stride=2, d=d)
            cx = ResHead.complexity(cx, 64, nc=cfg.MODEL.NUM_CLASSES)
        else:
            g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
            (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
            w_b = gw * g
            cx = ResStemIN.complexity(cx, 3, 64)
            cx = ResStage.complexity(cx, 64, 256, 1, d=d1, w_b=w_b, num_gs=g)
            cx = ResStage.complexity(cx, 256, 512, 2, d=d2, w_b=w_b * 2, num_gs=g)
            cx = ResStage.complexity(cx, 512, 1024, 2, d=d3, w_b=w_b * 4, num_gs=g)
            cx = ResStage.complexity(cx, 1024, 2048, 2, d=d4, w_b=w_b * 8, num_gs=g)
            cx = ResHead.complexity(cx, 2048, nc=cfg.MODEL.NUM_CLASSES)
        return cx
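A concrete reading of the ImageNet construction above (a sketch, not part of the commit): for a ResNeXt-50 32x4d style configuration, NUM_GROUPS=32 and WIDTH_PER_GROUP=4 give the per-stage bottleneck widths used in _construct_imagenet, while the stage output widths stay fixed at 256/512/1024/2048, so only the grouped 3x3 convolution narrows.

g, gw = 32, 4                 # cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
w_b = gw * g                  # bottleneck width of stage s1
assert [w_b, w_b * 2, w_b * 4, w_b * 8] == [128, 256, 512, 1024]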