#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Tools for training and testing a model."""

import os
from thop import profile

import numpy as np
import pycls.core.benchmark as benchmark
import pycls.core.builders as builders
import pycls.core.checkpoint as checkpoint
import pycls.core.config as config
import pycls.core.distributed as dist
import pycls.core.logging as logging
import pycls.core.meters as meters
import pycls.core.net as net
import pycls.core.optimizer as optim
import pycls.datasets.loader as loader
import torch
import torch.nn.functional as F
from pycls.core.config import cfg


# Module-level logger obtained through the pycls logging wrapper
logger = logging.get_logger(__name__)


def setup_env():
    """Prepares the process environment shared by training and testing.

    The master process creates the output directory and dumps the config;
    every process then configures logging, logs the config, seeds the NumPy
    and PyTorch RNGs from cfg.RNG_SEED, and applies the cuDNN benchmark flag.
    """
    # Only the master process touches the output directory
    if dist.is_master_proc():
        os.makedirs(cfg.OUT_DIR, exist_ok=True)
        config.dump_cfg()
    logging.setup_logging()
    # The config is logged twice: human readable and in the log-data format
    readable_cfg = "Config:\n{}".format(cfg)
    logger.info(readable_cfg)
    structured_cfg = logging.dump_log_data(cfg, "cfg")
    logger.info(structured_cfg)
    # Seed both RNG sources (see RNG comment in core/config.py for discussion)
    seed = cfg.RNG_SEED
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Let the config decide whether cuDNN auto-tunes its kernels
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK


def setup_model():
    """Builds the model, logs its complexity, and moves it to the current GPU.

    Returns:
        The model on the current CUDA device, wrapped in
        DistributedDataParallel when cfg.NUM_GPUS > 1.
    """
    model = builders.build_model()
    logger.info("Model:\n{}".format(model))
    # Spatial size of the dummy input used for complexity profiling
    # (this profiling replaces the disabled net.complexity based logging)
    if cfg.TASK == "seg" and cfg.TRAIN.DATASET == "cityscapes":
        height, width = 1025, 2049
    else:
        height, width = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
    # The jigsaw task gets an extra dimension holding JIGSAW_GRID ** 2 entries
    dummy_shape = [1, cfg.MODEL.INPUT_CHANNELS, height, width]
    if cfg.TASK == "jig":
        dummy_shape.insert(1, cfg.JIGSAW_GRID ** 2)
    dummy_input = torch.randn(*dummy_shape)
    # NOTE(review): thop's first return value is bound to `macs` here but is
    # logged under the "Flops" label -- confirm which unit is intended
    macs, params = profile(model, inputs=(dummy_input,), verbose=False)
    logger.info("Params: {:,}".format(params))
    logger.info("Flops: {:,}".format(macs))
    # Move the model onto the GPU owned by this process
    assert (
        cfg.NUM_GPUS <= torch.cuda.device_count()
    ), "Cannot use more GPU devices than available"
    device = torch.cuda.current_device()
    model = model.cuda(device=device)
    if cfg.NUM_GPUS <= 1:
        return model
    # Multi-GPU: one replica per process, pinned to this process' device
    return torch.nn.parallel.DistributedDataParallel(
        module=model, device_ids=[device], output_device=device
    )


def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Runs a single training epoch and records its statistics in train_meter."""
    # NAS models scale their drop path probability linearly with the epoch
    if cfg.MODEL.TYPE == "nas":
        bare_model = model.module if cfg.NUM_GPUS > 1 else model
        drop_prob = cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH
        bare_model.set_drop_path_prob(drop_prob)
    loader.shuffle(train_loader, cur_epoch)
    # Epoch-granularity learning rate schedule
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer, lr)
    model.train()
    train_meter.iter_tic()
    for step, (images, targets) in enumerate(train_loader):
        # Iteration-granularity learning rate schedule (fractional epoch)
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(cur_epoch + step / len(train_loader))
            optim.set_lr(optimizer, lr)
        # Move the batch to the current GPU device
        images = images.cuda()
        targets = targets.cuda(non_blocking=True)
        # Forward pass; a tuple output carries an auxiliary prediction whose
        # loss is added with weight cfg.NAS.AUX_WEIGHT
        outputs = model(images)
        if not isinstance(outputs, tuple):
            loss = loss_fun(outputs, targets)
        else:
            loss = loss_fun(outputs[0], targets) + cfg.NAS.AUX_WEIGHT * loss_fun(outputs[1], targets)
            outputs = outputs[0]
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Mini-batch size used to weight the stats (scaled by the GPU count)
        if cfg.TASK == "col":
            # Flatten to one prediction row / one label per spatial position
            num_channels = outputs.size(1)
            outputs = outputs.permute(0, 2, 3, 1).reshape(-1, num_channels)
            targets = targets.reshape(-1)
            mb_size = images.size(0) * images.size(2) * images.size(3) * cfg.NUM_GPUS
        else:
            mb_size = images.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # For segmentation the two "errors" are in fact inter and union
            top1_err, top5_err = meters.inter_union(outputs, targets, cfg.MODEL.NUM_CLASSES)
        else:
            # Cap k by the number of classes (rot only has 4 classes)
            top_ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]
            top1_err, top5_err = meters.topk_errors(outputs, targets, top_ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Bring the stats to the CPU (sync point)
        loss = loss.item()
        if cfg.TASK == "seg":
            top1_err = top1_err.cpu().numpy()
            top5_err = top5_err.cpu().numpy()
        else:
            top1_err = top1_err.item()
            top5_err = top5_err.item()
        train_meter.iter_toc()
        # Record and log the iteration stats, then restart the timer
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, step)
        train_meter.iter_tic()
    # Epoch summary
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()


def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Performs one epoch of differentiable architecture search.

    Args:
        train_loader: pair of loaders; index 0 drives the weight updates and
            index 1 supplies the batches for the architecture updates.
        model: the search network (possibly wrapped for multi-GPU use).
        loss_fun: loss used for the weight updates.
        optimizer: pair of optimizers; index 0 is stepped on the loss_fun
            loss, index 1 is stepped on the loss returned by the model's
            own `_loss` method.
        train_meter: meter that accumulates and logs the training stats.
        cur_epoch: zero-based index of the current epoch.
    """
    # Max gradient norm applied before each weight update
    max_grad_norm = 5.0
    m = model.module if cfg.NUM_GPUS > 1 else model
    # Shuffle the data
    loader.shuffle(train_loader[0], cur_epoch)
    loader.shuffle(train_loader[1], cur_epoch)
    # Update the learning rate per epoch
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer[0], lr)
    # Enable training mode
    model.train()
    train_meter.iter_tic()
    num_iters = len(train_loader[0])
    trainB_iter = iter(train_loader[1])
    for cur_iter, (inputs, labels) in enumerate(train_loader[0]):
        # Fractional epoch reached so far; drives both the per-iter learning
        # rate and the start of the architecture updates
        progress = cur_epoch + cur_iter / num_iters
        # Update the learning rate per iter
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(progress)
            optim.set_lr(optimizer[0], lr)
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Update architecture (only once cfg.OPTIM.ARCH_EPOCH has been reached)
        if progress >= cfg.OPTIM.ARCH_EPOCH:
            try:
                inputsB, labelsB = next(trainB_iter)
            except StopIteration:
                # The second loader ran out; restart it
                trainB_iter = iter(train_loader[1])
                inputsB, labelsB = next(trainB_iter)
            inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True)
            optimizer[1].zero_grad()
            loss = m._loss(inputsB, labelsB)
            loss.backward()
            optimizer[1].step()
        # Perform the forward pass
        preds = model(inputs)
        # Compute the loss
        loss = loss_fun(preds, labels)
        # Perform the backward pass
        optimizer[0].zero_grad()
        loss.backward()
        # In-place variant; the un-suffixed clip_grad_norm is deprecated
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        # Update the parameters
        optimizer[0].step()
        # Compute the errors
        if cfg.TASK == "col":
            # Flatten to one prediction row / one label per spatial position
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss = loss.item()
        if cfg.TASK == "seg":
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
    # Log genotype and the softmax of both sets of architecture parameters
    genotype = m.genotype()
    logger.info("genotype = %s", genotype)
    logger.info(F.softmax(m.net_.alphas_normal, dim=-1))
    logger.info(F.softmax(m.net_.alphas_reduce, dim=-1))


