Update NATS-Bench (tss version 1.0) and remove the trace of 301
This commit is contained in:
@@ -68,7 +68,7 @@ class NATSsize(NASBenchMetaAPI):
|
||||
self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
|
||||
else:
|
||||
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
|
||||
print ('Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(self._fast_mode, file_path_or_dict))
|
||||
print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict))
|
||||
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
|
||||
file_path_or_dict = str(file_path_or_dict)
|
||||
if verbose:
|
||||
@@ -125,10 +125,15 @@ class NATSsize(NASBenchMetaAPI):
|
||||
If index is None, overwrite all ckps.
|
||||
"""
|
||||
if self.verbose:
|
||||
print('{:} Call clear_params with archive_root={:} and index={:}'.format(time_string(), archive_root, index))
|
||||
print('{:} Call clear_params with archive_root={:} and index={:}'.format(
|
||||
time_string(), archive_root, index))
|
||||
if archive_root is None:
|
||||
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
|
||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||
if not os.path.isdir(archive_root):
|
||||
warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
|
||||
archive_root = self.archive_dir
|
||||
if archive_root is None or not os.path.isdir(archive_root):
|
||||
raise ValueError('Invalid archive_root : {:}'.format(archive_root))
|
||||
if index is None:
|
||||
indexes = list(range(len(self)))
|
||||
else:
|
||||
|
@@ -4,7 +4,7 @@
|
||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||
#####################################################################################
|
||||
# The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
|
||||
# [2020.08.31] #
|
||||
# [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 #
|
||||
#####################################################################################
|
||||
import os, copy, random, numpy as np
|
||||
from pathlib import Path
|
||||
@@ -19,14 +19,14 @@ from .api_utils import remap_dataset_set_names
|
||||
|
||||
|
||||
PICKLE_EXT = 'pickle.pbz2'
|
||||
ALL_BASE_NAMES = ['NATS-tss-v1_0-xxxxx']
|
||||
ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9']
|
||||
|
||||
|
||||
def print_information(information, extra_info=None, show=False):
|
||||
dataset_names = information.get_dataset_names()
|
||||
strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
|
||||
def metric2str(loss, acc):
|
||||
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
|
||||
return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)
|
||||
|
||||
for ida, dataset in enumerate(dataset_names):
|
||||
metric = information.get_compute_costs(dataset)
|
||||
@@ -61,12 +61,15 @@ class NATStopology(NASBenchMetaAPI):
|
||||
self._archive_dir = None
|
||||
self.reset_time()
|
||||
if file_path_or_dict is None:
|
||||
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
|
||||
if self._fast_mode:
|
||||
self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
|
||||
else:
|
||||
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
|
||||
print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict))
|
||||
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
|
||||
file_path_or_dict = str(file_path_or_dict)
|
||||
if verbose:
|
||||
print('{:} Try to create the NATS-Bench (topology) api from {:}'.format(time_string(), file_path_or_dict))
|
||||
print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
|
||||
if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
|
||||
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
|
||||
self.filename = Path(file_path_or_dict).name
|
||||
@@ -82,7 +85,7 @@ class NATStopology(NASBenchMetaAPI):
|
||||
file_path_or_dict = pickle_load(file_path_or_dict)
|
||||
elif isinstance(file_path_or_dict, dict):
|
||||
file_path_or_dict = copy.deepcopy(file_path_or_dict)
|
||||
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
|
||||
self.verbose = verbose
|
||||
if isinstance(file_path_or_dict, dict):
|
||||
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
||||
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
||||
@@ -91,13 +94,13 @@ class NATStopology(NASBenchMetaAPI):
|
||||
self.arch2infos_dict = OrderedDict()
|
||||
self._avaliable_hps = set()
|
||||
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
|
||||
all_info = file_path_or_dict['arch2infos'][xkey]
|
||||
all_infos = file_path_or_dict['arch2infos'][xkey]
|
||||
hp2archres = OrderedDict()
|
||||
for hp_key, results in all_infos.items():
|
||||
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
|
||||
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
|
||||
self.arch2infos_dict[xkey] = hp2archres
|
||||
self.evaluated_indexes = list(file_path_or_dict['evaluated_indexes'])
|
||||
self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
|
||||
elif self.archive_dir is not None:
|
||||
benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT))
|
||||
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
|
||||
@@ -116,7 +119,7 @@ class NATStopology(NASBenchMetaAPI):
|
||||
|
||||
def reload(self, archive_root: Text = None, index: int = None):
|
||||
"""Overwrite all information of the 'index'-th architecture in the search space.
|
||||
It will load its data from 'archive_root'.
|
||||
If index is None, overwrite all ckps.
|
||||
"""
|
||||
if self.verbose:
|
||||
print('{:} Call clear_params with archive_root={:} and index={:}'.format(
|
||||
|
Reference in New Issue
Block a user