Update REA, REINFORCE, and RANDOM
This commit is contained in:
@@ -155,7 +155,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
|
||||
population = collections.deque()
|
||||
api.reset_time()
|
||||
history, total_time_cost = [], [] # Not used by the algorithm, only used to report results.
|
||||
|
||||
current_best_index = []
|
||||
# Initialize the population with random models.
|
||||
while len(population) < population_size:
|
||||
model = Model()
|
||||
@@ -163,8 +163,9 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
|
||||
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12')
|
||||
# Append the info
|
||||
population.append(model)
|
||||
history.append(model)
|
||||
history.append((model.accuracy, model.arch))
|
||||
total_time_cost.append(total_cost)
|
||||
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
|
||||
|
||||
# Carry out evolution in cycles. Each cycle produces a model and removes another.
|
||||
while total_time_cost[-1] < time_budget:
|
||||
@@ -183,15 +184,16 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
|
||||
# Create the child model and store it.
|
||||
child = Model()
|
||||
child.arch = mutate_arch(parent.arch)
|
||||
child.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12')
|
||||
child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, '12')
|
||||
# Append the info
|
||||
population.append(child)
|
||||
history.append(child)
|
||||
history.append((child.accuracy, child.arch))
|
||||
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
|
||||
total_time_cost.append(total_cost)
|
||||
|
||||
# Remove the oldest model.
|
||||
population.popleft()
|
||||
return history, total_time_cost
|
||||
return history, current_best_index, total_time_cost
|
||||
|
||||
|
||||
def main(xargs, api):
|
||||
@@ -210,7 +212,7 @@ def main(xargs, api):
|
||||
x_start_time = time.time()
|
||||
logger.log('{:} use api : {:}'.format(time_string(), api))
|
||||
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
|
||||
history, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset)
|
||||
history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset)
|
||||
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time))
|
||||
best_arch = max(history, key=lambda i: i.accuracy)
|
||||
best_arch = best_arch.arch
|
||||
@@ -220,7 +222,7 @@ def main(xargs, api):
|
||||
logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
logger.close()
|
||||
return logger.log_dir, [api.query_index_by_arch(x.arch) for x in history], total_times
|
||||
return logger.log_dir, current_best_index, total_times
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -249,7 +251,7 @@ if __name__ == '__main__':
|
||||
print('save-dir : {:}'.format(args.save_dir))
|
||||
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_info = None, {}
|
||||
save_dir, all_info = None, collections.OrderedDict()
|
||||
for i in range(args.loops_if_rand):
|
||||
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
|
Reference in New Issue
Block a user