102->201 / NAS->autoDL / more configs of TAS / reorganize docs / fix bugs in NAS baselines

This commit is contained in:
D-X-Y
2020-01-15 00:52:06 +11:00
parent 33384a78af
commit bb2f405961
62 changed files with 789 additions and 412 deletions

View File

@@ -19,7 +19,7 @@ def get_cell_based_tiny_net(config):
super_type = getattr(config, 'super_type', 'basic')
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
if super_type == 'basic' and config.name in group_names:
from .cell_searchs import nas102_super_nets as nas_super_nets
from .cell_searchs import nas201_super_nets as nas_super_nets
try:
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
except:

View File

@@ -1,8 +1,13 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
# Cell for NAS-Bench-201
class InferCell(nn.Module):
def __init__(self, genotype, C_in, C_out, stride):

View File

@@ -1,9 +1,13 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################
import torch
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell
# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):
def __init__(self, C, N, genotype, num_classes):

View File

@@ -21,12 +21,11 @@ OPS = {
}
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
'aa-nas' : NAS_BENCH_102,
'nas-bench-102': NAS_BENCH_102,
'nas-bench-201': NAS_BENCH_201,
'darts' : DARTS_SPACE}

View File

@@ -1,7 +1,7 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# The macro structure is defined in NAS-Bench-102
# The macro structure is defined in NAS-Bench-201
from .search_model_darts import TinyNetworkDarts
from .search_model_gdas import TinyNetworkGDAS
from .search_model_setn import TinyNetworkSETN
@@ -12,7 +12,7 @@ from .genotypes import Structure as CellStructure, architectures as
from .search_model_gdas_nasnet import NASNetworkGDAS
nas102_super_nets = {'DARTS-V1': TinyNetworkDarts,
nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
'DARTS-V2': TinyNetworkDarts,
'GDAS' : TinyNetworkGDAS,
'SETN' : TinyNetworkSETN,

View File

@@ -9,11 +9,11 @@ from copy import deepcopy
from ..cell_operations import OPS
# This module is used for NAS-Bench-102, represents a small search space with a complete DAG
class NAS102SearchCell(nn.Module):
# This module is used for NAS-Bench-201, represents a small search space with a complete DAG
class NAS201SearchCell(nn.Module):
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
super(NAS102SearchCell, self).__init__()
super(NAS201SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
self.edges = nn.ModuleDict()

View File

@@ -7,7 +7,7 @@ import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS102SearchCell as SearchCell
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure

View File

@@ -7,7 +7,7 @@ import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS102SearchCell as SearchCell
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
from .search_model_enas_utils import Controller

View File

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS102SearchCell as SearchCell
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure

View File

@@ -7,7 +7,7 @@ import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS102SearchCell as SearchCell
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure

View File

@@ -7,7 +7,7 @@ import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS102SearchCell as SearchCell
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure