update NAS-Bench-102 baselines
This commit is contained in:
@@ -60,12 +60,12 @@ def train_and_eval(arch, nas_bench, extra_info):
|
||||
arch_index = nas_bench.query_index_by_arch( arch )
|
||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||
info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True)
|
||||
import pdb; pdb.set_trace()
|
||||
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
|
||||
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
|
||||
else:
|
||||
# train a model from scratch.
|
||||
raise ValueError('NOT IMPLEMENT YET')
|
||||
return valid_acc
|
||||
return valid_acc, time_cost
|
||||
|
||||
|
||||
def random_architecture_func(max_nodes, op_names):
|
||||
@@ -101,7 +101,7 @@ def mutate_arch_func(op_names):
|
||||
return mutate_arch_func
|
||||
|
||||
|
||||
def regularized_evolution(cycles, population_size, sample_size, random_arch, mutate_arch, nas_bench, extra_info):
|
||||
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info):
|
||||
"""Algorithm for regularized evolution (i.e. aging evolution).
|
||||
|
||||
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
|
||||
@@ -111,27 +111,30 @@ def regularized_evolution(cycles, population_size, sample_size, random_arch, mut
|
||||
cycles: the number of cycles the algorithm should run for.
|
||||
population_size: the number of individuals to keep in the population.
|
||||
sample_size: the number of individuals that should participate in each tournament.
|
||||
time_budget: the upper bound of searching cost
|
||||
|
||||
Returns:
|
||||
history: a list of `Model` instances, representing all the models computed
|
||||
during the evolution experiment.
|
||||
"""
|
||||
population = collections.deque()
|
||||
history = [] # Not used by the algorithm, only used to report results.
|
||||
history, total_time_cost = [], 0 # Not used by the algorithm, only used to report results.
|
||||
|
||||
# Initialize the population with random models.
|
||||
while len(population) < population_size:
|
||||
model = Model()
|
||||
model.arch = random_arch()
|
||||
model.accuracy = train_and_eval(model.arch, nas_bench, extra_info)
|
||||
model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info)
|
||||
population.append(model)
|
||||
history.append(model)
|
||||
total_time_cost += time_cost
|
||||
|
||||
# Carry out evolution in cycles. Each cycle produces a model and removes
|
||||
# another.
|
||||
while len(history) < cycles:
|
||||
#while len(history) < cycles:
|
||||
while total_time_cost < time_budget:
|
||||
# Sample randomly chosen models from the current population.
|
||||
sample = []
|
||||
start_time, sample = time.time(), []
|
||||
while len(sample) < sample_size:
|
||||
# Inefficient, but written this way for clarity. In the case of neural
|
||||
# nets, the efficiency of this line is irrelevant because training neural
|
||||
@@ -145,13 +148,18 @@ def regularized_evolution(cycles, population_size, sample_size, random_arch, mut
|
||||
# Create the child model and store it.
|
||||
child = Model()
|
||||
child.arch = mutate_arch(parent.arch)
|
||||
child.accuracy = train_and_eval(child.arch, nas_bench, extra_info)
|
||||
total_time_cost += time.time() - start_time
|
||||
child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info)
|
||||
if total_time_cost + time_cost > time_budget: # return
|
||||
return history, total_time_cost
|
||||
else:
|
||||
total_time_cost += time_cost
|
||||
population.append(child)
|
||||
history.append(child)
|
||||
|
||||
# Remove the oldest model.
|
||||
population.popleft()
|
||||
return history
|
||||
return history, total_time_cost
|
||||
|
||||
|
||||
def main(xargs, nas_bench):
|
||||
@@ -188,8 +196,9 @@ def main(xargs, nas_bench):
|
||||
mutate_arch = mutate_arch_func(search_space)
|
||||
#x =random_arch() ; y = mutate_arch(x)
|
||||
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
||||
history = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
|
||||
logger.log('{:} regularized_evolution finish with history of {:} arch.'.format(time_string(), len(history)))
|
||||
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
|
||||
history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
|
||||
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s.'.format(time_string(), len(history), total_cost))
|
||||
best_arch = max(history, key=lambda i: i.accuracy)
|
||||
best_arch = best_arch.arch
|
||||
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
|
||||
@@ -216,6 +225,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--ea_population', type=int, help='The population size in EA.')
|
||||
parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
|
||||
parser.add_argument('--ea_fast_by_api', type=int, help='Use our API to speed up the experiments or not.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
|
Reference in New Issue
Block a user