Update TAS abd FBV2 for NAS-Bench
This commit is contained in:
@@ -74,17 +74,17 @@ class DualSepConv(nn.Module):
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes, stride, affine=True):
|
||||
def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
|
||||
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
|
||||
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine, track_running_stats)
|
||||
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine, track_running_stats)
|
||||
if stride == 2:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
|
||||
elif inplanes != planes:
|
||||
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
|
||||
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine, track_running_stats)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.in_dim = inplanes
|
||||
|
Reference in New Issue
Block a user