@torch.no_grad()
def test_epoch(test_loader, model, test_meter, cur_epoch):
    """Runs the model over test_loader in eval mode and logs the resulting stats."""
    model.eval()
    test_meter.iter_tic()
    for step, (images, targets) in enumerate(test_loader):
        # Move the batch to the current GPU device
        images = images.cuda()
        targets = targets.cuda(non_blocking=True)
        outputs = model(images)
        # Mini-batch size used to weight the stats (scaled by the GPU count)
        if cfg.TASK == "col":
            # Flatten to one prediction row / one label per spatial position
            num_channels = outputs.size(1)
            outputs = outputs.permute(0, 2, 3, 1).reshape(-1, num_channels)
            targets = targets.reshape(-1)
            mb_size = images.size(0) * images.size(2) * images.size(3) * cfg.NUM_GPUS
        else:
            mb_size = images.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # For segmentation the two "errors" are in fact inter and union
            top1_err, top5_err = meters.inter_union(outputs, targets, cfg.MODEL.NUM_CLASSES)
        else:
            # Cap k by the number of classes (rot only has 4 classes)
            top_ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]
            top1_err, top5_err = meters.topk_errors(outputs, targets, top_ks)
        # Combine the errors across the GPUs (no reduction if 1 GPU used)
        top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
        # Bring the errors to the CPU (sync point)
        if cfg.TASK == "seg":
            top1_err = top1_err.cpu().numpy()
            top5_err = top5_err.cpu().numpy()
        else:
            top1_err = top1_err.item()
            top5_err = top5_err.item()
        test_meter.iter_toc()
        # Record and log the iteration stats, then restart the timer
        test_meter.update_stats(top1_err, top5_err, mb_size)
        test_meter.log_iter_stats(cur_epoch, step)
        test_meter.iter_tic()
    # Epoch summary
    test_meter.log_epoch_stats(cur_epoch)
    test_meter.reset()


def _construct_search_optimizers(model):
    """Builds the [weights, architecture] optimizer pair for architecture search.

    Parameters whose name contains "alphas" go to the architecture optimizer
    (Adam or SGD, chosen by cfg.OPTIM.ARCH_OPTIM); all others go to an SGD
    weight optimizer.

    Raises:
        ValueError: if cfg.OPTIM.ARCH_OPTIM is neither "adam" nor "sgd".
    """
    named_params = list(model.named_parameters())
    params_w = [v for k, v in named_params if "alphas" not in k]
    params_a = [v for k, v in named_params if "alphas" in k]
    optimizer_w = torch.optim.SGD(
        params=params_w,
        lr=cfg.OPTIM.BASE_LR,
        momentum=cfg.OPTIM.MOMENTUM,
        weight_decay=cfg.OPTIM.WEIGHT_DECAY,
        dampening=cfg.OPTIM.DAMPENING,
        nesterov=cfg.OPTIM.NESTEROV,
    )
    arch_optim = cfg.OPTIM.ARCH_OPTIM
    if arch_optim == "adam":
        optimizer_a = torch.optim.Adam(
            params=params_a,
            lr=cfg.OPTIM.ARCH_BASE_LR,
            betas=(0.5, 0.999),
            weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
        )
    elif arch_optim == "sgd":
        optimizer_a = torch.optim.SGD(
            params=params_a,
            lr=cfg.OPTIM.ARCH_BASE_LR,
            momentum=cfg.OPTIM.MOMENTUM,
            weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
            dampening=cfg.OPTIM.DAMPENING,
            nesterov=cfg.OPTIM.NESTEROV,
        )
    else:
        # Fail with a clear message instead of an unbound-variable error
        raise ValueError(
            "Unsupported cfg.OPTIM.ARCH_OPTIM: {!r} (expected 'adam' or 'sgd')".format(
                arch_optim
            )
        )
    return [optimizer_w, optimizer_a]


