Update NATS-Bench (sss version 1.0)
This commit is contained in:
@@ -10,10 +10,15 @@
|
||||
# History:
|
||||
# [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
|
||||
#
|
||||
import os, abc, copy, random, torch, numpy as np
|
||||
from pathlib import Path
|
||||
import abc, copy, random, numpy as np
|
||||
import importlib, warnings
|
||||
from typing import List, Text, Union, Dict, Optional
|
||||
from collections import OrderedDict, defaultdict
|
||||
USE_TORCH = importlib.find_loader('torch') is not None
|
||||
if USE_TORCH:
|
||||
import torch
|
||||
else:
|
||||
warnings.warn('Can not find PyTorch, and thus some features maybe invalid.')
|
||||
|
||||
|
||||
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
|
||||
@@ -545,6 +550,8 @@ class ArchResults(object):
|
||||
def create_from_state_dict(state_dict_or_file):
|
||||
x = ArchResults(-1, -1)
|
||||
if isinstance(state_dict_or_file, str): # a file path
|
||||
if not USE_TORCH:
|
||||
raise ValueError('Since torch is not imported, this logic can not be used.')
|
||||
state_dict = torch.load(state_dict_or_file, map_location='cpu')
|
||||
elif isinstance(state_dict_or_file, dict):
|
||||
state_dict = state_dict_or_file
|
||||
|
@@ -3,3 +3,4 @@ from .gpu_manager import GPUManager
|
||||
from .flop_benchmark import get_model_infos, count_parameters_in_MB
|
||||
from .affine_utils import normalize_points, denormalize_points
|
||||
from .affine_utils import identity2affine, solve2theta, affine2image
|
||||
from .hash_utils import get_md5_file
|
||||
|
16
lib/utils/hash_utils.py
Normal file
16
lib/utils/hash_utils.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import os, hashlib
|
||||
|
||||
|
||||
def get_md5_file(file_path, post_truncated=5):
|
||||
md5_hash = hashlib.md5()
|
||||
if os.path.exists(file_path):
|
||||
xfile = open(file_path, "rb")
|
||||
content = xfile.read()
|
||||
md5_hash.update(content)
|
||||
digest = md5_hash.hexdigest()
|
||||
else:
|
||||
raise ValueError('[get_md5_file] {:} does not exist'.format(file_path))
|
||||
if post_truncated is None:
|
||||
return digest
|
||||
else:
|
||||
return digest[-post_truncated:]
|
Reference in New Issue
Block a user