update NAS-Bench-102 baselines

This commit is contained in:
D-X-Y
2019-12-24 17:36:47 +11:00
parent af4212b4db
commit 44a0d51449
18 changed files with 105 additions and 110 deletions

View File

@@ -52,14 +52,18 @@ def main(xargs, nas_bench):
random_arch = random_architecture_func(xargs.max_nodes, search_space)
#x =random_arch() ; y = mutate_arch(x)
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
best_arch, best_acc = None, -1
for idx in range(xargs.random_num):
best_arch, best_acc, total_time_cost, history = None, -1, 0, []
#for idx in range(xargs.random_num):
while total_time_cost < xargs.time_budget:
arch = random_arch()
accuracy = train_and_eval(arch, nas_bench, extra_info)
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info)
if total_time_cost + cost_time > xargs.time_budget: break
else: total_time_cost += cost_time
history.append(arch)
if best_arch is None or best_acc < accuracy:
best_acc, best_arch = accuracy, arch
logger.log('[{:03d}/{:03d}] : {:} : accuracy = {:.2f}%'.format(idx, xargs.random_num, arch, accuracy))
logger.log('{:} best arch is {:}, accuracy = {:.2f}%'.format(time_string(), best_arch, best_acc))
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost))
info = nas_bench.query_by_arch( best_arch )
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
@@ -79,7 +83,8 @@ if __name__ == '__main__':
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
#parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')