add naswot
This commit is contained in:
419
graph_dit/naswot/pycls/core/trainer.py
Normal file
419
graph_dit/naswot/pycls/core/trainer.py
Normal file
@@ -0,0 +1,419 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Tools for training and testing a model."""
|
||||
|
||||
import os
|
||||
from thop import profile
|
||||
|
||||
import numpy as np
|
||||
import pycls.core.benchmark as benchmark
|
||||
import pycls.core.builders as builders
|
||||
import pycls.core.checkpoint as checkpoint
|
||||
import pycls.core.config as config
|
||||
import pycls.core.distributed as dist
|
||||
import pycls.core.logging as logging
|
||||
import pycls.core.meters as meters
|
||||
import pycls.core.net as net
|
||||
import pycls.core.optimizer as optim
|
||||
import pycls.datasets.loader as loader
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def setup_env():
    """Prepare the run environment: output dir, logging, RNG seeds, cudnn."""
    if dist.is_master_proc():
        # Only the master process owns the output directory and log files
        os.makedirs(cfg.OUT_DIR, exist_ok=True)
        config.dump_cfg()
        logging.setup_logging()
        # Emit the config twice: human-readable and as structured json
        logger.info("Config:\n{}".format(cfg))
        logger.info(logging.dump_log_data(cfg, "cfg"))
    # Seed numpy and torch (see the RNG comment in core/config.py)
    seed = cfg.RNG_SEED
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Let cudnn auto-tune kernels when the config enables it
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
|
||||
|
||||
|
||||
def setup_model():
    """Build the model, log its param/FLOP cost, and place it on the GPU(s).

    Returns the model, wrapped in DistributedDataParallel when NUM_GPUS > 1.
    """
    model = builders.build_model()
    logger.info("Model:\n{}".format(model))
    # Complexity logging via net.complexity is currently disabled:
    # logger.info(logging.dump_log_data(net.complexity(model), "complexity"))
    # Pick the dummy-input resolution used for the FLOP/param count
    if cfg.TASK == "seg" and cfg.TRAIN.DATASET == "cityscapes":
        h, w = 1025, 2049
    else:
        h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
    # Jigsaw models take an extra grid dimension of permuted tiles
    if cfg.TASK == "jig":
        dummy = torch.randn(1, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, h, w)
    else:
        dummy = torch.randn(1, cfg.MODEL.INPUT_CHANNELS, h, w)
    macs, params = profile(model, inputs=(dummy,), verbose=False)
    logger.info("Params: {:,}".format(params))
    logger.info("Flops: {:,}".format(macs))
    # Move the model onto the current device before (optionally) wrapping it
    err_str = "Cannot use more GPU devices than available"
    assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str
    cur_device = torch.cuda.current_device()
    model = model.cuda(device=cur_device)
    if cfg.NUM_GPUS > 1:
        # Each DDP replica operates on (and reports to) the current device
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device
        )
        # Set complexity function to be module's complexity function
        # model.complexity = model.module.complexity
    return model
|
||||
|
||||
|
||||
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Performs one epoch of training.

    Args:
        train_loader: data loader yielding (inputs, labels) minibatches.
        model: network to train (possibly DDP-wrapped when NUM_GPUS > 1).
        loss_fun: loss callable applied to (preds, labels).
        optimizer: single optimizer for all model parameters.
        train_meter: meter accumulating per-iter and per-epoch stats.
        cur_epoch: zero-based epoch index.
    """
    # Update drop path prob for NAS (linearly scaled over training)
    if cfg.MODEL.TYPE == "nas":
        m = model.module if cfg.NUM_GPUS > 1 else model
        m.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH)
    # Shuffle the data
    loader.shuffle(train_loader, cur_epoch)
    # Update the learning rate per epoch (per-iter mode handled inside the loop)
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer, lr)
    # Enable training mode
    model.train()
    train_meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(train_loader):
        # Update the learning rate per iter (fractional epoch position)
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader))
            optim.set_lr(optimizer, lr)
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Perform the forward pass
        preds = model(inputs)
        # Compute the loss; a tuple means (main head, auxiliary head) for NAS
        if isinstance(preds, tuple):
            loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
            preds = preds[0]
        else:
            loss = loss_fun(preds, labels)
        # Perform the backward pass
        optimizer.zero_grad()
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Compute the errors; colorization flattens to per-pixel predictions
        if cfg.TASK == "col":
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            # Minibatch "size" counts pixels, not images, for colorization
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss = loss.item()
        if cfg.TASK == "seg":
            # Seg keeps per-class arrays; other tasks reduce to scalars
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
|
||||
|
||||
|
||||
def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Performs one epoch of differentiable architecture search (DARTS-style).

    Args:
        train_loader: pair of loaders [weights split, arch split]; network
            weights train on the first, architecture alphas on the second.
        model: supernet exposing `_loss`, `genotype`, and `net_` alphas
            (via `.module` when DDP-wrapped).
        loss_fun: loss callable applied to (preds, labels).
        optimizer: pair [optimizer_w, optimizer_a] matching the two splits.
        train_meter: meter accumulating per-iter and per-epoch stats.
        cur_epoch: zero-based epoch index.
    """
    m = model.module if cfg.NUM_GPUS > 1 else model
    # Shuffle the data
    loader.shuffle(train_loader[0], cur_epoch)
    loader.shuffle(train_loader[1], cur_epoch)
    # Update the learning rate per epoch (per-iter mode handled inside the loop)
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer[0], lr)
    # Enable training mode
    model.train()
    train_meter.iter_tic()
    trainB_iter = iter(train_loader[1])
    for cur_iter, (inputs, labels) in enumerate(train_loader[0]):
        # Update the learning rate per iter (fractional epoch position)
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader[0]))
            optim.set_lr(optimizer[0], lr)
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Update architecture alphas once the warmup (ARCH_EPOCH) has passed
        if cur_epoch + cur_iter / len(train_loader[0]) >= cfg.OPTIM.ARCH_EPOCH:
            try:
                inputsB, labelsB = next(trainB_iter)
            except StopIteration:
                # The arch split may be shorter than the weight split: restart it
                trainB_iter = iter(train_loader[1])
                inputsB, labelsB = next(trainB_iter)
            inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True)
            optimizer[1].zero_grad()
            loss = m._loss(inputsB, labelsB)
            loss.backward()
            optimizer[1].step()
        # Perform the forward pass
        preds = model(inputs)
        # Compute the loss
        loss = loss_fun(preds, labels)
        # Perform the backward pass
        optimizer[0].zero_grad()
        loss.backward()
        # FIX: clip_grad_norm is deprecated (removed in modern torch);
        # clip_grad_norm_ is the in-place replacement with identical semantics.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        # Update the parameters
        optimizer[0].step()
        # Compute the errors; colorization flattens to per-pixel predictions
        if cfg.TASK == "col":
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            # Minibatch "size" counts pixels, not images, for colorization
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss = loss.item()
        if cfg.TASK == "seg":
            # Seg keeps per-class arrays; other tasks reduce to scalars
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
    # Log the derived genotype and the current alpha distributions
    genotype = m.genotype()
    logger.info("genotype = %s", genotype)
    logger.info(F.softmax(m.net_.alphas_normal, dim=-1))
    logger.info(F.softmax(m.net_.alphas_reduce, dim=-1))
|
||||
|
||||
|
||||
@torch.no_grad()
def test_epoch(test_loader, model, test_meter, cur_epoch):
    """Run one full evaluation pass over the test set and log the metrics."""
    # Inference mode: eval() freezes BN/dropout; the decorator disables autograd
    model.eval()
    test_meter.iter_tic()
    for it, (ims, targets) in enumerate(test_loader):
        # Move the minibatch onto the current GPU
        ims = ims.cuda()
        targets = targets.cuda(non_blocking=True)
        # Forward pass only
        preds = model(ims)
        # Per-task reshaping and minibatch-size accounting
        if cfg.TASK != "col":
            mb_size = ims.size(0) * cfg.NUM_GPUS
        else:
            # Colorization: flatten spatial predictions to per-pixel rows
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            targets = targets.reshape(-1)
            mb_size = ims.size(0) * ims.size(2) * ims.size(3) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # For segmentation these are really intersection / union counts
            top1_err, top5_err = meters.inter_union(preds, targets, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, targets, ks)
        # Average across GPUs (identity for a single GPU)
        top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
        # Sync point: bring the stats back to the host
        if cfg.TASK == "seg":
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        test_meter.iter_toc()
        # Update and log stats
        test_meter.update_stats(top1_err, top5_err, mb_size)
        test_meter.log_iter_stats(cur_epoch, it)
        test_meter.iter_tic()
    # Log epoch stats
    test_meter.log_epoch_stats(cur_epoch)
    test_meter.reset()
|
||||
|
||||
|
||||
def train_model():
    """Trains the model.

    Top-level training driver: builds the model/loss/optimizer(s), resumes
    from a checkpoint if one exists, constructs loaders and meters, then runs
    the epoch loop (dispatching to search_epoch for "search" model types).
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    if "search" in cfg.MODEL.TYPE:
        # DARTS-style search: weights ("w") and architecture alphas ("a")
        # are trained by two separate optimizers
        params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
        params_a = [v for k, v in model.named_parameters() if "alphas" in k]
        optimizer_w = torch.optim.SGD(
            params=params_w,
            lr=cfg.OPTIM.BASE_LR,
            momentum=cfg.OPTIM.MOMENTUM,
            weight_decay=cfg.OPTIM.WEIGHT_DECAY,
            dampening=cfg.OPTIM.DAMPENING,
            nesterov=cfg.OPTIM.NESTEROV
        )
        if cfg.OPTIM.ARCH_OPTIM == "adam":
            optimizer_a = torch.optim.Adam(
                params=params_a,
                lr=cfg.OPTIM.ARCH_BASE_LR,
                betas=(0.5, 0.999),
                weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY
            )
        elif cfg.OPTIM.ARCH_OPTIM == "sgd":
            optimizer_a = torch.optim.SGD(
                params=params_a,
                lr=cfg.OPTIM.ARCH_BASE_LR,
                momentum=cfg.OPTIM.MOMENTUM,
                weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
                dampening=cfg.OPTIM.DAMPENING,
                nesterov=cfg.OPTIM.NESTEROV
            )
        # NOTE(review): any ARCH_OPTIM value other than "adam"/"sgd" leaves
        # optimizer_a unbound and the next line raises NameError — confirm
        # whether an explicit error is wanted here.
        optimizer = [optimizer_w, optimizer_a]
    else:
        optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        # Resume from the epoch after the one stored in the checkpoint
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters; PORTION < 1 splits the train set into a
    # left ("l") part for training and a right ("r") part used for eval (and,
    # in search mode, for the architecture updates)
    if cfg.TRAIN.PORTION < 1:
        if "search" in cfg.MODEL.TYPE:
            # Two loaders: [weights split, arch split] — both from the train set
            train_loader = [loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=True,
                drop_last=True,
                portion=cfg.TRAIN.PORTION,
                side="l"
            ),
            loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=True,
                drop_last=True,
                portion=cfg.TRAIN.PORTION,
                side="r"
            )]
        else:
            train_loader = loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=True,
                drop_last=True,
                portion=cfg.TRAIN.PORTION,
                side="l"
            )
        # Held-out right side of the train split serves as the test set
        test_loader = loader._construct_loader(
            dataset_name=cfg.TRAIN.DATASET,
            split=cfg.TRAIN.SPLIT,
            batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
            shuffle=False,
            drop_last=False,
            portion=cfg.TRAIN.PORTION,
            side="r"
        )
    else:
        train_loader = loader.construct_train_loader()
        test_loader = loader.construct_test_loader()
    # Segmentation uses IoU-based meters; everything else uses error meters
    train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
    test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
    # Meters are sized from the weights-split loader when there are two loaders
    l = train_loader[0] if isinstance(train_loader, list) else train_loader
    train_meter = train_meter_type(len(l))
    test_meter = test_meter_type(len(test_loader))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        l = train_loader[0] if isinstance(train_loader, list) else train_loader
        benchmark.compute_time_full(model, loss_fun, l, test_loader)
    # Perform the training loop (epochs are logged 1-indexed)
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        f = search_epoch if "search" in cfg.MODEL.TYPE else train_epoch
        f(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model (always on the final epoch)
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            test_epoch(test_loader, model, test_meter, cur_epoch)
|
||||
|
||||
|
||||
def test_model():
    """Load a trained checkpoint and run a single evaluation pass."""
    # Setup training/testing environment
    setup_env()
    # Build the model and restore the weights named in the config
    model = setup_model()
    checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    # Evaluate once; the epoch index 0 is only used for logging
    test_loader = loader.construct_test_loader()
    test_meter = meters.TestMeter(len(test_loader))
    test_epoch(test_loader, model, test_meter, 0)
|
||||
|
||||
|
||||
def time_model():
    """Benchmark model forward/backward and data-loader timings (no training)."""
    # Setup training/testing environment
    setup_env()
    # Build the model and loss used for the timed iterations
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    # Timings are measured against the standard train/test loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
|
Reference in New Issue
Block a user