Add int search space

This commit is contained in:
D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions

View File

@@ -23,7 +23,13 @@ if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
get_optim_scheduler,
)
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import CellStructure, get_search_spaces
@@ -103,7 +109,15 @@ def mutate_size_func(info):
def regularized_evolution(
cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset
cycles,
population_size,
sample_size,
time_budget,
random_arch,
mutate_arch,
api,
use_proxy,
dataset,
):
"""Algorithm for regularized evolution (i.e. aging evolution).
@@ -122,7 +136,10 @@ def regularized_evolution(
"""
population = collections.deque()
api.reset_time()
history, total_time_cost = [], [] # Not used by the algorithm, only used to report results.
history, total_time_cost = (
[],
[],
) # Not used by the algorithm, only used to report results.
current_best_index = []
# Initialize the population with random models.
while len(population) < population_size:
@@ -135,7 +152,9 @@ def regularized_evolution(
population.append(model)
history.append((model.accuracy, model.arch))
total_time_cost.append(total_cost)
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
current_best_index.append(
api.query_index_by_arch(max(history, key=lambda x: x[0])[1])
)
# Carry out evolution in cycles. Each cycle produces a model and removes another.
while total_time_cost[-1] < time_budget:
@@ -160,7 +179,9 @@ def regularized_evolution(
# Append the info
population.append(child)
history.append((child.accuracy, child.arch))
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
current_best_index.append(
api.query_index_by_arch(max(history, key=lambda x: x[0])[1])
)
total_time_cost.append(total_cost)
# Remove the oldest model.
@@ -183,7 +204,10 @@ def main(xargs, api):
x_start_time = time.time()
logger.log("{:} use api : {:}".format(time_string(), api))
logger.log("-" * 30 + " start searching with the time budget of {:} s".format(xargs.time_budget))
logger.log(
"-" * 30
+ " start searching with the time budget of {:} s".format(xargs.time_budget)
)
history, current_best_index, total_times = regularized_evolution(
xargs.ea_cycles,
xargs.ea_population,
@@ -203,7 +227,9 @@ def main(xargs, api):
best_arch = max(history, key=lambda x: x[0])[1]
logger.log("{:} best arch is {:}".format(time_string(), best_arch))
info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90")
info = api.query_info_str_by_arch(
best_arch, "200" if xargs.search_space == "tss" else "90"
)
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
@@ -218,19 +244,39 @@ if __name__ == "__main__":
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
parser.add_argument(
"--search_space",
type=str,
choices=["tss", "sss"],
help="Choose the search space.",
)
# hyperparameters for REA
parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.")
parser.add_argument("--ea_population", type=int, help="The population size in EA.")
parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.")
parser.add_argument(
"--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)."
"--time_budget",
type=int,
default=20000,
help="The total time cost budge for searching (in seconds).",
)
parser.add_argument(
"--use_proxy",
type=int,
default=1,
help="Whether to use the proxy (H0) task or not.",
)
parser.add_argument("--use_proxy", type=int, default=1, help="Whether to use the proxy (H0) task or not.")
#
parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.")
parser.add_argument(
"--loops_if_rand", type=int, default=500, help="The total runs for evaluation."
)
# log
parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.")
parser.add_argument(
"--save_dir",
type=str,
default="./output/search",
help="Folder to save checkpoints and log.",
)
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
@@ -238,7 +284,9 @@ if __name__ == "__main__":
args.save_dir = os.path.join(
"{:}-{:}".format(args.save_dir, args.search_space),
"{:}-T{:}{:}".format(args.dataset, args.time_budget, "" if args.use_proxy > 0 else "-FULL"),
"{:}-T{:}{:}".format(
args.dataset, args.time_budget, "" if args.use_proxy > 0 else "-FULL"
),
"R-EA-SS{:}".format(args.ea_sample_size),
)
print("save-dir : {:}".format(args.save_dir))