Update NATS-Bench API to v1.1
This commit is contained in:
@@ -426,13 +426,13 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
arch_index, hp))
|
||||
self._prepare_info(arch_index)
|
||||
if arch_index in self.arch2infos_dict:
|
||||
if hp not in self.arch2infos_dict[arch_index]:
|
||||
if str(hp) not in self.arch2infos_dict[arch_index]:
|
||||
raise ValueError('The {:}-th architecture only has hyper-parameters of '
|
||||
'{:} instead of {:}.'.format(
|
||||
arch_index,
|
||||
list(self.arch2infos_dict[arch_index].keys()),
|
||||
hp))
|
||||
info = self.arch2infos_dict[arch_index][hp]
|
||||
info = self.arch2infos_dict[arch_index][str(hp)]
|
||||
else:
|
||||
raise ValueError('arch_index [{:}] does not in arch2infos'.format(
|
||||
arch_index))
|
||||
@@ -472,7 +472,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
if self.verbose:
|
||||
print('{:} Call query_by_index with arch_index={:}, dataname={:}, '
|
||||
'hp={:}'.format(time_string(), arch_index, dataname, hp))
|
||||
info = self.query_meta_info_by_index(arch_index, hp)
|
||||
info = self.query_meta_info_by_index(arch_index, str(hp))
|
||||
if dataname is None:
|
||||
return info
|
||||
else:
|
||||
|
Reference in New Issue
Block a user