Update NATS-Bench (sss version 1.3)

This commit is contained in:
D-X-Y
2020-08-30 09:25:45 +00:00
parent 5f151d1970
commit e04808c14e
5 changed files with 89 additions and 25 deletions

View File

@@ -1,5 +1,5 @@
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 ##########################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
@@ -11,7 +11,7 @@ from .api_topology import NATStopology
from .api_size import NATSsize
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.28]
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.31]
def version():

View File

@@ -1,10 +1,10 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
#####################################################################################
# The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
# [2020.08.28] NATS-tss-v1_0-50262.pickle.pbz2 #
# The history of benchmark files (the name is NATS-sss-[version]-[md5].pickle.pbz2) #
# [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2 #
#####################################################################################
import os, copy, random, numpy as np
from pathlib import Path
@@ -17,7 +17,7 @@ from .api_utils import remap_dataset_set_names
PICKLE_EXT = 'pickle.pbz2'
ALL_BASE_NAMES = ['NATS-tss-v1_0-50262']
ALL_BASE_NAMES = ['NATS-sss-v1_0-50262']
def print_information(information, extra_info=None, show=False):

View File

@@ -1,20 +1,23 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
############################################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
############################################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
#####################################################################################
# The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
# [2020.08.31] #
#####################################################################################
import os, copy, random, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names
ALL_BENCHMARK_FILES = ['NAS-Bench-201-v1_0-e61699.pth', 'NAS-Bench-201-v1_1-096897.pth']
ALL_ARCHIVE_DIRS = ['NAS-Bench-201-v1_1-archive']
PICKLE_EXT = 'pickle.pbz2'
ALL_BASE_NAMES = ['NATS-tss-v1_0-xxxxx']
def print_information(information, extra_info=None, show=False):
@@ -49,10 +52,11 @@ This is the class for the API of topology search space in NATS-Bench.
class NATStopology(NASBenchMetaAPI):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None,
verbose: bool=True):
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
self.filename = None
self._search_space_name = 'topology'
self._fast_mode = fast_mode
self._archive_dir = None
self.reset_time()
if file_path_or_dict is None:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])