fix bugs in RANDOM-NAS and BOHB

commit f8f44bfb31 (parent 4c144b7437)
Author: D-X-Y
Date:   2019-12-29 20:17:26 +11:00

8 changed files with 469 additions and 67 deletions

View File

@@ -53,43 +53,50 @@ def config2structure_func(max_nodes):
 class MyWorker(Worker):

-  def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs):
+  def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs):
     super().__init__(*args, **kwargs)
     self.convert_func   = convert_func
     self.nas_bench      = nas_bench
-    self.time_scale     = time_scale
-    self.seen_arch      = 0
+    self.time_budget    = time_budget
+    self.seen_archs     = []
     self.sim_cost_time  = 0
     self.real_cost_time = 0
+    self.is_end         = False
+
+  def get_the_best(self):
+    assert len(self.seen_archs) > 0
+    best_index, best_acc = -1, None
+    for arch_index in self.seen_archs:
+      info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
+      vacc = info['valid-accuracy']
+      if best_acc is None or best_acc < vacc:
+        best_acc   = vacc
+        best_index = arch_index
+    assert best_index != -1
+    return best_index

   def compute(self, config, budget, **kwargs):
     start_time = time.time()
     structure  = self.convert_func( config )
     arch_index = self.nas_bench.query_index_by_arch( structure )
-    iepoch     = 0
-    while iepoch < 12:
-      info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True)
-      cur_time = info['train-all-time'] + info['valid-per-time']
-      cur_vacc = info['valid-accuracy']
-      if time.time() - start_time + cur_time / self.time_scale > budget:
-        break
-      else:
-        iepoch += 1
-    self.sim_cost_time += cur_time
-    self.seen_arch += 1
-    remaining_time = cur_time / self.time_scale - (time.time() - start_time)
-    if remaining_time > 0:
-      time.sleep(remaining_time)
-    else:
-      import pdb; pdb.set_trace()
+    info       = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
+    cur_time   = info['train-all-time'] + info['valid-per-time']
+    cur_vacc   = info['valid-accuracy']
     self.real_cost_time += (time.time() - start_time)
-    return ({
-            'loss': 100 - float(cur_vacc),
-            'info': {'seen-arch'      : self.seen_arch,
-                     'sim-test-time'  : self.sim_cost_time,
-                     'real-test-time' : self.real_cost_time,
-                     'current-arch'   : arch_index,
-                     'current-budget' : budget}
-           })
+    if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
+      self.sim_cost_time += cur_time
+      self.seen_archs.append( arch_index )
+      return ({'loss': 100 - float(cur_vacc),
+               'info': {'seen-arch'     : len(self.seen_archs),
+                        'sim-test-time' : self.sim_cost_time,
+                        'current-arch'  : arch_index}
+              })
+    else:
+      self.is_end = True
+      return ({'loss': 100,
+               'info': {'seen-arch'     : len(self.seen_archs),
+                        'sim-test-time' : self.sim_cost_time,
+                        'current-arch'  : None}
+              })
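The substance of the fix is visible in `compute`: the old version tried to reconcile simulated benchmark time with hpbandster's wall-clock `budget` by sleeping, and dropped into `pdb.set_trace()` when the sleep came out negative, whereas the new version does all accounting in simulated seconds against a single `time_budget`. Below is a minimal sketch of the new stopping rule with a stubbed benchmark standing in for the real NAS-Bench-201 API; `FakeBench` and its numbers are illustrative, not from the repo.

```python
class FakeBench:
  def get_more_info(self, arch_index, dataset, iepoch, use_12epochs_result):
    # pretend every architecture costs 300s to train and 5s to validate
    return {'train-all-time': 300.0, 'valid-per-time': 5.0, 'valid-accuracy': 85.0}

bench, time_budget = FakeBench(), 1200.0
sim_cost_time, seen_archs = 0.0, []
for arch_index in range(10):                     # candidates proposed by BOHB
  info     = bench.get_more_info(arch_index, 'cifar10-valid', None, True)
  cur_time = info['train-all-time'] + info['valid-per-time']
  if sim_cost_time + cur_time > time_budget:     # same rule as the new compute()
    break                                        # budget exhausted -> is_end = True
  sim_cost_time += cur_time
  seen_archs.append(arch_index)
print(len(seen_archs), sim_cost_time)            # 3 architectures, 915.0 simulated seconds
```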
@@ -139,16 +146,14 @@ def main(xargs, nas_bench):
   #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
   workers = []
   for i in range(num_workers):
-    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i)
+    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
     w.run(background=True)
     workers.append(w)

-  simulate_time_budge = xargs.time_budget // xargs.time_scale
   start_time = time.time()
-  logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge))
   bohb = BOHB(configspace=cs,
               run_id=hb_run_id,
-              eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge,
+              eta=3, min_budget=12, max_budget=200,
               nameserver=ns_host,
               nameserver_port=ns_port,
               num_samples=xargs.num_samples,
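With the `time_scale` machinery gone, `min_budget`/`max_budget` are no longer derived seconds but the constants 12 and 200, which I read as the 12-epoch and 200-epoch training schedules recorded in NAS-Bench-201 (the diff itself does not say so). A self-contained construction sketch against hpbandster's public API follows; the one-hyperparameter ConfigSpace is a placeholder for the real cell search space built by `config2structure_func`.

```python
import ConfigSpace
import hpbandster.core.nameserver as hpns
from hpbandster.optimizers.bohb import BOHB

# placeholder search space: one categorical choice instead of the full cell
cs = ConfigSpace.ConfigurationSpace()
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter('op', ['none', 'skip_connect', 'nor_conv_3x3']))

NS = hpns.NameServer(run_id='bohb-demo', host='127.0.0.1', port=0)
ns_host, ns_port = NS.start()
bohb = BOHB(configspace=cs, run_id='bohb-demo',
            eta=3, min_budget=12, max_budget=200,   # epochs now, not scaled seconds
            nameserver=ns_host, nameserver_port=ns_port)
# workers would be started here, then: results = bohb.run(n_iterations, min_n_workers=...)
bohb.shutdown(shutdown_workers=True)
NS.shutdown()
```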
@@ -161,11 +166,9 @@ def main(xargs, nas_bench):
   NS.shutdown()

   real_cost_time = time.time() - start_time
-  import pdb; pdb.set_trace()

   id2config = results.get_id2config_mapping()
   incumbent = results.get_incumbent_id()
   logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config']))
   best_arch = config2structure( id2config[incumbent]['config'] )
@@ -174,7 +177,7 @@ def main(xargs, nas_bench):
   else : logger.log('{:}'.format(info))
   logger.log('-'*100)
-  logger.log('workers : {:}'.format(workers[0].test_time))
+  logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
   logger.close()
   return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
@@ -190,14 +193,13 @@ if __name__ == '__main__':
   parser.add_argument('--channel', type=int, help='The number of channels.')
   parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
   parser.add_argument('--time_budget', type=int, help='The total time cost budget for searching (in seconds).')
-  parser.add_argument('--time_scale', type=int, help='The time scale to accelerate the time budget.')
   # BOHB
-  parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
-  parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
-  parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
+  parser.add_argument('--strategy',        default="sampling", type=str,   nargs='?', help='optimization strategy for the acquisition function')
+  parser.add_argument('--min_bandwidth',   default=.3,         type=float, nargs='?', help='minimum bandwidth for KDE')
+  parser.add_argument('--num_samples',     default=64,         type=int,   nargs='?', help='number of samples for the acquisition function')
   parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations')
-  parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
-  parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method')
+  parser.add_argument('--bandwidth_factor', default=3,   type=int, nargs='?', help='factor multiplied to the bandwidth')
+  parser.add_argument('--n_iters',          default=100, type=int, nargs='?', help='number of iterations for optimization method')
   # log
   parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
   parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
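After this hunk the script's only time-related knob is `--time_budget`, now interpreted directly as simulated seconds. A reduced argparse sketch of the surviving option (the `12000` is my example value, not a repo default):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--time_budget', type=int,
                    help='The total time cost budget for searching (in seconds).')
args = parser.parse_args(['--time_budget', '12000'])
print(args.time_budget)   # -> 12000 simulated seconds
```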

View File

@@ -82,14 +82,29 @@ def valid_func(xloader, network, criterion):
   return arch_losses.avg, arch_top1.avg, arch_top5.avg

-def search_find_best(valid_loader, network, criterion, select_num):
-  best_arch, best_acc = None, -1
-  for iarch in range(select_num):
-    arch = network.module.random_genotype( True )
-    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
-    if best_arch is None or best_acc < valid_a_top1:
-      best_arch, best_acc = arch, valid_a_top1
-  return best_arch
+def search_find_best(xloader, network, n_samples):
+  with torch.no_grad():
+    network.eval()
+    archs, valid_accs = [], []
+    #print ('obtain the top-{:} architectures'.format(n_samples))
+    loader_iter = iter(xloader)
+    for i in range(n_samples):
+      arch = network.module.random_genotype( True )
+      try:
+        inputs, targets = next(loader_iter)
+      except:
+        loader_iter = iter(xloader)
+        inputs, targets = next(loader_iter)
+      _, logits = network(inputs)
+      val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
+      archs.append( arch )
+      valid_accs.append( val_top1.item() )
+    best_idx = np.argmax(valid_accs)
+    best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
+    return best_arch, best_valid_acc

 def main(xargs):
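The rewritten `search_find_best` scores each randomly sampled genotype on a single mini-batch under `torch.no_grad()`, restarting the loader iterator whenever an epoch runs dry; the bare `except:` above is effectively catching `StopIteration`. Here is the recycling idiom in isolation, with a toy dataset standing in for the search data:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,))),
                    batch_size=4)                # 3 batches per epoch
loader_iter = iter(loader)
for step in range(7):                            # more draws than one epoch holds
  try:
    inputs, targets = next(loader_iter)
  except StopIteration:                          # epoch exhausted -> restart
    loader_iter = iter(loader)
    inputs, targets = next(loader_iter)
  # ... sample an architecture and score it on (inputs, targets) ...
```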
@@ -127,7 +142,7 @@ def main(xargs):
   search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
   # data loader
   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True, num_workers=xargs.workers, pin_memory=True)
-  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
+  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
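Switching the evaluation loader to `config.test_batch_size` decouples the (typically larger) evaluation batch from the training batch size while still drawing only from the held-out indices. A standalone sketch of that loader shape, with all sizes illustrative:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

valid_data   = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))
valid_split  = list(range(50, 100))              # held-out half of the indices
valid_loader = DataLoader(valid_data, batch_size=512,   # a test_batch_size stand-in
                          sampler=SubsetRandomSampler(valid_split))
inputs, targets = next(iter(valid_loader))
print(inputs.shape)                              # torch.Size([50, 3]): one big batch
```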
@@ -177,7 +192,8 @@ def main(xargs):
     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
     logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
-    cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
+    cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
+    logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc))
     genotypes[epoch] = cur_arch
     # check the best accuracy
     valid_accuracies[epoch] = valid_a_top1
@@ -211,13 +227,7 @@ def main(xargs):
   logger.log('\n' + '-'*200)
   logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
   start_time = time.time()
-  best_arch, best_acc = None, -1
-  for iarch in range(xargs.select_num):
-    arch = search_model.random_genotype( True )
-    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
-    logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
-    if best_arch is None or best_acc < valid_a_top1:
-      best_arch, best_acc = arch, valid_a_top1
+  best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
   search_time.update(time.time() - start_time)
   logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
   if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
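Both call sites now share one selection path: sample `select_num` architectures, score each on a mini-batch, and keep the argmax, so the final-evaluation loop above collapses into a single call that also returns the winning accuracy. The selection step by itself, with made-up genotype strings and accuracies:

```python
import numpy as np

archs      = ['|skip_connect~0|', '|nor_conv_3x3~0|', '|avg_pool_3x3~0|']
valid_accs = [71.2, 84.5, 79.8]                  # one-batch top-1 per sampled arch
best_idx   = np.argmax(valid_accs)
print(archs[best_idx], valid_accs[best_idx])     # |nor_conv_3x3~0| 84.5
```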