def _construct_portion_loader(side, shuffle, drop_last):
    """Builds a loader over one side ("l" or "r") of the cfg.TRAIN.PORTION split."""
    # presumably "l"/"r" select the two parts of the split -- verify in loader
    return loader._construct_loader(
        dataset_name=cfg.TRAIN.DATASET,
        split=cfg.TRAIN.SPLIT,
        batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
        shuffle=shuffle,
        drop_last=drop_last,
        portion=cfg.TRAIN.PORTION,
        side=side,
    )


def train_model():
    """Trains the model.

    Sets up the environment, model, loss and optimizer(s), optionally resumes
    from a checkpoint or loads initial weights, builds the data loaders and
    meters, then runs the epoch loop with periodic checkpointing and
    evaluation. Model types containing "search" use search_epoch with a
    [weights, architecture] optimizer pair; all others use train_epoch.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    is_search = "search" in cfg.MODEL.TYPE
    if is_search:
        optimizer = _construct_search_optimizers(model)
    else:
        optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters
    if cfg.TRAIN.PORTION < 1:
        # Training split is divided: side "l" for training, side "r" for
        # evaluation (and, when searching, for the second training loader)
        if is_search:
            train_loader = [
                _construct_portion_loader(side="l", shuffle=True, drop_last=True),
                _construct_portion_loader(side="r", shuffle=True, drop_last=True),
            ]
        else:
            train_loader = _construct_portion_loader(
                side="l", shuffle=True, drop_last=True
            )
        test_loader = _construct_portion_loader(
            side="r", shuffle=False, drop_last=False
        )
    else:
        # NOTE(review): search_epoch indexes train_loader[0]/[1], but a single
        # loader is built here -- confirm search configs always set
        # TRAIN.PORTION < 1
        train_loader = loader.construct_train_loader()
        test_loader = loader.construct_test_loader()
    train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
    test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
    # Helpers that expect a single loader get the first (weight) loader
    weight_loader = train_loader[0] if isinstance(train_loader, list) else train_loader
    train_meter = train_meter_type(len(weight_loader))
    test_meter = test_meter_type(len(test_loader))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, weight_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    epoch_fun = search_epoch if is_search else train_epoch
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        epoch_fun(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats (single loader, never the search-mode list)
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, weight_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model periodically and always after the final epoch
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            test_epoch(test_loader, model, test_meter, cur_epoch)


def test_model():
    """Evaluates a trained model.

    Loads the weights from cfg.TEST.WEIGHTS and runs one evaluation pass over
    the test loader. The meter type follows cfg.TASK in the same way as in
    train_model, because test_epoch passes inter/union arrays for the "seg"
    task and scalar errors otherwise.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model
    model = setup_model()
    # Load model weights
    checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    # Create data loaders and meters
    test_loader = loader.construct_test_loader()
    test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
    test_meter = test_meter_type(len(test_loader))
    # Evaluate the model
    test_epoch(test_loader, model, test_meter, 0)


def time_model():
    """Benchmarks the model and the train/test data loaders."""
    setup_env()
    # The benchmark needs the model together with its loss function
    net_under_test = setup_model()
    criterion = builders.build_loss_fun().cuda()
    # Build both loaders, then hand everything to the timing routine
    timed_train_loader = loader.construct_train_loader()
    timed_test_loader = loader.construct_test_loader()
    benchmark.compute_time_full(
        net_under_test, criterion, timed_train_loader, timed_test_loader
    )