Update NATS-Bench (sss version 1.2)
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
############################################################################################
|
||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
|
||||
############################################################################################
|
||||
import os, copy, random, torch, numpy as np
|
||||
import os, copy, random, numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Text, Union, Dict, Optional
|
||||
from collections import OrderedDict, defaultdict
|
||||
@@ -62,7 +62,7 @@ class NATStopology(NASBenchMetaAPI):
|
||||
if verbose: print('try to create the NATS-Bench (topology) api from {:}'.format(file_path_or_dict))
|
||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||
self.filename = Path(file_path_or_dict).name
|
||||
file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
|
||||
file_path_or_dict = np.load(file_path_or_dict)
|
||||
elif isinstance(file_path_or_dict, dict):
|
||||
file_path_or_dict = copy.deepcopy(file_path_or_dict)
|
||||
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
|
||||
|
Reference in New Issue
Block a user