Update REA, REINFORCE, and RANDOM

This commit is contained in:
D-X-Y
2020-07-13 11:35:13 +00:00
parent 6dc494be08
commit ebad9197f7
5 changed files with 38 additions and 26 deletions

View File

@@ -155,7 +155,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
population = collections.deque()
api.reset_time()
history, total_time_cost = [], [] # Not used by the algorithm, only used to report results.
current_best_index = []
# Initialize the population with random models.
while len(population) < population_size:
model = Model()
@@ -163,8 +163,9 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12')
# Append the info
population.append(model)
history.append(model)
history.append((model.accuracy, model.arch))
total_time_cost.append(total_cost)
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
# Carry out evolution in cycles. Each cycle produces a model and removes another.
while total_time_cost[-1] < time_budget:
@@ -183,15 +184,16 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
# Create the child model and store it.
child = Model()
child.arch = mutate_arch(parent.arch)
child.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12')
child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, '12')
# Append the info
population.append(child)
history.append(child)
history.append((child.accuracy, child.arch))
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
total_time_cost.append(total_cost)
# Remove the oldest model.
population.popleft()
return history, total_time_cost
return history, current_best_index, total_time_cost
def main(xargs, api):
@@ -210,7 +212,7 @@ def main(xargs, api):
x_start_time = time.time()
logger.log('{:} use api : {:}'.format(time_string(), api))
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
history, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset)
history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset)
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time))
best_arch = max(history, key=lambda i: i.accuracy)
best_arch = best_arch.arch
@@ -220,7 +222,7 @@ def main(xargs, api):
logger.log('{:}'.format(info))
logger.log('-'*100)
logger.close()
return logger.log_dir, [api.query_index_by_arch(x.arch) for x in history], total_times
return logger.log_dir, current_best_index, total_times
if __name__ == '__main__':
@@ -249,7 +251,7 @@ if __name__ == '__main__':
print('save-dir : {:}'.format(args.save_dir))
if args.rand_seed < 0:
save_dir, all_info = None, {}
save_dir, all_info = None, collections.OrderedDict()
for i in range(args.loops_if_rand):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)