Add int search space

This commit is contained in:
D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions

View File

@@ -35,7 +35,9 @@ def get_configuration_space(max_nodes, search_space):
for i in range(1, max_nodes):
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space))
cs.add_hyperparameter(
ConfigSpace.CategoricalHyperparameter(node_str, search_space)
)
return cs
@@ -55,7 +57,15 @@ def config2structure_func(max_nodes):
class MyWorker(Worker):
def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs):
def __init__(
self,
*args,
convert_func=None,
dataname=None,
nas_bench=None,
time_budget=None,
**kwargs
):
super().__init__(*args, **kwargs)
self.convert_func = convert_func
self._dataname = dataname
@@ -70,7 +80,9 @@ class MyWorker(Worker):
assert len(self.seen_archs) > 0
best_index, best_acc = -1, None
for arch_index in self.seen_archs:
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp="200", is_random=True)
info = self._nas_bench.get_more_info(
arch_index, self._dataname, None, hp="200", is_random=True
)
vacc = info["valid-accuracy"]
if best_acc is None or best_acc < vacc:
best_acc = vacc
@@ -82,7 +94,9 @@ class MyWorker(Worker):
start_time = time.time()
structure = self.convert_func(config)
arch_index = self._nas_bench.query_index_by_arch(structure)
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp="200", is_random=True)
info = self._nas_bench.get_more_info(
arch_index, self._dataname, None, hp="200", is_random=True
)
cur_time = info["train-all-time"] + info["valid-per-time"]
cur_vacc = info["valid-accuracy"]
self.real_cost_time += time.time() - start_time
@@ -101,7 +115,11 @@ class MyWorker(Worker):
self.is_end = True
return {
"loss": 100,
"info": {"seen-arch": len(self.seen_archs), "sim-test-time": self.sim_cost_time, "current-arch": None},
"info": {
"seen-arch": len(self.seen_archs),
"sim-test-time": self.sim_cost_time,
"current-arch": None,
},
}
@@ -119,13 +137,17 @@ def main(xargs, nas_bench):
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
split_Fpath = "configs/nas-benchmark/cifar-split.txt"
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log("Load split file from {:}".format(split_Fpath))
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger)
config = load_config(
config_path, {"class_num": class_num, "xshape": xshape}, logger
)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
@@ -152,7 +174,11 @@ def main(xargs, nas_bench):
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader}
extra_info = {
"config": config,
"train_loader": train_loader,
"valid_loader": valid_loader,
}
else:
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, None, logger)
@@ -213,7 +239,11 @@ def main(xargs, nas_bench):
id2config = results.get_id2config_mapping()
incumbent = results.get_incumbent_id()
logger.log("Best found configuration: {:} within {:.3f} s".format(id2config[incumbent]["config"], real_cost_time))
logger.log(
"Best found configuration: {:} within {:.3f} s".format(
id2config[incumbent]["config"], real_cost_time
)
)
best_arch = config2structure(id2config[incumbent]["config"])
info = nas_bench.query_by_arch(best_arch, "200")
@@ -223,13 +253,19 @@ def main(xargs, nas_bench):
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.log("workers : {:.1f}s with {:} archs".format(workers[0].time_budget, len(workers[0].seen_archs)))
logger.log(
"workers : {:.1f}s with {:} archs".format(
workers[0].time_budget, len(workers[0].seen_archs)
)
)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch(best_arch), real_cost_time
if __name__ == "__main__":
parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale")
parser = argparse.ArgumentParser(
"BOHB: Robust and Efficient Hyperparameter Optimization at Scale"
)
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
@@ -241,28 +277,71 @@ if __name__ == "__main__":
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
parser.add_argument(
"--time_budget",
type=int,
help="The total time cost budge for searching (in seconds).",
)
# BOHB
parser.add_argument(
"--strategy", default="sampling", type=str, nargs="?", help="optimization strategy for the acquisition function"
)
parser.add_argument("--min_bandwidth", default=0.3, type=float, nargs="?", help="minimum bandwidth for KDE")
parser.add_argument(
"--num_samples", default=64, type=int, nargs="?", help="number of samples for the acquisition function"
"--strategy",
default="sampling",
type=str,
nargs="?",
help="optimization strategy for the acquisition function",
)
parser.add_argument(
"--random_fraction", default=0.33, type=float, nargs="?", help="fraction of random configurations"
"--min_bandwidth",
default=0.3,
type=float,
nargs="?",
help="minimum bandwidth for KDE",
)
parser.add_argument("--bandwidth_factor", default=3, type=int, nargs="?", help="factor multiplied to the bandwidth")
parser.add_argument(
"--n_iters", default=100, type=int, nargs="?", help="number of iterations for optimization method"
"--num_samples",
default=64,
type=int,
nargs="?",
help="number of samples for the acquisition function",
)
parser.add_argument(
"--random_fraction",
default=0.33,
type=float,
nargs="?",
help="fraction of random configurations",
)
parser.add_argument(
"--bandwidth_factor",
default=3,
type=int,
nargs="?",
help="factor multiplied to the bandwidth",
)
parser.add_argument(
"--n_iters",
default=100,
type=int,
nargs="?",
help="number of iterations for optimization method",
)
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (tiny-nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
@@ -271,7 +350,11 @@ if __name__ == "__main__":
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset))
print(
"{:} build NAS-Benchmark-API from {:}".format(
time_string(), args.arch_nas_dataset
)
)
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num, all_times = None, [], 500, []