simplify DARTS codes and update affine/track

This commit is contained in:
D-X-Y
2020-01-11 18:46:31 +11:00
parent c66afa4df8
commit 654015bf9d
15 changed files with 30 additions and 110 deletions

View File

@@ -14,7 +14,7 @@ from .search_model_enas_utils import Controller
class TinyNetworkENAS(nn.Module):
def __init__(self, C, N, max_nodes, num_classes, search_space):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkENAS, self).__init__()
self._C = C
self._layerN = N
@@ -32,7 +32,7 @@ class TinyNetworkENAS(nn.Module):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )