Update NATS-Bench (sss version 1.0)
This commit is contained in:
@@ -10,10 +10,15 @@
|
||||
# History:
|
||||
# [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
|
||||
#
|
||||
import os, abc, copy, random, torch, numpy as np
|
||||
from pathlib import Path
|
||||
import abc, copy, random, numpy as np
|
||||
import importlib, warnings
|
||||
from typing import List, Text, Union, Dict, Optional
|
||||
from collections import OrderedDict, defaultdict
|
||||
USE_TORCH = importlib.find_loader('torch') is not None
|
||||
if USE_TORCH:
|
||||
import torch
|
||||
else:
|
||||
warnings.warn('Can not find PyTorch, and thus some features maybe invalid.')
|
||||
|
||||
|
||||
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
|
||||
@@ -545,6 +550,8 @@ class ArchResults(object):
|
||||
def create_from_state_dict(state_dict_or_file):
|
||||
x = ArchResults(-1, -1)
|
||||
if isinstance(state_dict_or_file, str): # a file path
|
||||
if not USE_TORCH:
|
||||
raise ValueError('Since torch is not imported, this logic can not be used.')
|
||||
state_dict = torch.load(state_dict_or_file, map_location='cpu')
|
||||
elif isinstance(state_dict_or_file, dict):
|
||||
state_dict = state_dict_or_file
|
||||
|
Reference in New Issue
Block a user