Jack Turner
2021-02-26 16:12:51 +00:00
parent c895924c99
commit b74255e1f3
74 changed files with 11326 additions and 537 deletions

0
pycls/core/__init__.py Normal file

136
pycls/core/benchmark.py Normal file

@@ -0,0 +1,136 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Benchmarking functions."""
import pycls.core.logging as logging
import pycls.datasets.loader as loader
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer
logger = logging.get_logger(__name__)
@torch.no_grad()
def compute_time_eval(model):
"""Computes precise model forward test time using dummy data."""
# Use eval mode
model.eval()
# Generate a dummy mini-batch and copy data to GPU
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
if cfg.TASK == "jig":
inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
else:
inputs = torch.zeros(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
# Compute precise forward pass time
timer = Timer()
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
for cur_iter in range(total_iter):
# Reset the timers after the warmup phase
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
timer.reset()
# Forward
timer.tic()
model(inputs)
torch.cuda.synchronize()
timer.toc()
return timer.average_time
def compute_time_train(model, loss_fun):
"""Computes precise model forward + backward time using dummy data."""
# Use train mode
model.train()
# Generate a dummy mini-batch and copy data to GPU
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
if cfg.TASK == "jig":
inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
else:
inputs = torch.rand(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
if cfg.TASK in ['col', 'seg']:
labels = torch.zeros(batch_size, im_size, im_size, dtype=torch.int64).cuda(non_blocking=False)
else:
labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
# Cache BatchNorm2D running stats
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
# Compute precise forward backward pass time
fw_timer, bw_timer = Timer(), Timer()
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
for cur_iter in range(total_iter):
# Reset the timers after the warmup phase
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
fw_timer.reset()
bw_timer.reset()
# Forward
fw_timer.tic()
preds = model(inputs)
if isinstance(preds, tuple):
loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
preds = preds[0]
else:
loss = loss_fun(preds, labels)
torch.cuda.synchronize()
fw_timer.toc()
# Backward
bw_timer.tic()
loss.backward()
torch.cuda.synchronize()
bw_timer.toc()
# Restore BatchNorm2D running stats
for bn, (mean, var) in zip(bns, bn_stats):
bn.running_mean, bn.running_var = mean, var
return fw_timer.average_time, bw_timer.average_time
def compute_time_loader(data_loader):
"""Computes loader time."""
timer = Timer()
loader.shuffle(data_loader, 0)
data_loader_iterator = iter(data_loader)
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
total_iter = min(total_iter, len(data_loader))
for cur_iter in range(total_iter):
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
timer.reset()
timer.tic()
next(data_loader_iterator)
timer.toc()
return timer.average_time
def compute_time_full(model, loss_fun, train_loader, test_loader):
"""Times model and data loader."""
logger.info("Computing model and loader timings...")
# Compute timings
test_fw_time = compute_time_eval(model)
train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
train_fw_bw_time = train_fw_time + train_bw_time
train_loader_time = compute_time_loader(train_loader)
# Output iter timing
iter_times = {
"test_fw_time": test_fw_time,
"train_fw_time": train_fw_time,
"train_bw_time": train_bw_time,
"train_fw_bw_time": train_fw_bw_time,
"train_loader_time": train_loader_time,
}
logger.info(logging.dump_log_data(iter_times, "iter_times"))
# Output epoch timing
epoch_times = {
"test_fw_time": test_fw_time * len(test_loader),
"train_fw_time": train_fw_time * len(train_loader),
"train_bw_time": train_bw_time * len(train_loader),
"train_fw_bw_time": train_fw_bw_time * len(train_loader),
"train_loader_time": train_loader_time * len(train_loader),
}
logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
# Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS > 1)
overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))

88
pycls/core/builders.py Normal file

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Model and loss construction functions."""
import torch
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
from pycls.models.resnet import ResNet
from pycls.models.nas.nas import NAS
from pycls.models.nas.nas_search import NAS_Search
from pycls.models.nas_bench.model_builder import NAS_Bench
class LabelSmoothedCrossEntropyLoss(torch.nn.Module):
"""CrossEntropyLoss with label smoothing."""
def __init__(self):
super(LabelSmoothedCrossEntropyLoss, self).__init__()
self.eps = cfg.MODEL.LABEL_SMOOTHING_EPS
self.num_classes = cfg.MODEL.NUM_CLASSES
def forward(self, logits, target):
pred = logits.log_softmax(dim=-1)
with torch.no_grad():
target_dist = torch.ones_like(pred) * self.eps / (self.num_classes - 1)
target_dist.scatter_(-1, target.unsqueeze(-1), 1 - self.eps)
return (-target_dist * pred).sum(dim=-1).mean()
# Supported models
_models = {
"anynet": AnyNet,
"effnet": EffNet,
"resnet": ResNet,
"regnet": RegNet,
"nas": NAS,
"nas_search": NAS_Search,
"nas_bench": NAS_Bench,
}
# Supported loss functions
_loss_funs = {
"cross_entropy": torch.nn.CrossEntropyLoss,
"label_smoothed_cross_entropy": LabelSmoothedCrossEntropyLoss,
}
def get_model():
"""Gets the model class specified in the config."""
err_str = "Model type '{}' not supported"
assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE)
return _models[cfg.MODEL.TYPE]
def get_loss_fun():
"""Gets the loss function class specified in the config."""
err_str = "Loss function type '{}' not supported"
    assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
return _loss_funs[cfg.MODEL.LOSS_FUN]
def build_model():
"""Builds the model."""
return get_model()()
def build_loss_fun():
"""Build the loss function."""
if cfg.TASK == "seg":
return get_loss_fun()(ignore_index=255)
else:
return get_loss_fun()()
def register_model(name, ctor):
"""Registers a model dynamically."""
_models[name] = ctor
def register_loss_fun(name, ctor):
"""Registers a loss function dynamically."""
_loss_funs[name] = ctor
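
With eps = 0.1 and 10 classes, the smoothed target assigns 0.9 to the true class and 0.1/9 ≈ 0.011 to each other class. A sketch of extending the registries (TinyNet is a hypothetical example, not part of this commit):

import torch
import pycls.core.builders as builders

class TinyNet(torch.nn.Module):
    """Hypothetical model used only to illustrate register_model."""
    def __init__(self):
        super().__init__()
        self.head = torch.nn.Linear(3 * 32 * 32, 10)
    def forward(self, x):
        return self.head(x.flatten(1))

builders.register_model("tinynet", TinyNet)
# With cfg.MODEL.TYPE = "tinynet", build_model() now constructs TinyNet()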

98
pycls/core/checkpoint.py Normal file

@@ -0,0 +1,98 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions that handle saving and loading of checkpoints."""
import os
import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg
# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"
def get_checkpoint_dir():
"""Retrieves the location for storing checkpoints."""
return os.path.join(cfg.OUT_DIR, _DIR_NAME)
def get_checkpoint(epoch):
"""Retrieves the path to a checkpoint file."""
name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
return os.path.join(get_checkpoint_dir(), name)
def get_last_checkpoint():
"""Retrieves the most recent checkpoint (highest epoch number)."""
checkpoint_dir = get_checkpoint_dir()
# Checkpoint file names are in lexicographic order
checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
last_checkpoint_name = sorted(checkpoints)[-1]
return os.path.join(checkpoint_dir, last_checkpoint_name)
def has_checkpoint():
"""Determines if there are checkpoints available."""
checkpoint_dir = get_checkpoint_dir()
if not os.path.exists(checkpoint_dir):
return False
return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))
def save_checkpoint(model, optimizer, epoch):
"""Saves a checkpoint."""
# Save checkpoints only from the master process
if not dist.is_master_proc():
return
# Ensure that the checkpoint dir exists
os.makedirs(get_checkpoint_dir(), exist_ok=True)
# Omit the DDP wrapper in the multi-gpu setting
sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
# Record the state
if isinstance(optimizer, list):
checkpoint = {
"epoch": epoch,
"model_state": sd,
"optimizer_w_state": optimizer[0].state_dict(),
"optimizer_a_state": optimizer[1].state_dict(),
"cfg": cfg.dump(),
}
else:
checkpoint = {
"epoch": epoch,
"model_state": sd,
"optimizer_state": optimizer.state_dict(),
"cfg": cfg.dump(),
}
# Write the checkpoint
checkpoint_file = get_checkpoint(epoch + 1)
torch.save(checkpoint, checkpoint_file)
return checkpoint_file
def load_checkpoint(checkpoint_file, model, optimizer=None):
"""Loads the checkpoint from the given file."""
err_str = "Checkpoint '{}' not found"
assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
# Load the checkpoint on CPU to avoid GPU mem spike
checkpoint = torch.load(checkpoint_file, map_location="cpu")
# Account for the DDP wrapper in the multi-gpu setting
ms = model.module if cfg.NUM_GPUS > 1 else model
ms.load_state_dict(checkpoint["model_state"])
# Load the optimizer state (commonly not done when fine-tuning)
if optimizer:
if isinstance(optimizer, list):
optimizer[0].load_state_dict(checkpoint["optimizer_w_state"])
optimizer[1].load_state_dict(checkpoint["optimizer_a_state"])
else:
optimizer.load_state_dict(checkpoint["optimizer_state"])
return checkpoint["epoch"]

500
pycls/core/config.py Normal file

@@ -0,0 +1,500 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Configuration file (powered by YACS)."""
import argparse
import os
import sys
from pycls.core.io import cache_url
from yacs.config import CfgNode
# Global config object
_C = CfgNode()
# Example usage:
# from core.config import cfg
cfg = _C
# ------------------------------------------------------------------------------------ #
# Model options
# ------------------------------------------------------------------------------------ #
_C.MODEL = CfgNode()
# Model type
_C.MODEL.TYPE = ""
# Number of weight layers
_C.MODEL.DEPTH = 0
# Number of input channels
_C.MODEL.INPUT_CHANNELS = 3
# Number of classes
_C.MODEL.NUM_CLASSES = 10
# Loss function (see pycls/core/builders.py for options)
_C.MODEL.LOSS_FUN = "cross_entropy"
# Label smoothing eps
_C.MODEL.LABEL_SMOOTHING_EPS = 0.0
# ASPP channels
_C.MODEL.ASPP_CHANNELS = 256
# ASPP dilation rates
_C.MODEL.ASPP_RATES = [6, 12, 18]
# ------------------------------------------------------------------------------------ #
# ResNet options
# ------------------------------------------------------------------------------------ #
_C.RESNET = CfgNode()
# Transformation function (see pycls/models/resnet.py for options)
_C.RESNET.TRANS_FUN = "basic_transform"
# Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
_C.RESNET.NUM_GROUPS = 1
# Width of each group (64 -> ResNet; 4 -> ResNeXt)
_C.RESNET.WIDTH_PER_GROUP = 64
# Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
_C.RESNET.STRIDE_1X1 = True
# ------------------------------------------------------------------------------------ #
# AnyNet options
# ------------------------------------------------------------------------------------ #
_C.ANYNET = CfgNode()
# Stem type
_C.ANYNET.STEM_TYPE = "simple_stem_in"
# Stem width
_C.ANYNET.STEM_W = 32
# Block type
_C.ANYNET.BLOCK_TYPE = "res_bottleneck_block"
# Depth for each stage (number of blocks in the stage)
_C.ANYNET.DEPTHS = []
# Width for each stage (width of each block in the stage)
_C.ANYNET.WIDTHS = []
# Strides for each stage (applies to the first block of each stage)
_C.ANYNET.STRIDES = []
# Bottleneck multipliers for each stage (applies to bottleneck block)
_C.ANYNET.BOT_MULS = []
# Group widths for each stage (applies to bottleneck block)
_C.ANYNET.GROUP_WS = []
# Whether SE is enabled for res_bottleneck_block
_C.ANYNET.SE_ON = False
# SE ratio
_C.ANYNET.SE_R = 0.25
# ------------------------------------------------------------------------------------ #
# RegNet options
# ------------------------------------------------------------------------------------ #
_C.REGNET = CfgNode()
# Stem type
_C.REGNET.STEM_TYPE = "simple_stem_in"
# Stem width
_C.REGNET.STEM_W = 32
# Block type
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"
# Stride of each stage
_C.REGNET.STRIDE = 2
# Squeeze-and-Excitation (RegNetY)
_C.REGNET.SE_ON = False
_C.REGNET.SE_R = 0.25
# Depth
_C.REGNET.DEPTH = 10
# Initial width
_C.REGNET.W0 = 32
# Slope
_C.REGNET.WA = 5.0
# Quantization
_C.REGNET.WM = 2.5
# Group width
_C.REGNET.GROUP_W = 16
# Bottleneck multiplier (bm = 1 / b from the paper)
_C.REGNET.BOT_MUL = 1.0
# ------------------------------------------------------------------------------------ #
# EfficientNet options
# ------------------------------------------------------------------------------------ #
_C.EN = CfgNode()
# Stem width
_C.EN.STEM_W = 32
# Depth for each stage (number of blocks in the stage)
_C.EN.DEPTHS = []
# Width for each stage (width of each block in the stage)
_C.EN.WIDTHS = []
# Expansion ratios for MBConv blocks in each stage
_C.EN.EXP_RATIOS = []
# Squeeze-and-Excitation (SE) ratio
_C.EN.SE_R = 0.25
# Strides for each stage (applies to the first block of each stage)
_C.EN.STRIDES = []
# Kernel sizes for each stage
_C.EN.KERNELS = []
# Head width
_C.EN.HEAD_W = 1280
# Drop connect ratio
_C.EN.DC_RATIO = 0.0
# Dropout ratio
_C.EN.DROPOUT_RATIO = 0.0
# ---------------------------------------------------------------------------- #
# NAS options
# ---------------------------------------------------------------------------- #
_C.NAS = CfgNode()
# Cell genotype
_C.NAS.GENOTYPE = 'nas'
# Custom genotype
_C.NAS.CUSTOM_GENOTYPE = []
# Base NAS width
_C.NAS.WIDTH = 16
# Total number of cells
_C.NAS.DEPTH = 20
# Auxiliary heads
_C.NAS.AUX = False
# Weight for auxiliary heads
_C.NAS.AUX_WEIGHT = 0.4
# Drop path probability
_C.NAS.DROP_PROB = 0.0
# Matrix in NAS Bench
_C.NAS.MATRIX = []
# Operations in NAS Bench
_C.NAS.OPS = []
# Number of stacks in NAS Bench
_C.NAS.NUM_STACKS = 3
# Number of modules per stack in NAS Bench
_C.NAS.NUM_MODULES_PER_STACK = 3
# ------------------------------------------------------------------------------------ #
# Batch norm options
# ------------------------------------------------------------------------------------ #
_C.BN = CfgNode()
# BN epsilon
_C.BN.EPS = 1e-5
# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1
# Precise BN stats
_C.BN.USE_PRECISE_STATS = False
_C.BN.NUM_SAMPLES_PRECISE = 1024
# Initialize the gamma of the final BN of each block to zero
_C.BN.ZERO_INIT_FINAL_GAMMA = False
# Use a different weight decay for BN layers
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0
# ------------------------------------------------------------------------------------ #
# Optimizer options
# ------------------------------------------------------------------------------------ #
_C.OPTIM = CfgNode()
# Base learning rate
_C.OPTIM.BASE_LR = 0.1
# Learning rate policy select from {'cos', 'exp', 'steps'}
_C.OPTIM.LR_POLICY = "cos"
# Exponential decay factor
_C.OPTIM.GAMMA = 0.1
# Steps for 'steps' policy (in epochs)
_C.OPTIM.STEPS = []
# Learning rate multiplier for 'steps' policy
_C.OPTIM.LR_MULT = 0.1
# Maximal number of epochs
_C.OPTIM.MAX_EPOCH = 200
# Momentum
_C.OPTIM.MOMENTUM = 0.9
# Momentum dampening
_C.OPTIM.DAMPENING = 0.0
# Nesterov momentum
_C.OPTIM.NESTEROV = True
# L2 regularization
_C.OPTIM.WEIGHT_DECAY = 5e-4
# Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
_C.OPTIM.WARMUP_FACTOR = 0.1
# Gradually warm up the OPTIM.BASE_LR over this number of epochs
_C.OPTIM.WARMUP_EPOCHS = 0
# Update the learning rate per iter
_C.OPTIM.ITER_LR = False
# Base learning rate for arch
_C.OPTIM.ARCH_BASE_LR = 0.0003
# L2 regularization for arch
_C.OPTIM.ARCH_WEIGHT_DECAY = 0.001
# Optimizer for arch
_C.OPTIM.ARCH_OPTIM = 'adam'
# Epoch to start optimizing arch
_C.OPTIM.ARCH_EPOCH = 0.0
# ------------------------------------------------------------------------------------ #
# Training options
# ------------------------------------------------------------------------------------ #
_C.TRAIN = CfgNode()
# Dataset and split
_C.TRAIN.DATASET = ""
_C.TRAIN.SPLIT = "train"
# Total mini-batch size
_C.TRAIN.BATCH_SIZE = 128
# Image size
_C.TRAIN.IM_SIZE = 224
# Evaluate model on test data every eval period epochs
_C.TRAIN.EVAL_PERIOD = 1
# Save model checkpoint every checkpoint period epochs
_C.TRAIN.CHECKPOINT_PERIOD = 1
# Resume training from the latest checkpoint in the output directory
_C.TRAIN.AUTO_RESUME = True
# Weights to start training from
_C.TRAIN.WEIGHTS = ""
# Percentage of gray images in jig
_C.TRAIN.GRAY_PERCENTAGE = 0.0
# Portion to create trainA/trainB split
_C.TRAIN.PORTION = 1.0
# ------------------------------------------------------------------------------------ #
# Testing options
# ------------------------------------------------------------------------------------ #
_C.TEST = CfgNode()
# Dataset and split
_C.TEST.DATASET = ""
_C.TEST.SPLIT = "val"
# Total mini-batch size
_C.TEST.BATCH_SIZE = 200
# Image size
_C.TEST.IM_SIZE = 256
# Weights to use for testing
_C.TEST.WEIGHTS = ""
# ------------------------------------------------------------------------------------ #
# Common train/test data loader options
# ------------------------------------------------------------------------------------ #
_C.DATA_LOADER = CfgNode()
# Number of data loader workers per process
_C.DATA_LOADER.NUM_WORKERS = 8
# Load data to pinned host memory
_C.DATA_LOADER.PIN_MEMORY = True
# ------------------------------------------------------------------------------------ #
# Memory options
# ------------------------------------------------------------------------------------ #
_C.MEM = CfgNode()
# Perform ReLU inplace
_C.MEM.RELU_INPLACE = True
# ------------------------------------------------------------------------------------ #
# CUDNN options
# ------------------------------------------------------------------------------------ #
_C.CUDNN = CfgNode()
# Perform benchmarking to select the fastest CUDNN algorithms to use
# Note that this may increase the memory usage and will likely not result
# in overall speedups when variable size inputs are used (e.g. COCO training)
_C.CUDNN.BENCHMARK = True
# ------------------------------------------------------------------------------------ #
# Precise timing options
# ------------------------------------------------------------------------------------ #
_C.PREC_TIME = CfgNode()
# Number of iterations to warm up the caches
_C.PREC_TIME.WARMUP_ITER = 3
# Number of iterations to compute avg time
_C.PREC_TIME.NUM_ITER = 30
# ------------------------------------------------------------------------------------ #
# Misc options
# ------------------------------------------------------------------------------------ #
# Number of GPUs to use (applies to both training and testing)
_C.NUM_GPUS = 1
# Task (cls, seg, rot, col, jig)
_C.TASK = "cls"
# Grid in Jigsaw (2, 3); no effect if TASK is not jig
_C.JIGSAW_GRID = 3
# Output directory
_C.OUT_DIR = "/tmp"
# Config destination (in OUT_DIR)
_C.CFG_DEST = "config.yaml"
# RNG seed (note that non-determinism may still be present due to
# non-deterministic operator implementations in GPU operator libraries)
_C.RNG_SEED = 1
# Log destination ('stdout' or 'file')
_C.LOG_DEST = "stdout"
# Log period in iters
_C.LOG_PERIOD = 10
# Distributed backend
_C.DIST_BACKEND = "nccl"
# Hostname and port for initializing multi-process groups
_C.HOST = "localhost"
_C.PORT = 10001
# Models weights referred to by URL are downloaded to this local cache
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
# ------------------------------------------------------------------------------------ #
# Deprecated keys
# ------------------------------------------------------------------------------------ #
_C.register_deprecated_key("PREC_TIME.BATCH_SIZE")
_C.register_deprecated_key("PREC_TIME.ENABLED")
def assert_and_infer_cfg(cache_urls=True):
"""Checks config values invariants."""
err_str = "The first lr step must start at 0"
assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
data_splits = ["train", "val", "test"]
err_str = "Data split '{}' not supported"
assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT)
assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT)
err_str = "Mini-batch size should be a multiple of NUM_GPUS."
assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
err_str = "Precise BN stats computation not verified for > 1 GPU"
assert not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1, err_str
err_str = "Log destination '{}' not supported"
assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
if cache_urls:
cache_cfg_urls()
def cache_cfg_urls():
"""Download URLs in config, cache them, and rewrite cfg to use cached file."""
_C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
_C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
def dump_cfg():
"""Dumps the config to the output directory."""
cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
with open(cfg_file, "w") as f:
_C.dump(stream=f)
def load_cfg(out_dir, cfg_dest="config.yaml"):
"""Loads config from specified output directory."""
cfg_file = os.path.join(out_dir, cfg_dest)
_C.merge_from_file(cfg_file)
def load_cfg_fom_args(description="Config file options."):
"""Load config from command line arguments and set any specified options."""
parser = argparse.ArgumentParser(description=description)
help_s = "Config file location"
parser.add_argument("--cfg", dest="cfg_file", help=help_s, required=True, type=str)
help_s = "See pycls/core/config.py for all options"
parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
args = parser.parse_args()
_C.merge_from_file(args.cfg_file)
_C.merge_from_list(args.opts)
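
A sketch of driving the config from code (the yaml path is hypothetical; on the command line, the trailing opts pairs override any key, e.g. `OPTIM.BASE_LR 0.05`):

from pycls.core.config import assert_and_infer_cfg, cfg, dump_cfg

cfg.merge_from_file("configs/example.yaml")            # hypothetical file
cfg.merge_from_list(["OPTIM.BASE_LR", 0.05, "NUM_GPUS", 2])
assert_and_infer_cfg(cache_urls=False)                 # validate invariants
dump_cfg()                                             # write cfg to OUT_DIR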

157
pycls/core/distributed.py Normal file

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Distributed helpers."""
import multiprocessing
import os
import signal
import threading
import traceback
import torch
from pycls.core.config import cfg
def is_master_proc():
"""Determines if the current process is the master process.
Master process is responsible for logging, writing and loading checkpoints. In
the multi GPU setting, we assign the master role to the rank 0 process. When
training using a single GPU, there is a single process which is considered master.
"""
return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
def init_process_group(proc_rank, world_size):
"""Initializes the default process group."""
# Set the GPU to use
torch.cuda.set_device(proc_rank)
# Initialize the process group
torch.distributed.init_process_group(
backend=cfg.DIST_BACKEND,
init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
world_size=world_size,
rank=proc_rank,
)
def destroy_process_group():
"""Destroys the default process group."""
torch.distributed.destroy_process_group()
def scaled_all_reduce(tensors):
"""Performs the scaled all_reduce operation on the provided tensors.
The input tensors are modified in-place. Currently supports only the sum
reduction operator. The reduced values are scaled by the inverse size of the
process group (equivalent to cfg.NUM_GPUS).
"""
# There is no need for reduction in the single-proc case
if cfg.NUM_GPUS == 1:
return tensors
# Queue the reductions
reductions = []
for tensor in tensors:
reduction = torch.distributed.all_reduce(tensor, async_op=True)
reductions.append(reduction)
# Wait for reductions to finish
for reduction in reductions:
reduction.wait()
# Scale the results
for tensor in tensors:
tensor.mul_(1.0 / cfg.NUM_GPUS)
return tensors
class ChildException(Exception):
"""Wraps an exception from a child process."""
def __init__(self, child_trace):
super(ChildException, self).__init__(child_trace)
class ErrorHandler(object):
"""Multiprocessing error handler (based on fairseq's).
Listens for errors in child processes and propagates the tracebacks to the parent.
"""
def __init__(self, error_queue):
# Shared error queue
self.error_queue = error_queue
# Children processes sharing the error queue
self.children_pids = []
# Start a thread listening to errors
self.error_listener = threading.Thread(target=self.listen, daemon=True)
self.error_listener.start()
# Register the signal handler
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
"""Registers a child process."""
self.children_pids.append(pid)
def listen(self):
"""Listens for errors in the error queue."""
# Wait until there is an error in the queue
child_trace = self.error_queue.get()
# Put the error back for the signal handler
self.error_queue.put(child_trace)
# Invoke the signal handler
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, _sig_num, _stack_frame):
"""Signal handler."""
# Kill children processes
for pid in self.children_pids:
os.kill(pid, signal.SIGINT)
# Propagate the error from the child process
raise ChildException(self.error_queue.get())
def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
"""Runs a function from a child process."""
try:
# Initialize the process group
init_process_group(proc_rank, world_size)
# Run the function
fun(*fun_args, **fun_kwargs)
except KeyboardInterrupt:
# Killed by the parent process
pass
except Exception:
# Propagate exception to the parent process
error_queue.put(traceback.format_exc())
finally:
# Destroy the process group
destroy_process_group()
def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
"""Runs a function in a multi-proc setting (unless num_proc == 1)."""
# There is no need for multi-proc in the single-proc case
fun_kwargs = fun_kwargs if fun_kwargs else {}
if num_proc == 1:
fun(*fun_args, **fun_kwargs)
return
# Handle errors from training subprocesses
error_queue = multiprocessing.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Run each training subprocess
ps = []
for i in range(num_proc):
p_i = multiprocessing.Process(
target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
)
ps.append(p_i)
p_i.start()
error_handler.add_child(p_i.pid)
# Wait for each subprocess to finish
for p in ps:
p.join()
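
A launch sketch: scaled_all_reduce turns per-GPU sums into means across the process group, and multi_proc_run spawns one worker process per GPU (train_fun is a hypothetical worker, not part of this commit):

import pycls.core.distributed as dist
from pycls.core.config import cfg

def train_fun():
    # Runs once per process; rank 0 is the master (logging, checkpoints)
    ...

dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=train_fun)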

77
pycls/core/io.py Normal file

@@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""IO utilities (adapted from Detectron)"""
import logging
import os
import re
import sys
from urllib import request as urlrequest
logger = logging.getLogger(__name__)
_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
def cache_url(url_or_file, cache_dir):
"""Download the file specified by the URL to the cache_dir and return the path to
the cached file. If the argument is not a URL, simply return it as is.
"""
is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
if not is_url:
return url_or_file
url = url_or_file
err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}"
assert url.startswith(_PYCLS_BASE_URL), err_str.format(_PYCLS_BASE_URL)
cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
if os.path.exists(cache_file_path):
return cache_file_path
cache_file_dir = os.path.dirname(cache_file_path)
if not os.path.exists(cache_file_dir):
os.makedirs(cache_file_dir)
logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
download_url(url, cache_file_path)
return cache_file_path
def _progress_bar(count, total):
"""Report download progress. Credit:
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
"""
bar_len = 60
filled_len = int(round(bar_len * count / float(total)))
percents = round(100.0 * count / float(total), 1)
bar = "=" * filled_len + "-" * (bar_len - filled_len)
sys.stdout.write(
" [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
)
sys.stdout.flush()
if count >= total:
sys.stdout.write("\n")
def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
"""Download url and write it to dst_file_path. Credit:
https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
"""
req = urlrequest.Request(url)
response = urlrequest.urlopen(req)
total_size = response.info().get("Content-Length").strip()
total_size = int(total_size)
bytes_so_far = 0
with open(dst_file_path, "wb") as f:
        while True:
chunk = response.read(chunk_size)
bytes_so_far += len(chunk)
if not chunk:
break
if progress_hook:
progress_hook(bytes_so_far, total_size)
f.write(chunk)
return bytes_so_far
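
A usage sketch (the model URL is hypothetical; non-URL arguments pass through unchanged):

from pycls.core.io import cache_url

path = cache_url(
    "https://dl.fbaipublicfiles.com/pycls/model.pyth",  # hypothetical URL
    "/tmp/pycls-download-cache",
)
# -> "/tmp/pycls-download-cache/model.pyth" (downloaded on first call only)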

138
pycls/core/logging.py Normal file

@@ -0,0 +1,138 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Logging."""
import builtins
import decimal
import logging
import os
import sys
import pycls.core.distributed as dist
import simplejson
from pycls.core.config import cfg
# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"
# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"
# Data output with dump_log_data(data, data_type) will be tagged w/ this
_TAG = "json_stats: "
# Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
_TYPE = "_type"
def _suppress_print():
"""Suppresses printing from the current process."""
def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
pass
builtins.print = ignore
def setup_logging():
"""Sets up the logging."""
# Enable logging only for the master process
if dist.is_master_proc():
# Clear the root logger to prevent any existing logging config
# (e.g. set by another module) from messing with our setup
logging.root.handlers = []
# Construct logging configuration
logging_config = {"level": logging.INFO, "format": _FORMAT}
# Log either to stdout or to a file
if cfg.LOG_DEST == "stdout":
logging_config["stream"] = sys.stdout
else:
logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
# Configure logging
logging.basicConfig(**logging_config)
else:
_suppress_print()
def get_logger(name):
"""Retrieves the logger."""
return logging.getLogger(name)
def dump_log_data(data, data_type, prec=4):
"""Covert data (a dictionary) into tagged json string for logging."""
data[_TYPE] = data_type
data = float_to_decimal(data, prec)
data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
return "{:s}{:s}".format(_TAG, data_json)
def float_to_decimal(data, prec=4):
"""Convert floats to decimals which allows for fixed width json."""
if isinstance(data, dict):
return {k: float_to_decimal(v, prec) for k, v in data.items()}
if isinstance(data, float):
return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
else:
return data
def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
"""Get all log files in directory containing subdirs of trained models."""
names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
files = [os.path.join(log_dir, n, log_file) for n in names]
f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
files, names = zip(*f_n_ps) if f_n_ps else ([], [])
return files, names
def load_log_data(log_file, data_types_to_skip=()):
"""Loads log data into a dictionary of the form data[data_type][metric][index]."""
# Load log_file
assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
with open(log_file, "r") as f:
lines = f.readlines()
# Extract and parse lines that start with _TAG and have a type specified
lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
lines = [simplejson.loads(l) for l in lines]
lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip]
# Generate data structure accessed by data[data_type][index][metric]
data_types = [l[_TYPE] for l in lines]
data = {t: [] for t in data_types}
for t, line in zip(data_types, lines):
del line[_TYPE]
data[t].append(line)
# Generate data structure accessed by data[data_type][metric][index]
for t in data:
metrics = sorted(data[t][0].keys())
err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
data[t] = {m: [d[m] for d in data[t]] for m in metrics}
return data
def sort_log_data(data):
"""Sort each data[data_type][metric] by epoch or keep only first instance."""
for t in data:
if "epoch" in data[t]:
assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
epoch = data[t]["epoch_ind"]
if "iter" in data[t]:
assert "iter_ind" not in data[t] and "iter_max" not in data[t]
data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
for m in data[t]:
data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
else:
data[t] = {m: d[0] for m, d in data[t].items()}
return data
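
The tagged json format in action (a sketch; note that dump_log_data mutates its input dict by adding the _type key):

import pycls.core.logging as logging

line = logging.dump_log_data({"top1_err": 23.456789, "lr": 0.1}, "test_epoch")
# 'json_stats: {"_type": "test_epoch", "lr": 0.1000, "top1_err": 23.4568}'
# load_log_data/sort_log_data later recover these records from the log file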

435
pycls/core/meters.py Normal file

@@ -0,0 +1,435 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Meters."""
from collections import deque
import numpy as np
import pycls.core.logging as logging
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer
logger = logging.get_logger(__name__)
def time_string(seconds):
"""Converts time in seconds to a fixed-width string format."""
days, rem = divmod(int(seconds), 24 * 3600)
hrs, rem = divmod(rem, 3600)
mins, secs = divmod(rem, 60)
return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)
def inter_union(preds, labels, num_classes):
    """Computes per-class intersection and union counts (for mIoU)."""
_, preds = torch.max(preds, 1)
preds = preds.type(torch.uint8) + 1
labels = labels.type(torch.uint8) + 1
preds = preds * (labels > 0).type(torch.uint8)
inter = preds * (preds == labels).type(torch.uint8)
area_inter = torch.histc(inter.type(torch.int64), bins=num_classes, min=1, max=num_classes)
area_preds = torch.histc(preds.type(torch.int64), bins=num_classes, min=1, max=num_classes)
area_labels = torch.histc(labels.type(torch.int64), bins=num_classes, min=1, max=num_classes)
area_union = area_preds + area_labels - area_inter
return [area_inter.type(torch.float64) / labels.size(0), area_union.type(torch.float64) / labels.size(0)]
def topk_errors(preds, labels, ks):
"""Computes the top-k error for each k."""
err_str = "Batch dim of predictions and labels must match"
assert preds.size(0) == labels.size(0), err_str
# Find the top max_k predictions for each sample
_top_max_k_vals, top_max_k_inds = torch.topk(
preds, max(ks), dim=1, largest=True, sorted=True
)
# (batch_size, max_k) -> (max_k, batch_size)
top_max_k_inds = top_max_k_inds.t()
# (batch_size, ) -> (max_k, batch_size)
rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
# (i, j) = 1 if top i-th prediction for the j-th sample is correct
top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
# Compute the number of topk correct predictions for each k
topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks]
return [(1.0 - x / preds.size(0)) * 100.0 for x in topks_correct]
def gpu_mem_usage():
"""Computes the GPU memory usage for the current device (MB)."""
mem_usage_bytes = torch.cuda.max_memory_allocated()
return mem_usage_bytes / 1024 / 1024
class ScalarMeter(object):
"""Measures a scalar value (adapted from Detectron)."""
def __init__(self, window_size):
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
def reset(self):
self.deque.clear()
self.total = 0.0
self.count = 0
def add_value(self, value):
self.deque.append(value)
self.count += 1
self.total += value
def get_win_median(self):
return np.median(self.deque)
def get_win_avg(self):
return np.mean(self.deque)
def get_global_avg(self):
return self.total / self.count
class TrainMeter(object):
"""Measures training stats."""
def __init__(self, epoch_iters):
self.epoch_iters = epoch_iters
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
self.iter_timer = Timer()
self.loss = ScalarMeter(cfg.LOG_PERIOD)
self.loss_total = 0.0
self.lr = None
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
# Number of misclassified examples
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def reset(self, timer=False):
if timer:
self.iter_timer.reset()
self.loss.reset()
self.loss_total = 0.0
self.lr = None
self.mb_top1_err.reset()
self.mb_top5_err.reset()
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
# Current minibatch stats
self.mb_top1_err.add_value(top1_err)
self.mb_top5_err.add_value(top5_err)
self.loss.add_value(loss)
self.lr = lr
# Aggregate stats
self.num_top1_mis += top1_err * mb_size
self.num_top5_mis += top5_err * mb_size
self.loss_total += loss * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"eta": time_string(eta_sec),
"top1_err": self.mb_top1_err.get_win_median(),
"top5_err": self.mb_top5_err.get_win_median(),
"loss": self.loss.get_win_median(),
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "train_iter"))
def get_epoch_stats(self, cur_epoch):
cur_iter_total = (cur_epoch + 1) * self.epoch_iters
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
top1_err = self.num_top1_mis / self.num_samples
top5_err = self.num_top5_mis / self.num_samples
avg_loss = self.loss_total / self.num_samples
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"eta": time_string(eta_sec),
"top1_err": top1_err,
"top5_err": top5_err,
"loss": avg_loss,
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "train_epoch"))
class TestMeter(object):
"""Measures testing stats."""
def __init__(self, max_iter):
self.max_iter = max_iter
self.iter_timer = Timer()
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
# Min errors (over the full test set)
self.min_top1_err = 100.0
self.min_top5_err = 100.0
# Number of misclassified examples
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def reset(self, min_errs=False):
if min_errs:
self.min_top1_err = 100.0
self.min_top5_err = 100.0
self.iter_timer.reset()
self.mb_top1_err.reset()
self.mb_top5_err.reset()
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, top1_err, top5_err, mb_size):
self.mb_top1_err.add_value(top1_err)
self.mb_top5_err.add_value(top5_err)
self.num_top1_mis += top1_err * mb_size
self.num_top5_mis += top5_err * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
mem_usage = gpu_mem_usage()
iter_stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"top1_err": self.mb_top1_err.get_win_median(),
"top5_err": self.mb_top5_err.get_win_median(),
"mem": int(np.ceil(mem_usage)),
}
return iter_stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "test_iter"))
def get_epoch_stats(self, cur_epoch):
top1_err = self.num_top1_mis / self.num_samples
top5_err = self.num_top5_mis / self.num_samples
self.min_top1_err = min(self.min_top1_err, top1_err)
self.min_top5_err = min(self.min_top5_err, top5_err)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"top1_err": top1_err,
"top5_err": top5_err,
"min_top1_err": self.min_top1_err,
"min_top5_err": self.min_top5_err,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "test_epoch"))
class TrainMeterIoU(object):
"""Measures training stats."""
def __init__(self, epoch_iters):
self.epoch_iters = epoch_iters
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
self.iter_timer = Timer()
self.loss = ScalarMeter(cfg.LOG_PERIOD)
self.loss_total = 0.0
self.lr = None
self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def reset(self, timer=False):
if timer:
self.iter_timer.reset()
self.loss.reset()
self.loss_total = 0.0
self.lr = None
self.mb_miou.reset()
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, inter, union, loss, lr, mb_size):
# Current minibatch stats
self.mb_miou.add_value((inter / (union + 1e-10)).mean())
self.loss.add_value(loss)
self.lr = lr
# Aggregate stats
self.num_inter += inter * mb_size
self.num_union += union * mb_size
self.loss_total += loss * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"eta": time_string(eta_sec),
"miou": self.mb_miou.get_win_median(),
"loss": self.loss.get_win_median(),
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "train_iter"))
def get_epoch_stats(self, cur_epoch):
cur_iter_total = (cur_epoch + 1) * self.epoch_iters
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
miou = (self.num_inter / (self.num_union + 1e-10)).mean()
avg_loss = self.loss_total / self.num_samples
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"eta": time_string(eta_sec),
"miou": miou,
"loss": avg_loss,
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "train_epoch"))
class TestMeterIoU(object):
"""Measures testing stats."""
def __init__(self, max_iter):
self.max_iter = max_iter
self.iter_timer = Timer()
self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
self.max_miou = 0.0
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def reset(self, min_errs=False):
if min_errs:
self.max_miou = 0.0
self.iter_timer.reset()
self.mb_miou.reset()
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, inter, union, mb_size):
self.mb_miou.add_value((inter / (union + 1e-10)).mean())
self.num_inter += inter * mb_size
self.num_union += union * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
mem_usage = gpu_mem_usage()
iter_stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"miou": self.mb_miou.get_win_median(),
"mem": int(np.ceil(mem_usage)),
}
return iter_stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "test_iter"))
def get_epoch_stats(self, cur_epoch):
miou = (self.num_inter / (self.num_union + 1e-10)).mean()
self.max_miou = max(self.max_miou, miou)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"miou": miou,
"max_miou": self.max_miou,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "test_epoch"))
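
A worked example for topk_errors (a sketch; the returned values are 0-dim tensors):

import torch
from pycls.core.meters import topk_errors

preds = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
labels = torch.tensor([2, 0])
top1, top2 = topk_errors(preds, labels, [1, 2])
# Sample 0: argmax is class 1, label is 2 -> top-1 wrong, top-2 correct
# Sample 1: argmax is class 0, label is 0 -> correct at both k
# top1.item() == 50.0, top2.item() == 0.0  (error percentages)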

129
pycls/core/net.py Normal file

@@ -0,0 +1,129 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions for manipulating networks."""
import itertools
import math
import torch
import torch.nn as nn
from pycls.core.config import cfg
def init_weights(m):
"""Performs ResNet-style weight initialization."""
if isinstance(m, nn.Conv2d):
# Note that there is no bias due to BN
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
elif isinstance(m, nn.BatchNorm2d):
zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(mean=0.0, std=0.01)
m.bias.data.zero_()
@torch.no_grad()
def compute_precise_bn_stats(model, loader):
"""Computes precise BN stats on training data."""
# Compute the number of minibatches to use
num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
# Retrieve the BN layers
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
# Initialize stats storage
mus = [torch.zeros_like(bn.running_mean) for bn in bns]
sqs = [torch.zeros_like(bn.running_var) for bn in bns]
# Remember momentum values
moms = [bn.momentum for bn in bns]
# Disable momentum
for bn in bns:
bn.momentum = 1.0
# Accumulate the stats across the data samples
for inputs, _labels in itertools.islice(loader, num_iter):
model(inputs.cuda())
# Accumulate the stats for each BN layer
for i, bn in enumerate(bns):
m, v = bn.running_mean, bn.running_var
sqs[i] += (v + m * m) / num_iter
mus[i] += m / num_iter
# Set the stats and restore momentum values
for i, bn in enumerate(bns):
        bn.running_var = sqs[i] - mus[i] * mus[i]  # Var[x] = E[x^2] - E[x]^2
bn.running_mean = mus[i]
bn.momentum = moms[i]
def reset_bn_stats(model):
"""Resets running BN stats."""
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.reset_running_stats()
def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
"""Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts)."""
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
h = (h + 2 * padding - k) // stride + 1
w = (w + 2 * padding - k) // stride + 1
flops += k * k * w_in * w_out * h * w // groups
params += k * k * w_in * w_out // groups
flops += w_out if bias else 0
params += w_out if bias else 0
acts += w_out * h * w
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
def complexity_batchnorm2d(cx, w_in):
"""Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts)."""
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
params += 2 * w_in
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
def complexity_maxpool2d(cx, k, stride, padding):
"""Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts)."""
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
h = (h + 2 * padding - k) // stride + 1
w = (w + 2 * padding - k) // stride + 1
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
def complexity(model):
"""Compute model complexity (model can be model instance or model class)."""
size = cfg.TRAIN.IM_SIZE
cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
cx = model.complexity(cx)
return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}
def drop_connect(x, drop_ratio):
"""Drop connect (adapted from DARTS)."""
keep_ratio = 1.0 - drop_ratio
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
mask.bernoulli_(keep_ratio)
x.div_(keep_ratio)
x.mul_(mask)
return x
def get_flat_weights(model):
"""Gets all model weights as a single flat vector."""
return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
def set_flat_weights(model, flat_weights):
"""Sets all model weights from a single flat vector."""
k = 0
for p in model.parameters():
n = p.data.numel()
p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
k += n
assert k == flat_weights.numel()
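
A worked example for the complexity helpers (a sketch, mirroring a 3x3 stride-2 stem conv on a 224x224 input):

from pycls.core.net import complexity_batchnorm2d, complexity_conv2d

cx = {"h": 224, "w": 224, "flops": 0, "params": 0, "acts": 0}
cx = complexity_conv2d(cx, w_in=3, w_out=32, k=3, stride=2, padding=1)
# h = w = (224 + 2 - 3) // 2 + 1 = 112
# flops = 3*3*3*32 * 112*112 = 10,838,016; params = 3*3*3*32 = 864
cx = complexity_batchnorm2d(cx, 32)  # params += 2 * 32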

95
pycls/core/optimizer.py Normal file

@@ -0,0 +1,95 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Optimizer."""
import numpy as np
import torch
from pycls.core.config import cfg
def construct_optimizer(model):
"""Constructs the optimizer.
Note that the momentum update in PyTorch differs from the one in Caffe2.
In particular,
Caffe2:
V := mu * V + lr * g
p := p - V
PyTorch:
V := mu * V + g
p := p - lr * V
where V is the velocity, mu is the momentum factor, lr is the learning rate,
g is the gradient and p are the parameters.
Since V is defined independently of the learning rate in PyTorch,
when the learning rate is changed there is no need to perform the
momentum correction by scaling V (unlike in the Caffe2 case).
"""
if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
# Apply different weight decay to Batchnorm and non-batchnorm parameters.
p_bn = [p for n, p in model.named_parameters() if "bn" in n]
p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n]
optim_params = [
{"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
{"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
]
else:
optim_params = model.parameters()
return torch.optim.SGD(
optim_params,
lr=cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV,
)
def lr_fun_steps(cur_epoch):
"""Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind)
def lr_fun_exp(cur_epoch):
"""Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch)
def lr_fun_cos(cur_epoch):
"""Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH
return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))
def get_lr_fun():
"""Retrieves the specified lr policy function"""
lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
if lr_fun not in globals():
raise NotImplementedError("Unknown LR policy:" + cfg.OPTIM.LR_POLICY)
return globals()[lr_fun]
def get_epoch_lr(cur_epoch):
"""Retrieves the lr for the given epoch according to the policy."""
lr = get_lr_fun()(cur_epoch)
# Linear warmup
if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
lr *= warmup_factor
return lr
def set_lr(optimizer, new_lr):
"""Sets the optimizer lr to the specified value."""
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr

132
pycls/core/plotting.py Normal file

@@ -0,0 +1,132 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Plotting functions."""
import colorlover as cl
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.offline as offline
import pycls.core.logging as logging
def get_plot_colors(max_colors, color_format="pyplot"):
"""Generate colors for plotting."""
colors = cl.scales["11"]["qual"]["Paired"]
if max_colors > len(colors):
colors = cl.to_rgb(cl.interp(colors, max_colors))
if color_format == "pyplot":
return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
return colors
def prepare_plot_data(log_files, names, metric="top1_err"):
"""Load logs and extract data for plotting error curves."""
plot_data = []
for file, name in zip(log_files, names):
d, data = {}, logging.sort_log_data(logging.load_log_data(file))
for phase in ["train", "test"]:
x = data[phase + "_epoch"]["epoch_ind"]
y = data[phase + "_epoch"][metric]
d["x_" + phase], d["y_" + phase] = x, y
d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
plot_data.append(d)
assert len(plot_data) > 0, "No data to plot"
return plot_data
def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
"""Plot error curves using plotly and save to file."""
plot_data = prepare_plot_data(log_files, names, metric)
colors = get_plot_colors(len(plot_data), "plotly")
# Prepare data for plots (3 sets, train duplicated w and w/o legend)
data = []
for i, d in enumerate(plot_data):
s = str(i)
line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
data.append(
go.Scatter(
x=d["x_train"],
y=d["y_train"],
mode="lines",
name=d["train_label"],
line=line_train,
legendgroup=s,
visible=True,
showlegend=False,
)
)
data.append(
go.Scatter(
x=d["x_test"],
y=d["y_test"],
mode="lines",
name=d["test_label"],
line=line_test,
legendgroup=s,
visible=True,
showlegend=True,
)
)
data.append(
go.Scatter(
x=d["x_train"],
y=d["y_train"],
mode="lines",
name=d["train_label"],
line=line_train,
legendgroup=s,
visible=False,
showlegend=True,
)
)
# Prepare layout w ability to toggle 'all', 'train', 'test'
titlefont = {"size": 18, "color": "#7f7f7f"}
vis = [[True, True, False], [False, False, True], [False, True, False]]
buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
layout = go.Layout(
title=metric + " vs. epoch<br>[dash=train, solid=test]",
xaxis={"title": "epoch", "titlefont": titlefont},
yaxis={"title": metric, "titlefont": titlefont},
showlegend=True,
hoverlabel={"namelength": -1},
updatemenus=[
{
"buttons": buttons,
"direction": "down",
"showactive": True,
"x": 1.02,
"xanchor": "left",
"y": 1.08,
"yanchor": "top",
}
],
)
# Create plotly plot
offline.plot({"data": data, "layout": layout}, filename=filename)
def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
"""Plot error curves using matplotlib.pyplot and save to file."""
plot_data = prepare_plot_data(log_files, names, metric)
colors = get_plot_colors(len(names))
for ind, d in enumerate(plot_data):
c, lbl = colors[ind], d["test_label"]
plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
plt.xlabel("epoch", fontsize=14)
plt.ylabel(metric, fontsize=14)
plt.grid(alpha=0.4)
plt.legend()
if filename:
plt.savefig(filename)
plt.clf()
else:
plt.show()
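# Usage sketch (log paths and run names below are illustrative): compare two
# runs on top-1 error and save the matplotlib version of the curves to disk:
#   plot_error_curves_pyplot(
#       ["runs/a/stdout.log", "runs/b/stdout.log"], ["run_a", "run_b"],
#       filename="curves.png", metric="top1_err",
#   )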

39
pycls/core/timer.py Normal file
View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Timer."""
import time
class Timer(object):
"""A simple timer (adapted from Detectron)."""
def __init__(self):
self.total_time = None
self.calls = None
self.start_time = None
self.diff = None
self.average_time = None
self.reset()
def tic(self):
        # Use time.time, since time.clock does not normalize for multithreading
self.start_time = time.time()
def toc(self):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
def reset(self):
self.total_time = 0.0
self.calls = 0
self.start_time = 0.0
self.diff = 0.0
self.average_time = 0.0
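# A minimal usage sketch: time a few iterations of a workload and read off the
# running average (time.sleep stands in for real work):
timer = Timer()
for _ in range(3):
    timer.tic()
    time.sleep(0.01)  # placeholder workload
    timer.toc()
print("avg: {:.4f}s over {} calls".format(timer.average_time, timer.calls))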

419
pycls/core/trainer.py Normal file
View File

@@ -0,0 +1,419 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tools for training and testing a model."""
import os
from thop import profile
import numpy as np
import pycls.core.benchmark as benchmark
import pycls.core.builders as builders
import pycls.core.checkpoint as checkpoint
import pycls.core.config as config
import pycls.core.distributed as dist
import pycls.core.logging as logging
import pycls.core.meters as meters
import pycls.core.net as net
import pycls.core.optimizer as optim
import pycls.datasets.loader as loader
import torch
import torch.nn.functional as F
from pycls.core.config import cfg
logger = logging.get_logger(__name__)
def setup_env():
"""Sets up environment for training or testing."""
if dist.is_master_proc():
# Ensure that the output dir exists
os.makedirs(cfg.OUT_DIR, exist_ok=True)
# Save the config
config.dump_cfg()
# Setup logging
logging.setup_logging()
# Log the config as both human readable and as a json
logger.info("Config:\n{}".format(cfg))
logger.info(logging.dump_log_data(cfg, "cfg"))
# Fix the RNG seeds (see RNG comment in core/config.py for discussion)
np.random.seed(cfg.RNG_SEED)
torch.manual_seed(cfg.RNG_SEED)
# Configure the CUDNN backend
torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
def setup_model():
"""Sets up a model for training or testing and log the results."""
# Build the model
model = builders.build_model()
logger.info("Model:\n{}".format(model))
# Log model complexity
# logger.info(logging.dump_log_data(net.complexity(model), "complexity"))
if cfg.TASK == "seg" and cfg.TRAIN.DATASET == "cityscapes":
h, w = 1025, 2049
else:
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
if cfg.TASK == "jig":
x = torch.randn(1, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, h, w)
else:
x = torch.randn(1, cfg.MODEL.INPUT_CHANNELS, h, w)
macs, params = profile(model, inputs=(x, ), verbose=False)
logger.info("Params: {:,}".format(params))
logger.info("Flops: {:,}".format(macs))
# Transfer the model to the current GPU device
err_str = "Cannot use more GPU devices than available"
assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str
cur_device = torch.cuda.current_device()
model = model.cuda(device=cur_device)
# Use multi-process data parallel model in the multi-gpu setting
if cfg.NUM_GPUS > 1:
# Make model replica operate on the current device
model = torch.nn.parallel.DistributedDataParallel(
module=model, device_ids=[cur_device], output_device=cur_device
)
# Set complexity function to be module's complexity function
# model.complexity = model.module.complexity
return model
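# Note that thop's first return value counts multiply-accumulates (MACs). A
# standalone sketch of the same measurement (torchvision's resnet18 is purely
# illustrative; any nn.Module works):
#   import torch, torchvision
#   from thop import profile
#   m = torchvision.models.resnet18()
#   macs, params = profile(m, inputs=(torch.randn(1, 3, 224, 224),), verbose=False)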
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
"""Performs one epoch of training."""
# Update drop path prob for NAS
if cfg.MODEL.TYPE == "nas":
m = model.module if cfg.NUM_GPUS > 1 else model
m.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH)
# Shuffle the data
loader.shuffle(train_loader, cur_epoch)
# Update the learning rate per epoch
if not cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch)
optim.set_lr(optimizer, lr)
# Enable training mode
model.train()
train_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(train_loader):
# Update the learning rate per iter
if cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader))
optim.set_lr(optimizer, lr)
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Perform the forward pass
preds = model(inputs)
# Compute the loss
if isinstance(preds, tuple):
loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
preds = preds[0]
else:
loss = loss_fun(preds, labels)
# Perform the backward pass
optimizer.zero_grad()
loss.backward()
# Update the parameters
optimizer.step()
# Compute the errors
if cfg.TASK == "col":
preds = preds.permute(0, 2, 3, 1)
preds = preds.reshape(-1, preds.size(3))
labels = labels.reshape(-1)
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
else:
mb_size = inputs.size(0) * cfg.NUM_GPUS
if cfg.TASK == "seg":
# top1_err is in fact inter; top5_err is in fact union
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
else:
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
# Combine the stats across the GPUs (no reduction if 1 GPU used)
loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
# Copy the stats from GPU to CPU (sync point)
loss = loss.item()
if cfg.TASK == "seg":
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
else:
top1_err, top5_err = top1_err.item(), top5_err.item()
train_meter.iter_toc()
# Update and log stats
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
train_meter.log_iter_stats(cur_epoch, cur_iter)
train_meter.iter_tic()
# Log epoch stats
train_meter.log_epoch_stats(cur_epoch)
train_meter.reset()
def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
"""Performs one epoch of differentiable architecture search."""
m = model.module if cfg.NUM_GPUS > 1 else model
# Shuffle the data
loader.shuffle(train_loader[0], cur_epoch)
loader.shuffle(train_loader[1], cur_epoch)
# Update the learning rate per epoch
if not cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch)
optim.set_lr(optimizer[0], lr)
# Enable training mode
model.train()
train_meter.iter_tic()
trainB_iter = iter(train_loader[1])
for cur_iter, (inputs, labels) in enumerate(train_loader[0]):
# Update the learning rate per iter
if cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader[0]))
optim.set_lr(optimizer[0], lr)
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Update architecture
if cur_epoch + cur_iter / len(train_loader[0]) >= cfg.OPTIM.ARCH_EPOCH:
try:
inputsB, labelsB = next(trainB_iter)
except StopIteration:
trainB_iter = iter(train_loader[1])
inputsB, labelsB = next(trainB_iter)
inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True)
optimizer[1].zero_grad()
loss = m._loss(inputsB, labelsB)
loss.backward()
optimizer[1].step()
# Perform the forward pass
preds = model(inputs)
# Compute the loss
loss = loss_fun(preds, labels)
# Perform the backward pass
optimizer[0].zero_grad()
loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
# Update the parameters
optimizer[0].step()
# Compute the errors
if cfg.TASK == "col":
preds = preds.permute(0, 2, 3, 1)
preds = preds.reshape(-1, preds.size(3))
labels = labels.reshape(-1)
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
else:
mb_size = inputs.size(0) * cfg.NUM_GPUS
if cfg.TASK == "seg":
# top1_err is in fact inter; top5_err is in fact union
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
else:
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
# Combine the stats across the GPUs (no reduction if 1 GPU used)
loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
# Copy the stats from GPU to CPU (sync point)
loss = loss.item()
if cfg.TASK == "seg":
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
else:
top1_err, top5_err = top1_err.item(), top5_err.item()
train_meter.iter_toc()
# Update and log stats
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
train_meter.log_iter_stats(cur_epoch, cur_iter)
train_meter.iter_tic()
# Log epoch stats
train_meter.log_epoch_stats(cur_epoch)
train_meter.reset()
# Log genotype
genotype = m.genotype()
logger.info("genotype = %s", genotype)
logger.info(F.softmax(m.net_.alphas_normal, dim=-1))
logger.info(F.softmax(m.net_.alphas_reduce, dim=-1))
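# The loop above is first-order DARTS-style bi-level optimization: one step on
# the architecture parameters (alphas) using held-out batch B, then one SGD
# step on the network weights using batch A. Stripped to its core:
#   for (xA, yA), (xB, yB) in zip(loader_a, loader_b):
#       opt_alpha.zero_grad(); m._loss(xB, yB).backward(); opt_alpha.step()
#       opt_w.zero_grad(); loss_fun(model(xA), yA).backward(); opt_w.step()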
@torch.no_grad()
def test_epoch(test_loader, model, test_meter, cur_epoch):
"""Evaluates the model on the test set."""
# Enable eval mode
model.eval()
test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(test_loader):
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Compute the predictions
preds = model(inputs)
# Compute the errors
if cfg.TASK == "col":
preds = preds.permute(0, 2, 3, 1)
preds = preds.reshape(-1, preds.size(3))
labels = labels.reshape(-1)
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
else:
mb_size = inputs.size(0) * cfg.NUM_GPUS
if cfg.TASK == "seg":
# top1_err is in fact inter; top5_err is in fact union
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
else:
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
# Combine the errors across the GPUs (no reduction if 1 GPU used)
top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
# Copy the errors from GPU to CPU (sync point)
if cfg.TASK == "seg":
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
else:
top1_err, top5_err = top1_err.item(), top5_err.item()
test_meter.iter_toc()
# Update and log stats
test_meter.update_stats(top1_err, top5_err, mb_size)
test_meter.log_iter_stats(cur_epoch, cur_iter)
test_meter.iter_tic()
# Log epoch stats
test_meter.log_epoch_stats(cur_epoch)
test_meter.reset()
def train_model():
"""Trains the model."""
# Setup training/testing environment
setup_env()
# Construct the model, loss_fun, and optimizer
model = setup_model()
loss_fun = builders.build_loss_fun().cuda()
if "search" in cfg.MODEL.TYPE:
params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
params_a = [v for k, v in model.named_parameters() if "alphas" in k]
optimizer_w = torch.optim.SGD(
params=params_w,
lr=cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV
)
if cfg.OPTIM.ARCH_OPTIM == "adam":
optimizer_a = torch.optim.Adam(
params=params_a,
lr=cfg.OPTIM.ARCH_BASE_LR,
betas=(0.5, 0.999),
weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY
)
elif cfg.OPTIM.ARCH_OPTIM == "sgd":
optimizer_a = torch.optim.SGD(
params=params_a,
lr=cfg.OPTIM.ARCH_BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV
)
optimizer = [optimizer_w, optimizer_a]
else:
optimizer = optim.construct_optimizer(model)
# Load checkpoint or initial weights
start_epoch = 0
if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
last_checkpoint = checkpoint.get_last_checkpoint()
checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
start_epoch = checkpoint_epoch + 1
elif cfg.TRAIN.WEIGHTS:
checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
# Create data loaders and meters
if cfg.TRAIN.PORTION < 1:
if "search" in cfg.MODEL.TYPE:
train_loader = [loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
portion=cfg.TRAIN.PORTION,
side="l"
),
loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
portion=cfg.TRAIN.PORTION,
side="r"
)]
else:
train_loader = loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
portion=cfg.TRAIN.PORTION,
side="l"
)
test_loader = loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=False,
drop_last=False,
portion=cfg.TRAIN.PORTION,
side="r"
)
else:
train_loader = loader.construct_train_loader()
test_loader = loader.construct_test_loader()
train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
l = train_loader[0] if isinstance(train_loader, list) else train_loader
train_meter = train_meter_type(len(l))
test_meter = test_meter_type(len(test_loader))
# Compute model and loader timings
if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
l = train_loader[0] if isinstance(train_loader, list) else train_loader
benchmark.compute_time_full(model, loss_fun, l, test_loader)
# Perform the training loop
logger.info("Start epoch: {}".format(start_epoch + 1))
for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
# Train for one epoch
f = search_epoch if "search" in cfg.MODEL.TYPE else train_epoch
f(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
# Compute precise BN stats
if cfg.BN.USE_PRECISE_STATS:
net.compute_precise_bn_stats(model, train_loader)
# Save a checkpoint
if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
# Evaluate the model
next_epoch = cur_epoch + 1
if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
test_epoch(test_loader, model, test_meter, cur_epoch)
def test_model():
"""Evaluates a trained model."""
# Setup training/testing environment
setup_env()
# Construct the model
model = setup_model()
# Load model weights
checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
# Create data loaders and meters
test_loader = loader.construct_test_loader()
test_meter = meters.TestMeter(len(test_loader))
# Evaluate the model
test_epoch(test_loader, model, test_meter, 0)
def time_model():
"""Times model and data loader."""
# Setup training/testing environment
setup_env()
# Construct the model and loss_fun
model = setup_model()
loss_fun = builders.build_loss_fun().cuda()
# Create data loaders
train_loader = loader.construct_train_loader()
test_loader = loader.construct_test_loader()
# Compute model and loader timings
benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
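# Dispatch sketch: with a populated cfg, each mode maps to one entry point.
#   train_model()  # train with periodic eval and checkpointing
#   test_model()   # evaluate cfg.TEST.WEIGHTS on the test set
#   time_model()   # precise fw/bw and loader timings via benchmark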

0
pycls/models/__init__.py Normal file
View File

406
pycls/models/anynet.py Normal file
View File

@@ -0,0 +1,406 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""AnyNet models."""
import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg
def get_stem_fun(stem_type):
"""Retrieves the stem function by name."""
stem_funs = {
"res_stem_cifar": ResStemCifar,
"res_stem_in": ResStemIN,
"simple_stem_in": SimpleStemIN,
}
err_str = "Stem type '{}' not supported"
assert stem_type in stem_funs.keys(), err_str.format(stem_type)
return stem_funs[stem_type]
def get_block_fun(block_type):
"""Retrieves the block function by name."""
block_funs = {
"vanilla_block": VanillaBlock,
"res_basic_block": ResBasicBlock,
"res_bottleneck_block": ResBottleneckBlock,
}
err_str = "Block type '{}' not supported"
assert block_type in block_funs.keys(), err_str.format(block_type)
return block_funs[block_type]
class AnyHead(nn.Module):
"""AnyNet head: AvgPool, 1x1."""
def __init__(self, w_in, nc):
super(AnyHead, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(w_in, nc, bias=True)
def forward(self, x):
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
@staticmethod
def complexity(cx, w_in, nc):
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
return cx
class VanillaBlock(nn.Module):
"""Vanilla block: [3x3 conv, BN, Relu] x2."""
def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Vanilla block does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
super(VanillaBlock, self).__init__()
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Vanilla block does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class BasicTransform(nn.Module):
"""Basic transformation: [3x3 conv, BN, Relu] x2."""
def __init__(self, w_in, w_out, stride):
super(BasicTransform, self).__init__()
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResBasicBlock(nn.Module):
"""Residual basic block: x + F(x), F = basic transform."""
def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Basic transform does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
super(ResBasicBlock, self).__init__()
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.f = BasicTransform(w_in, w_out, stride)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
err_str = "Basic transform does not support bm, gw, and se_r options"
assert bm is None and gw is None and se_r is None, err_str
proj_block = (w_in != w_out) or (stride != 1)
if proj_block:
h, w = cx["h"], cx["w"]
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = h, w # parallel branch
cx = BasicTransform.complexity(cx, w_in, w_out, stride)
return cx
class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU, FC, Sigmoid."""
def __init__(self, w_in, w_se):
super(SE, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.f_ex = nn.Sequential(
nn.Conv2d(w_in, w_se, 1, bias=True),
nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
nn.Conv2d(w_se, w_in, 1, bias=True),
nn.Sigmoid(),
)
def forward(self, x):
return x * self.f_ex(self.avg_pool(x))
@staticmethod
def complexity(cx, w_in, w_se):
h, w = cx["h"], cx["w"]
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
cx["h"], cx["w"] = h, w
return cx
class BottleneckTransform(nn.Module):
"""Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""
def __init__(self, w_in, w_out, stride, bm, gw, se_r):
super(BottleneckTransform, self).__init__()
w_b = int(round(w_out * bm))
g = w_b // gw
self.a = nn.Conv2d(w_in, w_b, 1, stride=1, padding=0, bias=False)
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
self.b = nn.Conv2d(w_b, w_b, 3, stride=stride, padding=1, groups=g, bias=False)
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
if se_r:
w_se = int(round(w_in * se_r))
self.se = SE(w_b, w_se)
self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.c_bn.final_bn = True
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm, gw, se_r):
w_b = int(round(w_out * bm))
g = w_b // gw
cx = net.complexity_conv2d(cx, w_in, w_b, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_b)
cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride, 1, g)
cx = net.complexity_batchnorm2d(cx, w_b)
if se_r:
w_se = int(round(w_in * se_r))
cx = SE.complexity(cx, w_b, w_se)
cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResBottleneckBlock(nn.Module):
"""Residual bottleneck block: x + F(x), F = bottleneck transform."""
def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
super(ResBottleneckBlock, self).__init__()
# Use skip connection with projection if shape changes
self.proj_block = (w_in != w_out) or (stride != 1)
if self.proj_block:
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
if self.proj_block:
x = self.bn(self.proj(x)) + self.f(x)
else:
x = x + self.f(x)
x = self.relu(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
proj_block = (w_in != w_out) or (stride != 1)
if proj_block:
h, w = cx["h"], cx["w"]
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = h, w # parallel branch
cx = BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r)
return cx
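# Usage sketch: a bottleneck block that halves resolution with a grouped 3x3
# (bm=0.25 and gw=4 are illustrative RegNet-style settings):
import torch  # for the dummy input below
blk = ResBottleneckBlock(w_in=64, w_out=128, stride=2, bm=0.25, gw=4)
out = blk(torch.rand(1, 64, 56, 56))  # -> (1, 128, 28, 28)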
class ResStemCifar(nn.Module):
"""ResNet stem for CIFAR: 3x3, BN, ReLU."""
def __init__(self, w_in, w_out):
super(ResStemCifar, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class ResStemIN(nn.Module):
"""ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""
def __init__(self, w_in, w_out):
super(ResStemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
cx = net.complexity_batchnorm2d(cx, w_out)
cx = net.complexity_maxpool2d(cx, 3, 2, 1)
return cx
class SimpleStemIN(nn.Module):
"""Simple stem for ImageNet: 3x3, BN, ReLU."""
def __init__(self, w_in, w_out):
super(SimpleStemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class AnyStage(nn.Module):
"""AnyNet stage (sequence of blocks w/ the same output shape)."""
def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
super(AnyStage, self).__init__()
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
name = "b{}".format(i + 1)
self.add_module(name, block_fun(b_w_in, w_out, b_stride, bm, gw, se_r))
def forward(self, x):
for block in self.children():
x = block(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
cx = block_fun.complexity(cx, b_w_in, w_out, b_stride, bm, gw, se_r)
return cx
class AnyNet(nn.Module):
"""AnyNet model."""
@staticmethod
def get_args():
return {
"stem_type": cfg.ANYNET.STEM_TYPE,
"stem_w": cfg.ANYNET.STEM_W,
"block_type": cfg.ANYNET.BLOCK_TYPE,
"ds": cfg.ANYNET.DEPTHS,
"ws": cfg.ANYNET.WIDTHS,
"ss": cfg.ANYNET.STRIDES,
"bms": cfg.ANYNET.BOT_MULS,
"gws": cfg.ANYNET.GROUP_WS,
"se_r": cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
"nc": cfg.MODEL.NUM_CLASSES,
}
def __init__(self, **kwargs):
super(AnyNet, self).__init__()
kwargs = self.get_args() if not kwargs else kwargs
self._construct(**kwargs)
self.apply(net.init_weights)
def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
# Generate dummy bot muls and gs for models that do not use them
bms = bms if bms else [None for _d in ds]
gws = gws if gws else [None for _d in ds]
stage_params = list(zip(ds, ws, ss, bms, gws))
stem_fun = get_stem_fun(stem_type)
self.stem = stem_fun(3, stem_w)
block_fun = get_block_fun(block_type)
prev_w = stem_w
for i, (d, w, s, bm, gw) in enumerate(stage_params):
name = "s{}".format(i + 1)
self.add_module(name, AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r))
prev_w = w
self.head = AnyHead(w_in=prev_w, nc=nc)
def forward(self, x, get_ints=False):
for module in self.children():
x = module(x)
return x
@staticmethod
def complexity(cx, **kwargs):
"""Computes model complexity. If you alter the model, make sure to update."""
kwargs = AnyNet.get_args() if not kwargs else kwargs
return AnyNet._complexity(cx, **kwargs)
@staticmethod
def _complexity(cx, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
bms = bms if bms else [None for _d in ds]
gws = gws if gws else [None for _d in ds]
stage_params = list(zip(ds, ws, ss, bms, gws))
stem_fun = get_stem_fun(stem_type)
cx = stem_fun.complexity(cx, 3, stem_w)
block_fun = get_block_fun(block_type)
prev_w = stem_w
for d, w, s, bm, gw in stage_params:
cx = AnyStage.complexity(cx, prev_w, w, s, d, block_fun, bm, gw, se_r)
prev_w = w
cx = AnyHead.complexity(cx, prev_w, nc)
return cx
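# Complexity sketch (assumes the cx dict layout used by pycls.core.net, i.e.
# running "h"/"w" plus "flops"/"params"/"acts" counters; 32x32 is illustrative):
#   cx = {"h": 32, "w": 32, "flops": 0, "params": 0, "acts": 0}
#   cx = ResBasicBlock.complexity(cx, w_in=16, w_out=16, stride=1)
#   print(cx["flops"], cx["params"])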

108
pycls/models/common.py Normal file
View File

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from pycls.core.config import cfg
def Preprocess(x):
if cfg.TASK == 'jig':
assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
x = x.view([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]])
return x
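# Shape sketch of the jigsaw flattening above (grid G=3 assumed): N images of
# G*G patches become N*G*G independent backbone inputs.
x = torch.rand(2, 3 ** 2, 3, 64, 64)   # (N, G*G, C, H, W)
patches = x.view(-1, *x.shape[2:])     # (N*G*G, C, H, W)
assert patches.shape == (18, 3, 64, 64)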
class Classifier(nn.Module):
def __init__(self, channels, num_classes):
super(Classifier, self).__init__()
if cfg.TASK == 'jig':
self.jig_sq = cfg.JIGSAW_GRID ** 2
self.pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(channels * self.jig_sq, num_classes)
elif cfg.TASK == 'col':
self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
elif cfg.TASK == 'seg':
self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
else:
self.pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(channels, num_classes)
def forward(self, x, shape):
if cfg.TASK == 'jig':
x = self.pooling(x)
x = x.view([x.shape[0] // self.jig_sq, x.shape[1] * self.jig_sq, x.shape[2], x.shape[3]])
x = self.classifier(x.view(x.size(0), -1))
elif cfg.TASK in ['col', 'seg']:
x = self.classifier(x)
x = nn.Upsample(shape, mode='bilinear', align_corners=True)(x)
else:
x = self.pooling(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class ASPP(nn.Module):
def __init__(self, in_channels, out_channels, num_classes, rates):
super(ASPP, self).__init__()
assert len(rates) in [1, 3]
self.rates = rates
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.aspp1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp2 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[0],
padding=rates[0], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
if len(self.rates) == 3:
self.aspp3 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[1],
padding=rates[1], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp4 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[2],
padding=rates[2], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp5 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.classifier = nn.Sequential(
nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, num_classes, 1)
)
def forward(self, x):
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x5 = self.global_pooling(x)
x5 = self.aspp5(x5)
x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
align_corners=True)(x5)
if len(self.rates) == 3:
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x = torch.cat((x1, x2, x3, x4, x5), 1)
else:
x = torch.cat((x1, x2, x5), 1)
x = self.classifier(x)
return x
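# Usage sketch: a three-rate ASPP head over a low-resolution feature map
# (channel counts and rates below are illustrative):
aspp = ASPP(in_channels=256, out_channels=64, num_classes=19, rates=[6, 12, 18])
out = aspp(torch.rand(1, 256, 33, 33))  # -> (1, 19, 33, 33)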

232
pycls/models/effnet.py Normal file
View File

@@ -0,0 +1,232 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""EfficientNet models."""
import pycls.core.net as net
import torch
import torch.nn as nn
from pycls.core.config import cfg
class EffHead(nn.Module):
"""EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC."""
def __init__(self, w_in, w_out, nc):
super(EffHead, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False)
self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.conv_swish = Swish()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
if cfg.EN.DROPOUT_RATIO > 0.0:
self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
self.fc = nn.Linear(w_out, nc, bias=True)
def forward(self, x):
x = self.conv_swish(self.conv_bn(self.conv(x)))
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.dropout(x) if hasattr(self, "dropout") else x
x = self.fc(x)
return x
@staticmethod
def complexity(cx, w_in, w_out, nc):
cx = net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_out, nc, 1, 1, 0, bias=True)
return cx
class Swish(nn.Module):
"""Swish activation function: x * sigmoid(x)."""
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * torch.sigmoid(x)
class SE(nn.Module):
"""Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""
def __init__(self, w_in, w_se):
super(SE, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.f_ex = nn.Sequential(
nn.Conv2d(w_in, w_se, 1, bias=True),
Swish(),
nn.Conv2d(w_se, w_in, 1, bias=True),
nn.Sigmoid(),
)
def forward(self, x):
return x * self.f_ex(self.avg_pool(x))
@staticmethod
def complexity(cx, w_in, w_se):
h, w = cx["h"], cx["w"]
cx["h"], cx["w"] = 1, 1
cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
cx["h"], cx["w"] = h, w
return cx
class MBConv(nn.Module):
"""Mobile inverted bottleneck block w/ SE (MBConv)."""
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
# expansion, 3x3 dwise, BN, Swish, SE, 1x1, BN, skip_connection
super(MBConv, self).__init__()
self.exp = None
w_exp = int(w_in * exp_r)
if w_exp != w_in:
self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False)
self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.exp_swish = Swish()
dwise_args = {"groups": w_exp, "padding": (kernel - 1) // 2, "bias": False}
self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args)
self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.dwise_swish = Swish()
self.se = SE(w_exp, int(w_in * se_r))
self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False)
self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
# Skip connection if in and out shapes are the same (MN-V2 style)
self.has_skip = stride == 1 and w_in == w_out
def forward(self, x):
f_x = x
if self.exp:
f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
f_x = self.se(f_x)
f_x = self.lin_proj_bn(self.lin_proj(f_x))
if self.has_skip:
if self.training and cfg.EN.DC_RATIO > 0.0:
f_x = net.drop_connect(f_x, cfg.EN.DC_RATIO)
f_x = x + f_x
return f_x
@staticmethod
def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out):
w_exp = int(w_in * exp_r)
if w_exp != w_in:
cx = net.complexity_conv2d(cx, w_in, w_exp, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_exp)
padding = (kernel - 1) // 2
cx = net.complexity_conv2d(cx, w_exp, w_exp, kernel, stride, padding, w_exp)
cx = net.complexity_batchnorm2d(cx, w_exp)
cx = SE.complexity(cx, w_exp, int(w_in * se_r))
cx = net.complexity_conv2d(cx, w_exp, w_out, 1, 1, 0)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
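# Usage sketch for a single block (EfficientNet-B0-style values, purely
# illustrative): 16 -> 24 channels at stride 2, expansion 6, SE ratio 0.25.
block = MBConv(w_in=16, exp_r=6, kernel=3, stride=2, se_r=0.25, w_out=24)
out = block(torch.rand(2, 16, 112, 112))  # -> (2, 24, 56, 56)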
class EffStage(nn.Module):
"""EfficientNet stage."""
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
super(EffStage, self).__init__()
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
name = "b{}".format(i + 1)
self.add_module(name, MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out))
def forward(self, x):
for block in self.children():
x = block(x)
return x
@staticmethod
def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out, d):
for i in range(d):
b_stride = stride if i == 0 else 1
b_w_in = w_in if i == 0 else w_out
cx = MBConv.complexity(cx, b_w_in, exp_r, kernel, b_stride, se_r, w_out)
return cx
class StemIN(nn.Module):
"""EfficientNet stem for ImageNet: 3x3, BN, Swish."""
def __init__(self, w_in, w_out):
super(StemIN, self).__init__()
self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
self.swish = Swish()
def forward(self, x):
for layer in self.children():
x = layer(x)
return x
@staticmethod
def complexity(cx, w_in, w_out):
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
cx = net.complexity_batchnorm2d(cx, w_out)
return cx
class EffNet(nn.Module):
"""EfficientNet model."""
@staticmethod
def get_args():
return {
"stem_w": cfg.EN.STEM_W,
"ds": cfg.EN.DEPTHS,
"ws": cfg.EN.WIDTHS,
"exp_rs": cfg.EN.EXP_RATIOS,
"se_r": cfg.EN.SE_R,
"ss": cfg.EN.STRIDES,
"ks": cfg.EN.KERNELS,
"head_w": cfg.EN.HEAD_W,
"nc": cfg.MODEL.NUM_CLASSES,
}
def __init__(self):
err_str = "Dataset {} is not supported"
assert cfg.TRAIN.DATASET in ["imagenet"], err_str.format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in ["imagenet"], err_str.format(cfg.TEST.DATASET)
super(EffNet, self).__init__()
self._construct(**EffNet.get_args())
self.apply(net.init_weights)
def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
self.stem = StemIN(3, stem_w)
prev_w = stem_w
for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
name = "s{}".format(i + 1)
self.add_module(name, EffStage(prev_w, exp_r, kernel, stride, se_r, w, d))
prev_w = w
self.head = EffHead(prev_w, head_w, nc)
def forward(self, x):
for module in self.children():
x = module(x)
return x
@staticmethod
def complexity(cx):
"""Computes model complexity. If you alter the model, make sure to update."""
return EffNet._complexity(cx, **EffNet.get_args())
@staticmethod
def _complexity(cx, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
cx = StemIN.complexity(cx, 3, stem_w)
prev_w = stem_w
for d, w, exp_r, stride, kernel in stage_params:
cx = EffStage.complexity(cx, prev_w, exp_r, kernel, stride, se_r, w, d)
prev_w = w
cx = EffHead.complexity(cx, prev_w, head_w, nc)
return cx

634
pycls/models/nas/genotypes.py Normal file
View File

@@ -0,0 +1,634 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS genotypes (adopted from DARTS)."""
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
# NASNet ops
NASNET_OPS = [
'skip_connect',
'conv_3x1_1x3',
'conv_7x1_1x7',
'dil_conv_3x3',
'avg_pool_3x3',
'max_pool_3x3',
'max_pool_5x5',
'max_pool_7x7',
'conv_1x1',
'conv_3x3',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
]
# ENAS ops
ENAS_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'avg_pool_3x3',
'max_pool_3x3',
]
# AmoebaNet ops
AMOEBA_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'avg_pool_3x3',
'max_pool_3x3',
'dil_sep_conv_3x3',
'conv_7x1_1x7',
]
# NAO ops
NAO_OPS = [
'skip_connect',
'conv_1x1',
'conv_3x3',
'conv_3x1_1x3',
'conv_7x1_1x7',
'max_pool_2x2',
'max_pool_3x3',
'max_pool_5x5',
'avg_pool_2x2',
'avg_pool_3x3',
'avg_pool_5x5',
]
# PNAS ops
PNAS_OPS = [
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'conv_7x1_1x7',
'skip_connect',
'avg_pool_3x3',
'max_pool_3x3',
'dil_conv_3x3',
]
# DARTS ops
DARTS_OPS = [
'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
]
NASNet = Genotype(
normal=[
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 0),
('avg_pool_3x3', 0),
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 1),
('sep_conv_7x7', 0),
('max_pool_3x3', 1),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('skip_connect', 3),
('avg_pool_3x3', 2),
('sep_conv_3x3', 2),
('max_pool_3x3', 1),
],
reduce_concat=[4, 5, 6],
)
PNASNet = Genotype(
normal=[
('sep_conv_5x5', 0),
('max_pool_3x3', 0),
('sep_conv_7x7', 1),
('max_pool_3x3', 1),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 4),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 0),
('max_pool_3x3', 0),
('sep_conv_7x7', 1),
('max_pool_3x3', 1),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 4),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 1),
],
reduce_concat=[2, 3, 4, 5, 6],
)
AmoebaNet = Genotype(
normal=[
('avg_pool_3x3', 0),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('avg_pool_3x3', 3),
('sep_conv_3x3', 1),
('skip_connect', 1),
('skip_connect', 0),
('avg_pool_3x3', 1),
],
normal_concat=[4, 5, 6],
reduce=[
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_7x7', 2),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('conv_7x1_1x7', 0),
('sep_conv_3x3', 5),
],
reduce_concat=[3, 4, 6]
)
DARTS_V1 = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 0),
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 2)
],
normal_concat=[2, 3, 4, 5],
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('skip_connect', 2),
('max_pool_3x3', 0),
('max_pool_3x3', 0),
('skip_connect', 2),
('skip_connect', 2),
('avg_pool_3x3', 0)
],
reduce_concat=[2, 3, 4, 5]
)
DARTS_V2 = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('skip_connect', 0),
('skip_connect', 0),
('dil_conv_3x3', 2)
],
normal_concat=[2, 3, 4, 5],
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('skip_connect', 2),
('max_pool_3x3', 1),
('max_pool_3x3', 0),
('skip_connect', 2),
('skip_connect', 2),
('max_pool_3x3', 1)
],
reduce_concat=[2, 3, 4, 5]
)
PDARTS = Genotype(
normal=[
('skip_connect', 0),
('dil_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('dil_conv_5x5', 4)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('dil_conv_3x3', 1),
('dil_conv_3x3', 1),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
PCDARTS_C10 = Genotype(
normal=[
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 0),
('dil_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 1),
('avg_pool_3x3', 0),
('dil_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 1),
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2)
],
reduce_concat=range(2, 6)
)
PCDARTS_IN1K = Genotype(
normal=[
('skip_connect', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('skip_connect', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('dil_conv_5x5', 4)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 2),
('max_pool_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_CLS = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('skip_connect', 1),
('max_pool_3x3', 0),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_ROT = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('sep_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_COL = Genotype(
normal=[
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_5x5', 3),
('max_pool_3x3', 0),
('sep_conv_3x3', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_JIG = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 1)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_CLS = Genotype(
normal=[
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('dil_conv_5x5', 3),
('dil_conv_5x5', 2),
('dil_conv_5x5', 4),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_ROT = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('dil_conv_5x5', 2),
('sep_conv_5x5', 0),
('dil_conv_5x5', 3),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('sep_conv_3x3', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_COL = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 2),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('sep_conv_3x3', 4),
('sep_conv_5x5', 1)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_JIG = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 4)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 0),
('skip_connect', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 0),
('sep_conv_5x5', 3),
('sep_conv_5x5', 0),
('sep_conv_5x5', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_SEG = Genotype(
normal=[
('skip_connect', 0),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 0),
('sep_conv_3x3', 4),
('sep_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_ROT = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_5x5', 2),
('sep_conv_5x5', 1),
('sep_conv_5x5', 3),
('dil_conv_5x5', 2),
('sep_conv_5x5', 2),
('sep_conv_5x5', 0)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_COL = Genotype(
normal=[
('dil_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 0),
('sep_conv_5x5', 2),
('dil_conv_3x3', 3),
('skip_connect', 0),
('skip_connect', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_JIG = Genotype(
normal=[
('dil_conv_5x5', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 0),
('dil_conv_5x5', 1)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 1),
('dil_conv_5x5', 2),
('dil_conv_5x5', 2),
('dil_conv_5x5', 0),
('dil_conv_5x5', 3),
('dil_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
# Supported genotypes
GENOTYPES = {
'nas': NASNet,
'pnas': PNASNet,
'amoeba': AmoebaNet,
'darts_v1': DARTS_V1,
'darts_v2': DARTS_V2,
'pdarts': PDARTS,
'pcdarts_c10': PCDARTS_C10,
'pcdarts_in1k': PCDARTS_IN1K,
'unnas_imagenet_cls': UNNAS_IMAGENET_CLS,
'unnas_imagenet_rot': UNNAS_IMAGENET_ROT,
'unnas_imagenet_col': UNNAS_IMAGENET_COL,
'unnas_imagenet_jig': UNNAS_IMAGENET_JIG,
'unnas_imagenet22k_cls': UNNAS_IMAGENET22K_CLS,
'unnas_imagenet22k_rot': UNNAS_IMAGENET22K_ROT,
'unnas_imagenet22k_col': UNNAS_IMAGENET22K_COL,
'unnas_imagenet22k_jig': UNNAS_IMAGENET22K_JIG,
'unnas_cityscapes_seg': UNNAS_CITYSCAPES_SEG,
'unnas_cityscapes_rot': UNNAS_CITYSCAPES_ROT,
'unnas_cityscapes_col': UNNAS_CITYSCAPES_COL,
'unnas_cityscapes_jig': UNNAS_CITYSCAPES_JIG,
'custom': None,
}
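# Lookup sketch: each entry is a plain Genotype namedtuple, so cell specs can
# be inspected directly:
g = GENOTYPES['darts_v2']
print(g.normal[0])            # ('sep_conv_3x3', 0)
print(list(g.reduce_concat))  # [2, 3, 4, 5]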

299
pycls/models/nas/nas.py Normal file
View File

@@ -0,0 +1,299 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS network (adopted from DARTS)."""
import torch
import torch.nn as nn
import pycls.core.logging as logging
from pycls.core.config import cfg
from pycls.models.common import Preprocess
from pycls.models.common import Classifier
from pycls.models.nas.genotypes import GENOTYPES
from pycls.models.nas.genotypes import Genotype
from pycls.models.nas.operations import FactorizedReduce
from pycls.models.nas.operations import OPS
from pycls.models.nas.operations import ReLUConvBN
from pycls.models.nas.operations import Identity
logger = logging.get_logger(__name__)
def drop_path(x, drop_prob):
    """Drop path (ported from DARTS)."""
    if drop_prob > 0.0:
        keep_prob = 1.0 - drop_prob
        # Per-sample binary mask, broadcast over (C, H, W)
        mask = torch.empty(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x
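# Behavior sketch: with drop_prob=0.2, each sample's residual path is zeroed
# with probability 0.2 and survivors are rescaled by 1/0.8, so the expected
# activation is unchanged at train time:
#   x = torch.ones(4, 8, 1, 1)
#   drop_path(x, 0.2)  # rows are all-zero or all-1.25, 1.0 in expectation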
class Cell(nn.Module):
"""NAS cell (ported from DARTS)."""
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(Cell, self).__init__()
logger.info('{}, {}, {}'.format(C_prev_prev, C_prev, C))
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
if reduction:
op_names, indices = zip(*genotype.reduce)
concat = genotype.reduce_concat
else:
op_names, indices = zip(*genotype.normal)
concat = genotype.normal_concat
self._compile(C, op_names, indices, concat, reduction)
def _compile(self, C, op_names, indices, concat, reduction):
assert len(op_names) == len(indices)
self._steps = len(op_names) // 2
self._concat = concat
self.multiplier = len(concat)
self._ops = nn.ModuleList()
for name, index in zip(op_names, indices):
stride = 2 if reduction and index < 2 else 1
op = OPS[name](C, stride, True)
self._ops += [op]
self._indices = indices
def forward(self, s0, s1, drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
h1 = states[self._indices[2*i]]
h2 = states[self._indices[2*i+1]]
op1 = self._ops[2*i]
op2 = self._ops[2*i+1]
h1 = op1(h1)
h2 = op2(h2)
if self.training and drop_prob > 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
s = h1 + h2
states += [s]
return torch.cat([states[i] for i in self._concat], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
return x
class AuxiliaryHeadImageNet(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 14x14"""
super(AuxiliaryHeadImageNet, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
# Commenting it out for consistency with the experiments in the paper.
# nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
return x
class NetworkCIFAR(nn.Module):
"""CIFAR network (ported from DARTS)."""
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkCIFAR, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
stem_multiplier = 3
C_curr = stem_multiplier*C
self.stem = nn.Sequential(
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
nn.BatchNorm2d(C_curr)
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
reduction_prev = False
for i in range(layers):
if i in [layers//3, 2*layers//3]:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
if i == 2*layers//3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
self.classifier = Classifier(C_prev, num_classes)
def forward(self, input):
input = Preprocess(input)
logits_aux = None
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2*self._layers//3:
if self._auxiliary and self.training:
logits_aux = self.auxiliary_head(s1)
logits = self.classifier(s1, input.shape[2:])
if self._auxiliary and self.training:
return logits, logits_aux
return logits


class NetworkImageNet(nn.Module):
    """ImageNet network (ported from DARTS)."""

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkImageNet, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary
        self.stem0 = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        self.stem1 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        C_prev_prev, C_prev, C_curr = C, C, C
        self.cells = nn.ModuleList()
        reduction_prev = True
        reduction_layers = [layers // 3] if cfg.TASK == 'seg' else [layers // 3, 2 * layers // 3]
        for i in range(layers):
            if i in reduction_layers:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                C_to_auxiliary = C_prev
        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        input = Preprocess(input)
        logits_aux = None
        s0 = self.stem0(input)
        s1 = self.stem1(s0)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2 * self._layers // 3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits


class NAS(nn.Module):
    """NAS net wrapper (delegates to nets from DARTS)."""

    def __init__(self):
        assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Testing on {} is not supported'.format(cfg.TEST.DATASET)
        assert cfg.NAS.GENOTYPE in GENOTYPES, \
            'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
        super(NAS, self).__init__()
        logger.info('Constructing NAS: {}'.format(cfg.NAS))
        # Use a custom or predefined genotype
        if cfg.NAS.GENOTYPE == 'custom':
            genotype = Genotype(
                normal=cfg.NAS.CUSTOM_GENOTYPE[0],
                normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
                reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
                reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
            )
        else:
            genotype = GENOTYPES[cfg.NAS.GENOTYPE]
        # Determine the network constructor for dataset
        if 'cifar' in cfg.TRAIN.DATASET:
            net_ctor = NetworkCIFAR
        else:
            net_ctor = NetworkImageNet
        # Construct the network
        self.net_ = net_ctor(
            C=cfg.NAS.WIDTH,
            num_classes=cfg.MODEL.NUM_CLASSES,
            layers=cfg.NAS.DEPTH,
            auxiliary=cfg.NAS.AUX,
            genotype=genotype
        )
        # Drop path probability (set / annealed based on epoch)
        self.net_.drop_path_prob = 0.0

    def set_drop_path_prob(self, drop_path_prob):
        self.net_.drop_path_prob = drop_path_prob

    def forward(self, x):
        return self.net_.forward(x)
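
For context, a minimal driver sketch for the drop path hook above, assuming a valid cfg (dataset and genotype) has already been loaded; the linear schedule and the names max_prob and num_epochs are illustrative, not part of this commit:

# Hypothetical training driver; schedule and values are illustrative.
model = NAS()
max_prob, num_epochs = 0.2, 600
for epoch in range(num_epochs):
    model.set_drop_path_prob(max_prob * epoch / num_epochs)
    # ... run one training epoch ...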

View File

@@ -0,0 +1,201 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS ops (adopted from DARTS)."""
import torch
import torch.nn as nn
OPS = {
    'none': lambda C, stride, affine:
        Zero(stride),
    'avg_pool_2x2': lambda C, stride, affine:
        nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False),
    'avg_pool_3x3': lambda C, stride, affine:
        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
    'avg_pool_5x5': lambda C, stride, affine:
        nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False),
    'max_pool_2x2': lambda C, stride, affine:
        nn.MaxPool2d(2, stride=stride, padding=0),
    'max_pool_3x3': lambda C, stride, affine:
        nn.MaxPool2d(3, stride=stride, padding=1),
    'max_pool_5x5': lambda C, stride, affine:
        nn.MaxPool2d(5, stride=stride, padding=2),
    'max_pool_7x7': lambda C, stride, affine:
        nn.MaxPool2d(7, stride=stride, padding=3),
    'skip_connect': lambda C, stride, affine:
        Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
    'conv_1x1': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
    'conv_3x3': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
    'sep_conv_3x3': lambda C, stride, affine:
        SepConv(C, C, 3, stride, 1, affine=affine),
    'sep_conv_5x5': lambda C, stride, affine:
        SepConv(C, C, 5, stride, 2, affine=affine),
    'sep_conv_7x7': lambda C, stride, affine:
        SepConv(C, C, 7, stride, 3, affine=affine),
    'dil_conv_3x3': lambda C, stride, affine:
        DilConv(C, C, 3, stride, 2, 2, affine=affine),
    'dil_conv_5x5': lambda C, stride, affine:
        DilConv(C, C, 5, stride, 4, 2, affine=affine),
    'dil_sep_conv_3x3': lambda C, stride, affine:
        DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
    'conv_3x1_1x3': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, (1, 3), stride=(1, stride), padding=(0, 1), bias=False),
            nn.Conv2d(C, C, (3, 1), stride=(stride, 1), padding=(1, 0), bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
    'conv_7x1_1x7': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
            nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
}
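
Each OPS entry maps an op name to a constructor with signature (C, stride, affine), so candidate ops can be built uniformly. A quick shape sanity check, as a sketch with illustrative values:

# e.g. a stride-2 separable 3x3 conv on 16 channels
op = OPS['sep_conv_3x3'](16, 2, True)  # C=16, stride=2, affine=True
out = op(torch.randn(1, 16, 32, 32))   # -> torch.Size([1, 16, 16, 16])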


class ReLUConvBN(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_out, kernel_size, stride=stride,
                padding=padding, bias=False
            ),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.op(x)


class DilConv(nn.Module):
    def __init__(
        self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
    ):
        super(DilConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        return self.op(x)


class SepConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=1,
                padding=padding, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        return self.op(x)


class DilSepConv(nn.Module):
    def __init__(
        self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
    ):
        super(DilSepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=1,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        return self.op(x)


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Zero(nn.Module):
    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def forward(self, x):
        # Return zeros, spatially subsampled when stride > 1
        if self.stride == 1:
            return x.mul(0.)
        return x[:, :, ::self.stride, ::self.stride].mul(0.)


class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, affine=True):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)

    def forward(self, x):
        x = self.relu(x)
        # Shift the second branch by one pixel so its stride-2 conv samples
        # the spatial locations that the first branch skips
        y = self.pad(x)
        out = torch.cat([self.conv_1(x), self.conv_2(y[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
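
As a quick check of the offset trick above (values illustrative), FactorizedReduce halves the spatial resolution and splits the output width across the two shifted stride-2 branches:

fr = FactorizedReduce(16, 32)
out = fr(torch.randn(1, 16, 8, 8))  # -> torch.Size([1, 32, 4, 4])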

89
pycls/models/regnet.py Normal file
View File

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""RegNet models."""
import numpy as np
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
def quantize_float(f, q):
    """Converts a float to the closest non-zero int divisible by q."""
    return int(round(f / q) * q)
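
For example, quantize_float(100, 24) rounds 100 / 24 ≈ 4.17 to 4 and returns 96, while quantize_float(32, 8) returns 32 unchanged.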

def adjust_ws_gs_comp(ws, bms, gs):
    """Adjusts the compatibility of widths and groups."""
    ws_bot = [int(w * b) for w, b in zip(ws, bms)]
    gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
    ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
    ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
    return ws, gs

def get_stages_from_blocks(ws, rs):
    """Gets ws/ds of network at each stage from per-block values."""
    # Flag the blocks at which the width or resolution changes (stage starts)
    ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
    ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
    # Stage widths are the widths at stage starts; depths are the gaps between starts
    s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
    s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
    return s_ws, s_ds


def generate_regnet(w_a, w_0, w_m, d, q=8):
    """Generates per-block ws from RegNet parameters."""
    assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
    # Continuous widths: u_j = w_0 + w_a * j for block j
    ws_cont = np.arange(d) * w_a + w_0
    # Snap each width to the nearest w_0 * w_m^k, then quantize to a multiple of q
    ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
    ws = w_0 * np.power(w_m, ks)
    ws = np.round(np.divide(ws, q)) * q
    num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
    ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
    return ws, num_stages, max_stage, ws_cont
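
A worked example of the two functions above, with illustrative parameters:

ws, num_stages, _, _ = generate_regnet(w_a=8, w_0=48, w_m=2, d=8)
# ws == [48, 48, 48, 96, 96, 96, 96, 96], num_stages == 2
s_ws, s_ds = get_stages_from_blocks(ws, ws)
# s_ws == [48, 96], s_ds == [3, 5]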


class RegNet(AnyNet):
    """RegNet model."""

    @staticmethod
    def get_args():
        """Convert RegNet to AnyNet parameter format."""
        # Generate RegNet ws per block
        w_a, w_0, w_m, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
        ws, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
        # Convert to per stage format
        s_ws, s_ds = get_stages_from_blocks(ws, ws)
        # Use the same gw, bm and ss for each stage
        s_gs = [cfg.REGNET.GROUP_W for _ in range(num_stages)]
        s_bs = [cfg.REGNET.BOT_MUL for _ in range(num_stages)]
        s_ss = [cfg.REGNET.STRIDE for _ in range(num_stages)]
        # Adjust the compatibility of ws and gws
        s_ws, s_gs = adjust_ws_gs_comp(s_ws, s_bs, s_gs)
        # Get AnyNet arguments defining the RegNet
        return {
            "stem_type": cfg.REGNET.STEM_TYPE,
            "stem_w": cfg.REGNET.STEM_W,
            "block_type": cfg.REGNET.BLOCK_TYPE,
            "ds": s_ds,
            "ws": s_ws,
            "ss": s_ss,
            "bms": s_bs,
            "gws": s_gs,
            "se_r": cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self):
        kwargs = RegNet.get_args()
        super(RegNet, self).__init__(**kwargs)

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update this function to match."""
        kwargs = RegNet.get_args() if not kwargs else kwargs
        return AnyNet.complexity(cx, **kwargs)
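
A hedged construction sketch (it assumes the remaining REGNET defaults are valid for AnyNet and that cfg fields can be mutated programmatically; the parameter values are illustrative only):

from pycls.core.config import cfg
cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH = 8.0, 48, 2.0, 8
model = RegNet()  # get_args() resolves these into AnyNet arguments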

280
pycls/models/resnet.py Normal file
View File

@@ -0,0 +1,280 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""ResNe(X)t models."""
import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg
# Stage depths for ImageNet models
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}

def get_trans_fun(name):
    """Retrieves the transformation function by name."""
    trans_funs = {
        "basic_transform": BasicTransform,
        "bottleneck_transform": BottleneckTransform,
    }
    err_str = "Transformation function '{}' not supported"
    assert name in trans_funs.keys(), err_str.format(name)
    return trans_funs[name]


class ResHead(nn.Module):
    """ResNet head: AvgPool, 1x1."""

    def __init__(self, w_in, nc):
        super(ResHead, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    @staticmethod
    def complexity(cx, w_in, nc):
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
        return cx


class BasicTransform(nn.Module):
    """Basic transformation: 3x3, BN, ReLU, 3x3, BN."""

    def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        super(BasicTransform, self).__init__()
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1):
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN."""

    def __init__(self, w_in, w_out, stride, w_b, num_gs):
        super(BottleneckTransform, self).__init__()
        # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
        (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
        self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False)
        self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.c_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, w_b, num_gs):
        (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
        cx = net.complexity_conv2d(cx, w_in, w_b, 1, s1, 0)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_b, 3, s3, 1, num_gs)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class ResBlock(nn.Module):
    """Residual block: x + F(x)."""

    def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
        super(ResBlock, self).__init__()
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs):
        proj_block = (w_in != w_out) or (stride != 1)
        if proj_block:
            h, w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = h, w  # parallel branch
        cx = trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs)
        return cx


class ResStage(nn.Module):
    """Stage of ResNet."""

    def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
        super(ResStage, self).__init__()
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN)
            res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs)
            self.add_module("b{}".format(i + 1), res_block)

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1):
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            trans_f = get_trans_fun(cfg.RESNET.TRANS_FUN)
            cx = ResBlock.complexity(cx, b_w_in, w_out, b_stride, trans_f, w_b, num_gs)
        return cx


class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR: 3x3, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super(ResStemCifar, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""

    def __init__(self, w_in, w_out):
        super(ResStemIN, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_maxpool2d(cx, 3, 2, 1)
        return cx


class ResNet(nn.Module):
    """ResNet model."""

    def __init__(self):
        datasets = ["cifar10", "imagenet"]
        err_str = "Dataset {} is not supported"
        assert cfg.TRAIN.DATASET in datasets, err_str.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in datasets, err_str.format(cfg.TEST.DATASET)
        super(ResNet, self).__init__()
        if "cifar" in cfg.TRAIN.DATASET:
            self._construct_cifar()
        else:
            self._construct_imagenet()
        self.apply(net.init_weights)

    def _construct_cifar(self):
        err_str = "Model depth should be of the format 6n + 2 for cifar"
        assert (cfg.MODEL.DEPTH - 2) % 6 == 0, err_str
        d = int((cfg.MODEL.DEPTH - 2) / 6)
        self.stem = ResStemCifar(3, 16)
        self.s1 = ResStage(16, 16, stride=1, d=d)
        self.s2 = ResStage(16, 32, stride=2, d=d)
        self.s3 = ResStage(32, 64, stride=2, d=d)
        self.head = ResHead(64, nc=cfg.MODEL.NUM_CLASSES)

    def _construct_imagenet(self):
        g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
        (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
        w_b = gw * g
        self.stem = ResStemIN(3, 64)
        self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g)
        self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g)
        self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g)
        self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g)
        self.head = ResHead(2048, nc=cfg.MODEL.NUM_CLASSES)

    def forward(self, x):
        for module in self.children():
            x = module(x)
        return x

    @staticmethod
    def complexity(cx):
        """Computes model complexity. If you alter the model, make sure to update this function to match."""
        if "cifar" in cfg.TRAIN.DATASET:
            d = int((cfg.MODEL.DEPTH - 2) / 6)
            cx = ResStemCifar.complexity(cx, 3, 16)
            cx = ResStage.complexity(cx, 16, 16, stride=1, d=d)
            cx = ResStage.complexity(cx, 16, 32, stride=2, d=d)
            cx = ResStage.complexity(cx, 32, 64, stride=2, d=d)
            cx = ResHead.complexity(cx, 64, nc=cfg.MODEL.NUM_CLASSES)
        else:
            g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
            (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
            w_b = gw * g
            cx = ResStemIN.complexity(cx, 3, 64)
            cx = ResStage.complexity(cx, 64, 256, 1, d=d1, w_b=w_b, num_gs=g)
            cx = ResStage.complexity(cx, 256, 512, 2, d=d2, w_b=w_b * 2, num_gs=g)
            cx = ResStage.complexity(cx, 512, 1024, 2, d=d3, w_b=w_b * 4, num_gs=g)
            cx = ResStage.complexity(cx, 1024, 2048, 2, d=d4, w_b=w_b * 8, num_gs=g)
            cx = ResHead.complexity(cx, 2048, nc=cfg.MODEL.NUM_CLASSES)
        return cx
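
For reference, a hedged sketch of driving this complexity API: the "h"/"w" fields are visible above, but the seed dict and its accumulator names are an assumption about pycls.core.net conventions, not something this file defines.

# Hypothetical driver; accumulator field names are assumed.
cx = {"h": cfg.TRAIN.IM_SIZE, "w": cfg.TRAIN.IM_SIZE, "flops": 0, "params": 0, "acts": 0}
cx = ResNet.complexity(cx)
print(cx["flops"], cx["params"])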