Clean unnecessary files
This commit is contained in:
@@ -1,104 +0,0 @@
|
||||
import copy, torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from collections import namedtuple
|
||||
from .genotypes import PRIMITIVES, STEPS, CONCAT, Genotype
|
||||
from .basemodel import DARTSCell, RNNModel
|
||||
|
||||
|
||||
class DARTSCellSearch(DARTSCell):
|
||||
|
||||
def __init__(self, ninp, nhid, dropouth, dropoutx):
|
||||
super(DARTSCellSearch, self).__init__(ninp, nhid, dropouth, dropoutx, genotype=None)
|
||||
self.bn = nn.BatchNorm1d(nhid, affine=False)
|
||||
self.check_zero = False
|
||||
|
||||
def set_check(self, check_zero):
|
||||
self.check_zero = check_zero
|
||||
|
||||
def cell(self, x, h_prev, x_mask, h_mask, arch_probs):
|
||||
s0 = self._compute_init_state(x, h_prev, x_mask, h_mask)
|
||||
s0 = self.bn(s0)
|
||||
if self.check_zero:
|
||||
arch_probs_cpu = arch_probs.cpu().tolist()
|
||||
#arch_probs = F.softmax(self.weights, dim=-1)
|
||||
|
||||
offset = 0
|
||||
states = s0.unsqueeze(0)
|
||||
for i in range(STEPS):
|
||||
if self.training:
|
||||
masked_states = states * h_mask.unsqueeze(0)
|
||||
else:
|
||||
masked_states = states
|
||||
ch = masked_states.view(-1, self.nhid).mm(self._Ws[i]).view(i+1, -1, 2*self.nhid)
|
||||
c, h = torch.split(ch, self.nhid, dim=-1)
|
||||
c = c.sigmoid()
|
||||
|
||||
s = torch.zeros_like(s0)
|
||||
for k, name in enumerate(PRIMITIVES):
|
||||
if name == 'none':
|
||||
continue
|
||||
fn = self._get_activation(name)
|
||||
unweighted = states + c * (fn(h) - states)
|
||||
if self.check_zero:
|
||||
INDEX, INDDX = [], []
|
||||
for jj in range(offset, offset+i+1):
|
||||
if arch_probs_cpu[jj][k] > 0:
|
||||
INDEX.append(jj)
|
||||
INDDX.append(jj-offset)
|
||||
if len(INDEX) == 0: continue
|
||||
s += torch.sum(arch_probs[INDEX, k].unsqueeze(-1).unsqueeze(-1) * unweighted[INDDX, :, :], dim=0)
|
||||
else:
|
||||
s += torch.sum(arch_probs[offset:offset+i+1, k].unsqueeze(-1).unsqueeze(-1) * unweighted, dim=0)
|
||||
s = self.bn(s)
|
||||
states = torch.cat([states, s.unsqueeze(0)], 0)
|
||||
offset += i+1
|
||||
output = torch.mean(states[-CONCAT:], dim=0)
|
||||
return output
|
||||
|
||||
|
||||
class RNNModelSearch(RNNModel):
|
||||
|
||||
def __init__(self, *args):
|
||||
super(RNNModelSearch, self).__init__(*args)
|
||||
self._args = copy.deepcopy( args )
|
||||
|
||||
k = sum(i for i in range(1, STEPS+1))
|
||||
self.arch_weights = nn.Parameter(torch.Tensor(k, len(PRIMITIVES)))
|
||||
nn.init.normal_(self.arch_weights, 0, 0.001)
|
||||
|
||||
def base_parameters(self):
|
||||
lists = list(self.lockdrop.parameters())
|
||||
lists += list(self.encoder.parameters())
|
||||
lists += list(self.rnns.parameters())
|
||||
lists += list(self.decoder.parameters())
|
||||
return lists
|
||||
|
||||
def arch_parameters(self):
|
||||
return [self.arch_weights]
|
||||
|
||||
def genotype(self):
|
||||
|
||||
def _parse(probs):
|
||||
gene = []
|
||||
start = 0
|
||||
for i in range(STEPS):
|
||||
end = start + i + 1
|
||||
W = probs[start:end].copy()
|
||||
#j = sorted(range(i + 1), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[0]
|
||||
j = sorted(range(i + 1), key=lambda x: -max(W[x][k] for k in range(len(W[x])) ))[0]
|
||||
k_best = None
|
||||
for k in range(len(W[j])):
|
||||
#if k != PRIMITIVES.index('none'):
|
||||
# if k_best is None or W[j][k] > W[j][k_best]:
|
||||
# k_best = k
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
gene.append((PRIMITIVES[k_best], j))
|
||||
start = end
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene = _parse(F.softmax(self.arch_weights, dim=-1).cpu().numpy())
|
||||
genotype = Genotype(recurrent=gene, concat=list(range(STEPS+1)[-CONCAT:]))
|
||||
return genotype
|
Reference in New Issue
Block a user