update NAS-Bench
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
####################
|
||||
# DARTS, ICLR 2019 #
|
||||
# DARTS, ICLR 2019 #
|
||||
####################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -11,7 +11,8 @@ from .search_cells import NASNetSearchCell as SearchCell
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkDARTS(nn.Module):
|
||||
|
||||
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
|
||||
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
|
||||
num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
|
||||
super(NASNetworkDARTS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
|
Reference in New Issue
Block a user