Update NATS-Bench (tss version 0.99)
This commit is contained in:
@@ -10,6 +10,7 @@ import os, copy, random, numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Text, Union, Dict, Optional
|
||||
from collections import OrderedDict, defaultdict
|
||||
from .api_utils import time_string
|
||||
from .api_utils import pickle_load
|
||||
from .api_utils import ArchResults
|
||||
from .api_utils import NASBenchMetaAPI
|
||||
@@ -71,7 +72,7 @@ class NATSsize(NASBenchMetaAPI):
|
||||
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
|
||||
file_path_or_dict = str(file_path_or_dict)
|
||||
if verbose:
|
||||
print('Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(file_path_or_dict, fast_mode))
|
||||
print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
|
||||
if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
|
||||
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
|
||||
self.filename = Path(file_path_or_dict).name
|
||||
@@ -116,14 +117,15 @@ class NATSsize(NASBenchMetaAPI):
|
||||
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
|
||||
self.archstr2index[arch] = idx
|
||||
if self.verbose:
|
||||
print('Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(len(self.evaluated_indexes), len(self.meta_archs)))
|
||||
print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(
|
||||
time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
|
||||
|
||||
def reload(self, archive_root: Text = None, index: int = None):
|
||||
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
|
||||
If index is None, overwrite all ckps.
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call clear_params with archive_root={:} and index={:}'.format(archive_root, index))
|
||||
print('{:} Call clear_params with archive_root={:} and index={:}'.format(time_string(), archive_root, index))
|
||||
if archive_root is None:
|
||||
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
|
||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||
@@ -155,7 +157,7 @@ class NATSsize(NASBenchMetaAPI):
|
||||
The difference between these three configurations are the number of training epochs.
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
|
||||
print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp))
|
||||
return self._query_info_str_by_arch(arch, hp, print_information)
|
||||
|
||||
def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
|
||||
@@ -177,7 +179,8 @@ class NATSsize(NASBenchMetaAPI):
|
||||
When is_random=False, the performanceo of all trials will be averaged.
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
|
||||
print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(
|
||||
time_string(), index, dataset, iepoch, hp, is_random))
|
||||
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
|
||||
self._prepare_info(index)
|
||||
if index not in self.arch2infos_dict:
|
||||
|
Reference in New Issue
Block a user