Update new version of BOHB
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
# pip install hpbandster ##################################
|
||||
###################################################################
|
||||
# OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
||||
# OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
||||
###################################################################
|
||||
import os, sys, time, random, argparse, collections
|
||||
from copy import deepcopy
|
||||
@@ -38,12 +39,9 @@ def get_topology_config_space(search_space, max_nodes=4):
|
||||
|
||||
def get_size_config_space(search_space):
|
||||
cs = ConfigSpace.ConfigurationSpace()
|
||||
import pdb; pdb.set_trace()
|
||||
#edge2index = {}
|
||||
for i in range(1, max_nodes):
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space))
|
||||
for ilayer in range(search_space['numbers']):
|
||||
node_str = 'layer-{:}'.format(ilayer)
|
||||
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space['candidates']))
|
||||
return cs
|
||||
|
||||
|
||||
@@ -61,6 +59,16 @@ def config2topology_func(max_nodes=4):
|
||||
return config2structure
|
||||
|
||||
|
||||
def config2size_func(search_space):
|
||||
def config2structure(config):
|
||||
channels = []
|
||||
for ilayer in range(search_space['numbers']):
|
||||
node_str = 'layer-{:}'.format(ilayer)
|
||||
channels.append(str(config[node_str]))
|
||||
return ':'.join(channels)
|
||||
return config2structure
|
||||
|
||||
|
||||
class MyWorker(Worker):
|
||||
|
||||
def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs):
|
||||
@@ -89,11 +97,11 @@ def main(xargs, api):
|
||||
api.reset_time()
|
||||
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
|
||||
if xargs.search_space == 'tss':
|
||||
cs = get_topology_config_space(search_space)
|
||||
config2structure = config2topology_func()
|
||||
cs = get_topology_config_space(search_space)
|
||||
config2structure = config2topology_func()
|
||||
else:
|
||||
cs = get_size_config_space(search_space)
|
||||
import pdb; pdb.set_trace()
|
||||
config2structure = config2size_func(search_space)
|
||||
|
||||
hb_run_id = '0'
|
||||
|
||||
|
@@ -17,3 +17,6 @@ do
|
||||
python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
||||
done
|
||||
done
|
||||
|
||||
python exps/experimental/vis-bench-algos.py --search_space tss
|
||||
python exps/experimental/vis-bench-algos.py --search_space sss
|
||||
|
Reference in New Issue
Block a user