Update docs of NATS-Bench
This commit is contained in:
@@ -17,6 +17,9 @@ from typing import List, Text, Union, Dict, Optional
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
|
||||
PICKLE_EXT = 'pickle.pbz2'
|
||||
|
||||
|
||||
def pickle_save(obj, file_path, ext='.pbz2', protocol=4):
|
||||
"""Use pickle to save data (obj) into file_path.
|
||||
According to https://docs.python.org/3/library/pickle.html#data-stream-format, Protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations. It is the default protocol starting with Python 3.8.
|
||||
@@ -132,6 +135,41 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
"""Return a random index of all architectures."""
|
||||
return random.randint(0, len(self.meta_archs)-1)
|
||||
|
||||
def reload(self, archive_root: Text = None, index: int = None):
|
||||
"""Overwrite all information of the 'index'-th architecture in the search space,
|
||||
where the data will be loaded from 'archive_root'.
|
||||
If archive_root is None, it will try to load from the default path os.environ['TORCH_HOME'] / 'BASE_NAME'-full.
|
||||
If index is None, overwrite all ckps.
|
||||
"""
|
||||
if self.verbose:
|
||||
print('{:} Call clear_params with archive_root={:} and index={:}'.format(
|
||||
time_string(), archive_root, index))
|
||||
if archive_root is None:
|
||||
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(self.ALL_BASE_NAMES[-1]))
|
||||
if not os.path.isdir(archive_root):
|
||||
warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
|
||||
archive_root = self.archive_dir
|
||||
if archive_root is None or not os.path.isdir(archive_root):
|
||||
raise ValueError('Invalid archive_root : {:}'.format(archive_root))
|
||||
if index is None:
|
||||
indexes = list(range(len(self)))
|
||||
else:
|
||||
indexes = [index]
|
||||
for idx in indexes:
|
||||
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
|
||||
xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
|
||||
if not os.path.isfile(xfile_path):
|
||||
xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
|
||||
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
|
||||
xdata = pickle_load(xfile_path)
|
||||
assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
|
||||
self.evaluated_indexes.add(idx)
|
||||
hp2archres = OrderedDict()
|
||||
for hp_key, results in xdata.items():
|
||||
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
|
||||
self._avaliable_hps.add(hp_key)
|
||||
self.arch2infos_dict[idx] = hp2archres
|
||||
|
||||
def query_index_by_arch(self, arch):
|
||||
""" This function is used to query the index of an architecture in the search space.
|
||||
In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';
|
||||
@@ -176,12 +214,6 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
if self.verbose:
|
||||
print('{:} Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(time_string(), index))
|
||||
|
||||
@abc.abstractmethod
|
||||
def reload(self, archive_root: Text = None, index: int = None):
|
||||
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
|
||||
If index is None, overwrite all ckps.
|
||||
"""
|
||||
|
||||
def clear_params(self, index: int, hp: Optional[Text]=None):
|
||||
"""Remove the architecture's weights to save memory.
|
||||
:arg
|
||||
|
Reference in New Issue
Block a user