Add get_torch_home func for NATS-Bench

This commit is contained in:
D-X-Y
2020-12-01 22:25:23 +08:00
parent 8afb62ad2e
commit 46b92e37e2
7 changed files with 294 additions and 10 deletions

View File

@@ -45,6 +45,17 @@ def get_file_system():
return _FILE_SYSTEM
def get_torch_home():
if 'TORCH_HOME' in os.environ:
return os.environ['TORCH_HOME']
elif 'HOME' in os.environ:
return os.path.join(os.environ['HOME'], '.torch')
else:
raise ValueError('Did not find HOME in os.environ. '
'Please at least setup the path of HOME or TORCH_HOME '
'in the environment.')
def nats_is_dir(file_path):
if _FILE_SYSTEM == 'default':
return os.path.isdir(file_path)