update README
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# required to install hpbandster #################
|
||||
# bash ./scripts-search/algos/BOHB.sh -1 #
|
||||
##################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np, collections
|
||||
@@ -19,7 +20,6 @@ from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from nas_102_api import NASBench102API as API
|
||||
from models import CellStructure, get_search_spaces
|
||||
from R_EA import train_and_eval
|
||||
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
|
||||
import ConfigSpace
|
||||
from hpbandster.optimizers.bohb import BOHB
|
||||
@@ -53,21 +53,44 @@ def config2structure_func(max_nodes):
|
||||
|
||||
class MyWorker(Worker):
|
||||
|
||||
def __init__(self, *args, sleep_interval=0, convert_func=None, nas_bench=None, **kwargs):
|
||||
def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sleep_interval = sleep_interval
|
||||
self.convert_func = convert_func
|
||||
self.nas_bench = nas_bench
|
||||
self.test_time = 0
|
||||
self.time_scale = time_scale
|
||||
self.seen_arch = 0
|
||||
self.sim_cost_time = 0
|
||||
self.real_cost_time = 0
|
||||
|
||||
def compute(self, config, budget, **kwargs):
|
||||
structure = self.convert_func( config )
|
||||
reward, time_cost = train_and_eval(structure, self.nas_bench, None)
|
||||
import pdb; pdb.set_trace()
|
||||
self.test_time += 1
|
||||
start_time = time.time()
|
||||
structure = self.convert_func( config )
|
||||
arch_index = self.nas_bench.query_index_by_arch( structure )
|
||||
iepoch = 0
|
||||
while iepoch < 12:
|
||||
info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True)
|
||||
cur_time = info['train-all-time'] + info['valid-per-time']
|
||||
cur_vacc = info['valid-accuracy']
|
||||
if time.time() - start_time + cur_time / self.time_scale > budget:
|
||||
break
|
||||
else:
|
||||
iepoch += 1
|
||||
self.sim_cost_time += cur_time
|
||||
self.seen_arch += 1
|
||||
remaining_time = cur_time / self.time_scale - (time.time() - start_time)
|
||||
if remaining_time > 0:
|
||||
time.sleep(remaining_time)
|
||||
else:
|
||||
import pdb; pdb.set_trace()
|
||||
self.real_cost_time += (time.time() - start_time)
|
||||
return ({
|
||||
'loss': float(100-reward),
|
||||
'info': time_cost})
|
||||
'loss': 100 - float(cur_vacc),
|
||||
'info': {'seen-arch' : self.seen_arch,
|
||||
'sim-test-time' : self.sim_cost_time,
|
||||
'real-test-time': self.real_cost_time,
|
||||
'current-arch' : arch_index,
|
||||
'current-budget': budget}
|
||||
})
|
||||
|
||||
|
||||
def main(xargs, nas_bench):
|
||||
@@ -116,26 +139,30 @@ def main(xargs, nas_bench):
|
||||
#logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
|
||||
workers = []
|
||||
for i in range(num_workers):
|
||||
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, run_id=hb_run_id, id=i)
|
||||
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i)
|
||||
w.run(background=True)
|
||||
workers.append(w)
|
||||
|
||||
simulate_time_budge = xargs.time_budget // xargs.time_scale
|
||||
start_time = time.time()
|
||||
logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge))
|
||||
bohb = BOHB(configspace=cs,
|
||||
run_id=hb_run_id,
|
||||
eta=3, min_budget=3, max_budget=xargs.time_budget,
|
||||
eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge,
|
||||
nameserver=ns_host,
|
||||
nameserver_port=ns_port,
|
||||
num_samples=xargs.num_samples,
|
||||
random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor,
|
||||
ping_interval=10, min_bandwidth=xargs.min_bandwidth)
|
||||
# optimization_strategy=xargs.strategy, num_samples=xargs.num_samples,
|
||||
|
||||
results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
bohb.shutdown(shutdown_workers=True)
|
||||
NS.shutdown()
|
||||
|
||||
real_cost_time = time.time() - start_time
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
id2config = results.get_id2config_mapping()
|
||||
incumbent = results.get_incumbent_id()
|
||||
|
||||
@@ -163,6 +190,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
parser.add_argument('--time_scale' , type=int, help='The time scale to accelerate the time budget.')
|
||||
# BOHB
|
||||
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
|
||||
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
|
||||
|
Reference in New Issue
Block a user