Prototype generic nas model (cont.) for ENAS.
This commit is contained in:
@@ -5,11 +5,75 @@ import torch, random
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import Text
|
||||
from torch.distributions.categorical import Categorical
|
||||
|
||||
from ..cell_operations import ResNetBasicblock, drop_path
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
from .search_model_enas_utils import Controller
|
||||
|
||||
|
||||
class Controller(nn.Module):
|
||||
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
|
||||
def __init__(self, edge2index, op_names, max_nodes, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
|
||||
super(Controller, self).__init__()
|
||||
# assign the attributes
|
||||
self.max_nodes = max_nodes
|
||||
self.num_edge = len(edge2index)
|
||||
self.edge2index = edge2index
|
||||
self.num_ops = len(op_names)
|
||||
self.op_names = op_names
|
||||
self.lstm_size = lstm_size
|
||||
self.lstm_N = lstm_num_layers
|
||||
self.tanh_constant = tanh_constant
|
||||
self.temperature = temperature
|
||||
# create parameters
|
||||
self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size)))
|
||||
self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
|
||||
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
|
||||
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
|
||||
|
||||
nn.init.uniform_(self.input_vars , -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_embd.weight , -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_pred.weight , -0.1, 0.1)
|
||||
|
||||
def convert_structure(self, _arch):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
op_index = _arch[self.edge2index[node_str]]
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append( tuple(xlist) )
|
||||
return Structure(genotypes)
|
||||
|
||||
def forward(self):
|
||||
|
||||
inputs, h0 = self.input_vars, None
|
||||
log_probs, entropys, sampled_arch = [], [], []
|
||||
for iedge in range(self.num_edge):
|
||||
outputs, h0 = self.w_lstm(inputs, h0)
|
||||
|
||||
logits = self.w_pred(outputs)
|
||||
logits = logits / self.temperature
|
||||
logits = self.tanh_constant * torch.tanh(logits)
|
||||
# distribution
|
||||
op_distribution = Categorical(logits=logits)
|
||||
op_index = op_distribution.sample()
|
||||
sampled_arch.append( op_index.item() )
|
||||
|
||||
op_log_prob = op_distribution.log_prob(op_index)
|
||||
log_probs.append( op_log_prob.view(-1) )
|
||||
op_entropy = op_distribution.entropy()
|
||||
entropys.append( op_entropy.view(-1) )
|
||||
|
||||
# obtain the input embedding for the next step
|
||||
inputs = self.w_embd(op_index)
|
||||
return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), self.convert_structure(sampled_arch)
|
||||
|
||||
|
||||
|
||||
class GenericNAS201Model(nn.Module):
|
||||
@@ -55,7 +119,7 @@ class GenericNAS201Model(nn.Module):
|
||||
assert self._algo is None, 'This functioin can only be called once.'
|
||||
self._algo = algo
|
||||
if algo == 'enas':
|
||||
self.controller = Controller(len(self.edge2index), len(self._op_names))
|
||||
self.controller = Controller(self.edge2index, self._op_names, self._max_nodes)
|
||||
else:
|
||||
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(self._num_edge, len(self._op_names)) )
|
||||
if algo == 'gdas':
|
||||
@@ -116,10 +180,9 @@ class GenericNAS201Model(nn.Module):
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
if self._algo == 'enas':
|
||||
import pdb; pdb.set_trace()
|
||||
print('-')
|
||||
return 'w_pred :\n{:}'.format(self.controller.w_pred.weight)
|
||||
else:
|
||||
return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() )
|
||||
return 'arch-parameters :\n{:}'.format(nn.functional.softmax(self.arch_parameters, dim=-1).cpu())
|
||||
|
||||
|
||||
def extra_repr(self):
|
||||
|
Reference in New Issue
Block a user