Fix bugs in TAS: missing ReLU at the end of each searching block
@@ -15,8 +15,8 @@ OPS = {
   'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
   'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
   'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
-  'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
-  'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
+  'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
+  'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
   'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
 }
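ReLUConvBN, DualSepConv, and SepConv are defined elsewhere in the repository; this hunk only shows how the OPS table wires them up. As a point of reference, in DARTS-style NAS codebases ReLUConvBN conventionally stacks ReLU -> Conv2d -> BatchNorm2d. The sketch below assumes that layout and mirrors the lambda signature above; it is an illustration, not the repository's implementation.

import torch
import torch.nn as nn

# Hedged sketch of a DARTS-style ReLUConvBN (assumed layout, not the repo's code).
class ReLUConvBNSketch(nn.Module):
    def __init__(self, C_in, C_out, kernel, stride, padding, dilation,
                 affine, track_running_stats):
        super().__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel, stride, padding, dilation, bias=False),
            nn.BatchNorm2d(C_out, affine=affine,
                           track_running_stats=track_running_stats),
        )

    def forward(self, x):
        return self.op(x)

# Instantiating one candidate the way the OPS lambdas do (values illustrative):
op = ReLUConvBNSketch(16, 16, (1, 1), (1, 1), (0, 0), (1, 1), True, True)
y = op(torch.randn(2, 16, 32, 32))  # -> (2, 16, 32, 32)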
@@ -172,7 +172,7 @@ class ResNetBasicblock(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_b)
-    return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
+    return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
 
   def basic_forward(self, inputs):
     basicblock = self.conv_a(inputs)
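This hunk is the heart of the fix: the searching forward returned the raw residual sum, whereas a standard post-activation ResNet basic block applies a ReLU after the addition. A minimal self-contained sketch of the corrected control flow follows (an illustrative module, not the repository's ResNetBasicblock):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative basic residual block with the post-activation layout the fix restores.
class BasicBlockSketch(nn.Module):
    def __init__(self, c_in, c_out, stride=1):
        super().__init__()
        self.conv_a = nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, stride, 1, bias=False),
            nn.BatchNorm2d(c_out), nn.ReLU(inplace=True))
        self.conv_b = nn.Sequential(
            nn.Conv2d(c_out, c_out, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_out))
        self.downsample = None
        if stride != 1 or c_in != c_out:
            self.downsample = nn.Sequential(
                nn.Conv2d(c_in, c_out, 1, stride, bias=False),
                nn.BatchNorm2d(c_out))

    def forward(self, x):
        out = self.conv_b(self.conv_a(x))
        residual = x if self.downsample is None else self.downsample(x)
        # The bug: returning residual + out directly. The fix: a final ReLU,
        # mirroring `return nn.functional.relu(out, inplace=True)` in the diff.
        return F.relu(residual + out, inplace=True)

block = BasicBlockSketch(16, 32, stride=2)
z = block(torch.randn(2, 16, 32, 32))  # -> (2, 32, 16, 16)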
@@ -244,8 +244,7 @@ class ResNetBottleneck(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_1x4)
-    return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
-
+    return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
 
 
 class SearchShapeCifarResNet(nn.Module):
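Note that each searching forward returns a three-tuple: the (now activated) output, the expected input width of the next block, and the summed expected FLOPs. The real methods also thread sampling probabilities and indexes through each block; the hedged sketch below omits that and only shows how a caller might fold these tuples over a sequence of blocks (all names here are illustrative, not the repository's API):

import torch

# Stand-in for a searched block: returns (output, expected_inC, expected_flop),
# matching the return shape in the hunks above (dummy values, for illustration).
def dummy_block(x):
    return torch.relu(x), x.shape[1], 1000.0

def accumulate(blocks, x):
    total_flop = 0.0
    for block in blocks:
        x, expected_inC, flop = block(x)  # expected_inC would size the next block
        total_flop += flop
    return x, total_flop

out, flops = accumulate([dummy_block, dummy_block], torch.randn(2, 8, 4, 4))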
@@ -156,7 +156,7 @@ class ResNetBasicblock(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_b)
-    return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
+    return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
 
   def basic_forward(self, inputs):
     basicblock = self.conv_a(inputs)
@@ -228,8 +228,7 @@ class ResNetBottleneck(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_1x4)
-    return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
-
+    return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
 
 
 class SearchWidthCifarResNet(nn.Module):
@@ -171,7 +171,7 @@ class ResNetBasicblock(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_b)
-    return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
+    return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
 
   def basic_forward(self, inputs):
     basicblock = self.conv_a(inputs)
@@ -243,8 +243,7 @@ class ResNetBottleneck(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_1x4)
-    return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
-
+    return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
 
 
 class SearchShapeImagenetResNet(nn.Module):
@@ -153,7 +153,7 @@ class SimBlock(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out)
-    return out, expected_next_inC, sum([expected_flop, expected_flop_c])
+    return nn.functional.relu(out, inplace=True), expected_next_inC, sum([expected_flop, expected_flop_c])
 
   def basic_forward(self, inputs):
     basicblock = self.conv(inputs)
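A quick hedged sanity check for the whole patch: because every searching block now ends in a ReLU, its feature output must be non-negative, an invariant the pre-fix blocks violated whenever the residual sum dipped below zero. The snippet demonstrates exactly that, with random tensors standing in for the residual and the conv-branch output:

import torch
import torch.nn.functional as F

residual = torch.randn(2, 8, 4, 4)  # stand-in for the (possibly downsampled) input
branch = torch.randn(2, 8, 4, 4)    # stand-in for the conv branch output

pre_fix = residual + branch          # may contain negative values
post_fix = F.relu(residual + branch) # the patched behavior
assert (post_fix >= 0).all()         # invariant guaranteed by the trailing ReLU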