Prototype generic nas model (cont.) for GDAS.

This commit is contained in:
D-X-Y
2020-07-20 08:45:41 +00:00
parent 5cf66d24a1
commit 8d27050f6f
2 changed files with 22 additions and 9 deletions

View File

@@ -102,17 +102,18 @@ class GenericNAS201Model(nn.Module):
self._op_names = deepcopy(search_space)
self._Layer = len(self._cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev, affine=affine, track_running_stats=track_running_stats), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self._num_edge = num_edge
# algorithm related
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.arch_parameters = nn.Parameter(1e-3*torch.randn(num_edge, len(search_space)))
self._mode = None
self.dynamic_cell = None
self._tau = None
self._algo = None
self._drop_path = None
self.verbose = False
def set_algo(self, algo: Text):
# used for searching
@@ -256,33 +257,45 @@ class GenericNAS201Model(nn.Module):
else: break
with torch.no_grad():
hardwts_cpu = hardwts.detach().cpu()
return hardwts, hardwts_cpu, index
return hardwts, hardwts_cpu, index, 'GUMBEL'
else:
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
index = alphas.max(-1, keepdim=True)[1]
with torch.no_grad():
alphas_cpu = alphas.detach().cpu()
return alphas, alphas_cpu, index
return alphas, alphas_cpu, index, 'SOFTMAX'
def forward(self, inputs):
alphas, alphas_cpu, index = self.normalize_archp()
alphas, alphas_cpu, index, verbose_str = self.normalize_archp()
feature = self._stem(inputs)
for i, cell in enumerate(self._cells):
if isinstance(cell, SearchCell):
if self.mode == 'urs':
feature = cell.forward_urs(feature)
if self.verbose:
verbose_str += '-forward_urs'
elif self.mode == 'select':
feature = cell.forward_select(feature, alphas_cpu)
if self.verbose:
verbose_str += '-forward_select'
elif self.mode == 'joint':
feature = cell.forward_joint(feature, alphas)
if self.verbose:
verbose_str += '-forward_joint'
elif self.mode == 'dynamic':
feature = cell.forward_dynamic(feature, self.dynamic_cell)
if self.verbose:
verbose_str += '-forward_dynamic'
elif self.mode == 'gdas':
feature = cell.forward_gdas(feature, alphas, index)
if self.verbose:
verbose_str += '-forward_gdas'
else: raise ValueError('invalid mode={:}'.format(self.mode))
else: feature = cell(feature)
if self.drop_path is not None:
feature = drop_path(feature, self.drop_path)
if self.verbose and random.random() < 0.001:
print(verbose_str)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)