Add int search space
This commit is contained in:
@@ -35,7 +35,9 @@ def get_configuration_space(max_nodes, search_space):
|
||||
for i in range(1, max_nodes):
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space))
|
||||
cs.add_hyperparameter(
|
||||
ConfigSpace.CategoricalHyperparameter(node_str, search_space)
|
||||
)
|
||||
return cs
|
||||
|
||||
|
||||
@@ -55,7 +57,15 @@ def config2structure_func(max_nodes):
|
||||
|
||||
|
||||
class MyWorker(Worker):
|
||||
def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
convert_func=None,
|
||||
dataname=None,
|
||||
nas_bench=None,
|
||||
time_budget=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.convert_func = convert_func
|
||||
self._dataname = dataname
|
||||
@@ -70,7 +80,9 @@ class MyWorker(Worker):
|
||||
assert len(self.seen_archs) > 0
|
||||
best_index, best_acc = -1, None
|
||||
for arch_index in self.seen_archs:
|
||||
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp="200", is_random=True)
|
||||
info = self._nas_bench.get_more_info(
|
||||
arch_index, self._dataname, None, hp="200", is_random=True
|
||||
)
|
||||
vacc = info["valid-accuracy"]
|
||||
if best_acc is None or best_acc < vacc:
|
||||
best_acc = vacc
|
||||
@@ -82,7 +94,9 @@ class MyWorker(Worker):
|
||||
start_time = time.time()
|
||||
structure = self.convert_func(config)
|
||||
arch_index = self._nas_bench.query_index_by_arch(structure)
|
||||
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp="200", is_random=True)
|
||||
info = self._nas_bench.get_more_info(
|
||||
arch_index, self._dataname, None, hp="200", is_random=True
|
||||
)
|
||||
cur_time = info["train-all-time"] + info["valid-per-time"]
|
||||
cur_vacc = info["valid-accuracy"]
|
||||
self.real_cost_time += time.time() - start_time
|
||||
@@ -101,7 +115,11 @@ class MyWorker(Worker):
|
||||
self.is_end = True
|
||||
return {
|
||||
"loss": 100,
|
||||
"info": {"seen-arch": len(self.seen_archs), "sim-test-time": self.sim_cost_time, "current-arch": None},
|
||||
"info": {
|
||||
"seen-arch": len(self.seen_archs),
|
||||
"sim-test-time": self.sim_cost_time,
|
||||
"current-arch": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -119,13 +137,17 @@ def main(xargs, nas_bench):
|
||||
else:
|
||||
dataname = xargs.dataset
|
||||
if xargs.data_path is not None:
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
train_data, valid_data, xshape, class_num = get_datasets(
|
||||
xargs.dataset, xargs.data_path, -1
|
||||
)
|
||||
split_Fpath = "configs/nas-benchmark/cifar-split.txt"
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log("Load split file from {:}".format(split_Fpath))
|
||||
config_path = "configs/nas-benchmark/algos/R-EA.config"
|
||||
config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger)
|
||||
config = load_config(
|
||||
config_path, {"class_num": class_num, "xshape": xshape}, logger
|
||||
)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
@@ -152,7 +174,11 @@ def main(xargs, nas_bench):
|
||||
)
|
||||
)
|
||||
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
|
||||
extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader}
|
||||
extra_info = {
|
||||
"config": config,
|
||||
"train_loader": train_loader,
|
||||
"valid_loader": valid_loader,
|
||||
}
|
||||
else:
|
||||
config_path = "configs/nas-benchmark/algos/R-EA.config"
|
||||
config = load_config(config_path, None, logger)
|
||||
@@ -213,7 +239,11 @@ def main(xargs, nas_bench):
|
||||
|
||||
id2config = results.get_id2config_mapping()
|
||||
incumbent = results.get_incumbent_id()
|
||||
logger.log("Best found configuration: {:} within {:.3f} s".format(id2config[incumbent]["config"], real_cost_time))
|
||||
logger.log(
|
||||
"Best found configuration: {:} within {:.3f} s".format(
|
||||
id2config[incumbent]["config"], real_cost_time
|
||||
)
|
||||
)
|
||||
best_arch = config2structure(id2config[incumbent]["config"])
|
||||
|
||||
info = nas_bench.query_by_arch(best_arch, "200")
|
||||
@@ -223,13 +253,19 @@ def main(xargs, nas_bench):
|
||||
logger.log("{:}".format(info))
|
||||
logger.log("-" * 100)
|
||||
|
||||
logger.log("workers : {:.1f}s with {:} archs".format(workers[0].time_budget, len(workers[0].seen_archs)))
|
||||
logger.log(
|
||||
"workers : {:.1f}s with {:} archs".format(
|
||||
workers[0].time_budget, len(workers[0].seen_archs)
|
||||
)
|
||||
)
|
||||
logger.close()
|
||||
return logger.log_dir, nas_bench.query_index_by_arch(best_arch), real_cost_time
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale")
|
||||
parser = argparse.ArgumentParser(
|
||||
"BOHB: Robust and Efficient Hyperparameter Optimization at Scale"
|
||||
)
|
||||
parser.add_argument("--data_path", type=str, help="Path to dataset")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
@@ -241,28 +277,71 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--search_space_name", type=str, help="The search space name.")
|
||||
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
|
||||
parser.add_argument("--channel", type=int, help="The number of channels.")
|
||||
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
|
||||
parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).")
|
||||
parser.add_argument(
|
||||
"--num_cells", type=int, help="The number of cells in one stage."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time_budget",
|
||||
type=int,
|
||||
help="The total time cost budge for searching (in seconds).",
|
||||
)
|
||||
# BOHB
|
||||
parser.add_argument(
|
||||
"--strategy", default="sampling", type=str, nargs="?", help="optimization strategy for the acquisition function"
|
||||
)
|
||||
parser.add_argument("--min_bandwidth", default=0.3, type=float, nargs="?", help="minimum bandwidth for KDE")
|
||||
parser.add_argument(
|
||||
"--num_samples", default=64, type=int, nargs="?", help="number of samples for the acquisition function"
|
||||
"--strategy",
|
||||
default="sampling",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="optimization strategy for the acquisition function",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_fraction", default=0.33, type=float, nargs="?", help="fraction of random configurations"
|
||||
"--min_bandwidth",
|
||||
default=0.3,
|
||||
type=float,
|
||||
nargs="?",
|
||||
help="minimum bandwidth for KDE",
|
||||
)
|
||||
parser.add_argument("--bandwidth_factor", default=3, type=int, nargs="?", help="factor multiplied to the bandwidth")
|
||||
parser.add_argument(
|
||||
"--n_iters", default=100, type=int, nargs="?", help="number of iterations for optimization method"
|
||||
"--num_samples",
|
||||
default=64,
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="number of samples for the acquisition function",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_fraction",
|
||||
default=0.33,
|
||||
type=float,
|
||||
nargs="?",
|
||||
help="fraction of random configurations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bandwidth_factor",
|
||||
default=3,
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="factor multiplied to the bandwidth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_iters",
|
||||
default=100,
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="number of iterations for optimization method",
|
||||
)
|
||||
# log
|
||||
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
|
||||
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
|
||||
parser.add_argument(
|
||||
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
|
||||
"--workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="number of data loading workers (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dir", type=str, help="Folder to save checkpoints and log."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch_nas_dataset",
|
||||
type=str,
|
||||
help="The path to load the architecture dataset (tiny-nas-benchmark).",
|
||||
)
|
||||
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
|
||||
parser.add_argument("--rand_seed", type=int, help="manual seed")
|
||||
@@ -271,7 +350,11 @@ if __name__ == "__main__":
|
||||
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
|
||||
nas_bench = None
|
||||
else:
|
||||
print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset))
|
||||
print(
|
||||
"{:} build NAS-Benchmark-API from {:}".format(
|
||||
time_string(), args.arch_nas_dataset
|
||||
)
|
||||
)
|
||||
nas_bench = API(args.arch_nas_dataset)
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_indexes, num, all_times = None, [], 500, []
|
||||
|
Reference in New Issue
Block a user