update NAS-Bench

This commit is contained in:
D-X-Y
2020-03-09 19:38:00 +11:00
parent 9a83814a46
commit e59eb804cb
35 changed files with 693 additions and 64 deletions

View File

@@ -1,5 +1,5 @@
####################
# DARTS, ICLR 2019 #
# DARTS, ICLR 2019 #
####################
import torch
import torch.nn as nn
@@ -11,7 +11,8 @@ from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkDARTS(nn.Module):
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
super(NASNetworkDARTS, self).__init__()
self._C = C
self._layerN = N