update NAS-Bench-102 baselines

This commit is contained in:
D-X-Y
2019-12-25 10:30:50 +11:00
parent 44a0d51449
commit 1d5e8debad
5 changed files with 48 additions and 17 deletions

View File

@@ -62,11 +62,12 @@ class MyWorker(Worker):
def compute(self, config, budget, **kwargs):
structure = self.convert_func( config )
reward = train_and_eval(structure, self.nas_bench, None)
reward, time_cost = train_and_eval(structure, self.nas_bench, None)
import pdb; pdb.set_trace()
self.test_time += 1
return ({
'loss': float(100-reward),
'info': None})
'info': time_cost})
def main(xargs, nas_bench):
@@ -121,7 +122,7 @@ def main(xargs, nas_bench):
bohb = BOHB(configspace=cs,
run_id=hb_run_id,
eta=3, min_budget=3, max_budget=108,
eta=3, min_budget=3, max_budget=xargs.time_budget,
nameserver=ns_host,
nameserver_port=ns_port,
num_samples=xargs.num_samples,
@@ -130,6 +131,7 @@ def main(xargs, nas_bench):
# optimization_strategy=xargs.strategy, num_samples=xargs.num_samples,
results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
import pdb; pdb.set_trace()
bohb.shutdown(shutdown_workers=True)
NS.shutdown()
@@ -160,9 +162,10 @@ if __name__ == '__main__':
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
# BOHB
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations')
parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')