Prototype generic nas model (cont.).
This commit is contained in:
@@ -67,6 +67,14 @@ class GenericNAS201Model(nn.Module):
|
||||
if mode == 'dynamic': self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else : self.dynamic_cell = None
|
||||
|
||||
def set_drop_path(self, progress, drop_path_rate):
|
||||
if drop_path_rate is None:
|
||||
self._drop_path = None
|
||||
elif progress is None:
|
||||
self._drop_path = drop_path_rate
|
||||
else:
|
||||
self._drop_path = progress * drop_path_rate
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
return self._mode
|
||||
@@ -210,6 +218,8 @@ class GenericNAS201Model(nn.Module):
|
||||
feature = cell.forward_gdas(feature, alphas, index)
|
||||
else: raise ValueError('invalid mode={:}'.format(self.mode))
|
||||
else: feature = cell(feature)
|
||||
if self.drop_path is not None:
|
||||
feature = drop_path(feature, self.drop_path)
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
|
Reference in New Issue
Block a user