Prototype generic nas model (cont.).

This commit is contained in:
D-X-Y
2020-07-19 08:11:29 +00:00
parent 7ca2ca70b4
commit c34620ab1b
3 changed files with 91 additions and 6 deletions

View File

@@ -67,6 +67,14 @@ class GenericNAS201Model(nn.Module):
if mode == 'dynamic': self.dynamic_cell = deepcopy(dynamic_cell)
else : self.dynamic_cell = None
def set_drop_path(self, progress, drop_path_rate):
if drop_path_rate is None:
self._drop_path = None
elif progress is None:
self._drop_path = drop_path_rate
else:
self._drop_path = progress * drop_path_rate
@property
def mode(self):
return self._mode
@@ -210,6 +218,8 @@ class GenericNAS201Model(nn.Module):
feature = cell.forward_gdas(feature, alphas, index)
else: raise ValueError('invalid mode={:}'.format(self.mode))
else: feature = cell(feature)
if self.drop_path is not None:
feature = drop_path(feature, self.drop_path)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)