Add get_torch_home func for NATS-Bench

This commit is contained in:
D-X-Y
2020-12-01 22:25:23 +08:00
parent 8afb62ad2e
commit 46b92e37e2
7 changed files with 294 additions and 10 deletions

View File

@@ -17,6 +17,7 @@ from typing import Dict, Optional, Text, Union, Any
from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import NASBenchMetaAPI
from nats_bench.api_utils import get_torch_home
from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
@@ -88,10 +89,10 @@ class NATSsize(NASBenchMetaAPI):
if file_path_or_dict is None:
if self._fast_mode:
self._archive_dir = os.path.join(
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else:
file_path_or_dict = os.path.join(
os.environ['TORCH_HOME'], '{:}.{:}'.format(
get_torch_home(), '{:}.{:}'.format(
ALL_BASE_NAMES[-1], PICKLE_EXT))
print('{:} Try to use the default NATS-Bench (size) path from '
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode,

View File

@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Text, Union
from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import NASBenchMetaAPI
from nats_bench.api_utils import get_torch_home
from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
@@ -88,10 +89,10 @@ class NATStopology(NASBenchMetaAPI):
if file_path_or_dict is None:
if self._fast_mode:
self._archive_dir = os.path.join(
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else:
file_path_or_dict = os.path.join(
os.environ['TORCH_HOME'], '{:}.{:}'.format(
get_torch_home(), '{:}.{:}'.format(
ALL_BASE_NAMES[-1], PICKLE_EXT))
print('{:} Try to use the default NATS-Bench (topology) path from '
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict))

View File

@@ -45,6 +45,17 @@ def get_file_system():
return _FILE_SYSTEM
def get_torch_home():
if 'TORCH_HOME' in os.environ:
return os.environ['TORCH_HOME']
elif 'HOME' in os.environ:
return os.path.join(os.environ['HOME'], '.torch')
else:
raise ValueError('Did not find HOME in os.environ. '
'Please at least setup the path of HOME or TORCH_HOME '
'in the environment.')
def nats_is_dir(file_path):
if _FILE_SYSTEM == 'default':
return os.path.isdir(file_path)