Update docs of NATS-Bench

This commit is contained in:
D-X-Y
2020-09-16 09:04:22 +00:00
parent 9db28392c2
commit 7052265501
14 changed files with 99 additions and 95 deletions

View File

@@ -15,9 +15,9 @@ from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names
from .api_utils import PICKLE_EXT
PICKLE_EXT = 'pickle.pbz2'
ALL_BASE_NAMES = ['NATS-sss-v1_0-50262']
@@ -58,6 +58,7 @@ class NATSsize(NASBenchMetaAPI):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
self.ALL_BASE_NAMES = ALL_BASE_NAMES
self.filename = None
self._search_space_name = 'size'
self._fast_mode = fast_mode
@@ -120,39 +121,6 @@ class NATSsize(NASBenchMetaAPI):
print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(
time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
If index is None, overwrite all ckps.
"""
if self.verbose:
print('{:} Call clear_params with archive_root={:} and index={:}'.format(
time_string(), archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
if not os.path.isdir(archive_root):
warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
archive_root = self.archive_dir
if archive_root is None or not os.path.isdir(archive_root):
raise ValueError('Invalid archive_root : {:}'.format(archive_root))
if index is None:
indexes = list(range(len(self)))
else:
indexes = [index]
for idx in indexes:
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
if not os.path.isfile(xfile_path):
xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = pickle_load(xfile_path)
assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
self.evaluated_indexes.add(idx)
hp2archres = OrderedDict()
for hp_key, results in xdata.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key)
self.arch2infos_dict[idx] = hp2archres
def query_info_str_by_arch(self, arch, hp: Text='12'):
""" This function is used to query the information of a specific architecture
'arch' can be an architecture index or an architecture string

View File

@@ -16,9 +16,9 @@ from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names
from .api_utils import PICKLE_EXT
PICKLE_EXT = 'pickle.pbz2'
ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9']
@@ -55,6 +55,7 @@ class NATStopology(NASBenchMetaAPI):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
self.ALL_BASE_NAMES = ALL_BASE_NAMES
self.filename = None
self._search_space_name = 'topology'
self._fast_mode = fast_mode
@@ -117,39 +118,6 @@ class NATStopology(NASBenchMetaAPI):
print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures avaliable.'.format(
time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space.
If index is None, overwrite all ckps.
"""
if self.verbose:
print('{:} Call clear_params with archive_root={:} and index={:}'.format(
time_string(), archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
if not os.path.isdir(archive_root):
warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
archive_root = self.archive_dir
if archive_root is None or not os.path.isdir(archive_root):
raise ValueError('Invalid archive_root : {:}'.format(archive_root))
if index is None:
indexes = list(range(len(self)))
else:
indexes = [index]
for idx in indexes:
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
if not os.path.isfile(xfile_path):
xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = pickle_load(xfile_path)
assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
self.evaluated_indexes.add(idx)
hp2archres = OrderedDict()
for hp_key, results in xdata.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key)
self.arch2infos_dict[idx] = hp2archres
def query_info_str_by_arch(self, arch, hp: Text='12'):
""" This function is used to query the information of a specific architecture
'arch' can be an architecture index or an architecture string

View File

@@ -17,6 +17,9 @@ from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
PICKLE_EXT = 'pickle.pbz2'
def pickle_save(obj, file_path, ext='.pbz2', protocol=4):
"""Use pickle to save data (obj) into file_path.
According to https://docs.python.org/3/library/pickle.html#data-stream-format, Protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations. It is the default protocol starting with Python 3.8.
@@ -132,6 +135,41 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
"""Return a random index of all architectures."""
return random.randint(0, len(self.meta_archs)-1)
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space,
where the data will be loaded from 'archive_root'.
If archive_root is None, it will try to load from the default path os.environ['TORCH_HOME'] / 'BASE_NAME'-full.
If index is None, overwrite all ckps.
"""
if self.verbose:
print('{:} Call clear_params with archive_root={:} and index={:}'.format(
time_string(), archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(self.ALL_BASE_NAMES[-1]))
if not os.path.isdir(archive_root):
warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
archive_root = self.archive_dir
if archive_root is None or not os.path.isdir(archive_root):
raise ValueError('Invalid archive_root : {:}'.format(archive_root))
if index is None:
indexes = list(range(len(self)))
else:
indexes = [index]
for idx in indexes:
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
if not os.path.isfile(xfile_path):
xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = pickle_load(xfile_path)
assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
self.evaluated_indexes.add(idx)
hp2archres = OrderedDict()
for hp_key, results in xdata.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key)
self.arch2infos_dict[idx] = hp2archres
def query_index_by_arch(self, arch):
""" This function is used to query the index of an architecture in the search space.
In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';
@@ -176,12 +214,6 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
if self.verbose:
print('{:} Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(time_string(), index))
@abc.abstractmethod
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
If index is None, overwrite all ckps.
"""
def clear_params(self, index: int, hp: Optional[Text]=None):
"""Remove the architecture's weights to save memory.
:arg