update GDAS and SETN
This commit is contained in:
@@ -6,9 +6,9 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from .infer_cells import ResNetBasicblock
|
||||
from .search_cells import SearchCell
|
||||
from .genotypes import Structure
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkGDAS(nn.Module):
|
||||
@@ -44,7 +44,6 @@ class TinyNetworkGDAS(nn.Module):
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
||||
self.tau = 10
|
||||
self.nan_count = 0
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
||||
@@ -52,9 +51,8 @@ class TinyNetworkGDAS(nn.Module):
|
||||
xlist+= list( self.classifier.parameters() )
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau, _nan_count=0):
|
||||
def set_tau(self, tau):
|
||||
self.tau = tau
|
||||
self.nan_count = _nan_count
|
||||
|
||||
def get_tau(self):
|
||||
return self.tau
|
||||
@@ -85,27 +83,10 @@ class TinyNetworkGDAS(nn.Module):
|
||||
return Structure( genotypes )
|
||||
|
||||
def forward(self, inputs):
|
||||
def gumbel_softmax(_logits, _tau):
|
||||
while True: # a trick to avoid the gumbels bug
|
||||
gumbels = -torch.empty_like(_logits).exponential_().log()
|
||||
new_logits = (_logits.log_softmax(dim=1) + gumbels) / _tau
|
||||
probs = nn.functional.softmax(new_logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
if index[0].item() == self.op_names.index('none') and index[3].item() == self.op_names.index('none') and index[5].item() == self.op_names.index('none'): continue
|
||||
if index[1].item() == self.op_names.index('none') and index[2].item() == self.op_names.index('none') and index[3].item() == self.op_names.index('none') and index[4].item() == self.op_names.index('none'): continue
|
||||
if index[3].item() == self.op_names.index('none') and index[4].item() == self.op_names.index('none') and index[5].item() == self.op_names.index('none'): continue
|
||||
if index[3].item() == self.op_names.index('none') and index[0].item() == self.op_names.index('none') and index[1].item() == self.op_names.index('none'): continue
|
||||
one_h = torch.zeros_like(_logits).scatter_(-1, index, 1.0)
|
||||
xres = one_h - probs.detach() + probs
|
||||
if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break
|
||||
self.nan_count += 1
|
||||
return xres, index
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
alphas, IDX = gumbel_softmax(self.arch_parameters, self.tau)
|
||||
feature = cell.forward_gdas(feature, alphas, IDX.cpu())
|
||||
feature = cell.forward_gdas(feature, self.arch_parameters, self.tau)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
|
Reference in New Issue
Block a user