Support accumulate and reset time function for API
This commit is contained in:
@@ -57,6 +57,7 @@ class NASBench301API(NASBenchMetaAPI):
|
||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
|
||||
self.filename = None
|
||||
self.reset_time()
|
||||
if file_path_or_dict is None:
|
||||
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
|
||||
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
|
||||
@@ -128,7 +129,7 @@ class NASBench301API(NASBenchMetaAPI):
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
|
||||
self._query_info_str_by_arch(arch, hp, print_information)
|
||||
return self._query_info_str_by_arch(arch, hp, print_information)
|
||||
|
||||
def get_more_info(self, index: int, dataset: Text, iepoch=None, hp='12', is_random=True):
|
||||
"""This function will return the metric for the `index`-th architecture
|
||||
|
Reference in New Issue
Block a user