Update NATS-Bench (sss version 1.0)

This commit is contained in:
D-X-Y
2020-08-28 06:02:35 +00:00
parent 3529b993ff
commit c68458f66c
9 changed files with 413 additions and 28 deletions

View File

@@ -10,10 +10,15 @@
# History:
# [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
#
import os, abc, copy, random, torch, numpy as np
from pathlib import Path
import abc, copy, random, numpy as np
import importlib, warnings
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
USE_TORCH = importlib.find_loader('torch') is not None
if USE_TORCH:
import torch
else:
warnings.warn('Can not find PyTorch, and thus some features maybe invalid.')
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
@@ -545,6 +550,8 @@ class ArchResults(object):
def create_from_state_dict(state_dict_or_file):
x = ArchResults(-1, -1)
if isinstance(state_dict_or_file, str): # a file path
if not USE_TORCH:
raise ValueError('Since torch is not imported, this logic can not be used.')
state_dict = torch.load(state_dict_or_file, map_location='cpu')
elif isinstance(state_dict_or_file, dict):
state_dict = state_dict_or_file