update NAS-Bench-102 baselines
This commit is contained in:
@@ -52,14 +52,18 @@ def main(xargs, nas_bench):
|
||||
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
||||
#x =random_arch() ; y = mutate_arch(x)
|
||||
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
||||
best_arch, best_acc = None, -1
|
||||
for idx in range(xargs.random_num):
|
||||
best_arch, best_acc, total_time_cost, history = None, -1, 0, []
|
||||
#for idx in range(xargs.random_num):
|
||||
while total_time_cost < xargs.time_budget:
|
||||
arch = random_arch()
|
||||
accuracy = train_and_eval(arch, nas_bench, extra_info)
|
||||
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
||||
if total_time_cost + cost_time > xargs.time_budget: break
|
||||
else: total_time_cost += cost_time
|
||||
history.append(arch)
|
||||
if best_arch is None or best_acc < accuracy:
|
||||
best_acc, best_arch = accuracy, arch
|
||||
logger.log('[{:03d}/{:03d}] : {:} : accuracy = {:.2f}%'.format(idx, xargs.random_num, arch, accuracy))
|
||||
logger.log('{:} best arch is {:}, accuracy = {:.2f}%'.format(time_string(), best_arch, best_acc))
|
||||
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
|
||||
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost))
|
||||
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
@@ -79,7 +83,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
|
||||
#parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
|
Reference in New Issue
Block a user