Update for Rebuttal

This commit is contained in:
D-X-Y
2020-12-01 12:34:00 +08:00
parent 29428bf5a3
commit 8afb62ad2e
5 changed files with 96 additions and 6 deletions

View File

@@ -95,7 +95,7 @@ def mutate_size_func(info):
return mutate_size_func
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, dataset):
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset):
"""Algorithm for regularized evolution (i.e. aging evolution).
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
@@ -119,7 +119,10 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
while len(population) < population_size:
model = Model()
model.arch = random_arch()
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
if use_proxy:
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
else:
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp=api.full_train_epochs)
# Append the info
population.append(model)
history.append((model.accuracy, model.arch))
@@ -171,7 +174,11 @@ def main(xargs, api):
x_start_time = time.time()
logger.log('{:} use api : {:}'.format(time_string(), api))
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset)
history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles,
xargs.ea_population,
xargs.ea_sample_size,
xargs.time_budget,
random_arch, mutate_arch, api, xargs.use_proxy > 0, xargs.dataset)
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time))
best_arch = max(history, key=lambda x: x[0])[1]
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
@@ -187,11 +194,13 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
# channels and number-of-cells
# hyperparameters for REA
parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.')
parser.add_argument('--ea_population', type=int, help='The population size in EA.')
parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--use_proxy', type=int, default=1, help='Whether to use the proxy (H0) task or not.')
#
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
# log
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
@@ -201,7 +210,8 @@ if __name__ == '__main__':
api = create(None, args.search_space, fast_mode=True, verbose=False)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
'{:}-T{:}'.format(args.dataset, args.time_budget), 'R-EA-SS{:}'.format(args.ea_sample_size))
'{:}-T{:}{:}'.format(args.dataset, args.time_budget, '' if args.use_proxy > 0 else '-FULL'),
'R-EA-SS{:}'.format(args.ea_sample_size))
print('save-dir : {:}'.format(args.save_dir))
print('xargs : {:}'.format(args))