Update NATS-Bench (tss version 0.99)

This commit is contained in:
D-X-Y
2020-09-05 10:40:29 +00:00
parent 8d64afd4a3
commit bd9288f45d
9 changed files with 379 additions and 56 deletions

View File

@@ -10,9 +10,9 @@
# History:
# [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
#
import os, abc, copy, random, numpy as np
import os, abc, time, copy, random, numpy as np
import bz2, pickle
import importlib, warnings
import warnings
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
@@ -36,6 +36,12 @@ def pickle_load(file_path, ext='.pbz2'):
return pickle.load(cfile)
def time_string():
ISOTIMEFORMAT='%Y-%m-%d %X'
string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
"""re-map the metric_on_set to internal keys"""
if verbose:
@@ -136,7 +142,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
"""
if self.verbose:
print('Call query_index_by_arch with arch={:}'.format(arch))
print('{:} Call query_index_by_arch with arch={:}'.format(time_string(), arch))
if isinstance(arch, int):
if 0 <= arch < len(self):
return arch
@@ -162,13 +168,13 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
self.reload(self.archive_dir, index)
elif not self.fast_mode:
if self.verbose:
print('Call _prepare_info with index={:} skip because it is not the fast mode.'.format(index))
print('{:} Call _prepare_info with index={:} skip because it is not the fast mode.'.format(time_string(), index))
else:
raise ValueError('Invalid status: fast_mode={:} and archive_dir={:}'.format(self.fast_mode, self.archive_dir))
else:
assert index in self.evaluated_indexes, 'The index of {:} is not in self.evaluated_indexes, there must be something wrong.'.format(index)
if self.verbose:
print('Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(index))
print('{:} Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(time_string(), index))
@abc.abstractmethod
def reload(self, archive_root: Text = None, index: int = None):
@@ -185,7 +191,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
-- '01' or '12' or '90': clear all the weights in arch2infos_dict[index][hp].
"""
if self.verbose:
print('Call clear_params with index={:} and hp={:}'.format(index, hp))
print('{:} Call clear_params with index={:} and hp={:}'.format(time_string(), index, hp))
if index not in self.arch2infos_dict:
warnings.warn('The {:}-th architecture is not in the benchmark data yet, no need to clear params.'.format(index))
elif hp is None:
@@ -243,7 +249,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
"""
if self.verbose:
print('Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(arch_index, dataname, hp))
print('{:} Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(time_string(), arch_index, dataname, hp))
info = self.query_meta_info_by_index(arch_index, hp)
if dataname is None: return info
else:
@@ -254,7 +260,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'):
"""Find the architecture with the highest accuracy based on some constraints."""
if self.verbose:
print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
print('{:} Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(
time_string(), dataset, metric_on_set, hp, FLOP_max, Param_max))
dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
best_index, highest_accuracy = -1, None
evaluated_indexes = sorted(list(self.evaluated_indexes))
@@ -287,7 +294,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
-- 200 : train the model by 200 epochs
"""
if self.verbose:
print('Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(index, dataset, seed, hp))
print('{:} Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(time_string(), index, dataset, seed, hp))
info = self.query_meta_info_by_index(index, hp)
return info.get_net_param(dataset, seed)
@@ -304,7 +311,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
config = api.get_net_config(128, 'cifar10')
"""
if self.verbose:
print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
print('{:} Call the get_net_config function with index={:}, dataset={:}.'.format(time_string(), index, dataset))
self._prepare_info(index)
if index in self.arch2infos_dict:
info = self.arch2infos_dict[index]
@@ -318,7 +325,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]:
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
if self.verbose:
print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
print('{:} Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(time_string(), index, dataset, hp))
self._prepare_info(index)
info = self.query_meta_info_by_index(index, hp)
return info.get_compute_costs(dataset)
@@ -331,7 +338,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
:return: return a float value in seconds
"""
if self.verbose:
print('Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
print('{:} Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(time_string(), index, dataset, hp))
cost_dict = self.get_cost_info(index, dataset, hp)
return cost_dict['latency']