update NAS-Bench-102 baselines / support track_running_stats
This commit is contained in:
@@ -11,7 +11,7 @@ from ..cell_operations import OPS
|
||||
|
||||
class SearchCell(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, max_nodes, op_names):
|
||||
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
|
||||
super(SearchCell, self).__init__()
|
||||
|
||||
self.op_names = deepcopy(op_names)
|
||||
@@ -23,9 +23,9 @@ class SearchCell(nn.Module):
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
if j == 0:
|
||||
xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names]
|
||||
xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names]
|
||||
else:
|
||||
xlists = [OPS[op_name](C_in , C_out, 1, False) for op_name in op_names]
|
||||
xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names]
|
||||
self.edges[ node_str ] = nn.ModuleList( xlists )
|
||||
self.edge_keys = sorted(list(self.edges.keys()))
|
||||
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
|
||||
|
Reference in New Issue
Block a user