Update NATS-Bench (sss version 1.2)

This commit is contained in:
D-X-Y
2020-08-30 08:04:52 +00:00
parent 469a207945
commit 5f151d1970
15 changed files with 317 additions and 229 deletions

View File

@@ -3,7 +3,7 @@
############################################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
############################################################################################
import os, copy, random, torch, numpy as np
import os, copy, random, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
@@ -62,7 +62,7 @@ class NATStopology(NASBenchMetaAPI):
if verbose: print('try to create the NATS-Bench (topology) api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
self.filename = Path(file_path_or_dict).name
file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
file_path_or_dict = np.load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy(file_path_or_dict)
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))