fix bugs in RANDOM-NAS and BOHB
This commit is contained in:
@@ -53,43 +53,50 @@ def config2structure_func(max_nodes):
|
||||
|
||||
class MyWorker(Worker):
|
||||
|
||||
def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs):
|
||||
def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.convert_func = convert_func
|
||||
self.nas_bench = nas_bench
|
||||
self.time_scale = time_scale
|
||||
self.seen_arch = 0
|
||||
self.time_budget = time_budget
|
||||
self.seen_archs = []
|
||||
self.sim_cost_time = 0
|
||||
self.real_cost_time = 0
|
||||
self.is_end = False
|
||||
|
||||
def get_the_best(self):
|
||||
assert len(self.seen_archs) > 0
|
||||
best_index, best_acc = -1, None
|
||||
for arch_index in self.seen_archs:
|
||||
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
|
||||
vacc = info['valid-accuracy']
|
||||
if best_acc is None or best_acc < vacc:
|
||||
best_acc = vacc
|
||||
best_index = arch_index
|
||||
assert best_index != -1
|
||||
return best_index
|
||||
|
||||
def compute(self, config, budget, **kwargs):
|
||||
start_time = time.time()
|
||||
structure = self.convert_func( config )
|
||||
arch_index = self.nas_bench.query_index_by_arch( structure )
|
||||
iepoch = 0
|
||||
while iepoch < 12:
|
||||
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True)
|
||||
cur_time = info['train-all-time'] + info['valid-per-time']
|
||||
cur_vacc = info['valid-accuracy']
|
||||
if time.time() - start_time + cur_time / self.time_scale > budget:
|
||||
break
|
||||
else:
|
||||
iepoch += 1
|
||||
self.sim_cost_time += cur_time
|
||||
self.seen_arch += 1
|
||||
remaining_time = cur_time / self.time_scale - (time.time() - start_time)
|
||||
if remaining_time > 0:
|
||||
time.sleep(remaining_time)
|
||||
else:
|
||||
import pdb; pdb.set_trace()
|
||||
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
|
||||
cur_time = info['train-all-time'] + info['valid-per-time']
|
||||
cur_vacc = info['valid-accuracy']
|
||||
self.real_cost_time += (time.time() - start_time)
|
||||
return ({
|
||||
'loss': 100 - float(cur_vacc),
|
||||
'info': {'seen-arch' : self.seen_arch,
|
||||
'sim-test-time' : self.sim_cost_time,
|
||||
'real-test-time': self.real_cost_time,
|
||||
'current-arch' : arch_index,
|
||||
'current-budget': budget}
|
||||
if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
|
||||
self.sim_cost_time += cur_time
|
||||
self.seen_archs.append( arch_index )
|
||||
return ({'loss': 100 - float(cur_vacc),
|
||||
'info': {'seen-arch' : len(self.seen_archs),
|
||||
'sim-test-time' : self.sim_cost_time,
|
||||
'current-arch' : arch_index}
|
||||
})
|
||||
else:
|
||||
self.is_end = True
|
||||
return ({'loss': 100,
|
||||
'info': {'seen-arch' : len(self.seen_archs),
|
||||
'sim-test-time' : self.sim_cost_time,
|
||||
'current-arch' : None}
|
||||
})
|
||||
|
||||
|
||||
@@ -139,16 +146,14 @@ def main(xargs, nas_bench):
|
||||
#logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
|
||||
workers = []
|
||||
for i in range(num_workers):
|
||||
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i)
|
||||
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
|
||||
w.run(background=True)
|
||||
workers.append(w)
|
||||
|
||||
simulate_time_budge = xargs.time_budget // xargs.time_scale
|
||||
start_time = time.time()
|
||||
logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge))
|
||||
bohb = BOHB(configspace=cs,
|
||||
run_id=hb_run_id,
|
||||
eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge,
|
||||
eta=3, min_budget=12, max_budget=200,
|
||||
nameserver=ns_host,
|
||||
nameserver_port=ns_port,
|
||||
num_samples=xargs.num_samples,
|
||||
@@ -161,11 +166,9 @@ def main(xargs, nas_bench):
|
||||
NS.shutdown()
|
||||
|
||||
real_cost_time = time.time() - start_time
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
id2config = results.get_id2config_mapping()
|
||||
incumbent = results.get_incumbent_id()
|
||||
|
||||
logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config']))
|
||||
best_arch = config2structure( id2config[incumbent]['config'] )
|
||||
|
||||
@@ -174,7 +177,7 @@ def main(xargs, nas_bench):
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
|
||||
logger.log('workers : {:}'.format(workers[0].test_time))
|
||||
logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
|
||||
logger.close()
|
||||
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
|
||||
|
||||
@@ -190,14 +193,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
parser.add_argument('--time_scale' , type=int, help='The time scale to accelerate the time budget.')
|
||||
# BOHB
|
||||
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
|
||||
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
|
||||
parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
|
||||
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
|
||||
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
|
||||
parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
|
||||
parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations')
|
||||
parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
|
||||
parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method')
|
||||
parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
|
||||
parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
|
Reference in New Issue
Block a user