update NAS-Bench-102 baselines
This commit is contained in:
@@ -20,7 +20,10 @@ def get_cell_based_tiny_net(config):
|
||||
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
|
||||
if super_type == 'basic' and config.name in group_names:
|
||||
from .cell_searchs import nas_super_nets
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||
try:
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
|
||||
except:
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||
elif super_type == 'l2s-base' and config.name in group_names:
|
||||
from .l2s_cell_searchs import nas_super_nets
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space \
|
||||
|
@@ -11,7 +11,8 @@ from .genotypes import Structure
|
||||
|
||||
class TinyNetworkGDAS(nn.Module):
|
||||
|
||||
def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
|
||||
#def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
|
||||
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
|
||||
super(TinyNetworkGDAS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
|
@@ -13,7 +13,7 @@ from .genotypes import Structure
|
||||
|
||||
class TinyNetworkSETN(nn.Module):
|
||||
|
||||
def __init__(self, C, N, max_nodes, num_classes, search_space):
|
||||
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
|
||||
super(TinyNetworkSETN, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
@@ -31,7 +31,7 @@ class TinyNetworkSETN(nn.Module):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
|
||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
|
||||
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
||||
self.cells.append( cell )
|
||||
|
Reference in New Issue
Block a user