update NAS-Bench-102 baselines
This commit is contained in:
@@ -62,11 +62,12 @@ class MyWorker(Worker):
|
||||
|
||||
def compute(self, config, budget, **kwargs):
|
||||
structure = self.convert_func( config )
|
||||
reward = train_and_eval(structure, self.nas_bench, None)
|
||||
reward, time_cost = train_and_eval(structure, self.nas_bench, None)
|
||||
import pdb; pdb.set_trace()
|
||||
self.test_time += 1
|
||||
return ({
|
||||
'loss': float(100-reward),
|
||||
'info': None})
|
||||
'info': time_cost})
|
||||
|
||||
|
||||
def main(xargs, nas_bench):
|
||||
@@ -121,7 +122,7 @@ def main(xargs, nas_bench):
|
||||
|
||||
bohb = BOHB(configspace=cs,
|
||||
run_id=hb_run_id,
|
||||
eta=3, min_budget=3, max_budget=108,
|
||||
eta=3, min_budget=3, max_budget=xargs.time_budget,
|
||||
nameserver=ns_host,
|
||||
nameserver_port=ns_port,
|
||||
num_samples=xargs.num_samples,
|
||||
@@ -130,6 +131,7 @@ def main(xargs, nas_bench):
|
||||
# optimization_strategy=xargs.strategy, num_samples=xargs.num_samples,
|
||||
|
||||
results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
bohb.shutdown(shutdown_workers=True)
|
||||
NS.shutdown()
|
||||
@@ -160,9 +162,10 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
# BOHB
|
||||
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
|
||||
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
|
||||
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
|
||||
parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
|
||||
parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations')
|
||||
parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
|
||||
|
Reference in New Issue
Block a user