Add int search space
This commit is contained in:
@@ -23,7 +23,13 @@ if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from procedures import (
|
||||
prepare_seed,
|
||||
prepare_logger,
|
||||
save_checkpoint,
|
||||
copy_checkpoint,
|
||||
get_optim_scheduler,
|
||||
)
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import CellStructure, get_search_spaces
|
||||
@@ -103,7 +109,15 @@ def mutate_size_func(info):
|
||||
|
||||
|
||||
def regularized_evolution(
|
||||
cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset
|
||||
cycles,
|
||||
population_size,
|
||||
sample_size,
|
||||
time_budget,
|
||||
random_arch,
|
||||
mutate_arch,
|
||||
api,
|
||||
use_proxy,
|
||||
dataset,
|
||||
):
|
||||
"""Algorithm for regularized evolution (i.e. aging evolution).
|
||||
|
||||
@@ -122,7 +136,10 @@ def regularized_evolution(
|
||||
"""
|
||||
population = collections.deque()
|
||||
api.reset_time()
|
||||
history, total_time_cost = [], [] # Not used by the algorithm, only used to report results.
|
||||
history, total_time_cost = (
|
||||
[],
|
||||
[],
|
||||
) # Not used by the algorithm, only used to report results.
|
||||
current_best_index = []
|
||||
# Initialize the population with random models.
|
||||
while len(population) < population_size:
|
||||
@@ -135,7 +152,9 @@ def regularized_evolution(
|
||||
population.append(model)
|
||||
history.append((model.accuracy, model.arch))
|
||||
total_time_cost.append(total_cost)
|
||||
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
|
||||
current_best_index.append(
|
||||
api.query_index_by_arch(max(history, key=lambda x: x[0])[1])
|
||||
)
|
||||
|
||||
# Carry out evolution in cycles. Each cycle produces a model and removes another.
|
||||
while total_time_cost[-1] < time_budget:
|
||||
@@ -160,7 +179,9 @@ def regularized_evolution(
|
||||
# Append the info
|
||||
population.append(child)
|
||||
history.append((child.accuracy, child.arch))
|
||||
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
|
||||
current_best_index.append(
|
||||
api.query_index_by_arch(max(history, key=lambda x: x[0])[1])
|
||||
)
|
||||
total_time_cost.append(total_cost)
|
||||
|
||||
# Remove the oldest model.
|
||||
@@ -183,7 +204,10 @@ def main(xargs, api):
|
||||
|
||||
x_start_time = time.time()
|
||||
logger.log("{:} use api : {:}".format(time_string(), api))
|
||||
logger.log("-" * 30 + " start searching with the time budget of {:} s".format(xargs.time_budget))
|
||||
logger.log(
|
||||
"-" * 30
|
||||
+ " start searching with the time budget of {:} s".format(xargs.time_budget)
|
||||
)
|
||||
history, current_best_index, total_times = regularized_evolution(
|
||||
xargs.ea_cycles,
|
||||
xargs.ea_population,
|
||||
@@ -203,7 +227,9 @@ def main(xargs, api):
|
||||
best_arch = max(history, key=lambda x: x[0])[1]
|
||||
logger.log("{:} best arch is {:}".format(time_string(), best_arch))
|
||||
|
||||
info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90")
|
||||
info = api.query_info_str_by_arch(
|
||||
best_arch, "200" if xargs.search_space == "tss" else "90"
|
||||
)
|
||||
logger.log("{:}".format(info))
|
||||
logger.log("-" * 100)
|
||||
logger.close()
|
||||
@@ -218,19 +244,39 @@ if __name__ == "__main__":
|
||||
choices=["cifar10", "cifar100", "ImageNet16-120"],
|
||||
help="Choose between Cifar10/100 and ImageNet-16.",
|
||||
)
|
||||
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
|
||||
parser.add_argument(
|
||||
"--search_space",
|
||||
type=str,
|
||||
choices=["tss", "sss"],
|
||||
help="Choose the search space.",
|
||||
)
|
||||
# hyperparameters for REA
|
||||
parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.")
|
||||
parser.add_argument("--ea_population", type=int, help="The population size in EA.")
|
||||
parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.")
|
||||
parser.add_argument(
|
||||
"--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)."
|
||||
"--time_budget",
|
||||
type=int,
|
||||
default=20000,
|
||||
help="The total time cost budge for searching (in seconds).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_proxy",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Whether to use the proxy (H0) task or not.",
|
||||
)
|
||||
parser.add_argument("--use_proxy", type=int, default=1, help="Whether to use the proxy (H0) task or not.")
|
||||
#
|
||||
parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.")
|
||||
parser.add_argument(
|
||||
"--loops_if_rand", type=int, default=500, help="The total runs for evaluation."
|
||||
)
|
||||
# log
|
||||
parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.")
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./output/search",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -238,7 +284,9 @@ if __name__ == "__main__":
|
||||
|
||||
args.save_dir = os.path.join(
|
||||
"{:}-{:}".format(args.save_dir, args.search_space),
|
||||
"{:}-T{:}{:}".format(args.dataset, args.time_budget, "" if args.use_proxy > 0 else "-FULL"),
|
||||
"{:}-T{:}{:}".format(
|
||||
args.dataset, args.time_budget, "" if args.use_proxy > 0 else "-FULL"
|
||||
),
|
||||
"R-EA-SS{:}".format(args.ea_sample_size),
|
||||
)
|
||||
print("save-dir : {:}".format(args.save_dir))
|
||||
|
Reference in New Issue
Block a user