update NAS-Bench-102 baselines / support track_running_stats

This commit is contained in:
D-X-Y
2019-12-23 13:32:20 +11:00
parent 729ce136db
commit 2dc8dce6d3
9 changed files with 56 additions and 57 deletions

View File

@@ -11,7 +11,7 @@ from ..cell_operations import OPS
class SearchCell(nn.Module):
def __init__(self, C_in, C_out, stride, max_nodes, op_names):
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
super(SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
@@ -23,9 +23,9 @@ class SearchCell(nn.Module):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if j == 0:
xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names]
xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names]
else:
xlists = [OPS[op_name](C_in , C_out, 1, False) for op_name in op_names]
xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names]
self.edges[ node_str ] = nn.ModuleList( xlists )
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}