Update NATS-Bench (tss version 0.99)

This commit is contained in:
D-X-Y
2020-09-05 10:40:29 +00:00
parent 8d64afd4a3
commit bd9288f45d
9 changed files with 379 additions and 56 deletions

View File

@@ -10,6 +10,7 @@ import os, copy, random, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
from .api_utils import time_string
from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
@@ -71,7 +72,7 @@ class NATSsize(NASBenchMetaAPI):
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose:
print('Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(file_path_or_dict, fast_mode))
print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
self.filename = Path(file_path_or_dict).name
@@ -116,14 +117,15 @@ class NATSsize(NASBenchMetaAPI):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[arch] = idx
if self.verbose:
print('Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(len(self.evaluated_indexes), len(self.meta_archs)))
print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(
time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
If index is None, overwrite all ckps.
"""
if self.verbose:
print('Call clear_params with archive_root={:} and index={:}'.format(archive_root, index))
print('{:} Call clear_params with archive_root={:} and index={:}'.format(time_string(), archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
@@ -155,7 +157,7 @@ class NATSsize(NASBenchMetaAPI):
The difference between these three configurations are the number of training epochs.
"""
if self.verbose:
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
@@ -177,7 +179,8 @@ class NATSsize(NASBenchMetaAPI):
When is_random=False, the performanceo of all trials will be averaged.
"""
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(
time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index)
if index not in self.arch2infos_dict:

View File

@@ -10,6 +10,8 @@ import os, copy, random, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
import warnings
from .api_utils import time_string
from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
@@ -60,58 +62,89 @@ class NATStopology(NASBenchMetaAPI):
self.reset_time()
if file_path_or_dict is None:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
print ('Try to use the default NATS-Bench (topology) path from {:}.'.format(file_path_or_dict))
print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict))
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose: print('try to create the NATS-Bench (topology) api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
if verbose:
print('{:} Try to create the NATS-Bench (topology) api from {:}'.format(time_string(), file_path_or_dict))
if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
self.filename = Path(file_path_or_dict).name
file_path_or_dict = np.load(file_path_or_dict)
if fast_mode:
if os.path.isfile(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict))
else:
self._archive_dir = file_path_or_dict
else:
if os.path.isdir(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict))
else:
file_path_or_dict = pickle_load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy(file_path_or_dict)
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set(['12', '200'])
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_info = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
# self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
# self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
hp2archres['12'] = ArchResults.create_from_state_dict(all_info['less'])
hp2archres['200'] = ArchResults.create_from_state_dict(all_info['full'])
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
if isinstance(file_path_or_dict, dict):
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_info = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = list(file_path_or_dict['evaluated_indexes'])
elif self.archive_dir is not None:
benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT))
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set()
self.evaluated_indexes = set()
else:
raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict)))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[ arch ] = idx
self.archstr2index[arch] = idx
if self.verbose:
print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures avaliable.'.format(
time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space.
It will load its data from 'archive_root'.
"""
if self.verbose:
print('{:} Call clear_params with archive_root={:} and index={:}'.format(
time_string(), archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
if not os.path.isdir(archive_root):
warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
archive_root = self.archive_dir
if archive_root is None or not os.path.isdir(archive_root):
raise ValueError('Invalid archive_root : {:}'.format(archive_root))
if index is None:
indexes = list(range(len(self)))
else:
indexes = [index]
for idx in indexes:
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
if not os.path.isfile(xfile_path):
xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path, map_location='cpu')
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
xdata = pickle_load(xfile_path)
assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
self.evaluated_indexes.add(idx)
hp2archres = OrderedDict()
hp2archres['12'] = ArchResults.create_from_state_dict(xdata['less'])
hp2archres['200'] = ArchResults.create_from_state_dict(xdata['full'])
for hp_key, results in xdata.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key)
self.arch2infos_dict[idx] = hp2archres
def query_info_str_by_arch(self, arch, hp: Text='12'):
@@ -122,7 +155,7 @@ class NATStopology(NASBenchMetaAPI):
The difference between these three configurations are the number of training epochs.
"""
if self.verbose:
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
# obtain the metric for the `index`-th architecture
@@ -142,8 +175,10 @@ class NATStopology(NASBenchMetaAPI):
# When is_random=False, the performanceo of all trials will be averaged.
def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(
time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index)
if index not in self.arch2infos_dict:
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
archresult = self.arch2infos_dict[index][str(hp)]

View File

@@ -10,9 +10,9 @@
# History:
# [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
#
import os, abc, copy, random, numpy as np
import os, abc, time, copy, random, numpy as np
import bz2, pickle
import importlib, warnings
import warnings
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
@@ -36,6 +36,12 @@ def pickle_load(file_path, ext='.pbz2'):
return pickle.load(cfile)
def time_string():
ISOTIMEFORMAT='%Y-%m-%d %X'
string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
"""re-map the metric_on_set to internal keys"""
if verbose:
@@ -136,7 +142,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
"""
if self.verbose:
print('Call query_index_by_arch with arch={:}'.format(arch))
print('{:} Call query_index_by_arch with arch={:}'.format(time_string(), arch))
if isinstance(arch, int):
if 0 <= arch < len(self):
return arch
@@ -162,13 +168,13 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
self.reload(self.archive_dir, index)
elif not self.fast_mode:
if self.verbose:
print('Call _prepare_info with index={:} skip because it is not the fast mode.'.format(index))
print('{:} Call _prepare_info with index={:} skip because it is not the fast mode.'.format(time_string(), index))
else:
raise ValueError('Invalid status: fast_mode={:} and archive_dir={:}'.format(self.fast_mode, self.archive_dir))
else:
assert index in self.evaluated_indexes, 'The index of {:} is not in self.evaluated_indexes, there must be something wrong.'.format(index)
if self.verbose:
print('Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(index))
print('{:} Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(time_string(), index))
@abc.abstractmethod
def reload(self, archive_root: Text = None, index: int = None):
@@ -185,7 +191,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
-- '01' or '12' or '90': clear all the weights in arch2infos_dict[index][hp].
"""
if self.verbose:
print('Call clear_params with index={:} and hp={:}'.format(index, hp))
print('{:} Call clear_params with index={:} and hp={:}'.format(time_string(), index, hp))
if index not in self.arch2infos_dict:
warnings.warn('The {:}-th architecture is not in the benchmark data yet, no need to clear params.'.format(index))
elif hp is None:
@@ -243,7 +249,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
"""
if self.verbose:
print('Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(arch_index, dataname, hp))
print('{:} Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(time_string(), arch_index, dataname, hp))
info = self.query_meta_info_by_index(arch_index, hp)
if dataname is None: return info
else:
@@ -254,7 +260,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'):
"""Find the architecture with the highest accuracy based on some constraints."""
if self.verbose:
print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
print('{:} Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(
time_string(), dataset, metric_on_set, hp, FLOP_max, Param_max))
dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
best_index, highest_accuracy = -1, None
evaluated_indexes = sorted(list(self.evaluated_indexes))
@@ -287,7 +294,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
-- 200 : train the model by 200 epochs
"""
if self.verbose:
print('Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(index, dataset, seed, hp))
print('{:} Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(time_string(), index, dataset, seed, hp))
info = self.query_meta_info_by_index(index, hp)
return info.get_net_param(dataset, seed)
@@ -304,7 +311,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
config = api.get_net_config(128, 'cifar10')
"""
if self.verbose:
print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
print('{:} Call the get_net_config function with index={:}, dataset={:}.'.format(time_string(), index, dataset))
self._prepare_info(index)
if index in self.arch2infos_dict:
info = self.arch2infos_dict[index]
@@ -318,7 +325,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]:
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
if self.verbose:
print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
print('{:} Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(time_string(), index, dataset, hp))
self._prepare_info(index)
info = self.query_meta_info_by_index(index, hp)
return info.get_compute_costs(dataset)
@@ -331,7 +338,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
:return: return a float value in seconds
"""
if self.verbose:
print('Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
print('{:} Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(time_string(), index, dataset, hp))
cost_dict = self.get_cost_info(index, dataset, hp)
return cost_dict['latency']