import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .genotypes import STEPS
from .utils import mask2d, LockedDropout, embedded_dropout


# Half-width of the symmetric uniform range used for every weight init in this file.
INITRANGE = 4e-2

def none_func(x):
  """Return ``x`` scaled to zero (the 'none' op).

  Multiplying instead of building fresh zeros keeps the result's type, shape
  and (for tensors) dtype/device identical to those of ``x``.
  """
  return 0 * x


class DARTSCell(nn.Module):
  """Recurrent cell whose per-step wiring is described by a DARTS genotype.

  Every node is produced by a highway-style update
  ``s = prev + sigmoid(gate) * (act(cand) - prev)``, where ``gate`` and
  ``cand`` are the two ``nhid``-wide halves of a linear projection. The cell
  output at each timestep is the mean of the nodes listed in
  ``genotype.concat``.
  """

  def __init__(self, ninp, nhid, dropouth, dropoutx, genotype):
    super(DARTSCell, self).__init__()
    self.nhid, self.dropouth, self.dropoutx = nhid, dropouth, dropoutx
    self.genotype = genotype

    # With no genotype (architecture search) fall back to STEPS nodes.
    n_nodes = STEPS if genotype is None else len(genotype.recurrent)

    # Projects the concatenated [x, h_prev] to the gate/candidate halves of s0.
    self._W0 = nn.Parameter(
        torch.Tensor(ninp + nhid, 2 * nhid).uniform_(-INITRANGE, INITRANGE))
    # One (nhid -> 2*nhid) projection per intermediate node.
    self._Ws = nn.ParameterList([
        nn.Parameter(torch.Tensor(nhid, 2 * nhid).uniform_(-INITRANGE, INITRANGE))
        for _ in range(n_nodes)
    ])

  def forward(self, inputs, hidden, arch_probs):
    """Unroll the cell over the leading (time) dimension of ``inputs``.

    Args:
      inputs: tensor indexed along dim 0 per timestep; dim 1 is the batch.
      hidden: initial state; only ``hidden[0]`` is used (assumed shape
        (1, batch, nhid) -- TODO confirm against callers).
      arch_probs: passed unchanged to ``cell`` at every timestep.

    Returns:
      ``(outputs, last)``: per-step states stacked along dim 0, and the final
      state with a leading singleton dimension.
    """
    batch = inputs.size(1)

    # Masks are drawn once per call and reused at every timestep.
    x_mask = h_mask = None
    if self.training:
      x_mask = mask2d(batch, inputs.size(2), keep_prob=1. - self.dropoutx)
      h_mask = mask2d(batch, hidden.size(2), keep_prob=1. - self.dropouth)

    state = hidden[0]
    per_step = []
    for x_t in inputs.unbind(0):
      state = self.cell(x_t, state, x_mask, h_mask, arch_probs)
      per_step.append(state)
    stacked = torch.stack(per_step)
    return stacked, stacked[-1:]

  def _highway(self, prev, projected, activation):
    """Split ``projected`` into gate/candidate halves and blend with ``prev``."""
    gate, cand = torch.split(projected, self.nhid, dim=-1)
    return prev + gate.sigmoid() * (activation(cand) - prev)

  def _compute_init_state(self, x, h_prev, x_mask, h_mask):
    """Compute node s0 from the current input and the previous hidden state.

    In training mode both inputs are multiplied by their masks before the
    projection; the highway blend itself always uses the unmasked ``h_prev``.
    """
    if self.training:
      x, h_in = x * x_mask, h_prev * h_mask
    else:
      h_in = h_prev
    projected = torch.cat([x, h_in], dim=-1).mm(self._W0)
    return self._highway(h_prev, projected, torch.tanh)

  def _get_activation(self, name):
    """Map a genotype op name to its activation callable.

    Raises:
      NotImplementedError: if ``name`` is not a known op.
    """
    if name in ('tanh', 'relu', 'sigmoid'):
      return getattr(torch, name)
    if name == 'identity':
      return lambda x: x
    if name == 'none':
      return none_func
    raise NotImplementedError

  def cell(self, x, h_prev, x_mask, h_mask, _):
    """Advance one timestep following ``self.genotype`` and return the new state.

    The last argument (architecture probabilities) is ignored: the wiring here
    is fixed by the genotype.
    """
    states = [self._compute_init_state(x, h_prev, x_mask, h_mask)]
    for step, (op_name, src) in enumerate(self.genotype.recurrent):
      prev = states[src]
      # The same h_mask is applied to every node input in training mode.
      node_in = prev * h_mask if self.training else prev
      projected = node_in.mm(self._Ws[step])
      states.append(self._highway(prev, projected, self._get_activation(op_name)))
    chosen = torch.stack([states[k] for k in self.genotype.concat], -1)
    return chosen.mean(-1)


class RNNModel(nn.Module):
  """Container module with an encoder, a recurrent module, and a decoder.

  The decoder's weight is tied to the encoder's embedding matrix, and
  ``ninp``, ``nhid`` and ``nhidlast`` are required to be equal. Architecture
  weights (``self.arch_weights``) are ``None`` by default; when set, they are
  normalized in ``forward`` and passed to every recurrent cell.
  """
  def __init__(self, ntoken, ninp, nhid, nhidlast,
                 dropout=0.5, dropouth=0.5, dropoutx=0.5, dropouti=0.5, dropoute=0.1,
                 cell_cls=None, genotype=None):
    super(RNNModel, self).__init__()
    self.lockdrop = LockedDropout()
    self.encoder = nn.Embedding(ntoken, ninp)

    assert ninp == nhid == nhidlast
    # Only DARTSCell is built from a fixed genotype; any other cell class is
    # constructed without one.
    if cell_cls == DARTSCell:
      assert genotype is not None
      rnns = [cell_cls(ninp, nhid, dropouth, dropoutx, genotype)]
    else:
      assert genotype is None
      rnns = [cell_cls(ninp, nhid, dropouth, dropoutx)]

    self.rnns    = torch.nn.ModuleList(rnns)
    self.decoder = nn.Linear(ninp, ntoken)
    # Weight tying: decoder and encoder share one parameter tensor.
    self.decoder.weight = self.encoder.weight
    self.init_weights()
    self.arch_weights = None

    self.ninp = ninp
    self.nhid = nhid
    self.nhidlast = nhidlast
    self.dropout = dropout
    self.dropouti = dropouti
    self.dropoute = dropoute
    self.ntoken = ntoken
    self.cell_cls = cell_cls
    # acceleration: Gumbel-softmax temperature and toggle (see set_gumbel/set_tau)
    self.tau = None
    self.use_gumbel = False

  def set_gumbel(self, use_gumbel, set_check):
    """Toggle Gumbel-softmax sampling and forward ``set_check`` to every cell.

    NOTE(review): requires each cell to define ``set_check``; DARTSCell as
    defined in this file has none, so this presumably targets another cell class.
    """
    self.use_gumbel = use_gumbel
    for rnn in self.rnns:
      rnn.set_check(set_check)

  def set_tau(self, tau):
    """Set the Gumbel-softmax temperature used when ``use_gumbel`` is on."""
    self.tau = tau

  def get_tau(self):
    """Return the current Gumbel-softmax temperature (``None`` if unset)."""
    return self.tau

  def init_weights(self):
    """Re-initialize embedding/decoder weights uniformly and zero the decoder bias."""
    self.encoder.weight.data.uniform_(-INITRANGE, INITRANGE)
    self.decoder.bias.data.fill_(0)
    # decoder.weight is the same tensor as encoder.weight (tied in __init__),
    # so this re-draws the shared matrix; kept to preserve the RNG stream.
    self.decoder.weight.data.uniform_(-INITRANGE, INITRANGE)

  def _arch_probs(self):
    """Normalize ``self.arch_weights`` over the last dim, or return ``None``.

    Uses Gumbel-softmax (soft samples, temperature ``self.tau``) when
    ``use_gumbel`` is set, otherwise a plain softmax.

    Raises:
      ValueError: if Gumbel sampling is enabled but no temperature was set.
    """
    if self.arch_weights is None:
      return None
    if self.use_gumbel:
      if self.tau is None:
        raise ValueError("use_gumbel is enabled but tau is not set; call set_tau() first")
      return F.gumbel_softmax(self.arch_weights, tau=self.tau, hard=False)
    return F.softmax(self.arch_weights, dim=-1)

  def forward(self, input, hidden, return_h=False):
    """Run the language model over a batch of token sequences.

    Args:
      input: token indices; dim 1 is the batch size (assumed (seq_len, batch)
        -- TODO confirm against callers).
      hidden: list with one initial hidden state per recurrent layer.
      return_h: when True, also return per-layer raw outputs and the
        dropped-out final output.

    Returns:
      ``(log_probs, new_hidden)`` or, with ``return_h``,
      ``(log_probs, new_hidden, raw_outputs, outputs)``; ``log_probs`` is
      viewed as ``(-1, batch, ntoken)``.
    """
    batch_size = input.size(1)

    emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0)
    emb = self.lockdrop(emb, self.dropouti)

    # Computed after the embedding dropouts so the RNG draw order is unchanged.
    arch_probs = self._arch_probs()

    raw_output = emb
    new_hidden = []
    raw_outputs = []
    for layer, rnn in enumerate(self.rnns):
      raw_output, new_h = rnn(raw_output, hidden[layer], arch_probs)
      new_hidden.append(new_h)
      raw_outputs.append(raw_output)

    output = self.lockdrop(raw_output, self.dropout)
    outputs = [output]

    logit = self.decoder(output.view(-1, self.ninp))
    model_output = F.log_softmax(logit, dim=-1).view(-1, batch_size, self.ntoken)

    if return_h:
      return model_output, new_hidden, raw_outputs, outputs
    return model_output, new_hidden

  def init_hidden(self, bsz):
    """Return a one-element list holding a zero state of shape (1, bsz, nhid).

    The state matches the dtype/device of the model's first parameter. The
    parameter is only consulted for metadata; it is not copied.
    """
    weight = next(self.parameters())
    return [weight.new_zeros(1, bsz, self.nhid)]