Update TAS abd FBV2 for NAS-Bench

This commit is contained in:
D-X-Y
2020-07-24 12:56:34 +00:00
parent b9fbe5577c
commit 4a2292a863
8 changed files with 491 additions and 12 deletions

View File

@@ -74,17 +74,17 @@ class DualSepConv(nn.Module):
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, affine=True):
def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine, track_running_stats)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine, track_running_stats)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine, track_running_stats)
else:
self.downsample = None
self.in_dim = inplanes