Update NATS-Bench (sss version 1.0)

This commit is contained in:
D-X-Y
2020-08-28 06:02:35 +00:00
parent 3529b993ff
commit c68458f66c
9 changed files with 413 additions and 28 deletions

View File

@@ -10,10 +10,15 @@
# History:
# [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
#
import os, abc, copy, random, torch, numpy as np
from pathlib import Path
import abc, copy, random, numpy as np
import importlib, warnings
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
USE_TORCH = importlib.find_loader('torch') is not None
if USE_TORCH:
import torch
else:
warnings.warn('Can not find PyTorch, and thus some features maybe invalid.')
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
@@ -545,6 +550,8 @@ class ArchResults(object):
def create_from_state_dict(state_dict_or_file):
x = ArchResults(-1, -1)
if isinstance(state_dict_or_file, str): # a file path
if not USE_TORCH:
raise ValueError('Since torch is not imported, this logic can not be used.')
state_dict = torch.load(state_dict_or_file, map_location='cpu')
elif isinstance(state_dict_or_file, dict):
state_dict = state_dict_or_file

View File

@@ -3,3 +3,4 @@ from .gpu_manager import GPUManager
from .flop_benchmark import get_model_infos, count_parameters_in_MB
from .affine_utils import normalize_points, denormalize_points
from .affine_utils import identity2affine, solve2theta, affine2image
from .hash_utils import get_md5_file

16
lib/utils/hash_utils.py Normal file
View File

@@ -0,0 +1,16 @@
import os, hashlib
def get_md5_file(file_path, post_truncated=5):
md5_hash = hashlib.md5()
if os.path.exists(file_path):
xfile = open(file_path, "rb")
content = xfile.read()
md5_hash.update(content)
digest = md5_hash.hexdigest()
else:
raise ValueError('[get_md5_file] {:} does not exist'.format(file_path))
if post_truncated is None:
return digest
else:
return digest[-post_truncated:]