Update NATS-Bench API to v1.1

This commit is contained in:
D-X-Y
2020-12-20 00:30:14 +08:00
parent c4ef3f6620
commit dae387a97d
4 changed files with 38 additions and 12 deletions

View File

@@ -426,13 +426,13 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
arch_index, hp))
self._prepare_info(arch_index)
if arch_index in self.arch2infos_dict:
if hp not in self.arch2infos_dict[arch_index]:
if str(hp) not in self.arch2infos_dict[arch_index]:
raise ValueError('The {:}-th architecture only has hyper-parameters of '
'{:} instead of {:}.'.format(
arch_index,
list(self.arch2infos_dict[arch_index].keys()),
hp))
info = self.arch2infos_dict[arch_index][hp]
info = self.arch2infos_dict[arch_index][str(hp)]
else:
raise ValueError('arch_index [{:}] does not in arch2infos'.format(
arch_index))
@@ -472,7 +472,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
if self.verbose:
print('{:} Call query_by_index with arch_index={:}, dataname={:}, '
'hp={:}'.format(time_string(), arch_index, dataname, hp))
info = self.query_meta_info_by_index(arch_index, hp)
info = self.query_meta_info_by_index(arch_index, str(hp))
if dataname is None:
return info
else: