Add get_torch_home func for NATS-Bench

This commit is contained in:
D-X-Y
2020-12-01 22:25:23 +08:00
parent 8afb62ad2e
commit 46b92e37e2
7 changed files with 294 additions and 10 deletions

View File

@@ -17,6 +17,7 @@ from typing import Dict, Optional, Text, Union, Any
from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import NASBenchMetaAPI
from nats_bench.api_utils import get_torch_home
from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
@@ -88,10 +89,10 @@ class NATSsize(NASBenchMetaAPI):
if file_path_or_dict is None:
if self._fast_mode:
self._archive_dir = os.path.join(
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else:
file_path_or_dict = os.path.join(
os.environ['TORCH_HOME'], '{:}.{:}'.format(
get_torch_home(), '{:}.{:}'.format(
ALL_BASE_NAMES[-1], PICKLE_EXT))
print('{:} Try to use the default NATS-Bench (size) path from '
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode,