102->201 / NAS->autoDL / more configs of TAS / reorganize docs / fix bugs in NAS baselines
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
##################################################
|
||||
import os, sys, time, random, argparse
|
||||
from .share_args import add_shared_args
|
||||
|
@@ -19,7 +19,7 @@ def get_cell_based_tiny_net(config):
|
||||
super_type = getattr(config, 'super_type', 'basic')
|
||||
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
|
||||
if super_type == 'basic' and config.name in group_names:
|
||||
from .cell_searchs import nas102_super_nets as nas_super_nets
|
||||
from .cell_searchs import nas201_super_nets as nas_super_nets
|
||||
try:
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
|
||||
except:
|
||||
|
@@ -1,8 +1,13 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
##################################################
|
||||
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
||||
|
||||
# Cell for NAS-Bench-201
|
||||
class InferCell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_in, C_out, stride):
|
||||
|
@@ -1,9 +1,13 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
##################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .cells import InferCell
|
||||
|
||||
|
||||
# The macro structure for architectures in NAS-Bench-201
|
||||
class TinyNetwork(nn.Module):
|
||||
|
||||
def __init__(self, C, N, genotype, num_classes):
|
||||
|
@@ -21,12 +21,11 @@ OPS = {
|
||||
}
|
||||
|
||||
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
||||
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
|
||||
|
||||
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
||||
'aa-nas' : NAS_BENCH_102,
|
||||
'nas-bench-102': NAS_BENCH_102,
|
||||
'nas-bench-201': NAS_BENCH_201,
|
||||
'darts' : DARTS_SPACE}
|
||||
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# The macro structure is defined in NAS-Bench-102
|
||||
# The macro structure is defined in NAS-Bench-201
|
||||
from .search_model_darts import TinyNetworkDarts
|
||||
from .search_model_gdas import TinyNetworkGDAS
|
||||
from .search_model_setn import TinyNetworkSETN
|
||||
@@ -12,7 +12,7 @@ from .genotypes import Structure as CellStructure, architectures as
|
||||
from .search_model_gdas_nasnet import NASNetworkGDAS
|
||||
|
||||
|
||||
nas102_super_nets = {'DARTS-V1': TinyNetworkDarts,
|
||||
nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
|
||||
'DARTS-V2': TinyNetworkDarts,
|
||||
'GDAS' : TinyNetworkGDAS,
|
||||
'SETN' : TinyNetworkSETN,
|
||||
|
@@ -9,11 +9,11 @@ from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
||||
|
||||
# This module is used for NAS-Bench-102, represents a small search space with a complete DAG
|
||||
class NAS102SearchCell(nn.Module):
|
||||
# This module is used for NAS-Bench-201, represents a small search space with a complete DAG
|
||||
class NAS201SearchCell(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
|
||||
super(NAS102SearchCell, self).__init__()
|
||||
super(NAS201SearchCell, self).__init__()
|
||||
|
||||
self.op_names = deepcopy(op_names)
|
||||
self.edges = nn.ModuleDict()
|
||||
|
@@ -7,7 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS102SearchCell as SearchCell
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
|
@@ -7,7 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS102SearchCell as SearchCell
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
from .search_model_enas_utils import Controller
|
||||
|
||||
|
@@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS102SearchCell as SearchCell
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
|
@@ -7,7 +7,7 @@ import torch, random
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS102SearchCell as SearchCell
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
|
@@ -7,7 +7,7 @@ import torch, random
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS102SearchCell as SearchCell
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .api import NASBench102API
|
||||
from .api import NASBench201API
|
||||
from .api import ArchResults, ResultsCount
|
||||
|
||||
NAS_BENCH_102_API_VERSION="v1.0"
|
||||
NAS_BENCH_201_API_VERSION="v1.0"
|
@@ -1,9 +1,9 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
############################################################################################
|
||||
# NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
||||
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
||||
############################################################################################
|
||||
# NAS-Bench-102-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
|
||||
# NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
|
||||
#
|
||||
#
|
||||
#
|
||||
@@ -38,11 +38,11 @@ def print_information(information, extra_info=None, show=False):
|
||||
return strings
|
||||
|
||||
|
||||
class NASBench102API(object):
|
||||
class NASBench201API(object):
|
||||
|
||||
def __init__(self, file_path_or_dict, verbose=True):
|
||||
if isinstance(file_path_or_dict, str):
|
||||
if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict))
|
||||
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
|
||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||
file_path_or_dict = torch.load(file_path_or_dict)
|
||||
elif isinstance(file_path_or_dict, dict):
|
@@ -2,3 +2,4 @@
|
||||
from .CifarNet import NetworkCIFAR as CifarNet
|
||||
from .ImageNet import NetworkImageNet as ImageNet
|
||||
from .genotypes import Networks
|
||||
from .genotypes import build_genotype_from_dict
|
||||
|
@@ -167,3 +167,6 @@ Networks = {'DARTS_V1': DARTS_V1,
|
||||
'PNASNet' : PNASNet,
|
||||
'SETN' : SETN,
|
||||
}
|
||||
|
||||
def build_genotype_from_dict(xdict):
|
||||
import pdb; pdb.set_trace()
|
||||
|
@@ -1,6 +1,11 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
##################################################
|
||||
# I write this package to make AutoDL-Projects to be compatible with the old GDAS projects.
|
||||
# Ideally, this package will be merged into lib/models/cell_infers in future.
|
||||
# Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
|
||||
##################################################
|
||||
|
||||
import torch
|
||||
|
||||
def obtain_nas_infer_model(config):
|
||||
|
@@ -14,10 +14,10 @@ OPS = {
|
||||
'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride)
|
||||
}
|
||||
|
||||
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
|
||||
SearchSpaceNames = {
|
||||
'nas-bench-102': NAS_BENCH_102,
|
||||
'nas-bench-201': NAS_BENCH_201,
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user