Update NATS-Bench (sss version 1.2)
This commit is contained in:
@@ -10,15 +10,30 @@
|
||||
# History:
|
||||
# [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
|
||||
#
|
||||
import abc, copy, random, numpy as np
|
||||
import os, abc, copy, random, numpy as np
|
||||
import bz2, pickle
|
||||
import importlib, warnings
|
||||
from typing import List, Text, Union, Dict, Optional
|
||||
from collections import OrderedDict, defaultdict
|
||||
USE_TORCH = importlib.find_loader('torch') is not None
|
||||
if USE_TORCH:
|
||||
import torch
|
||||
else:
|
||||
warnings.warn('Can not find PyTorch, and thus some features maybe invalid.')
|
||||
|
||||
|
||||
def pickle_save(obj, file_path, ext='.pbz2', protocol=4):
|
||||
"""Use pickle to save data (obj) into file_path.
|
||||
According to https://docs.python.org/3/library/pickle.html#data-stream-format, Protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations. It is the default protocol starting with Python 3.8.
|
||||
"""
|
||||
# with open(file_path, 'wb') as cfile:
|
||||
with bz2.BZ2File(str(file_path) + ext, 'wb') as cfile:
|
||||
pickle.dump(obj, cfile, protocol=protocol)
|
||||
|
||||
|
||||
def pickle_load(file_path, ext='.pbz2'):
|
||||
# return pickle.load(open(file_path, "rb"))
|
||||
if os.path.isfile(str(file_path)):
|
||||
xfile_path = str(file_path)
|
||||
else:
|
||||
xfile_path = str(file_path) + ext
|
||||
with bz2.BZ2File(xfile_path, 'rb') as cfile:
|
||||
return pickle.load(cfile)
|
||||
|
||||
|
||||
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
|
||||
@@ -60,7 +75,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
return len(self.meta_archs)
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))
|
||||
return ('{name}({num}/{total} architectures, fast_mode={fast_mode}, file={filename})'.format(
|
||||
name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs),
|
||||
fast_mode=self.fast_mode, filename=self.filename))
|
||||
|
||||
@property
|
||||
def avaliable_hps(self):
|
||||
@@ -74,6 +91,20 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
def search_space_name(self):
|
||||
return self._search_space_name
|
||||
|
||||
@property
|
||||
def fast_mode(self):
|
||||
return self._fast_mode
|
||||
|
||||
@property
|
||||
def archive_dir(self):
|
||||
return self._archive_dir
|
||||
|
||||
def reset_archive_dir(self, archive_dir):
|
||||
self._archive_dir = archive_dir
|
||||
|
||||
def reset_fast_mode(self, fast_mode):
|
||||
self._fast_mode = fast_mode
|
||||
|
||||
def reset_time(self):
|
||||
self._used_time = 0
|
||||
|
||||
@@ -121,9 +152,24 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
return arch_index
|
||||
|
||||
def query_by_arch(self, arch, hp):
|
||||
# This is to make the current version be compatible with the old version.
|
||||
"""This is to make the current version be compatible with the old version."""
|
||||
return self.query_info_str_by_arch(arch, hp)
|
||||
|
||||
def _prepare_info(self, index):
|
||||
"""This is a function to load the data from disk when using fast mode."""
|
||||
if index not in self.arch2infos_dict:
|
||||
if self.fast_mode and self.archive_dir is not None:
|
||||
self.reload(self.archive_dir, index)
|
||||
elif not self.fast_mode:
|
||||
if self.verbose:
|
||||
print('Call _prepare_info with index={:} skip because it is not the fast mode.'.format(index))
|
||||
else:
|
||||
raise ValueError('Invalid status: fast_mode={:} and archive_dir={:}'.format(self.fast_mode, self.archive_dir))
|
||||
else:
|
||||
assert index in self.evaluated_indexes, 'The index of {:} is not in self.evaluated_indexes, there must be something wrong.'.format(index)
|
||||
if self.verbose:
|
||||
print('Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(index))
|
||||
|
||||
@abc.abstractmethod
|
||||
def reload(self, archive_root: Text = None, index: int = None):
|
||||
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
|
||||
@@ -140,7 +186,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call clear_params with index={:} and hp={:}'.format(index, hp))
|
||||
if hp is None:
|
||||
if index not in self.arch2infos_dict:
|
||||
warnings.warn('The {:}-th architecture is not in the benchmark data yet, no need to clear params.'.format(index))
|
||||
elif hp is None:
|
||||
for key, result in self.arch2infos_dict[index].items():
|
||||
result.clear_params()
|
||||
else:
|
||||
@@ -154,6 +202,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
|
||||
def _query_info_str_by_arch(self, arch, hp: Text='12', print_information=None):
|
||||
arch_index = self.query_index_by_arch(arch)
|
||||
self._prepare_info(arch_index)
|
||||
if arch_index in self.arch2infos_dict:
|
||||
if hp not in self.arch2infos_dict[arch_index]:
|
||||
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[arch_index].keys()), hp))
|
||||
@@ -161,13 +210,14 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
strings = print_information(info, 'arch-index={:}'.format(arch_index))
|
||||
return '\n'.join(strings)
|
||||
else:
|
||||
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
|
||||
warnings.warn('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
|
||||
return None
|
||||
|
||||
def query_meta_info_by_index(self, arch_index, hp: Text = '12'):
|
||||
"""Return the ArchResults for the 'arch_index'-th architecture. This function is similar to query_by_index."""
|
||||
if self.verbose:
|
||||
print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp))
|
||||
self._prepare_info(arch_index)
|
||||
if arch_index in self.arch2infos_dict:
|
||||
if hp not in self.arch2infos_dict[arch_index]:
|
||||
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
|
||||
@@ -207,7 +257,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
|
||||
dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
|
||||
best_index, highest_accuracy = -1, None
|
||||
for i, arch_index in enumerate(self.evaluated_indexes):
|
||||
evaluated_indexes = sorted(list(self.evaluated_indexes))
|
||||
for i, arch_index in enumerate(evaluated_indexes):
|
||||
arch_info = self.arch2infos_dict[arch_index][hp]
|
||||
info = arch_info.get_compute_costs(dataset) # the information of costs
|
||||
flop, param, latency = info['flops'], info['params'], info['latency']
|
||||
@@ -254,10 +305,11 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
|
||||
self._prepare_info(index)
|
||||
if index in self.arch2infos_dict:
|
||||
info = self.arch2infos_dict[index]
|
||||
else:
|
||||
raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(arch_index))
|
||||
raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(index))
|
||||
info = next(iter(info.values()))
|
||||
results = info.query(dataset, None)
|
||||
results = next(iter(results.values()))
|
||||
@@ -267,6 +319,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
|
||||
if self.verbose:
|
||||
print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
|
||||
self._prepare_info(index)
|
||||
info = self.query_meta_info_by_index(index, hp)
|
||||
return info.get_compute_costs(dataset)
|
||||
|
||||
@@ -296,8 +349,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
if index < 0: # show all architectures
|
||||
print(self)
|
||||
for i, idx in enumerate(self.evaluated_indexes):
|
||||
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
|
||||
evaluated_indexes = sorted(list(self.evaluated_indexes))
|
||||
for i, idx in enumerate(evaluated_indexes):
|
||||
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(evaluated_indexes), idx) + '-'*10)
|
||||
print('arch : {:}'.format(self.meta_archs[idx]))
|
||||
for key, result in self.arch2infos_dict[index].items():
|
||||
strings = print_information(result)
|
||||
@@ -325,7 +379,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
if dataset not in valid_datasets:
|
||||
raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
|
||||
nums, hp = defaultdict(lambda: 0), str(hp)
|
||||
for index in range(len(self)):
|
||||
# for index in range(len(self)):
|
||||
for index in self.evaluated_indexes:
|
||||
archInfo = self.arch2infos_dict[index][hp]
|
||||
dataset_seed = archInfo.dataset_seed
|
||||
if dataset not in dataset_seed:
|
||||
@@ -550,9 +605,7 @@ class ArchResults(object):
|
||||
def create_from_state_dict(state_dict_or_file):
|
||||
x = ArchResults(-1, -1)
|
||||
if isinstance(state_dict_or_file, str): # a file path
|
||||
if not USE_TORCH:
|
||||
raise ValueError('Since torch is not imported, this logic can not be used.')
|
||||
state_dict = torch.load(state_dict_or_file, map_location='cpu')
|
||||
state_dict = pickle_load(state_dict_or_file)
|
||||
elif isinstance(state_dict_or_file, dict):
|
||||
state_dict = state_dict_or_file
|
||||
else:
|
||||
|
Reference in New Issue
Block a user