#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from pycls.core.config import cfg


def Preprocess(x):
    if cfg.TASK == 'jig':
        assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
        assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
        x = x.view([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]])
    return x


class Classifier(nn.Module):
    def __init__(self, channels, num_classes):
        super(Classifier, self).__init__()
        if cfg.TASK == 'jig':
            self.jig_sq = cfg.JIGSAW_GRID ** 2
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(channels * self.jig_sq, num_classes)
        elif cfg.TASK == 'col':
            self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
        elif cfg.TASK == 'seg':
            self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
        else:
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(channels, num_classes)

    def forward(self, x, shape):
        if cfg.TASK == 'jig':
            x = self.pooling(x)
            x = x.view([x.shape[0] // self.jig_sq, x.shape[1] * self.jig_sq, x.shape[2], x.shape[3]])
            x = self.classifier(x.view(x.size(0), -1))
        elif cfg.TASK in ['col', 'seg']:
            x = self.classifier(x)
            x = nn.Upsample(shape, mode='bilinear', align_corners=True)(x)
        else:
            x = self.pooling(x)
            x = self.classifier(x.view(x.size(0), -1))
        return x


class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, num_classes, rates):
        super(ASPP, self).__init__()
        assert len(rates) in [1, 3]
        self.rates = rates
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.aspp1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, dilation=rates[0],
                padding=rates[0], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        if len(self.rates) == 3:
            self.aspp3 = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, dilation=rates[1],
                    padding=rates[1], bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
            self.aspp4 = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, dilation=rates[2],
                    padding=rates[2], bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        self.aspp5 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
                bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, num_classes, 1)
        )

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x5 = self.global_pooling(x)
        x5 = self.aspp5(x5)
        x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
                align_corners=True)(x5)
        if len(self.rates) == 3:
            x3 = self.aspp3(x)
            x4 = self.aspp4(x)
            x = torch.cat((x1, x2, x3, x4, x5), 1)
        else:
            x = torch.cat((x1, x2, x5), 1)
        x = self.classifier(x)
        return x