Add get_torch_home func for NATS-Bench
This commit is contained in:
@@ -17,6 +17,7 @@ from typing import Dict, Optional, Text, Union, Any
|
||||
|
||||
from nats_bench.api_utils import ArchResults
|
||||
from nats_bench.api_utils import NASBenchMetaAPI
|
||||
from nats_bench.api_utils import get_torch_home
|
||||
from nats_bench.api_utils import nats_is_dir
|
||||
from nats_bench.api_utils import nats_is_file
|
||||
from nats_bench.api_utils import PICKLE_EXT
|
||||
@@ -88,10 +89,10 @@ class NATSsize(NASBenchMetaAPI):
|
||||
if file_path_or_dict is None:
|
||||
if self._fast_mode:
|
||||
self._archive_dir = os.path.join(
|
||||
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
|
||||
get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1]))
|
||||
else:
|
||||
file_path_or_dict = os.path.join(
|
||||
os.environ['TORCH_HOME'], '{:}.{:}'.format(
|
||||
get_torch_home(), '{:}.{:}'.format(
|
||||
ALL_BASE_NAMES[-1], PICKLE_EXT))
|
||||
print('{:} Try to use the default NATS-Bench (size) path from '
|
||||
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode,
|
||||
|
Reference in New Issue
Block a user