Add get_torch_home func for NATS-Bench
This commit is contained in:
@@ -17,6 +17,7 @@ from typing import Dict, Optional, Text, Union, Any
|
||||
|
||||
from nats_bench.api_utils import ArchResults
|
||||
from nats_bench.api_utils import NASBenchMetaAPI
|
||||
from nats_bench.api_utils import get_torch_home
|
||||
from nats_bench.api_utils import nats_is_dir
|
||||
from nats_bench.api_utils import nats_is_file
|
||||
from nats_bench.api_utils import PICKLE_EXT
|
||||
@@ -88,10 +89,10 @@ class NATSsize(NASBenchMetaAPI):
|
||||
if file_path_or_dict is None:
|
||||
if self._fast_mode:
|
||||
self._archive_dir = os.path.join(
|
||||
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
|
||||
get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1]))
|
||||
else:
|
||||
file_path_or_dict = os.path.join(
|
||||
os.environ['TORCH_HOME'], '{:}.{:}'.format(
|
||||
get_torch_home(), '{:}.{:}'.format(
|
||||
ALL_BASE_NAMES[-1], PICKLE_EXT))
|
||||
print('{:} Try to use the default NATS-Bench (size) path from '
|
||||
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode,
|
||||
|
@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Text, Union
|
||||
|
||||
from nats_bench.api_utils import ArchResults
|
||||
from nats_bench.api_utils import NASBenchMetaAPI
|
||||
from nats_bench.api_utils import get_torch_home
|
||||
from nats_bench.api_utils import nats_is_dir
|
||||
from nats_bench.api_utils import nats_is_file
|
||||
from nats_bench.api_utils import PICKLE_EXT
|
||||
@@ -88,10 +89,10 @@ class NATStopology(NASBenchMetaAPI):
|
||||
if file_path_or_dict is None:
|
||||
if self._fast_mode:
|
||||
self._archive_dir = os.path.join(
|
||||
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
|
||||
get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1]))
|
||||
else:
|
||||
file_path_or_dict = os.path.join(
|
||||
os.environ['TORCH_HOME'], '{:}.{:}'.format(
|
||||
get_torch_home(), '{:}.{:}'.format(
|
||||
ALL_BASE_NAMES[-1], PICKLE_EXT))
|
||||
print('{:} Try to use the default NATS-Bench (topology) path from '
|
||||
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict))
|
||||
|
@@ -45,6 +45,17 @@ def get_file_system():
|
||||
return _FILE_SYSTEM
|
||||
|
||||
|
||||
def get_torch_home():
|
||||
if 'TORCH_HOME' in os.environ:
|
||||
return os.environ['TORCH_HOME']
|
||||
elif 'HOME' in os.environ:
|
||||
return os.path.join(os.environ['HOME'], '.torch')
|
||||
else:
|
||||
raise ValueError('Did not find HOME in os.environ. '
|
||||
'Please at least setup the path of HOME or TORCH_HOME '
|
||||
'in the environment.')
|
||||
|
||||
|
||||
def nats_is_dir(file_path):
|
||||
if _FILE_SYSTEM == 'default':
|
||||
return os.path.isdir(file_path)
|
||||
|
Reference in New Issue
Block a user