pycls/core/benchmark.py | 136 (new file)
@@ -0,0 +1,136 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Benchmarking functions."""

import pycls.core.logging as logging
import pycls.datasets.loader as loader
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer


logger = logging.get_logger(__name__)

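# Timing protocol shared by the functions below: run PREC_TIME.WARMUP_ITER
# warmup iterations whose measurements are discarded (the timer is reset once
# warmup ends), then average over the remaining PREC_TIME.NUM_ITER timed
# iterations, so the reported value is total timed time / NUM_ITER.
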
@torch.no_grad()
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data."""
    # Use eval mode
    model.eval()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        inputs = torch.rand(
            batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    else:
        inputs = torch.zeros(
            batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    # Compute precise forward pass time
    timer = Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timer after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        # Forward
        timer.tic()
        model(inputs)
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time

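# CUDA kernels launch asynchronously, so model(inputs) alone would only
# measure launch overhead; the torch.cuda.synchronize() before timer.toc()
# blocks until the GPU has finished, making the measured interval the true
# forward-pass time. A minimal usage sketch (build_model is a stand-in for
# however the surrounding codebase constructs models, not part of this file):
#
#   model = build_model().cuda()
#   logger.info("Eval forward time: {:.4f}s".format(compute_time_eval(model)))
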
def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data."""
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        inputs = torch.rand(
            batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    else:
        inputs = torch.rand(
            batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    if cfg.TASK in ["col", "seg"]:
        labels = torch.zeros(
            batch_size, im_size, im_size, dtype=torch.int64
        ).cuda(non_blocking=False)
    else:
        labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    # Cache BatchNorm2D running stats
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
    # Compute precise forward backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        preds = model(inputs)
        if isinstance(preds, tuple):
            loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(
                preds[1], labels
            )
            preds = preds[0]
        else:
            loss = loss_fun(preds, labels)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time

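# Why the BatchNorm caching above matters: forward passes in train mode update
# each BatchNorm2d's running_mean/running_var, so compute_time_train snapshots
# them before timing and restores them afterwards, leaving the model's
# statistics untouched by benchmarking. Gradients do accumulate across the
# timed iterations, but no optimizer step is taken, so the parameters (and the
# measured fw/bw times) are unaffected.
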
def compute_time_loader(data_loader):
    """Computes loader time."""
    timer = Timer()
    # Re-seed the shuffling so the measured loading order is reproducible
    loader.shuffle(data_loader, 0)
    data_loader_iterator = iter(data_loader)
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    total_iter = min(total_iter, len(data_loader))
    for cur_iter in range(total_iter):
        # Reset the timer after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        timer.tic()
        next(data_loader_iterator)
        timer.toc()
    return timer.average_time

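# Note that compute_time_loader times next() on the loader iterator, i.e. how
# long the training process actually waits for a batch. With
# DATA_LOADER.NUM_WORKERS > 0, batches are prefetched in background worker
# processes, so after warmup this wait can be much smaller than the raw cost
# of decoding and augmenting a batch.
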
def compute_time_full(model, loss_fun, train_loader, test_loader):
    """Times model and data loader."""
    logger.info("Computing model and loader timings...")
    # Compute timings
    test_fw_time = compute_time_eval(model)
    train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
    train_fw_bw_time = train_fw_time + train_bw_time
    train_loader_time = compute_time_loader(train_loader)
    # Output iter timing
    iter_times = {
        "test_fw_time": test_fw_time,
        "train_fw_time": train_fw_time,
        "train_bw_time": train_bw_time,
        "train_fw_bw_time": train_fw_bw_time,
        "train_loader_time": train_loader_time,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))
    # Output epoch timing
    epoch_times = {
        "test_fw_time": test_fw_time * len(test_loader),
        "train_fw_time": train_fw_time * len(train_loader),
        "train_bw_time": train_bw_time * len(train_loader),
        "train_fw_bw_time": train_fw_bw_time * len(train_loader),
        "train_loader_time": train_loader_time * len(train_loader),
    }
    logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
    # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS > 1)
    overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
    logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
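

# End-to-end usage sketch (illustrative only; the loader constructors are
# assumed to come from pycls.datasets.loader, and build_model/build_loss_fun
# stand in for the surrounding codebase's setup, none of which is defined in
# this file):
#
#   model = build_model().cuda()
#   loss_fun = build_loss_fun().cuda()
#   train_loader = loader.construct_train_loader()
#   test_loader = loader.construct_test_loader()
#   compute_time_full(model, loss_fun, train_loader, test_loader)
#
# Worked example of the overhead formula: if train_loader_time is 0.30s per
# batch and train_fw_bw_time is 0.25s, overhead = max(0, 0.30 - 0.25) / 0.25,
# i.e. the loader adds 20% on top of compute.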