Upgrade NAS-API to v2.0:

we use an abstract class NASBenchMetaAPI to define the spec of an API; it can be inherited to support different kinds of NAS API, while keep the query interface the same.
This commit is contained in:
D-X-Y
2020-06-30 09:05:38 +00:00
parent 91ee265bd2
commit 6effb6f127
23 changed files with 1888 additions and 944 deletions

View File

@@ -65,7 +65,7 @@ class MyWorker(Worker):
assert len(self.seen_archs) > 0
best_index, best_acc = -1, None
for arch_index in self.seen_archs:
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True)
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
vacc = info['valid-accuracy']
if best_acc is None or best_acc < vacc:
best_acc = vacc
@@ -77,7 +77,7 @@ class MyWorker(Worker):
start_time = time.time()
structure = self.convert_func( config )
arch_index = self._nas_bench.query_index_by_arch( structure )
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, True, True)
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
cur_time = info['train-all-time'] + info['valid-per-time']
cur_vacc = info['valid-accuracy']
self.real_cost_time += (time.time() - start_time)