update 10 NAS algs

This commit is contained in:
D-X-Y
2019-11-15 17:15:07 +11:00
parent 672a9ef0db
commit c72e66b66c
139 changed files with 5863 additions and 368 deletions

View File

@@ -0,0 +1,9 @@
# utils
from .utils import batchify, get_batch, repackage_hidden
# models
from .model_search import RNNModelSearch
from .model_search import DARTSCellSearch
from .basemodel import DARTSCell, RNNModel
# architecture
from .genotypes import DARTS_V1, DARTS_V2
from .genotypes import GDAS

View File

@@ -0,0 +1,181 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .genotypes import STEPS
from .utils import mask2d, LockedDropout, embedded_dropout
INITRANGE = 0.04
def none_func(x):
return x * 0
class DARTSCell(nn.Module):
def __init__(self, ninp, nhid, dropouth, dropoutx, genotype):
super(DARTSCell, self).__init__()
self.nhid = nhid
self.dropouth = dropouth
self.dropoutx = dropoutx
self.genotype = genotype
# genotype is None when doing arch search
steps = len(self.genotype.recurrent) if self.genotype is not None else STEPS
self._W0 = nn.Parameter(torch.Tensor(ninp+nhid, 2*nhid).uniform_(-INITRANGE, INITRANGE))
self._Ws = nn.ParameterList([
nn.Parameter(torch.Tensor(nhid, 2*nhid).uniform_(-INITRANGE, INITRANGE)) for i in range(steps)
])
def forward(self, inputs, hidden, arch_probs):
T, B = inputs.size(0), inputs.size(1)
if self.training:
x_mask = mask2d(B, inputs.size(2), keep_prob=1.-self.dropoutx)
h_mask = mask2d(B, hidden.size(2), keep_prob=1.-self.dropouth)
else:
x_mask = h_mask = None
hidden = hidden[0]
hiddens = []
for t in range(T):
hidden = self.cell(inputs[t], hidden, x_mask, h_mask, arch_probs)
hiddens.append(hidden)
hiddens = torch.stack(hiddens)
return hiddens, hiddens[-1].unsqueeze(0)
def _compute_init_state(self, x, h_prev, x_mask, h_mask):
if self.training:
xh_prev = torch.cat([x * x_mask, h_prev * h_mask], dim=-1)
else:
xh_prev = torch.cat([x, h_prev], dim=-1)
c0, h0 = torch.split(xh_prev.mm(self._W0), self.nhid, dim=-1)
c0 = c0.sigmoid()
h0 = h0.tanh()
s0 = h_prev + c0 * (h0-h_prev)
return s0
def _get_activation(self, name):
if name == 'tanh':
f = torch.tanh
elif name == 'relu':
f = torch.relu
elif name == 'sigmoid':
f = torch.sigmoid
elif name == 'identity':
f = lambda x: x
elif name == 'none':
f = none_func
else:
raise NotImplementedError
return f
def cell(self, x, h_prev, x_mask, h_mask, _):
s0 = self._compute_init_state(x, h_prev, x_mask, h_mask)
states = [s0]
for i, (name, pred) in enumerate(self.genotype.recurrent):
s_prev = states[pred]
if self.training:
ch = (s_prev * h_mask).mm(self._Ws[i])
else:
ch = s_prev.mm(self._Ws[i])
c, h = torch.split(ch, self.nhid, dim=-1)
c = c.sigmoid()
fn = self._get_activation(name)
h = fn(h)
s = s_prev + c * (h-s_prev)
states += [s]
output = torch.mean(torch.stack([states[i] for i in self.genotype.concat], -1), -1)
return output
class RNNModel(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""
def __init__(self, ntoken, ninp, nhid, nhidlast,
dropout=0.5, dropouth=0.5, dropoutx=0.5, dropouti=0.5, dropoute=0.1,
cell_cls=None, genotype=None):
super(RNNModel, self).__init__()
self.lockdrop = LockedDropout()
self.encoder = nn.Embedding(ntoken, ninp)
assert ninp == nhid == nhidlast
if cell_cls == DARTSCell:
assert genotype is not None
rnns = [cell_cls(ninp, nhid, dropouth, dropoutx, genotype)]
else:
assert genotype is None
rnns = [cell_cls(ninp, nhid, dropouth, dropoutx)]
self.rnns = torch.nn.ModuleList(rnns)
self.decoder = nn.Linear(ninp, ntoken)
self.decoder.weight = self.encoder.weight
self.init_weights()
self.arch_weights = None
self.ninp = ninp
self.nhid = nhid
self.nhidlast = nhidlast
self.dropout = dropout
self.dropouti = dropouti
self.dropoute = dropoute
self.ntoken = ntoken
self.cell_cls = cell_cls
# acceleration
self.tau = None
self.use_gumbel = False
def set_gumbel(self, use_gumbel, set_check):
self.use_gumbel = use_gumbel
for i, rnn in enumerate(self.rnns):
rnn.set_check(set_check)
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def init_weights(self):
self.encoder.weight.data.uniform_(-INITRANGE, INITRANGE)
self.decoder.bias.data.fill_(0)
self.decoder.weight.data.uniform_(-INITRANGE, INITRANGE)
def forward(self, input, hidden, return_h=False):
batch_size = input.size(1)
emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0)
emb = self.lockdrop(emb, self.dropouti)
raw_output = emb
new_hidden = []
raw_outputs = []
outputs = []
if self.arch_weights is None:
arch_probs = None
else:
if self.use_gumbel: arch_probs = F.gumbel_softmax(self.arch_weights, self.tau, False)
else : arch_probs = F.softmax(self.arch_weights, dim=-1)
for l, rnn in enumerate(self.rnns):
current_input = raw_output
raw_output, new_h = rnn(raw_output, hidden[l], arch_probs)
new_hidden.append(new_h)
raw_outputs.append(raw_output)
hidden = new_hidden
output = self.lockdrop(raw_output, self.dropout)
outputs.append(output)
logit = self.decoder(output.view(-1, self.ninp))
log_prob = nn.functional.log_softmax(logit, dim=-1)
model_output = log_prob
model_output = model_output.view(-1, batch_size, self.ntoken)
if return_h: return model_output, hidden, raw_outputs, outputs
else : return model_output, hidden
def init_hidden(self, bsz):
weight = next(self.parameters()).clone()
return [weight.new(1, bsz, self.nhid).zero_()]

View File

@@ -0,0 +1,55 @@
from collections import namedtuple
Genotype = namedtuple('Genotype', 'recurrent concat')
PRIMITIVES = [
'none',
'tanh',
'relu',
'sigmoid',
'identity'
]
STEPS = 8
CONCAT = 8
ENAS = Genotype(
recurrent = [
('tanh', 0),
('tanh', 1),
('relu', 1),
('tanh', 3),
('tanh', 3),
('relu', 3),
('relu', 4),
('relu', 7),
('relu', 8),
('relu', 8),
('relu', 8),
],
concat = [2, 5, 6, 9, 10, 11]
)
DARTS_V1 = Genotype(
recurrent = [
('relu', 0),
('relu', 1),
('tanh', 2),
('relu', 3), ('relu', 4), ('identity', 1), ('relu', 5), ('relu', 1)
],
concat=range(1, 9)
)
DARTS_V2 = Genotype(
recurrent = [
('sigmoid', 0), ('relu', 1), ('relu', 1),
('identity', 1), ('tanh', 2), ('sigmoid', 5),
('tanh', 3), ('relu', 5)
],
concat=range(1, 9)
)
GDAS = Genotype(
recurrent=[('relu', 0), ('relu', 0), ('identity', 1), ('relu', 1), ('tanh', 0), ('relu', 2), ('identity', 4), ('identity', 2)],
concat=range(1, 9)
)

View File

@@ -0,0 +1,104 @@
import copy, torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
from .genotypes import PRIMITIVES, STEPS, CONCAT, Genotype
from .basemodel import DARTSCell, RNNModel
class DARTSCellSearch(DARTSCell):
def __init__(self, ninp, nhid, dropouth, dropoutx):
super(DARTSCellSearch, self).__init__(ninp, nhid, dropouth, dropoutx, genotype=None)
self.bn = nn.BatchNorm1d(nhid, affine=False)
self.check_zero = False
def set_check(self, check_zero):
self.check_zero = check_zero
def cell(self, x, h_prev, x_mask, h_mask, arch_probs):
s0 = self._compute_init_state(x, h_prev, x_mask, h_mask)
s0 = self.bn(s0)
if self.check_zero:
arch_probs_cpu = arch_probs.cpu().tolist()
#arch_probs = F.softmax(self.weights, dim=-1)
offset = 0
states = s0.unsqueeze(0)
for i in range(STEPS):
if self.training:
masked_states = states * h_mask.unsqueeze(0)
else:
masked_states = states
ch = masked_states.view(-1, self.nhid).mm(self._Ws[i]).view(i+1, -1, 2*self.nhid)
c, h = torch.split(ch, self.nhid, dim=-1)
c = c.sigmoid()
s = torch.zeros_like(s0)
for k, name in enumerate(PRIMITIVES):
if name == 'none':
continue
fn = self._get_activation(name)
unweighted = states + c * (fn(h) - states)
if self.check_zero:
INDEX, INDDX = [], []
for jj in range(offset, offset+i+1):
if arch_probs_cpu[jj][k] > 0:
INDEX.append(jj)
INDDX.append(jj-offset)
if len(INDEX) == 0: continue
s += torch.sum(arch_probs[INDEX, k].unsqueeze(-1).unsqueeze(-1) * unweighted[INDDX, :, :], dim=0)
else:
s += torch.sum(arch_probs[offset:offset+i+1, k].unsqueeze(-1).unsqueeze(-1) * unweighted, dim=0)
s = self.bn(s)
states = torch.cat([states, s.unsqueeze(0)], 0)
offset += i+1
output = torch.mean(states[-CONCAT:], dim=0)
return output
class RNNModelSearch(RNNModel):
def __init__(self, *args):
super(RNNModelSearch, self).__init__(*args)
self._args = copy.deepcopy( args )
k = sum(i for i in range(1, STEPS+1))
self.arch_weights = nn.Parameter(torch.Tensor(k, len(PRIMITIVES)))
nn.init.normal_(self.arch_weights, 0, 0.001)
def base_parameters(self):
lists = list(self.lockdrop.parameters())
lists += list(self.encoder.parameters())
lists += list(self.rnns.parameters())
lists += list(self.decoder.parameters())
return lists
def arch_parameters(self):
return [self.arch_weights]
def genotype(self):
def _parse(probs):
gene = []
start = 0
for i in range(STEPS):
end = start + i + 1
W = probs[start:end].copy()
#j = sorted(range(i + 1), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[0]
j = sorted(range(i + 1), key=lambda x: -max(W[x][k] for k in range(len(W[x])) ))[0]
k_best = None
for k in range(len(W[j])):
#if k != PRIMITIVES.index('none'):
# if k_best is None or W[j][k] > W[j][k_best]:
# k_best = k
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
gene.append((PRIMITIVES[k_best], j))
start = end
return gene
with torch.no_grad():
gene = _parse(F.softmax(self.arch_weights, dim=-1).cpu().numpy())
genotype = Genotype(recurrent=gene, concat=list(range(STEPS+1)[-CONCAT:]))
return genotype

View File

@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
import os, shutil
import numpy as np
def repackage_hidden(h):
if isinstance(h, torch.Tensor):
return h.detach()
else:
return tuple(repackage_hidden(v) for v in h)
def batchify(data, bsz, use_cuda):
nbatch = data.size(0) // bsz
data = data.narrow(0, 0, nbatch * bsz)
data = data.view(bsz, -1).t().contiguous()
if use_cuda: return data.cuda()
else : return data
def get_batch(source, i, seq_len):
seq_len = min(seq_len, len(source) - 1 - i)
data = source[i:i+seq_len].clone()
target = source[i+1:i+1+seq_len].clone()
return data, target
def embedded_dropout(embed, words, dropout=0.1, scale=None):
if dropout:
mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
mask.requires_grad_(True)
masked_embed_weight = mask * embed.weight
else:
masked_embed_weight = embed.weight
if scale:
masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight
padding_idx = embed.padding_idx
if padding_idx is None:
padding_idx = -1
X = torch.nn.functional.embedding(
words, masked_embed_weight,
padding_idx, embed.max_norm, embed.norm_type,
embed.scale_grad_by_freq, embed.sparse)
return X
class LockedDropout(nn.Module):
def __init__(self):
super(LockedDropout, self).__init__()
def forward(self, x, dropout=0.5):
if not self.training or not dropout:
return x
m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
mask = m.div_(1 - dropout).detach()
mask = mask.expand_as(x)
return mask * x
def mask2d(B, D, keep_prob, cuda=True):
m = torch.floor(torch.rand(B, D) + keep_prob) / keep_prob
if cuda: return m.cuda()
else : return m