Update NATS-Bench (sss version 1.2)

This commit is contained in:
D-X-Y
2020-08-30 08:04:52 +00:00
parent 469a207945
commit 5f151d1970
15 changed files with 317 additions and 229 deletions

View File

@@ -1,25 +1,36 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
#####################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
#####################################################
#
#
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 ##########################
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# The official Application Programming Interface (API) for NATS-Bench. #
##############################################################################
from .api_utils import pickle_save, pickle_load
from .api_utils import ArchResults, ResultsCount
from .api_topology import NATStopology
from .api_size import NATSsize
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.07.30]
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.28]
def version():
return NATS_BENCH_API_VERSIONs[-1]
def create(file_path_or_dict, search_space, verbose=True):
def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
"""Create the instead for NATS API.
Args:
file_path_or_dict: None or a file path or a directory path.
search_space: A string that indicates the search space in NATS-Bench.
fast_mode: If True, we will not load all the data at initialization; instead, the data for each candidate architecture will be loaded when querying it.
If False, we will load all the data during initialization.
verbose: A flag indicating whether to log additional information.
"""
if search_space in ['tss', 'topology']:
return NATStopology(file_path_or_dict, verbose)
return NATStopology(file_path_or_dict, fast_mode, verbose)
elif search_space in ['sss', 'size']:
return NATSsize(file_path_or_dict, verbose)
return NATSsize(file_path_or_dict, fast_mode, verbose)
else:
raise ValueError('invalid search space : {:}'.format(search_space))
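For reference, a minimal usage sketch of the new create() interface above; this assumes the package imports as nats_bench, and the file-system paths are placeholders rather than values from this diff:

# Minimal sketch of the new `create` signature; paths are hypothetical.
from nats_bench import create

# Size search space, fast mode: pass the directory of per-architecture files;
# data is then loaded lazily, architecture by architecture, on query.
fast_api = create('/path/to/NATS-sss-simple-dir', 'sss', fast_mode=True, verbose=False)

# Size search space, non-fast mode: pass the single .pickle.pbz2 benchmark file;
# everything is loaded during initialization.
full_api = create('/path/to/NATS-sss-benchmark.pickle.pbz2', 'sss', fast_mode=False)

print(fast_api)  # the repr now reports fast_mode and the loaded/total counts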

View File

@@ -1,21 +1,23 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
############################################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
############################################################################################
# The history of benchmark files:
#
import os, copy, random, torch, numpy as np
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
#####################################################################################
# The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
# [2020.08.28] NATS-tss-v1_0-50262.pickle.pbz2 #
#####################################################################################
import os, copy, random, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names
ALL_BENCHMARK_FILES = ['NAS-Bench-301-v1_0-363be7.pth']
ALL_ARCHIVE_DIRS = ['NAS-Bench-301-v1_0-archive']
PICKLE_EXT = 'pickle.pbz2'
ALL_BASE_NAMES = ['NATS-tss-v1_0-50262']
def print_information(information, extra_info=None, show=False):
@@ -54,42 +56,65 @@ This is the class for the API of size search space in NATS-Bench.
class NATSsize(NASBenchMetaAPI):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
self.filename = None
self._search_space_name = 'size'
self._fast_mode = fast_mode
self._archive_dir = None
self.reset_time()
if file_path_or_dict is None:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
print ('Try to use the default NATS-Bench (size) path from {:}.'.format(file_path_or_dict))
if self._fast_mode:
self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
print('Try to use the default NATS-Bench (size) path: fast_mode={:}, path={:}.'.format(self._fast_mode, file_path_or_dict))
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose: print('try to create the NATS-Bench (size) api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
if verbose:
print('Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(file_path_or_dict, fast_mode))
if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
raise ValueError('{:} is neither a file nor a directory.'.format(file_path_or_dict))
self.filename = Path(file_path_or_dict).name
file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
if fast_mode:
if os.path.isfile(file_path_or_dict):
raise ValueError('fast_mode={:} requires a directory path, but got a file: {:}'.format(fast_mode, file_path_or_dict))
else:
self._archive_dir = file_path_or_dict
else:
if os.path.isdir(file_path_or_dict):
raise ValueError('fast_mode={:} requires a file path, but got a directory: {:}'.format(fast_mode, file_path_or_dict))
else:
file_path_or_dict = pickle_load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy( file_path_or_dict )
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_infos = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
file_path_or_dict = copy.deepcopy(file_path_or_dict)
self.verbose = verbose
if isinstance(file_path_or_dict, dict):
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_infos = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # record the available hyper-parameter
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
elif self.archive_dir is not None:
benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT))
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set()
self.evaluated_indexes = set()
else:
raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict)))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
assert arch not in self.archstr2index, 'The {:}-th arch {:} is already in the dict (index {:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[ arch ] = idx
self.archstr2index[arch] = idx
if self.verbose:
print('Create NATS-Bench (size) done with {:}/{:} architectures available.'.format(len(self.evaluated_indexes), len(self.meta_archs)))
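As orientation for the branches above, a hedged sketch of how the default locations are resolved when file_path_or_dict is None; the base name is ALL_BASE_NAMES[-1] from this commit, and TORCH_HOME must point at your local benchmark folder:

# Sketch only: default paths used by __init__ (fast/non-fast) and by reload().
import os

torch_home = os.environ['TORCH_HOME']  # required; raises KeyError if unset
base = 'NATS-tss-v1_0-50262'           # ALL_BASE_NAMES[-1] in this commit

# fast_mode=True: a directory of per-architecture pickle files, loaded lazily.
fast_mode_dir = os.path.join(torch_home, '{:}-simple'.format(base))
# fast_mode=False: one bz2-compressed pickle holding the whole benchmark.
full_pickle = os.path.join(torch_home, '{:}.{:}'.format(base, 'pickle.pbz2'))
# reload(archive_root=None, ...) falls back to this per-architecture archive.
reload_archive = os.path.join(torch_home, '{:}-full'.format(base))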
@@ -100,7 +125,7 @@ class NATSsize(NASBenchMetaAPI):
if self.verbose:
print('Call clear_params with archive_root={:} and index={:}'.format(archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
if index is None:
indexes = list(range(len(self)))
@@ -108,16 +133,17 @@ class NATSsize(NASBenchMetaAPI):
indexes = [index]
for idx in indexes:
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
if not os.path.isfile(xfile_path):
xfile_path = os.path.join(archive_root, '{:d}-FULL.pth'.format(idx))
xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path, map_location='cpu')
xdata = pickle_load(xfile_path)
assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
self.evaluated_indexes.add(idx)
hp2archres = OrderedDict()
for hp_key, results in xdata.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key)
self.arch2infos_dict[idx] = hp2archres
def query_info_str_by_arch(self, arch, hp: Text='12'):
@@ -153,6 +179,7 @@ class NATSsize(NASBenchMetaAPI):
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index)
if index not in self.arch2infos_dict:
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
archresult = self.arch2infos_dict[index][str(hp)]
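To illustrate the fast-mode flow that get_more_info now follows (via _prepare_info and reload), a hedged usage sketch; the benchmark directory and the architecture index 123 are placeholders:

# Sketch: querying in fast mode lazily loads one architecture's results.
from nats_bench import create

api = create('/path/to/NATS-sss-simple-dir', 'sss', fast_mode=True, verbose=True)

# get_more_info() calls _prepare_info(index) first, which in fast mode reads
# only this architecture's pickle file from the archive directory.
info = api.get_more_info(123, 'cifar10', hp='12', is_random=False)
print(info)

# The loaded results for that index can be dropped from memory again later.
api.clear_params(123, None)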

View File

@@ -3,7 +3,7 @@
############################################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
############################################################################################
import os, copy, random, torch, numpy as np
import os, copy, random, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
@@ -62,7 +62,7 @@ class NATStopology(NASBenchMetaAPI):
if verbose: print('try to create the NATS-Bench (topology) api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
self.filename = Path(file_path_or_dict).name
file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
file_path_or_dict = np.load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy(file_path_or_dict)
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))

View File

@@ -10,15 +10,30 @@
# History:
# [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
#
import abc, copy, random, numpy as np
import os, abc, copy, random, numpy as np
import bz2, pickle
import importlib, warnings
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
USE_TORCH = importlib.find_loader('torch') is not None
if USE_TORCH:
import torch
else:
warnings.warn('Cannot find PyTorch, and thus some features may be unavailable.')
def pickle_save(obj, file_path, ext='.pbz2', protocol=4):
"""Use pickle to save data (obj) into file_path.
According to https://docs.python.org/3/library/pickle.html#data-stream-format, Protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations. It is the default protocol starting with Python 3.8.
"""
# with open(file_path, 'wb') as cfile:
with bz2.BZ2File(str(file_path) + ext, 'wb') as cfile:
pickle.dump(obj, cfile, protocol=protocol)
def pickle_load(file_path, ext='.pbz2'):
# return pickle.load(open(file_path, "rb"))
if os.path.isfile(str(file_path)):
xfile_path = str(file_path)
else:
xfile_path = str(file_path) + ext
with bz2.BZ2File(xfile_path, 'rb') as cfile:
return pickle.load(cfile)
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
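As a sanity check of the two pickle helpers above, a small round-trip sketch; the file name is arbitrary, and pickle_save/pickle_load are assumed to be imported through the package's __init__ as shown earlier:

# Round-trip sketch for the bz2-compressed pickle helpers.
from nats_bench import pickle_save, pickle_load

data = {'meta_archs': ['arch-0', 'arch-1'], 'evaluated_indexes': [0, 1]}

# pickle_save appends the '.pbz2' extension, producing 'toy.pickle.pbz2'.
pickle_save(data, 'toy.pickle')

# pickle_load accepts the name with or without the '.pbz2' suffix.
restored = pickle_load('toy.pickle')
assert restored == data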
@@ -60,7 +75,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
return len(self.meta_archs)
def __repr__(self):
return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))
return ('{name}({num}/{total} architectures, fast_mode={fast_mode}, file={filename})'.format(
name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs),
fast_mode=self.fast_mode, filename=self.filename))
@property
def avaliable_hps(self):
@@ -74,6 +91,20 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def search_space_name(self):
return self._search_space_name
@property
def fast_mode(self):
return self._fast_mode
@property
def archive_dir(self):
return self._archive_dir
def reset_archive_dir(self, archive_dir):
self._archive_dir = archive_dir
def reset_fast_mode(self, fast_mode):
self._fast_mode = fast_mode
def reset_time(self):
self._used_time = 0
@@ -121,9 +152,24 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
return arch_index
def query_by_arch(self, arch, hp):
# This is to make the current version be compatible with the old version.
"""This is to make the current version be compatible with the old version."""
return self.query_info_str_by_arch(arch, hp)
def _prepare_info(self, index):
"""This is a function to load the data from disk when using fast mode."""
if index not in self.arch2infos_dict:
if self.fast_mode and self.archive_dir is not None:
self.reload(self.archive_dir, index)
elif not self.fast_mode:
if self.verbose:
print('Call _prepare_info with index={:}: skipped because fast mode is disabled.'.format(index))
else:
raise ValueError('Invalid status: fast_mode={:} and archive_dir={:}'.format(self.fast_mode, self.archive_dir))
else:
assert index in self.evaluated_indexes, 'The index {:} is not in self.evaluated_indexes; something must be wrong.'.format(index)
if self.verbose:
print('Call _prepare_info with index={:}: skipped because it is already in arch2infos_dict.'.format(index))
@abc.abstractmethod
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
@@ -140,7 +186,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
"""
if self.verbose:
print('Call clear_params with index={:} and hp={:}'.format(index, hp))
if hp is None:
if index not in self.arch2infos_dict:
warnings.warn('The {:}-th architecture is not in the benchmark data yet, no need to clear params.'.format(index))
elif hp is None:
for key, result in self.arch2infos_dict[index].items():
result.clear_params()
else:
@@ -154,6 +202,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def _query_info_str_by_arch(self, arch, hp: Text='12', print_information=None):
arch_index = self.query_index_by_arch(arch)
self._prepare_info(arch_index)
if arch_index in self.arch2infos_dict:
if hp not in self.arch2infos_dict[arch_index]:
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
@@ -161,13 +210,14 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
strings = print_information(info, 'arch-index={:}'.format(arch_index))
return '\n'.join(strings)
else:
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
warnings.warn('Found arch-index {:}, but this architecture has not been evaluated.'.format(arch_index))
return None
def query_meta_info_by_index(self, arch_index, hp: Text = '12'):
"""Return the ArchResults for the 'arch_index'-th architecture. This function is similar to query_by_index."""
if self.verbose:
print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp))
self._prepare_info(arch_index)
if arch_index in self.arch2infos_dict:
if hp not in self.arch2infos_dict[arch_index]:
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
@@ -207,7 +257,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
best_index, highest_accuracy = -1, None
for i, arch_index in enumerate(self.evaluated_indexes):
evaluated_indexes = sorted(list(self.evaluated_indexes))
for i, arch_index in enumerate(evaluated_indexes):
arch_info = self.arch2infos_dict[arch_index][hp]
info = arch_info.get_compute_costs(dataset) # the information of costs
flop, param, latency = info['flops'], info['params'], info['latency']
@@ -254,10 +305,11 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
"""
if self.verbose:
print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
self._prepare_info(index)
if index in self.arch2infos_dict:
info = self.arch2infos_dict[index]
else:
raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(arch_index))
raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(index))
info = next(iter(info.values()))
results = info.query(dataset, None)
results = next(iter(results.values()))
@@ -267,6 +319,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
if self.verbose:
print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
self._prepare_info(index)
info = self.query_meta_info_by_index(index, hp)
return info.get_compute_costs(dataset)
@@ -296,8 +349,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
"""
if index < 0: # show all architectures
print(self)
for i, idx in enumerate(self.evaluated_indexes):
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
evaluated_indexes = sorted(list(self.evaluated_indexes))
for i, idx in enumerate(evaluated_indexes):
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(evaluated_indexes), idx) + '-'*10)
print('arch : {:}'.format(self.meta_archs[idx]))
for key, result in self.arch2infos_dict[idx].items():
strings = print_information(result)
@@ -325,7 +379,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
if dataset not in valid_datasets:
raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
nums, hp = defaultdict(lambda: 0), str(hp)
for index in range(len(self)):
# for index in range(len(self)):
for index in self.evaluated_indexes:
archInfo = self.arch2infos_dict[index][hp]
dataset_seed = archInfo.dataset_seed
if dataset not in dataset_seed:
@@ -550,9 +605,7 @@ class ArchResults(object):
def create_from_state_dict(state_dict_or_file):
x = ArchResults(-1, -1)
if isinstance(state_dict_or_file, str): # a file path
if not USE_TORCH:
raise ValueError('Since torch is not imported, this logic can not be used.')
state_dict = torch.load(state_dict_or_file, map_location='cpu')
state_dict = pickle_load(state_dict_or_file)
elif isinstance(state_dict_or_file, dict):
state_dict = state_dict_or_file
else:
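Finally, a hedged sketch mirroring the reload() logic shown earlier: each per-architecture archive file is a bz2 pickle mapping an hp key (e.g. '12') to an ArchResults state dict, and create_from_state_dict now loads file paths via pickle_load instead of torch.load. The archive path and index below are illustrative only:

# Sketch: rebuild the hp -> ArchResults mapping for one architecture by hand.
from collections import OrderedDict
from nats_bench import ArchResults, pickle_load

xdata = pickle_load('/path/to/NATS-sss-full-dir/000000.pickle.pbz2')
hp2archres = OrderedDict()
for hp_key, state_dict in xdata.items():
    # create_from_state_dict also accepts a file path, loaded via pickle_load.
    hp2archres[hp_key] = ArchResults.create_from_state_dict(state_dict)
print(list(hp2archres.keys()))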