Prototype generic nas model (cont.).

This commit is contained in:
D-X-Y
2020-07-18 22:49:35 +00:00
parent 68f9d037eb
commit 7ca2ca70b4
3 changed files with 115 additions and 52 deletions

View File

@@ -242,6 +242,16 @@ class PartAwareOp(nn.Module):
return outputs
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1. - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x = torch.div(x, keep_prob)
x.mul_(mask)
return x
# Searching for A Robust Neural Architecture in Four GPU Hours
class GDAS_Reduction_Cell(nn.Module):