Update NATS-Bench (tss version 0.99)

This commit is contained in:
D-X-Y
2020-09-05 10:40:29 +00:00
parent 8d64afd4a3
commit bd9288f45d
9 changed files with 379 additions and 56 deletions

View File

@@ -10,6 +10,7 @@ import os, copy, random, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
from .api_utils import time_string
from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
@@ -71,7 +72,7 @@ class NATSsize(NASBenchMetaAPI):
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose:
print('Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(file_path_or_dict, fast_mode))
print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
self.filename = Path(file_path_or_dict).name
@@ -116,14 +117,15 @@ class NATSsize(NASBenchMetaAPI):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[arch] = idx
if self.verbose:
print('Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(len(self.evaluated_indexes), len(self.meta_archs)))
print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(
time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
If index is None, overwrite all ckps.
"""
if self.verbose:
print('Call clear_params with archive_root={:} and index={:}'.format(archive_root, index))
print('{:} Call clear_params with archive_root={:} and index={:}'.format(time_string(), archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
@@ -155,7 +157,7 @@ class NATSsize(NASBenchMetaAPI):
The difference between these three configurations are the number of training epochs.
"""
if self.verbose:
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
@@ -177,7 +179,8 @@ class NATSsize(NASBenchMetaAPI):
When is_random=False, the performanceo of all trials will be averaged.
"""
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(
time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index)
if index not in self.arch2infos_dict: