Update for Rebuttal
This commit is contained in:
@@ -95,7 +95,7 @@ def mutate_size_func(info):
|
||||
return mutate_size_func
|
||||
|
||||
|
||||
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, dataset):
|
||||
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset):
|
||||
"""Algorithm for regularized evolution (i.e. aging evolution).
|
||||
|
||||
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
|
||||
@@ -119,7 +119,10 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
|
||||
while len(population) < population_size:
|
||||
model = Model()
|
||||
model.arch = random_arch()
|
||||
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
|
||||
if use_proxy:
|
||||
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
|
||||
else:
|
||||
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp=api.full_train_epochs)
|
||||
# Append the info
|
||||
population.append(model)
|
||||
history.append((model.accuracy, model.arch))
|
||||
@@ -171,7 +174,11 @@ def main(xargs, api):
|
||||
x_start_time = time.time()
|
||||
logger.log('{:} use api : {:}'.format(time_string(), api))
|
||||
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
|
||||
history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset)
|
||||
history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles,
|
||||
xargs.ea_population,
|
||||
xargs.ea_sample_size,
|
||||
xargs.time_budget,
|
||||
random_arch, mutate_arch, api, xargs.use_proxy > 0, xargs.dataset)
|
||||
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time))
|
||||
best_arch = max(history, key=lambda x: x[0])[1]
|
||||
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
|
||||
@@ -187,11 +194,13 @@ if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
|
||||
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
|
||||
# channels and number-of-cells
|
||||
# hyperparameters for REA
|
||||
parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.')
|
||||
parser.add_argument('--ea_population', type=int, help='The population size in EA.')
|
||||
parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
|
||||
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
|
||||
parser.add_argument('--use_proxy', type=int, default=1, help='Whether to use the proxy (H0) task or not.')
|
||||
#
|
||||
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
|
||||
# log
|
||||
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
|
||||
@@ -201,7 +210,8 @@ if __name__ == '__main__':
|
||||
api = create(None, args.search_space, fast_mode=True, verbose=False)
|
||||
|
||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
|
||||
'{:}-T{:}'.format(args.dataset, args.time_budget), 'R-EA-SS{:}'.format(args.ea_sample_size))
|
||||
'{:}-T{:}{:}'.format(args.dataset, args.time_budget, '' if args.use_proxy > 0 else '-FULL'),
|
||||
'R-EA-SS{:}'.format(args.ea_sample_size))
|
||||
print('save-dir : {:}'.format(args.save_dir))
|
||||
print('xargs : {:}'.format(args))
|
||||
|
||||
|
Reference in New Issue
Block a user