update GDAS
This commit is contained in:
@@ -69,7 +69,7 @@ class MyWorker(Worker):
|
||||
'info': None})
|
||||
|
||||
|
||||
def main(xargs):
|
||||
def main(xargs, nas_bench):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
@@ -111,7 +111,7 @@ def main(xargs):
|
||||
ns_host, ns_port = NS.start()
|
||||
num_workers = 1
|
||||
|
||||
nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
|
||||
#nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
|
||||
logger.log('{:} Create AA-NAS-BENCH-API DONE'.format(time_string()))
|
||||
workers = []
|
||||
for i in range(num_workers):
|
||||
@@ -140,15 +140,14 @@ def main(xargs):
|
||||
logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config']))
|
||||
best_arch = config2structure( id2config[incumbent]['config'] )
|
||||
|
||||
if nas_bench is not None:
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
|
||||
logger.log('workers : {:}'.format(workers[0].test_time))
|
||||
|
||||
logger.close()
|
||||
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
|
||||
|
||||
|
||||
|
||||
@@ -175,5 +174,19 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
main(args)
|
||||
#if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
|
||||
nas_bench = None
|
||||
else:
|
||||
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
|
||||
nas_bench = AANASBenchAPI(args.arch_nas_dataset)
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_indexes, num = None, [], 500
|
||||
for i in range(num):
|
||||
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
save_dir, index = main(args, nas_bench)
|
||||
all_indexes.append( index )
|
||||
torch.save(all_indexes, save_dir / 'results.pth')
|
||||
else:
|
||||
main(args, nas_bench)
|
||||
|
@@ -19,7 +19,7 @@ from aa_nas_api import AANASBenchAPI
|
||||
from R_EA import train_and_eval, random_architecture_func
|
||||
|
||||
|
||||
def main(xargs):
|
||||
def main(xargs, nas_bench):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
@@ -51,12 +51,6 @@ def main(xargs):
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
||||
#x =random_arch() ; y = mutate_arch(x)
|
||||
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
nas_bench = None
|
||||
else:
|
||||
logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset))
|
||||
nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
|
||||
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
||||
best_arch, best_acc = None, -1
|
||||
for idx in range(xargs.random_num):
|
||||
@@ -67,13 +61,12 @@ def main(xargs):
|
||||
logger.log('[{:03d}/{:03d}] : {:} : accuracy = {:.2f}%'.format(idx, xargs.random_num, arch, accuracy))
|
||||
logger.log('{:} best arch is {:}, accuracy = {:.2f}%'.format(time_string(), best_arch, best_acc))
|
||||
|
||||
if nas_bench is not None:
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
|
||||
logger.close()
|
||||
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
|
||||
|
||||
|
||||
|
||||
@@ -94,5 +87,19 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
main(args)
|
||||
#if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
|
||||
nas_bench = None
|
||||
else:
|
||||
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
|
||||
nas_bench = AANASBenchAPI(args.arch_nas_dataset)
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_indexes, num = None, [], 500
|
||||
for i in range(num):
|
||||
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
save_dir, index = main(args, nas_bench)
|
||||
all_indexes.append( index )
|
||||
torch.save(all_indexes, save_dir / 'results.pth')
|
||||
else:
|
||||
main(args, nas_bench)
|
||||
|
@@ -60,7 +60,8 @@ def train_and_eval(arch, nas_bench, extra_info):
|
||||
arch_index = nas_bench.query_index_by_arch( arch )
|
||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||
info = nas_bench.arch2infos[ arch_index ]
|
||||
_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25) # use the validation accuracy after 25 training epochs
|
||||
_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
|
||||
#import pdb; pdb.set_trace()
|
||||
else:
|
||||
# train a model from scratch.
|
||||
raise ValueError('NOT IMPLEMENT YET')
|
||||
@@ -153,7 +154,7 @@ def regularized_evolution(cycles, population_size, sample_size, random_arch, mut
|
||||
return history
|
||||
|
||||
|
||||
def main(xargs):
|
||||
def main(xargs, nas_bench):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
@@ -186,12 +187,6 @@ def main(xargs):
|
||||
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
||||
mutate_arch = mutate_arch_func(search_space)
|
||||
#x =random_arch() ; y = mutate_arch(x)
|
||||
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
nas_bench = None
|
||||
else:
|
||||
logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset))
|
||||
nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
|
||||
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
||||
history = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
|
||||
logger.log('{:} regularized_evolution finish with history of {:} arch.'.format(time_string(), len(history)))
|
||||
@@ -199,13 +194,12 @@ def main(xargs):
|
||||
best_arch = best_arch.arch
|
||||
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
|
||||
|
||||
if nas_bench is not None:
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
|
||||
logger.close()
|
||||
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
|
||||
|
||||
|
||||
|
||||
@@ -227,8 +221,23 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
|
||||
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
#if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
args.ea_fast_by_api = args.ea_fast_by_api > 0
|
||||
main(args)
|
||||
|
||||
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
|
||||
nas_bench = None
|
||||
else:
|
||||
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
|
||||
nas_bench = AANASBenchAPI(args.arch_nas_dataset)
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_indexes, num = None, [], 500
|
||||
for i in range(num):
|
||||
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
save_dir, index = main(args, nas_bench)
|
||||
all_indexes.append( index )
|
||||
torch.save(all_indexes, save_dir / 'results.pth')
|
||||
else:
|
||||
main(args, nas_bench)
|
||||
|
@@ -89,7 +89,7 @@ def select_action(policy):
|
||||
return m.log_prob(action), action.cpu().tolist()
|
||||
|
||||
|
||||
def main(xargs):
|
||||
def main(xargs, nas_bench):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
@@ -128,12 +128,6 @@ def main(xargs):
|
||||
logger.log('eps : {:}'.format(eps))
|
||||
|
||||
# nas dataset load
|
||||
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
nas_bench = None
|
||||
else:
|
||||
logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset))
|
||||
nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
|
||||
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
||||
|
||||
# REINFORCE
|
||||
@@ -156,13 +150,12 @@ def main(xargs):
|
||||
|
||||
best_arch = policy.genotype()
|
||||
|
||||
if nas_bench is not None:
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
|
||||
logger.close()
|
||||
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
|
||||
|
||||
|
||||
|
||||
@@ -183,7 +176,21 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
|
||||
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
main(args)
|
||||
#if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
|
||||
nas_bench = None
|
||||
else:
|
||||
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
|
||||
nas_bench = AANASBenchAPI(args.arch_nas_dataset)
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_indexes, num = None, [], 500
|
||||
for i in range(num):
|
||||
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
save_dir, index = main(args, nas_bench)
|
||||
all_indexes.append( index )
|
||||
torch.save(all_indexes, save_dir / 'results.pth')
|
||||
else:
|
||||
main(args, nas_bench)
|
||||
|
Reference in New Issue
Block a user