update GDAS

This commit is contained in:
D-X-Y
2019-11-19 11:58:04 +11:00
parent c3672648d7
commit 09d68c6375
20 changed files with 1176 additions and 90 deletions

View File

@@ -69,7 +69,7 @@ class MyWorker(Worker):
'info': None})
def main(xargs):
def main(xargs, nas_bench):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
@@ -111,7 +111,7 @@ def main(xargs):
ns_host, ns_port = NS.start()
num_workers = 1
nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
#nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
logger.log('{:} Create AA-NAS-BENCH-API DONE'.format(time_string()))
workers = []
for i in range(num_workers):
@@ -140,15 +140,14 @@ def main(xargs):
logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config']))
best_arch = config2structure( id2config[incumbent]['config'] )
if nas_bench is not None:
info = nas_bench.query_by_arch( best_arch )
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info))
info = nas_bench.query_by_arch( best_arch )
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info))
logger.log('-'*100)
logger.log('workers : {:}'.format(workers[0].test_time))
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
@@ -175,5 +174,19 @@ if __name__ == '__main__':
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
main(args)
#if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
nas_bench = AANASBenchAPI(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num = None, [], 500
for i in range(num):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench)
all_indexes.append( index )
torch.save(all_indexes, save_dir / 'results.pth')
else:
main(args, nas_bench)