diff --git a/exps/KD-main.py b/exps/KD-main.py index 0670222..d130b91 100644 --- a/exps/KD-main.py +++ b/exps/KD-main.py @@ -2,161 +2,224 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # ##################################################### import sys, time, torch, random, argparse -from PIL import ImageFile +from PIL import ImageFile + ImageFile.LOAD_TRUNCATED_IMAGES = True -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, obtain_cls_kd_args as obtain_args -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint -from procedures import get_optim_scheduler, get_procedures -from datasets import get_datasets -from models import obtain_model, load_net_from_checkpoint -from utils import get_model_infos -from log_utils import AverageMeter, time_string, convert_secs2time +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint +from procedures import get_optim_scheduler, get_procedures +from datasets import get_datasets +from models import obtain_model, load_net_from_checkpoint +from utils import get_model_infos +from log_utils import AverageMeter, time_string, convert_secs2time def main(args): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = True - #torch.backends.cudnn.deterministic = True - torch.set_num_threads( args.workers ) - - prepare_seed(args.rand_seed) - logger = prepare_logger(args) - - train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) - train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) - # get configures - model_config = load_config(args.model_config, {'class_num': class_num}, logger) - optim_config = load_config(args.optim_config, - {'class_num': class_num, 'KD_alpha': args.KD_alpha, 'KD_temperature': args.KD_temperature}, - logger) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + torch.set_num_threads(args.workers) - # load checkpoint - teacher_base = load_net_from_checkpoint(args.KD_checkpoint) - teacher = torch.nn.DataParallel(teacher_base).cuda() + prepare_seed(args.rand_seed) + logger = prepare_logger(args) - base_model = obtain_model(model_config) - flop, param = get_model_infos(base_model, xshape) - logger.log('Student ====>>>>:\n{:}'.format(base_model)) - logger.log('Teacher ====>>>>:\n{:}'.format(teacher_base)) - logger.log('model information : {:}'.format(base_model.get_message())) - logger.log('-'*50) - logger.log('Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G'.format(param, flop, flop/1e3)) - logger.log('-'*50) - logger.log('train_data : {:}'.format(train_data)) - logger.log('valid_data : {:}'.format(valid_data)) - optimizer, scheduler, criterion = get_optim_scheduler(base_model.parameters(), optim_config) - logger.log('optimizer : {:}'.format(optimizer)) - logger.log('scheduler : {:}'.format(scheduler)) - logger.log('criterion : {:}'.format(criterion)) - - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() + train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) + train_loader = torch.utils.data.DataLoader( + train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True + ) + # get configures + model_config = load_config(args.model_config, {"class_num": class_num}, logger) + optim_config = load_config( + args.optim_config, + {"class_num": class_num, "KD_alpha": args.KD_alpha, "KD_temperature": args.KD_temperature}, + logger, + ) - if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) - last_info = torch.load(last_info) - start_epoch = last_info['epoch'] + 1 - checkpoint = torch.load(last_info['last_checkpoint']) - base_model.load_state_dict( checkpoint['base-model'] ) - scheduler.load_state_dict ( checkpoint['scheduler'] ) - optimizer.load_state_dict ( checkpoint['optimizer'] ) - valid_accuracies = checkpoint['valid_accuracies'] - max_bytes = checkpoint['max_bytes'] - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) - elif args.resume is not None: - assert Path(args.resume).exists(), 'Can not find the resume file : {:}'.format(args.resume) - checkpoint = torch.load( args.resume ) - start_epoch = checkpoint['epoch'] + 1 - base_model.load_state_dict( checkpoint['base-model'] ) - scheduler.load_state_dict ( checkpoint['scheduler'] ) - optimizer.load_state_dict ( checkpoint['optimizer'] ) - valid_accuracies = checkpoint['valid_accuracies'] - max_bytes = checkpoint['max_bytes'] - logger.log("=> loading checkpoint from '{:}' start with {:}-th epoch.".format(args.resume, start_epoch)) - elif args.init_model is not None: - assert Path(args.init_model).exists(), 'Can not find the initialization file : {:}'.format(args.init_model) - checkpoint = torch.load( args.init_model ) - base_model.load_state_dict( checkpoint['base-model'] ) - start_epoch, valid_accuracies, max_bytes = 0, {'best': -1}, {} - logger.log('=> initialize the model from {:}'.format( args.init_model )) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, max_bytes = 0, {'best': -1}, {} + # load checkpoint + teacher_base = load_net_from_checkpoint(args.KD_checkpoint) + teacher = torch.nn.DataParallel(teacher_base).cuda() - train_func, valid_func = get_procedures(args.procedure) - - total_epoch = optim_config.epochs + optim_config.warmup - # Main Training and Evaluation Loop - start_time = time.time() - epoch_time = AverageMeter() - for epoch in range(start_epoch, total_epoch): - scheduler.update(epoch, 0.0) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch), True) ) - epoch_str = 'epoch={:03d}/{:03d}'.format(epoch, total_epoch) - LRs = scheduler.get_lr() - find_best = False + base_model = obtain_model(model_config) + flop, param = get_model_infos(base_model, xshape) + logger.log("Student ====>>>>:\n{:}".format(base_model)) + logger.log("Teacher ====>>>>:\n{:}".format(teacher_base)) + logger.log("model information : {:}".format(base_model.get_message())) + logger.log("-" * 50) + logger.log("Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(param, flop, flop / 1e3)) + logger.log("-" * 50) + logger.log("train_data : {:}".format(train_data)) + logger.log("valid_data : {:}".format(valid_data)) + optimizer, scheduler, criterion = get_optim_scheduler(base_model.parameters(), optim_config) + logger.log("optimizer : {:}".format(optimizer)) + logger.log("scheduler : {:}".format(scheduler)) + logger.log("criterion : {:}".format(criterion)) - logger.log('\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler)) - - # train for one epoch - train_loss, train_acc1, train_acc5 = train_func(train_loader, teacher, network, criterion, scheduler, optimizer, optim_config, epoch_str, args.print_freq, logger) - # log the results - logger.log('***{:s}*** TRAIN [{:}] loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}'.format(time_string(), epoch_str, train_loss, train_acc1, train_acc5)) + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() - # evaluate the performance - if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): - logger.log('-'*150) - valid_loss, valid_acc1, valid_acc5 = valid_func(valid_loader, teacher, network, criterion, optim_config, epoch_str, args.print_freq_eval, logger) - valid_accuracies[epoch] = valid_acc1 - logger.log('***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}'.format(time_string(), epoch_str, valid_loss, valid_acc1, valid_acc5, valid_accuracies['best'], 100-valid_accuracies['best'])) - if valid_acc1 > valid_accuracies['best']: - valid_accuracies['best'] = valid_acc1 - find_best = True - logger.log('Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.'.format(epoch, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5, model_best_path)) - num_bytes = torch.cuda.max_memory_cached( next(network.parameters()).device ) * 1.0 - logger.log('[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]'.format(next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9)) - max_bytes[epoch] = num_bytes - if epoch % 10 == 0: torch.cuda.empty_cache() + if last_info.exists(): # automatically resume from previous checkpoint + logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + last_info = torch.load(last_info) + start_epoch = last_info["epoch"] + 1 + checkpoint = torch.load(last_info["last_checkpoint"]) + base_model.load_state_dict(checkpoint["base-model"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + valid_accuracies = checkpoint["valid_accuracies"] + max_bytes = checkpoint["max_bytes"] + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + ) + elif args.resume is not None: + assert Path(args.resume).exists(), "Can not find the resume file : {:}".format(args.resume) + checkpoint = torch.load(args.resume) + start_epoch = checkpoint["epoch"] + 1 + base_model.load_state_dict(checkpoint["base-model"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + valid_accuracies = checkpoint["valid_accuracies"] + max_bytes = checkpoint["max_bytes"] + logger.log("=> loading checkpoint from '{:}' start with {:}-th epoch.".format(args.resume, start_epoch)) + elif args.init_model is not None: + assert Path(args.init_model).exists(), "Can not find the initialization file : {:}".format(args.init_model) + checkpoint = torch.load(args.init_model) + base_model.load_state_dict(checkpoint["base-model"]) + start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} + logger.log("=> initialize the model from {:}".format(args.init_model)) + else: + logger.log("=> do not find the last-info file : {:}".format(last_info)) + start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} - # save checkpoint - save_path = save_checkpoint({ - 'epoch' : epoch, - 'args' : deepcopy(args), - 'max_bytes' : deepcopy(max_bytes), - 'FLOP' : flop, - 'PARAM' : param, - 'valid_accuracies': deepcopy(valid_accuracies), - 'model-config' : model_config._asdict(), - 'optim-config' : optim_config._asdict(), - 'base-model' : base_model.state_dict(), - 'scheduler' : scheduler.state_dict(), - 'optimizer' : optimizer.state_dict(), - }, model_base_path, logger) - if find_best: copy_checkpoint(model_base_path, model_best_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) + train_func, valid_func = get_procedures(args.procedure) - # measure elapsed time - epoch_time.update(time.time() - start_time) + total_epoch = optim_config.epochs + optim_config.warmup + # Main Training and Evaluation Loop start_time = time.time() + epoch_time = AverageMeter() + for epoch in range(start_epoch, total_epoch): + scheduler.update(epoch, 0.0) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch), True)) + epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch) + LRs = scheduler.get_lr() + find_best = False - logger.log('\n' + '-'*200) - logger.log('||| Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G'.format(param, flop, flop/1e3)) - logger.log('Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}'.format(convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e6, logger.path('info'))) - logger.log('-'*200 + '\n') - logger.close() + logger.log( + "\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}".format( + time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler + ) + ) + + # train for one epoch + train_loss, train_acc1, train_acc5 = train_func( + train_loader, + teacher, + network, + criterion, + scheduler, + optimizer, + optim_config, + epoch_str, + args.print_freq, + logger, + ) + # log the results + logger.log( + "***{:s}*** TRAIN [{:}] loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}".format( + time_string(), epoch_str, train_loss, train_acc1, train_acc5 + ) + ) + + # evaluate the performance + if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): + logger.log("-" * 150) + valid_loss, valid_acc1, valid_acc5 = valid_func( + valid_loader, teacher, network, criterion, optim_config, epoch_str, args.print_freq_eval, logger + ) + valid_accuracies[epoch] = valid_acc1 + logger.log( + "***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}".format( + time_string(), + epoch_str, + valid_loss, + valid_acc1, + valid_acc5, + valid_accuracies["best"], + 100 - valid_accuracies["best"], + ) + ) + if valid_acc1 > valid_accuracies["best"]: + valid_accuracies["best"] = valid_acc1 + find_best = True + logger.log( + "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format( + epoch, valid_acc1, valid_acc5, 100 - valid_acc1, 100 - valid_acc5, model_best_path + ) + ) + num_bytes = torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 + logger.log( + "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format( + next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9 + ) + ) + max_bytes[epoch] = num_bytes + if epoch % 10 == 0: + torch.cuda.empty_cache() + + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch, + "args": deepcopy(args), + "max_bytes": deepcopy(max_bytes), + "FLOP": flop, + "PARAM": param, + "valid_accuracies": deepcopy(valid_accuracies), + "model-config": model_config._asdict(), + "optim-config": optim_config._asdict(), + "base-model": base_model.state_dict(), + "scheduler": scheduler.state_dict(), + "optimizer": optimizer.state_dict(), + }, + model_base_path, + logger, + ) + if find_best: + copy_checkpoint(model_base_path, model_best_path, logger) + last_info = save_checkpoint( + { + "epoch": epoch, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("\n" + "-" * 200) + logger.log("||| Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(param, flop, flop / 1e3)) + logger.log( + "Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}".format( + convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e6, logger.path("info") + ) + ) + logger.log("-" * 200 + "\n") + logger.close() -if __name__ == '__main__': - args = obtain_args() - main(args) +if __name__ == "__main__": + args = obtain_args() + main(args) diff --git a/exps/NAS-Bench-201/check.py b/exps/NAS-Bench-201/check.py index afe8529..4ef4153 100644 --- a/exps/NAS-Bench-201/check.py +++ b/exps/NAS-Bench-201/check.py @@ -7,78 +7,112 @@ import sys, time, argparse, collections import torch from pathlib import Path from collections import defaultdict -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from log_utils import AverageMeter, time_string, convert_secs2time def check_files(save_dir, meta_file, basestr): - meta_infos = torch.load(meta_file, map_location='cpu') - meta_archs = meta_infos['archs'] - meta_num_archs = meta_infos['total'] - assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs)) + meta_infos = torch.load(meta_file, map_location="cpu") + meta_archs = meta_infos["archs"] + meta_num_archs = meta_infos["total"] + assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( + meta_num_archs, len(meta_archs) + ) - sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr)))) - print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs))) - - subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 - num_seeds = defaultdict(lambda: 0) - for index, sub_dir in enumerate(sub_model_dirs): - xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth')) - #xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth')) - arch_indexes = set() - for checkpoint in xcheckpoints: - temp_names = checkpoint.name.split('-') - assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name) - arch_indexes.add( temp_names[1] ) - subdir2archs[sub_dir] = sorted(list(arch_indexes)) - num_evaluated_arch += len(arch_indexes) - # count number of seeds for each architecture - for arch_index in arch_indexes: - num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1 - print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k*v for k, v in num_seeds.items()))) - for key in sorted( list( num_seeds.keys() ) ): print ('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key)) + sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) + print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) - dir2ckps, dir2ckp_exists = dict(), dict() - start_time, epoch_time = time.time(), AverageMeter() - for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()): - if basestr == 'C16-N5': - seeds = [777, 888, 999] - elif basestr == 'C16-N5-LESS': - seeds = [111, 777] - else: - raise ValueError('Invalid base str : {:}'.format(basestr)) - numrs = defaultdict(lambda: 0) - all_checkpoints, all_ckp_exists = [], [] - for arch_index in arch_indexes: - checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds] - ckp_exists = [(sub_dir/x).exists() for x in checkpoints] - arch_index = int(arch_index) - assert 0 <= arch_index < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index) - all_checkpoints += checkpoints - all_ckp_exists += ckp_exists - numrs[sum(ckp_exists)] += 1 - dir2ckps[ str(sub_dir) ] = all_checkpoints - dir2ckp_exists[ str(sub_dir) ] = all_ckp_exists - # measure time - epoch_time.update(time.time() - start_time) - start_time = time.time() - numrstr = ', '.join( ['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs.keys())] ) - print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(time_string(), IDX+1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists), sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs)-IDX-1), True), numrstr)) + subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 + num_seeds = defaultdict(lambda: 0) + for index, sub_dir in enumerate(sub_model_dirs): + xcheckpoints = list(sub_dir.glob("arch-*-seed-*.pth")) + # xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth')) + arch_indexes = set() + for checkpoint in xcheckpoints: + temp_names = checkpoint.name.split("-") + assert ( + len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed" + ), "invalid checkpoint name : {:}".format(checkpoint.name) + arch_indexes.add(temp_names[1]) + subdir2archs[sub_dir] = sorted(list(arch_indexes)) + num_evaluated_arch += len(arch_indexes) + # count number of seeds for each architecture + for arch_index in arch_indexes: + num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1 + print( + "There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).".format( + num_evaluated_arch, meta_num_archs, sum(k * v for k, v in num_seeds.items()) + ) + ) + for key in sorted(list(num_seeds.keys())): + print("There are {:5d} architectures that are evaluated {:} times.".format(num_seeds[key], key)) + + dir2ckps, dir2ckp_exists = dict(), dict() + start_time, epoch_time = time.time(), AverageMeter() + for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()): + if basestr == "C16-N5": + seeds = [777, 888, 999] + elif basestr == "C16-N5-LESS": + seeds = [111, 777] + else: + raise ValueError("Invalid base str : {:}".format(basestr)) + numrs = defaultdict(lambda: 0) + all_checkpoints, all_ckp_exists = [], [] + for arch_index in arch_indexes: + checkpoints = ["arch-{:}-seed-{:04d}.pth".format(arch_index, seed) for seed in seeds] + ckp_exists = [(sub_dir / x).exists() for x in checkpoints] + arch_index = int(arch_index) + assert 0 <= arch_index < len(meta_archs), "invalid arch-index {:} (not found in meta_archs)".format( + arch_index + ) + all_checkpoints += checkpoints + all_ckp_exists += ckp_exists + numrs[sum(ckp_exists)] += 1 + dir2ckps[str(sub_dir)] = all_checkpoints + dir2ckp_exists[str(sub_dir)] = all_ckp_exists + # measure time + epoch_time.update(time.time() - start_time) + start_time = time.time() + numrstr = ", ".join(["{:}: {:03d}".format(x, numrs[x]) for x in sorted(numrs.keys())]) + print( + "{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}".format( + time_string(), + IDX + 1, + len(subdir2archs), + len(arch_indexes), + len(all_checkpoints), + sum(all_ckp_exists), + sub_dir, + convert_secs2time(epoch_time.avg * (len(subdir2archs) - IDX - 1), True), + numrstr, + ) + ) -if __name__ == '__main__': +if __name__ == "__main__": - parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.') - parser.add_argument('--meta_path', type=str, default='./output/NAS-BENCH-201-4/meta-node-4.pth', help='The meta file path.') - parser.add_argument('--base_str', type=str, default='C16-N5', help='The basic string.') - args = parser.parse_args() + parser = argparse.ArgumentParser( + description="NAS Benchmark 201", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--base_save_dir", + type=str, + default="./output/NAS-BENCH-201-4", + help="The base-name of folder to save checkpoints and log.", + ) + parser.add_argument( + "--meta_path", type=str, default="./output/NAS-BENCH-201-4/meta-node-4.pth", help="The meta file path." + ) + parser.add_argument("--base_str", type=str, default="C16-N5", help="The basic string.") + args = parser.parse_args() - save_dir = Path(args.base_save_dir) - meta_path = Path(args.meta_path) - assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir) - assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path) - print ('check NAS-Bench-201 in {:}'.format(save_dir)) + save_dir = Path(args.base_save_dir) + meta_path = Path(args.meta_path) + assert save_dir.exists(), "invalid save dir path : {:}".format(save_dir) + assert meta_path.exists(), "invalid saved meta path : {:}".format(meta_path) + print("check NAS-Bench-201 in {:}".format(save_dir)) - check_files(save_dir, meta_path, args.base_str) + check_files(save_dir, meta_path, args.base_str) diff --git a/exps/NAS-Bench-201/dist-setup.py b/exps/NAS-Bench-201/dist-setup.py index a271ae8..0103f3f 100644 --- a/exps/NAS-Bench-201/dist-setup.py +++ b/exps/NAS-Bench-201/dist-setup.py @@ -9,23 +9,23 @@ import os from setuptools import setup -def read(fname='README.md'): - with open(os.path.join(os.path.dirname(__file__), fname), encoding='utf-8') as cfile: - return cfile.read() +def read(fname="README.md"): + with open(os.path.join(os.path.dirname(__file__), fname), encoding="utf-8") as cfile: + return cfile.read() setup( - name = "nas_bench_201", - version = "2.0", - author = "Xuanyi Dong", - author_email = "dongxuanyi888@gmail.com", - description = "API for NAS-Bench-201 (a benchmark for neural architecture search).", - license = "MIT", - keywords = "NAS Dataset API DeepLearning", - url = "https://github.com/D-X-Y/NAS-Bench-201", - packages=['nas_201_api'], - long_description=read('README.md'), - long_description_content_type='text/markdown', + name="nas_bench_201", + version="2.0", + author="Xuanyi Dong", + author_email="dongxuanyi888@gmail.com", + description="API for NAS-Bench-201 (a benchmark for neural architecture search).", + license="MIT", + keywords="NAS Dataset API DeepLearning", + url="https://github.com/D-X-Y/NAS-Bench-201", + packages=["nas_201_api"], + long_description=read("README.md"), + long_description_content_type="text/markdown", classifiers=[ "Programming Language :: Python", "Topic :: Database", diff --git a/exps/NAS-Bench-201/functions.py b/exps/NAS-Bench-201/functions.py index a1dabae..d6eed59 100644 --- a/exps/NAS-Bench-201/functions.py +++ b/exps/NAS-Bench-201/functions.py @@ -2,133 +2,162 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # ##################################################### import time, torch -from procedures import prepare_seed, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy +from procedures import prepare_seed, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy from config_utils import dict2config -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_cell_based_tiny_net -__all__ = ['evaluate_for_seed', 'pure_evaluate'] +__all__ = ["evaluate_for_seed", "pure_evaluate"] def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): - data_time, batch_time, batch = AverageMeter(), AverageMeter(), None - losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() - latencies = [] - network.eval() - with torch.no_grad(): - end = time.time() - for i, (inputs, targets) in enumerate(xloader): - targets = targets.cuda(non_blocking=True) - inputs = inputs.cuda(non_blocking=True) - data_time.update(time.time() - end) - # forward - features, logits = network(inputs) - loss = criterion(logits, targets) - batch_time.update(time.time() - end) - if batch is None or batch == inputs.size(0): - batch = inputs.size(0) - latencies.append( batch_time.val - data_time.val ) - # record loss and accuracy - prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) - losses.update(loss.item(), inputs.size(0)) - top1.update (prec1.item(), inputs.size(0)) - top5.update (prec5.item(), inputs.size(0)) - end = time.time() - if len(latencies) > 2: latencies = latencies[1:] - return losses.avg, top1.avg, top5.avg, latencies - + data_time, batch_time, batch = AverageMeter(), AverageMeter(), None + losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() + latencies = [] + network.eval() + with torch.no_grad(): + end = time.time() + for i, (inputs, targets) in enumerate(xloader): + targets = targets.cuda(non_blocking=True) + inputs = inputs.cuda(non_blocking=True) + data_time.update(time.time() - end) + # forward + features, logits = network(inputs) + loss = criterion(logits, targets) + batch_time.update(time.time() - end) + if batch is None or batch == inputs.size(0): + batch = inputs.size(0) + latencies.append(batch_time.val - data_time.val) + # record loss and accuracy + prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) + losses.update(loss.item(), inputs.size(0)) + top1.update(prec1.item(), inputs.size(0)) + top5.update(prec5.item(), inputs.size(0)) + end = time.time() + if len(latencies) > 2: + latencies = latencies[1:] + return losses.avg, top1.avg, top5.avg, latencies def procedure(xloader, network, criterion, scheduler, optimizer, mode): - losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() - if mode == 'train' : network.train() - elif mode == 'valid': network.eval() - else: raise ValueError("The mode is not right : {:}".format(mode)) + losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() + if mode == "train": + network.train() + elif mode == "valid": + network.eval() + else: + raise ValueError("The mode is not right : {:}".format(mode)) - data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() - for i, (inputs, targets) in enumerate(xloader): - if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) - - targets = targets.cuda(non_blocking=True) - if mode == 'train': optimizer.zero_grad() - # forward - features, logits = network(inputs) - loss = criterion(logits, targets) - # backward - if mode == 'train': - loss.backward() - optimizer.step() - # record loss and accuracy - prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) - losses.update(loss.item(), inputs.size(0)) - top1.update (prec1.item(), inputs.size(0)) - top5.update (prec5.item(), inputs.size(0)) - # count time - batch_time.update(time.time() - end) - end = time.time() - return losses.avg, top1.avg, top5.avg, batch_time.sum + data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() + for i, (inputs, targets) in enumerate(xloader): + if mode == "train": + scheduler.update(None, 1.0 * i / len(xloader)) + targets = targets.cuda(non_blocking=True) + if mode == "train": + optimizer.zero_grad() + # forward + features, logits = network(inputs) + loss = criterion(logits, targets) + # backward + if mode == "train": + loss.backward() + optimizer.step() + # record loss and accuracy + prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) + losses.update(loss.item(), inputs.size(0)) + top1.update(prec1.item(), inputs.size(0)) + top5.update(prec5.item(), inputs.size(0)) + # count time + batch_time.update(time.time() - end) + end = time.time() + return losses.avg, top1.avg, top5.avg, batch_time.sum def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger): - prepare_seed(seed) # random seed - net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny', - 'C': arch_config['channel'], 'N': arch_config['num_cells'], - 'genotype': arch, 'num_classes': config.class_num} - , None) - ) - #net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) - flop, param = get_model_infos(net, config.xshape) - logger.log('Network : {:}'.format(net.get_message()), False) - logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed)) - logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param)) - # train and valid - optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config) - network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda() - # start training - start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup - train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} - train_times , valid_times = {}, {} - for epoch in range(total_epoch): - scheduler.update(epoch, 0.0) + prepare_seed(seed) # random seed + net = get_cell_based_tiny_net( + dict2config( + { + "name": "infer.tiny", + "C": arch_config["channel"], + "N": arch_config["num_cells"], + "genotype": arch, + "num_classes": config.class_num, + }, + None, + ) + ) + # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) + flop, param = get_model_infos(net, config.xshape) + logger.log("Network : {:}".format(net.get_message()), False) + logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed)) + logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param)) + # train and valid + optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config) + network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda() + # start training + start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup + train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} + train_times, valid_times = {}, {} + for epoch in range(total_epoch): + scheduler.update(epoch, 0.0) - train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train') - train_losses[epoch] = train_loss - train_acc1es[epoch] = train_acc1 - train_acc5es[epoch] = train_acc5 - train_times [epoch] = train_tm - with torch.no_grad(): - for key, xloder in valid_loaders.items(): - valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloder , network, criterion, None, None, 'valid') - valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss - valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1 - valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5 - valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm + train_loss, train_acc1, train_acc5, train_tm = procedure( + train_loader, network, criterion, scheduler, optimizer, "train" + ) + train_losses[epoch] = train_loss + train_acc1es[epoch] = train_acc1 + train_acc5es[epoch] = train_acc5 + train_times[epoch] = train_tm + with torch.no_grad(): + for key, xloder in valid_loaders.items(): + valid_loss, valid_acc1, valid_acc5, valid_tm = procedure( + xloder, network, criterion, None, None, "valid" + ) + valid_losses["{:}@{:}".format(key, epoch)] = valid_loss + valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1 + valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5 + valid_times["{:}@{:}".format(key, epoch)] = valid_tm - # measure elapsed time - epoch_time.update(time.time() - start_time) - start_time = time.time() - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) ) - logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5)) - info_seed = {'flop' : flop, - 'param': param, - 'channel' : arch_config['channel'], - 'num_cells' : arch_config['num_cells'], - 'config' : config._asdict(), - 'total_epoch' : total_epoch , - 'train_losses': train_losses, - 'train_acc1es': train_acc1es, - 'train_acc5es': train_acc5es, - 'train_times' : train_times, - 'valid_losses': valid_losses, - 'valid_acc1es': valid_acc1es, - 'valid_acc5es': valid_acc5es, - 'valid_times' : valid_times, - 'net_state_dict': net.state_dict(), - 'net_string' : '{:}'.format(net), - 'finish-train': True - } - return info_seed + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)) + logger.log( + "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format( + time_string(), + need_time, + epoch, + total_epoch, + train_loss, + train_acc1, + train_acc5, + valid_loss, + valid_acc1, + valid_acc5, + ) + ) + info_seed = { + "flop": flop, + "param": param, + "channel": arch_config["channel"], + "num_cells": arch_config["num_cells"], + "config": config._asdict(), + "total_epoch": total_epoch, + "train_losses": train_losses, + "train_acc1es": train_acc1es, + "train_acc5es": train_acc5es, + "train_times": train_times, + "valid_losses": valid_losses, + "valid_acc1es": valid_acc1es, + "valid_acc5es": valid_acc5es, + "valid_times": valid_times, + "net_state_dict": net.state_dict(), + "net_string": "{:}".format(net), + "finish-train": True, + } + return info_seed diff --git a/exps/NAS-Bench-201/main.py b/exps/NAS-Bench-201/main.py index 652e938..b0dbf46 100644 --- a/exps/NAS-Bench-201/main.py +++ b/exps/NAS-Bench-201/main.py @@ -4,313 +4,492 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # ############################################################### import os, sys, time, torch, random, argparse -from PIL import ImageFile +from PIL import ImageFile + ImageFile.LOAD_TRUNCATED_IMAGES = True -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config -from procedures import save_checkpoint, copy_checkpoint -from procedures import get_machine_info -from datasets import get_datasets -from log_utils import Logger, AverageMeter, time_string, convert_secs2time -from models import CellStructure, CellArchitectures, get_search_spaces -from functions import evaluate_for_seed +from procedures import save_checkpoint, copy_checkpoint +from procedures import get_machine_info +from datasets import get_datasets +from log_utils import Logger, AverageMeter, time_string, convert_secs2time +from models import CellStructure, CellArchitectures, get_search_spaces +from functions import evaluate_for_seed def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger): - machine_info, arch_config = get_machine_info(), deepcopy(arch_config) - all_infos = {'info': machine_info} - all_dataset_keys = [] - # look all the datasets - for dataset, xpath, split in zip(datasets, xpaths, splits): - # train valid data - train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) - # load the configuration - if dataset == 'cifar10' or dataset == 'cifar100': - if use_less: config_path = 'configs/nas-benchmark/LESS.config' - else : config_path = 'configs/nas-benchmark/CIFAR.config' - split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None) - elif dataset.startswith('ImageNet16'): - if use_less: config_path = 'configs/nas-benchmark/LESS.config' - else : config_path = 'configs/nas-benchmark/ImageNet-16.config' - split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None) + machine_info, arch_config = get_machine_info(), deepcopy(arch_config) + all_infos = {"info": machine_info} + all_dataset_keys = [] + # look all the datasets + for dataset, xpath, split in zip(datasets, xpaths, splits): + # train valid data + train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) + # load the configuration + if dataset == "cifar10" or dataset == "cifar100": + if use_less: + config_path = "configs/nas-benchmark/LESS.config" + else: + config_path = "configs/nas-benchmark/CIFAR.config" + split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None) + elif dataset.startswith("ImageNet16"): + if use_less: + config_path = "configs/nas-benchmark/LESS.config" + else: + config_path = "configs/nas-benchmark/ImageNet-16.config" + split_info = load_config("configs/nas-benchmark/{:}-split.txt".format(dataset), None, None) + else: + raise ValueError("invalid dataset : {:}".format(dataset)) + config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger) + # check whether use splited validation set + if bool(split): + assert dataset == "cifar10" + ValLoaders = { + "ori-test": torch.utils.data.DataLoader( + valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + ) + } + assert len(train_data) == len(split_info.train) + len( + split_info.valid + ), "invalid length : {:} vs {:} + {:}".format(len(train_data), len(split_info.train), len(split_info.valid)) + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), + num_workers=workers, + pin_memory=True, + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), + num_workers=workers, + pin_memory=True, + ) + ValLoaders["x-valid"] = valid_loader + else: + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, batch_size=config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + ) + if dataset == "cifar10": + ValLoaders = {"ori-test": valid_loader} + elif dataset == "cifar100": + cifar100_splits = load_config("configs/nas-benchmark/cifar100-test-split.txt", None, None) + ValLoaders = { + "ori-test": valid_loader, + "x-valid": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), + num_workers=workers, + pin_memory=True, + ), + "x-test": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), + num_workers=workers, + pin_memory=True, + ), + } + elif dataset == "ImageNet16-120": + imagenet16_splits = load_config("configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None) + ValLoaders = { + "ori-test": valid_loader, + "x-valid": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), + num_workers=workers, + pin_memory=True, + ), + "x-test": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest), + num_workers=workers, + pin_memory=True, + ), + } + else: + raise ValueError("invalid dataset : {:}".format(dataset)) + + dataset_key = "{:}".format(dataset) + if bool(split): + dataset_key = dataset_key + "-valid" + logger.log( + "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)) + for key, value in ValLoaders.items(): + logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))) + results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger) + all_infos[dataset_key] = results + all_dataset_keys.append(dataset_key) + all_infos["all_dataset_keys"] = all_dataset_keys + return all_infos + + +def main( + save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config +): + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + # torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = True + torch.set_num_threads(workers) + + assert len(srange) == 2 and 0 <= srange[0] <= srange[1], "invalid srange : {:}".format(srange) + + if use_less: + sub_dir = Path(save_dir) / "{:06d}-{:06d}-C{:}-N{:}-LESS".format( + srange[0], srange[1], arch_config["channel"], arch_config["num_cells"] + ) else: - raise ValueError('invalid dataset : {:}'.format(dataset)) - config = load_config(config_path, \ - {'class_num': class_num, - 'xshape' : xshape}, \ - logger) - # check whether use splited validation set - if bool(split): - assert dataset == 'cifar10' - ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)} - assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid)) - train_data_v2 = deepcopy(train_data) - train_data_v2.transform = valid_data.transform - valid_data = train_data_v2 - # data loader - train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True) - ValLoaders['x-valid'] = valid_loader + sub_dir = Path(save_dir) / "{:06d}-{:06d}-C{:}-N{:}".format( + srange[0], srange[1], arch_config["channel"], arch_config["num_cells"] + ) + logger = Logger(str(sub_dir), 0, False) + + all_archs = meta_info["archs"] + assert srange[1] < meta_info["total"], "invalid range : {:}-{:} vs. {:}".format( + srange[0], srange[1], meta_info["total"] + ) + assert arch_index == -1 or srange[0] <= arch_index <= srange[1], "invalid range : {:} vs. {:} vs. {:}".format( + srange[0], arch_index, srange[1] + ) + if arch_index == -1: + to_evaluate_indexes = list(range(srange[0], srange[1] + 1)) else: - # data loader - train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True) - if dataset == 'cifar10': - ValLoaders = {'ori-test': valid_loader} - elif dataset == 'cifar100': - cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None) - ValLoaders = {'ori-test': valid_loader, - 'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True), - 'x-test' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True) - } - elif dataset == 'ImageNet16-120': - imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None) - ValLoaders = {'ori-test': valid_loader, - 'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True), - 'x-test' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True) - } - else: - raise ValueError('invalid dataset : {:}'.format(dataset)) + to_evaluate_indexes = [arch_index] + logger.log("xargs : seeds = {:}".format(seeds)) + logger.log("xargs : arch_index = {:}".format(arch_index)) + logger.log("xargs : cover_mode = {:}".format(cover_mode)) + logger.log("-" * 100) - dataset_key = '{:}'.format(dataset) - if bool(split): dataset_key = dataset_key + '-valid' - logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size)) - logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config)) - for key, value in ValLoaders.items(): - logger.log('Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value))) - results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger) - all_infos[dataset_key] = results - all_dataset_keys.append( dataset_key ) - all_infos['all_dataset_keys'] = all_dataset_keys - return all_infos + logger.log( + "Start evaluating range =: {:06d} vs. {:06d} vs. {:06d} / {:06d} with cover-mode={:}".format( + srange[0], arch_index, srange[1], meta_info["total"], cover_mode + ) + ) + for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): + logger.log( + "--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}".format( + i, len(datasets), dataset, xpath, split + ) + ) + logger.log("--->>> architecture config : {:}".format(arch_config)) + start_time, epoch_time = time.time(), AverageMeter() + for i, index in enumerate(to_evaluate_indexes): + arch = all_archs[index] + logger.log( + "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}".format( + "-" * 15, i, len(to_evaluate_indexes), index, meta_info["total"], seeds, "-" * 15 + ) + ) + # logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15)) + logger.log("{:} {:} {:}".format("-" * 15, arch, "-" * 15)) -def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - #torch.backends.cudnn.benchmark = True - torch.backends.cudnn.deterministic = True - torch.set_num_threads( workers ) + # test this arch on different datasets with different seeds + has_continue = False + for seed in seeds: + to_save_name = sub_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) + if to_save_name.exists(): + if cover_mode: + logger.log("Find existing file : {:}, remove it before evaluation".format(to_save_name)) + os.remove(str(to_save_name)) + else: + logger.log("Find existing file : {:}, skip this evaluation".format(to_save_name)) + has_continue = True + continue + results = evaluate_all_datasets( + CellStructure.str2structure(arch), + datasets, + xpaths, + splits, + use_less, + seed, + arch_config, + workers, + logger, + ) + torch.save(results, to_save_name) + logger.log( + "{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}".format( + "-" * 15, i, len(to_evaluate_indexes), index, meta_info["total"], seed, to_save_name + ) + ) + # measure elapsed time + if not has_continue: + epoch_time.update(time.time() - start_time) + start_time = time.time() + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True) + ) + logger.log("This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))) + logger.log("{:}".format("*" * 100)) + logger.log( + "{:} {:74s} {:}".format( + "*" * 10, + "{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}".format( + i, len(to_evaluate_indexes), index, meta_info["total"], need_time + ), + "*" * 10, + ) + ) + logger.log("{:}".format("*" * 100)) - assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange) - - if use_less: - sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells']) - else: - sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells']) - logger = Logger(str(sub_dir), 0, False) - - all_archs = meta_info['archs'] - assert srange[1] < meta_info['total'], 'invalid range : {:}-{:} vs. {:}'.format(srange[0], srange[1], meta_info['total']) - assert arch_index == -1 or srange[0] <= arch_index <= srange[1], 'invalid range : {:} vs. {:} vs. {:}'.format(srange[0], arch_index, srange[1]) - if arch_index == -1: - to_evaluate_indexes = list(range(srange[0], srange[1]+1)) - else: - to_evaluate_indexes = [arch_index] - logger.log('xargs : seeds = {:}'.format(seeds)) - logger.log('xargs : arch_index = {:}'.format(arch_index)) - logger.log('xargs : cover_mode = {:}'.format(cover_mode)) - logger.log('-'*100) - - logger.log('Start evaluating range =: {:06d} vs. {:06d} vs. {:06d} / {:06d} with cover-mode={:}'.format(srange[0], arch_index, srange[1], meta_info['total'], cover_mode)) - for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): - logger.log('--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split)) - logger.log('--->>> architecture config : {:}'.format(arch_config)) - - - start_time, epoch_time = time.time(), AverageMeter() - for i, index in enumerate(to_evaluate_indexes): - arch = all_archs[index] - logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seeds, '-'*15)) - #logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15)) - logger.log('{:} {:} {:}'.format('-'*15, arch, '-'*15)) - - # test this arch on different datasets with different seeds - has_continue = False - for seed in seeds: - to_save_name = sub_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) - if to_save_name.exists(): - if cover_mode: - logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name)) - os.remove(str(to_save_name)) - else : - logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name)) - has_continue = True - continue - results = evaluate_all_datasets(CellStructure.str2structure(arch), \ - datasets, xpaths, splits, use_less, seed, \ - arch_config, workers, logger) - torch.save(results, to_save_name) - logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name)) - # measure elapsed time - if not has_continue: epoch_time.update(time.time() - start_time) - start_time = time.time() - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) ) - logger.log('This arch costs : {:}'.format( convert_secs2time(epoch_time.val, True) )) - logger.log('{:}'.format('*'*100)) - logger.log('{:} {:74s} {:}'.format('*'*10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, meta_info['total'], need_time), '*'*10)) - logger.log('{:}'.format('*'*100)) - - logger.close() + logger.close() def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.deterministic = True - #torch.backends.cudnn.benchmark = True - torch.set_num_threads( workers ) - - save_dir = Path(save_dir) / 'specifics' / '{:}-{:}-{:}-{:}'.format('LESS' if use_less else 'FULL', model_str, arch_config['channel'], arch_config['num_cells']) - logger = Logger(str(save_dir), 0, False) - if model_str in CellArchitectures: - arch = CellArchitectures[model_str] - logger.log('The model string is found in pre-defined architecture dict : {:}'.format(model_str)) - else: - try: - arch = CellStructure.str2structure(model_str) - except: - raise ValueError('Invalid model string : {:}. It can not be found or parsed.'.format(model_str)) - assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch) - logger.log('Start train-evaluate {:}'.format(arch.tostr())) - logger.log('arch_config : {:}'.format(arch_config)) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = True + torch.set_num_threads(workers) - start_time, seed_time = time.time(), AverageMeter() - for _is, seed in enumerate(seeds): - logger.log('\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds), seed)) - to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed) - if to_save_name.exists(): - logger.log('Find the existing file {:}, directly load!'.format(to_save_name)) - checkpoint = torch.load(to_save_name) + save_dir = ( + Path(save_dir) + / "specifics" + / "{:}-{:}-{:}-{:}".format( + "LESS" if use_less else "FULL", model_str, arch_config["channel"], arch_config["num_cells"] + ) + ) + logger = Logger(str(save_dir), 0, False) + if model_str in CellArchitectures: + arch = CellArchitectures[model_str] + logger.log("The model string is found in pre-defined architecture dict : {:}".format(model_str)) else: - logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name)) - checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger) - torch.save(checkpoint, to_save_name) - # log information - logger.log('{:}'.format(checkpoint['info'])) - all_dataset_keys = checkpoint['all_dataset_keys'] - for dataset_key in all_dataset_keys: - logger.log('\n{:} dataset : {:} {:}'.format('-'*15, dataset_key, '-'*15)) - dataset_info = checkpoint[dataset_key] - #logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] )) - logger.log('Flops = {:} MB, Params = {:} MB'.format(dataset_info['flop'], dataset_info['param'])) - logger.log('config : {:}'.format(dataset_info['config'])) - logger.log('Training State (finish) = {:}'.format(dataset_info['finish-train'])) - last_epoch = dataset_info['total_epoch'] - 1 - train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es'] - valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es'] - logger.log('Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%'.format(train_acc1es[last_epoch], train_acc5es[last_epoch], 100-train_acc1es[last_epoch], valid_acc1es[last_epoch], valid_acc5es[last_epoch], 100-valid_acc1es[last_epoch])) - # measure elapsed time - seed_time.update(time.time() - start_time) - start_time = time.time() - need_time = 'Time Left: {:}'.format( convert_secs2time(seed_time.avg * (len(seeds)-_is-1), True) ) - logger.log('\n<<<***>>> The {:02d}/{:02d}-th seed is {:} other procedures need {:}'.format(_is, len(seeds), seed, need_time)) - logger.close() + try: + arch = CellStructure.str2structure(model_str) + except: + raise ValueError("Invalid model string : {:}. It can not be found or parsed.".format(model_str)) + assert arch.check_valid_op(get_search_spaces("cell", "full")), "{:} has the invalid op.".format(arch) + logger.log("Start train-evaluate {:}".format(arch.tostr())) + logger.log("arch_config : {:}".format(arch_config)) + + start_time, seed_time = time.time(), AverageMeter() + for _is, seed in enumerate(seeds): + logger.log( + "\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format( + _is, len(seeds), seed + ) + ) + to_save_name = save_dir / "seed-{:04d}.pth".format(seed) + if to_save_name.exists(): + logger.log("Find the existing file {:}, directly load!".format(to_save_name)) + checkpoint = torch.load(to_save_name) + else: + logger.log("Does not find the existing file {:}, train and evaluate!".format(to_save_name)) + checkpoint = evaluate_all_datasets( + arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger + ) + torch.save(checkpoint, to_save_name) + # log information + logger.log("{:}".format(checkpoint["info"])) + all_dataset_keys = checkpoint["all_dataset_keys"] + for dataset_key in all_dataset_keys: + logger.log("\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)) + dataset_info = checkpoint[dataset_key] + # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] )) + logger.log("Flops = {:} MB, Params = {:} MB".format(dataset_info["flop"], dataset_info["param"])) + logger.log("config : {:}".format(dataset_info["config"])) + logger.log("Training State (finish) = {:}".format(dataset_info["finish-train"])) + last_epoch = dataset_info["total_epoch"] - 1 + train_acc1es, train_acc5es = dataset_info["train_acc1es"], dataset_info["train_acc5es"] + valid_acc1es, valid_acc5es = dataset_info["valid_acc1es"], dataset_info["valid_acc5es"] + logger.log( + "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format( + train_acc1es[last_epoch], + train_acc5es[last_epoch], + 100 - train_acc1es[last_epoch], + valid_acc1es[last_epoch], + valid_acc5es[last_epoch], + 100 - valid_acc1es[last_epoch], + ) + ) + # measure elapsed time + seed_time.update(time.time() - start_time) + start_time = time.time() + need_time = "Time Left: {:}".format(convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)) + logger.log( + "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} other procedures need {:}".format( + _is, len(seeds), seed, need_time + ) + ) + logger.close() def generate_meta_info(save_dir, max_node, divide=40): - aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-201') - archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) - print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2))) + aa_nas_bench_ss = get_search_spaces("cell", "nas-bench-201") + archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) + print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2))) - random.seed( 88 ) # please do not change this line for reproducibility - random.shuffle( archs ) - # to test fixed-random shuffle - #print ('arch [0] : {:}\n---->>>> {:}'.format( archs[0], archs[0].tostr() )) - #print ('arch [9] : {:}\n---->>>> {:}'.format( archs[9], archs[9].tostr() )) - assert archs[0 ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0]) - assert archs[9 ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9]) - assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123]) - total_arch = len(archs) - - num = 50000 - indexes_5W = list(range(num)) - random.seed( 1021 ) - random.shuffle( indexes_5W ) - train_split = sorted( list(set(indexes_5W[:num//2])) ) - valid_split = sorted( list(set(indexes_5W[num//2:])) ) - assert len(train_split) + len(valid_split) == num - assert train_split[0] == 0 and train_split[10] == 26 and train_split[111] == 203 and valid_split[0] == 1 and valid_split[10] == 18 and valid_split[111] == 242, '{:} {:} {:} - {:} {:} {:}'.format(train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111]) - splits = {num: {'train': train_split, 'valid': valid_split} } + random.seed(88) # please do not change this line for reproducibility + random.shuffle(archs) + # to test fixed-random shuffle + # print ('arch [0] : {:}\n---->>>> {:}'.format( archs[0], archs[0].tostr() )) + # print ('arch [9] : {:}\n---->>>> {:}'.format( archs[9], archs[9].tostr() )) + assert ( + archs[0].tostr() + == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" + ), "please check the 0-th architecture : {:}".format(archs[0]) + assert ( + archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" + ), "please check the 9-th architecture : {:}".format(archs[9]) + assert ( + archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" + ), "please check the 123-th architecture : {:}".format(archs[123]) + total_arch = len(archs) - info = {'archs' : [x.tostr() for x in archs], - 'total' : total_arch, - 'max_node' : max_node, - 'splits': splits} + num = 50000 + indexes_5W = list(range(num)) + random.seed(1021) + random.shuffle(indexes_5W) + train_split = sorted(list(set(indexes_5W[: num // 2]))) + valid_split = sorted(list(set(indexes_5W[num // 2 :]))) + assert len(train_split) + len(valid_split) == num + assert ( + train_split[0] == 0 + and train_split[10] == 26 + and train_split[111] == 203 + and valid_split[0] == 1 + and valid_split[10] == 18 + and valid_split[111] == 242 + ), "{:} {:} {:} - {:} {:} {:}".format( + train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111] + ) + splits = {num: {"train": train_split, "valid": valid_split}} - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - save_name = save_dir / 'meta-node-{:}.pth'.format(max_node) - assert not save_name.exists(), '{:} already exist'.format(save_name) - torch.save(info, save_name) - print ('save the meta file into {:}'.format(save_name)) + info = {"archs": [x.tostr() for x in archs], "total": total_arch, "max_node": max_node, "splits": splits} - script_name_full = save_dir / 'BENCH-201-N{:}.opt-full.script'.format(max_node) - script_name_less = save_dir / 'BENCH-201-N{:}.opt-less.script'.format(max_node) - full_file = open(str(script_name_full), 'w') - less_file = open(str(script_name_less), 'w') - gaps = total_arch // divide - for start in range(0, total_arch, gaps): - xend = min(start+gaps, total_arch) - full_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 0 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) - less_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 1 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) - print ('save the training script into {:} and {:}'.format(script_name_full, script_name_less)) - full_file.close() - less_file.close() + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + save_name = save_dir / "meta-node-{:}.pth".format(max_node) + assert not save_name.exists(), "{:} already exist".format(save_name) + torch.save(info, save_name) + print("save the meta file into {:}".format(save_name)) - script_name = save_dir / 'meta-node-{:}.cal-script.txt'.format(max_node) - macro = 'OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0' - with open(str(script_name), 'w') as cfile: + script_name_full = save_dir / "BENCH-201-N{:}.opt-full.script".format(max_node) + script_name_less = save_dir / "BENCH-201-N{:}.opt-less.script".format(max_node) + full_file = open(str(script_name_full), "w") + less_file = open(str(script_name_less), "w") + gaps = total_arch // divide for start in range(0, total_arch, gaps): - xend = min(start+gaps, total_arch) - cfile.write('{:} python exps/NAS-Bench-201/statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1)) - print ('save the post-processing script into {:}'.format(script_name)) + xend = min(start + gaps, total_arch) + full_file.write( + "bash ./scripts-search/NAS-Bench-201/train-models.sh 0 {:5d} {:5d} -1 '777 888 999'\n".format( + start, xend - 1 + ) + ) + less_file.write( + "bash ./scripts-search/NAS-Bench-201/train-models.sh 1 {:5d} {:5d} -1 '777 888 999'\n".format( + start, xend - 1 + ) + ) + print("save the training script into {:} and {:}".format(script_name_full, script_name_less)) + full_file.close() + less_file.close() + + script_name = save_dir / "meta-node-{:}.cal-script.txt".format(max_node) + macro = "OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0" + with open(str(script_name), "w") as cfile: + for start in range(0, total_arch, gaps): + xend = min(start + gaps, total_arch) + cfile.write( + "{:} python exps/NAS-Bench-201/statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n".format( + macro, start, xend - 1 + ) + ) + print("save the post-processing script into {:}".format(script_name)) -if __name__ == '__main__': - #mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()] - #parser = argparse.ArgumentParser(description='Algorithm-Agnostic NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--mode' , type=str, required=True, help='The script mode.') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--max_node', type=int, help='The maximum node in a cell.') - # use for train the model - parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 2)') - parser.add_argument('--srange' , type=int, nargs='+', help='The range of models to be evaluated') - parser.add_argument('--arch_index', type=int, default=-1, help='The architecture index to be evaluated (cover mode).') - parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.') - parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for this dataset.') - parser.add_argument('--splits', type=int, nargs='+', help='The root path for this dataset.') - parser.add_argument('--use_less', type=int, default=0, choices=[0,1], help='Using the less-training-epoch config.') - parser.add_argument('--seeds' , type=int, nargs='+', help='The range of models to be evaluated') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - args = parser.parse_args() +if __name__ == "__main__": + # mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()] + # parser = argparse.ArgumentParser(description='Algorithm-Agnostic NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + description="NAS-Bench-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--mode", type=str, required=True, help="The script mode.") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument("--max_node", type=int, help="The maximum node in a cell.") + # use for train the model + parser.add_argument("--workers", type=int, default=8, help="number of data loading workers (default: 2)") + parser.add_argument("--srange", type=int, nargs="+", help="The range of models to be evaluated") + parser.add_argument( + "--arch_index", type=int, default=-1, help="The architecture index to be evaluated (cover mode)." + ) + parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.") + parser.add_argument("--xpaths", type=str, nargs="+", help="The root path for this dataset.") + parser.add_argument("--splits", type=int, nargs="+", help="The root path for this dataset.") + parser.add_argument("--use_less", type=int, default=0, choices=[0, 1], help="Using the less-training-epoch config.") + parser.add_argument("--seeds", type=int, nargs="+", help="The range of models to be evaluated") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + args = parser.parse_args() - assert args.mode in ['meta', 'new', 'cover'] or args.mode.startswith('specific-'), 'invalid mode : {:}'.format(args.mode) + assert args.mode in ["meta", "new", "cover"] or args.mode.startswith("specific-"), "invalid mode : {:}".format( + args.mode + ) - if args.mode == 'meta': - generate_meta_info(args.save_dir, args.max_node) - elif args.mode.startswith('specific'): - assert len(args.mode.split('-')) == 2, 'invalid mode : {:}'.format(args.mode) - model_str = args.mode.split('-')[1] - train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \ - tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells}) - else: - meta_path = Path(args.save_dir) / 'meta-node-{:}.pth'.format(args.max_node) - assert meta_path.exists(), '{:} does not exist.'.format(meta_path) - meta_info = torch.load( meta_path ) - # check whether args is ok - assert len(args.srange) == 2 and args.srange[0] <= args.srange[1], 'invalid length of srange args: {:}'.format(args.srange) - assert len(args.seeds) > 0, 'invalid length of seeds args: {:}'.format(args.seeds) - assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits)) - assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers) - - main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \ - tuple(args.srange), args.arch_index, tuple(args.seeds), \ - args.mode == 'cover', meta_info, \ - {'channel': args.channel, 'num_cells': args.num_cells}) + if args.mode == "meta": + generate_meta_info(args.save_dir, args.max_node) + elif args.mode.startswith("specific"): + assert len(args.mode.split("-")) == 2, "invalid mode : {:}".format(args.mode) + model_str = args.mode.split("-")[1] + train_single_model( + args.save_dir, + args.workers, + args.datasets, + args.xpaths, + args.splits, + args.use_less > 0, + tuple(args.seeds), + model_str, + {"channel": args.channel, "num_cells": args.num_cells}, + ) + else: + meta_path = Path(args.save_dir) / "meta-node-{:}.pth".format(args.max_node) + assert meta_path.exists(), "{:} does not exist.".format(meta_path) + meta_info = torch.load(meta_path) + # check whether args is ok + assert len(args.srange) == 2 and args.srange[0] <= args.srange[1], "invalid length of srange args: {:}".format( + args.srange + ) + assert len(args.seeds) > 0, "invalid length of seeds args: {:}".format(args.seeds) + assert len(args.datasets) == len(args.xpaths) == len(args.splits), "invalid infos : {:} vs {:} vs {:}".format( + len(args.datasets), len(args.xpaths), len(args.splits) + ) + assert args.workers > 0, "invalid number of workers : {:}".format(args.workers) + + main( + args.save_dir, + args.workers, + args.datasets, + args.xpaths, + args.splits, + args.use_less > 0, + tuple(args.srange), + args.arch_index, + tuple(args.seeds), + args.mode == "cover", + meta_info, + {"channel": args.channel, "num_cells": args.num_cells}, + ) diff --git a/exps/NAS-Bench-201/show-best.py b/exps/NAS-Bench-201/show-best.py index fe3fe3d..153d197 100644 --- a/exps/NAS-Bench-201/show-best.py +++ b/exps/NAS-Bench-201/show-best.py @@ -5,35 +5,37 @@ ################################################################################################ import sys, argparse from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from nas_201_api import NASBench201API as API -if __name__ == '__main__': - parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") - parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.') - args = parser.parse_args() +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from nas_201_api import NASBench201API as API - meta_file = Path(args.api_path) - assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) +if __name__ == "__main__": + parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") + parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.") + args = parser.parse_args() - api = API(str(meta_file)) + meta_file = Path(args.api_path) + assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) - # This will show the results of the best architecture based on the validation set of each dataset. - arch_index, accuracy = api.find_best('cifar10-valid', 'x-valid', None, None, False) - print('FOR CIFAR-010, using the hyper-parameters with 200 training epochs :::') - print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index))) - api.show(arch_index) - print('') + api = API(str(meta_file)) - arch_index, accuracy = api.find_best('cifar100', 'x-valid', None, None, False) - print('FOR CIFAR-100, using the hyper-parameters with 200 training epochs :::') - print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index))) - api.show(arch_index) - print('') + # This will show the results of the best architecture based on the validation set of each dataset. + arch_index, accuracy = api.find_best("cifar10-valid", "x-valid", None, None, False) + print("FOR CIFAR-010, using the hyper-parameters with 200 training epochs :::") + print("arch-index={:5d}, arch={:}".format(arch_index, api.arch(arch_index))) + api.show(arch_index) + print("") - arch_index, accuracy = api.find_best('ImageNet16-120', 'x-valid', None, None, False) - print('FOR ImageNet16-120, using the hyper-parameters with 200 training epochs :::') - print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index))) - api.show(arch_index) - print('') + arch_index, accuracy = api.find_best("cifar100", "x-valid", None, None, False) + print("FOR CIFAR-100, using the hyper-parameters with 200 training epochs :::") + print("arch-index={:5d}, arch={:}".format(arch_index, api.arch(arch_index))) + api.show(arch_index) + print("") + + arch_index, accuracy = api.find_best("ImageNet16-120", "x-valid", None, None, False) + print("FOR ImageNet16-120, using the hyper-parameters with 200 training epochs :::") + print("arch-index={:5d}, arch={:}".format(arch_index, api.arch(arch_index))) + api.show(arch_index) + print("") diff --git a/exps/NAS-Bench-201/statistics-v2.py b/exps/NAS-Bench-201/statistics-v2.py index 79eaee7..dfeaa0b 100644 --- a/exps/NAS-Bench-201/statistics-v2.py +++ b/exps/NAS-Bench-201/statistics-v2.py @@ -7,276 +7,396 @@ import torch from pathlib import Path from collections import defaultdict, OrderedDict from typing import Dict, Any, Text, List -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from log_utils import AverageMeter, time_string, convert_secs2time from config_utils import dict2config + # NAS-Bench-201 related module or function -from models import CellStructure, get_cell_based_tiny_net -from nas_201_api import NASBench201API, ArchResults, ResultsCount -from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders +from models import CellStructure, get_cell_based_tiny_net +from nas_201_api import NASBench201API, ArchResults, ResultsCount +from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders -api = NASBench201API('{:}/.torch/NAS-Bench-201-v1_0-e61699.pth'.format(os.environ['HOME'])) +api = NASBench201API("{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"])) -def create_result_count(used_seed: int, dataset: Text, arch_config: Dict[Text, Any], - results: Dict[Text, Any], dataloader_dict: Dict[Text, Any]) -> ResultsCount: - xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'], - results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None) - net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes': arch_config['class_num']}, None) - network = get_cell_based_tiny_net(net_config) - network.load_state_dict(xresult.get_net_param()) - if 'train_times' in results: # new version - xresult.update_train_info(results['train_acc1es'], results['train_acc5es'], results['train_losses'], results['train_times']) - xresult.update_eval(results['valid_acc1es'], results['valid_losses'], results['valid_times']) - else: - if dataset == 'cifar10-valid': - xresult.update_OLD_eval('x-valid' , results['valid_acc1es'], results['valid_losses']) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format('cifar10', 'test')], network.cuda()) - xresult.update_OLD_eval('ori-test', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) - xresult.update_latency(latencies) - elif dataset == 'cifar10': - xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses']) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda()) - xresult.update_latency(latencies) - elif dataset == 'cifar100' or dataset == 'ImageNet16-120': - xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses']) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'valid')], network.cuda()) - xresult.update_OLD_eval('x-valid', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda()) - xresult.update_OLD_eval('x-test' , {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) - xresult.update_latency(latencies) + +def create_result_count( + used_seed: int, + dataset: Text, + arch_config: Dict[Text, Any], + results: Dict[Text, Any], + dataloader_dict: Dict[Text, Any], +) -> ResultsCount: + xresult = ResultsCount( + dataset, + results["net_state_dict"], + results["train_acc1es"], + results["train_losses"], + results["param"], + results["flop"], + arch_config, + used_seed, + results["total_epoch"], + None, + ) + net_config = dict2config( + { + "name": "infer.tiny", + "C": arch_config["channel"], + "N": arch_config["num_cells"], + "genotype": CellStructure.str2structure(arch_config["arch_str"]), + "num_classes": arch_config["class_num"], + }, + None, + ) + network = get_cell_based_tiny_net(net_config) + network.load_state_dict(xresult.get_net_param()) + if "train_times" in results: # new version + xresult.update_train_info( + results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"] + ) + xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"]) else: - raise ValueError('invalid dataset name : {:}'.format(dataset)) - return xresult - + if dataset == "cifar10-valid": + xresult.update_OLD_eval("x-valid", results["valid_acc1es"], results["valid_losses"]) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda() + ) + xresult.update_OLD_eval("ori-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_latency(latencies) + elif dataset == "cifar10": + xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() + ) + xresult.update_latency(latencies) + elif dataset == "cifar100" or dataset == "ImageNet16-120": + xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda() + ) + xresult.update_OLD_eval("x-valid", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() + ) + xresult.update_OLD_eval("x-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_latency(latencies) + else: + raise ValueError("invalid dataset name : {:}".format(dataset)) + return xresult -def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], - datasets: List[Text], dataloader_dict: Dict[Text, Any]) -> ArchResults: - information = ArchResults(arch_index, arch_str) +def account_one_arch( + arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text], dataloader_dict: Dict[Text, Any] +) -> ArchResults: + information = ArchResults(arch_index, arch_str) - for checkpoint_path in checkpoints: - checkpoint = torch.load(checkpoint_path, map_location='cpu') - used_seed = checkpoint_path.name.split('-')[-1].split('.')[0] - ok_dataset = 0 - for dataset in datasets: - if dataset not in checkpoint: - print('Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, checkpoint_path)) - continue - else: - ok_dataset += 1 - results = checkpoint[dataset] - assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, used_seed, dataset, checkpoint_path) - arch_config = {'channel': results['channel'], 'num_cells': results['num_cells'], 'arch_str': arch_str, 'class_num': results['config']['class_num']} - - xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict) - information.update(dataset, int(used_seed), xresult) - if ok_dataset == 0: raise ValueError('{:} does not find any data'.format(checkpoint_path)) - return information + for checkpoint_path in checkpoints: + checkpoint = torch.load(checkpoint_path, map_location="cpu") + used_seed = checkpoint_path.name.split("-")[-1].split(".")[0] + ok_dataset = 0 + for dataset in datasets: + if dataset not in checkpoint: + print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path)) + continue + else: + ok_dataset += 1 + results = checkpoint[dataset] + assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( + arch_index, used_seed, dataset, checkpoint_path + ) + arch_config = { + "channel": results["channel"], + "num_cells": results["num_cells"], + "arch_str": arch_str, + "class_num": results["config"]["class_num"], + } + + xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict) + information.update(dataset, int(used_seed), xresult) + if ok_dataset == 0: + raise ValueError("{:} does not find any data".format(checkpoint_path)) + return information def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch_info_less: ArchResults): - # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth - cifar010_latency = (api.get_latency(arch_index, 'cifar10-valid', hp='200') + api.get_latency(arch_index, 'cifar10', hp='200')) / 2 - arch_info_full.reset_latency('cifar10-valid', None, cifar010_latency) - arch_info_full.reset_latency('cifar10', None, cifar010_latency) - arch_info_less.reset_latency('cifar10-valid', None, cifar010_latency) - arch_info_less.reset_latency('cifar10', None, cifar010_latency) + # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth + cifar010_latency = ( + api.get_latency(arch_index, "cifar10-valid", hp="200") + api.get_latency(arch_index, "cifar10", hp="200") + ) / 2 + arch_info_full.reset_latency("cifar10-valid", None, cifar010_latency) + arch_info_full.reset_latency("cifar10", None, cifar010_latency) + arch_info_less.reset_latency("cifar10-valid", None, cifar010_latency) + arch_info_less.reset_latency("cifar10", None, cifar010_latency) - cifar100_latency = api.get_latency(arch_index, 'cifar100', hp='200') - arch_info_full.reset_latency('cifar100', None, cifar100_latency) - arch_info_less.reset_latency('cifar100', None, cifar100_latency) + cifar100_latency = api.get_latency(arch_index, "cifar100", hp="200") + arch_info_full.reset_latency("cifar100", None, cifar100_latency) + arch_info_less.reset_latency("cifar100", None, cifar100_latency) - image_latency = api.get_latency(arch_index, 'ImageNet16-120', hp='200') - arch_info_full.reset_latency('ImageNet16-120', None, image_latency) - arch_info_less.reset_latency('ImageNet16-120', None, image_latency) + image_latency = api.get_latency(arch_index, "ImageNet16-120", hp="200") + arch_info_full.reset_latency("ImageNet16-120", None, image_latency) + arch_info_less.reset_latency("ImageNet16-120", None, image_latency) - train_per_epoch_time = list(arch_info_less.query('cifar10-valid', 777).train_times.values()) - train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) - eval_ori_test_time, eval_x_valid_time = [], [] - for key, value in arch_info_less.query('cifar10-valid', 777).eval_times.items(): - if key.startswith('ori-test@'): - eval_ori_test_time.append(value) - elif key.startswith('x-valid@'): - eval_x_valid_time.append(value) - else: raise ValueError('-- {:} --'.format(key)) - eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time)) - nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000, 'ImageNet16-120-test': 6000, - 'cifar10-valid-train': 25000, 'cifar10-valid-valid': 25000, - 'cifar10-train': 50000, 'cifar10-test': 10000, - 'cifar100-train': 50000, 'cifar100-test': 10000, 'cifar100-valid': 5000} - eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums['cifar10-valid-valid'] + nums['cifar10-test']) - for arch_info in [arch_info_less, arch_info_full]: - arch_info.reset_pseudo_train_times('cifar10-valid', None, - train_per_epoch_time / nums['cifar10-valid-train'] * nums['cifar10-valid-train']) - arch_info.reset_pseudo_train_times('cifar10', None, - train_per_epoch_time / nums['cifar10-valid-train'] * nums['cifar10-train']) - arch_info.reset_pseudo_train_times('cifar100', None, - train_per_epoch_time / nums['cifar10-valid-train'] * nums['cifar100-train']) - arch_info.reset_pseudo_train_times('ImageNet16-120', None, - train_per_epoch_time / nums['cifar10-valid-train'] * nums['ImageNet16-120-train']) - arch_info.reset_pseudo_eval_times('cifar10-valid', None, 'x-valid', eval_per_sample*nums['cifar10-valid-valid']) - arch_info.reset_pseudo_eval_times('cifar10-valid', None, 'ori-test', eval_per_sample * nums['cifar10-test']) - arch_info.reset_pseudo_eval_times('cifar10', None, 'ori-test', eval_per_sample * nums['cifar10-test']) - arch_info.reset_pseudo_eval_times('cifar100', None, 'x-valid', eval_per_sample * nums['cifar100-valid']) - arch_info.reset_pseudo_eval_times('cifar100', None, 'x-test', eval_per_sample * nums['cifar100-valid']) - arch_info.reset_pseudo_eval_times('cifar100', None, 'ori-test', eval_per_sample * nums['cifar100-test']) - arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'x-valid', eval_per_sample * nums['ImageNet16-120-valid']) - arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'x-test', eval_per_sample * nums['ImageNet16-120-valid']) - arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'ori-test', eval_per_sample * nums['ImageNet16-120-test']) - # arch_info_full.debug_test() - # arch_info_less.debug_test() - return arch_info_full, arch_info_less + train_per_epoch_time = list(arch_info_less.query("cifar10-valid", 777).train_times.values()) + train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) + eval_ori_test_time, eval_x_valid_time = [], [] + for key, value in arch_info_less.query("cifar10-valid", 777).eval_times.items(): + if key.startswith("ori-test@"): + eval_ori_test_time.append(value) + elif key.startswith("x-valid@"): + eval_x_valid_time.append(value) + else: + raise ValueError("-- {:} --".format(key)) + eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time)) + nums = { + "ImageNet16-120-train": 151700, + "ImageNet16-120-valid": 3000, + "ImageNet16-120-test": 6000, + "cifar10-valid-train": 25000, + "cifar10-valid-valid": 25000, + "cifar10-train": 50000, + "cifar10-test": 10000, + "cifar100-train": 50000, + "cifar100-test": 10000, + "cifar100-valid": 5000, + } + eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums["cifar10-valid-valid"] + nums["cifar10-test"]) + for arch_info in [arch_info_less, arch_info_full]: + arch_info.reset_pseudo_train_times( + "cifar10-valid", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-valid-train"] + ) + arch_info.reset_pseudo_train_times( + "cifar10", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"] + ) + arch_info.reset_pseudo_train_times( + "cifar100", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"] + ) + arch_info.reset_pseudo_train_times( + "ImageNet16-120", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["ImageNet16-120-train"] + ) + arch_info.reset_pseudo_eval_times( + "cifar10-valid", None, "x-valid", eval_per_sample * nums["cifar10-valid-valid"] + ) + arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"]) + arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"]) + arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"]) + arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"]) + arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"]) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", None, "x-valid", eval_per_sample * nums["ImageNet16-120-valid"] + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", None, "x-test", eval_per_sample * nums["ImageNet16-120-valid"] + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", None, "ori-test", eval_per_sample * nums["ImageNet16-120-test"] + ) + # arch_info_full.debug_test() + # arch_info_less.debug_test() + return arch_info_full, arch_info_less def simplify(save_dir, meta_file, basestr, target_dir): - meta_infos = torch.load(meta_file, map_location='cpu') - meta_archs = meta_infos['archs'] # a list of architecture strings - meta_num_archs = meta_infos['total'] - assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs)) + meta_infos = torch.load(meta_file, map_location="cpu") + meta_archs = meta_infos["archs"] # a list of architecture strings + meta_num_archs = meta_infos["total"] + assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( + meta_num_archs, len(meta_archs) + ) - sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr)))) - print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs))) - - subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 - num_seeds = defaultdict(lambda: 0) - for index, sub_dir in enumerate(sub_model_dirs): - xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth')) - arch_indexes = set() - for checkpoint in xcheckpoints: - temp_names = checkpoint.name.split('-') - assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name) - arch_indexes.add( temp_names[1] ) - subdir2archs[sub_dir] = sorted(list(arch_indexes)) - num_evaluated_arch += len(arch_indexes) - # count number of seeds for each architecture - for arch_index in arch_indexes: - num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1 - print('{:} There are {:5d} architectures that have been evaluated ({:} in total).'.format(time_string(), num_evaluated_arch, meta_num_archs)) - for key in sorted( list( num_seeds.keys() ) ): print ('{:} There are {:5d} architectures that are evaluated {:} times.'.format(time_string(), num_seeds[key], key)) + sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) + print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) - dataloader_dict = get_nas_bench_loaders( 6 ) - to_save_simply = save_dir / 'simplifies' - to_save_allarc = save_dir / 'simplifies' / 'architectures' - if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True) - if not to_save_allarc.exists(): to_save_allarc.mkdir(parents=True, exist_ok=True) + subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 + num_seeds = defaultdict(lambda: 0) + for index, sub_dir in enumerate(sub_model_dirs): + xcheckpoints = list(sub_dir.glob("arch-*-seed-*.pth")) + arch_indexes = set() + for checkpoint in xcheckpoints: + temp_names = checkpoint.name.split("-") + assert ( + len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed" + ), "invalid checkpoint name : {:}".format(checkpoint.name) + arch_indexes.add(temp_names[1]) + subdir2archs[sub_dir] = sorted(list(arch_indexes)) + num_evaluated_arch += len(arch_indexes) + # count number of seeds for each architecture + for arch_index in arch_indexes: + num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1 + print( + "{:} There are {:5d} architectures that have been evaluated ({:} in total).".format( + time_string(), num_evaluated_arch, meta_num_archs + ) + ) + for key in sorted(list(num_seeds.keys())): + print( + "{:} There are {:5d} architectures that are evaluated {:} times.".format(time_string(), num_seeds[key], key) + ) - assert (save_dir / target_dir) in subdir2archs, 'can not find {:}'.format(target_dir) - arch2infos, datasets = {}, ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120') - evaluated_indexes = set() - target_full_dir = save_dir / target_dir - target_less_dir = save_dir / '{:}-LESS'.format(target_dir) - arch_indexes = subdir2archs[ target_full_dir ] - num_seeds = defaultdict(lambda: 0) - end_time = time.time() - arch_time = AverageMeter() - for idx, arch_index in enumerate(arch_indexes): - checkpoints = list(target_full_dir.glob('arch-{:}-seed-*.pth'.format(arch_index))) - ckps_less = list(target_less_dir.glob('arch-{:}-seed-*.pth'.format(arch_index))) - # create the arch info for each architecture - try: - arch_info_full = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict) - arch_info_less = account_one_arch(arch_index, meta_archs[int(arch_index)], ckps_less, datasets, dataloader_dict) - num_seeds[ len(checkpoints) ] += 1 - except: - print('Loading {:} failed, : {:}'.format(arch_index, checkpoints)) - continue - assert int(arch_index) not in evaluated_indexes, 'conflict arch-index : {:}'.format(arch_index) - assert 0 <= int(arch_index) < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index) - arch_info = {'full': arch_info_full, 'less': arch_info_less} - evaluated_indexes.add(int(arch_index)) - arch2infos[int(arch_index)] = arch_info - # to correct the latency and training_time info. - arch_info_full, arch_info_less = correct_time_related_info(int(arch_index), arch_info_full, arch_info_less) - to_save_data = OrderedDict(full=arch_info_full.state_dict(), less=arch_info_less.state_dict()) - torch.save(to_save_data, to_save_allarc / '{:}-FULL.pth'.format(arch_index)) - arch_info['full'].clear_params() - arch_info['less'].clear_params() - torch.save(to_save_data, to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index)) - # measure elapsed time - arch_time.update(time.time() - end_time) - end_time = time.time() - need_time = '{:}'.format( convert_secs2time(arch_time.avg * (len(arch_indexes)-idx-1), True) ) - print('{:} {:} [{:03d}/{:03d}] : {:} still need {:}'.format(time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time)) - # measure time - xstrs = ['{:}:{:03d}'.format(key, num_seeds[key]) for key in sorted( list( num_seeds.keys() ) ) ] - print('{:} {:} done : {:}'.format(time_string(), target_dir, xstrs)) - final_infos = {'meta_archs' : meta_archs, - 'total_archs': meta_num_archs, - 'basestr' : basestr, - 'arch2infos' : arch2infos, - 'evaluated_indexes': evaluated_indexes} - save_file_name = to_save_simply / '{:}.pth'.format(target_dir) - torch.save(final_infos, save_file_name) - print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name)) + dataloader_dict = get_nas_bench_loaders(6) + to_save_simply = save_dir / "simplifies" + to_save_allarc = save_dir / "simplifies" / "architectures" + if not to_save_simply.exists(): + to_save_simply.mkdir(parents=True, exist_ok=True) + if not to_save_allarc.exists(): + to_save_allarc.mkdir(parents=True, exist_ok=True) + + assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format(target_dir) + arch2infos, datasets = {}, ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") + evaluated_indexes = set() + target_full_dir = save_dir / target_dir + target_less_dir = save_dir / "{:}-LESS".format(target_dir) + arch_indexes = subdir2archs[target_full_dir] + num_seeds = defaultdict(lambda: 0) + end_time = time.time() + arch_time = AverageMeter() + for idx, arch_index in enumerate(arch_indexes): + checkpoints = list(target_full_dir.glob("arch-{:}-seed-*.pth".format(arch_index))) + ckps_less = list(target_less_dir.glob("arch-{:}-seed-*.pth".format(arch_index))) + # create the arch info for each architecture + try: + arch_info_full = account_one_arch( + arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict + ) + arch_info_less = account_one_arch( + arch_index, meta_archs[int(arch_index)], ckps_less, datasets, dataloader_dict + ) + num_seeds[len(checkpoints)] += 1 + except: + print("Loading {:} failed, : {:}".format(arch_index, checkpoints)) + continue + assert int(arch_index) not in evaluated_indexes, "conflict arch-index : {:}".format(arch_index) + assert 0 <= int(arch_index) < len(meta_archs), "invalid arch-index {:} (not found in meta_archs)".format( + arch_index + ) + arch_info = {"full": arch_info_full, "less": arch_info_less} + evaluated_indexes.add(int(arch_index)) + arch2infos[int(arch_index)] = arch_info + # to correct the latency and training_time info. + arch_info_full, arch_info_less = correct_time_related_info(int(arch_index), arch_info_full, arch_info_less) + to_save_data = OrderedDict(full=arch_info_full.state_dict(), less=arch_info_less.state_dict()) + torch.save(to_save_data, to_save_allarc / "{:}-FULL.pth".format(arch_index)) + arch_info["full"].clear_params() + arch_info["less"].clear_params() + torch.save(to_save_data, to_save_allarc / "{:}-SIMPLE.pth".format(arch_index)) + # measure elapsed time + arch_time.update(time.time() - end_time) + end_time = time.time() + need_time = "{:}".format(convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1), True)) + print( + "{:} {:} [{:03d}/{:03d}] : {:} still need {:}".format( + time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time + ) + ) + # measure time + xstrs = ["{:}:{:03d}".format(key, num_seeds[key]) for key in sorted(list(num_seeds.keys()))] + print("{:} {:} done : {:}".format(time_string(), target_dir, xstrs)) + final_infos = { + "meta_archs": meta_archs, + "total_archs": meta_num_archs, + "basestr": basestr, + "arch2infos": arch2infos, + "evaluated_indexes": evaluated_indexes, + } + save_file_name = to_save_simply / "{:}.pth".format(target_dir) + torch.save(final_infos, save_file_name) + print( + "Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name) + ) def merge_all(save_dir, meta_file, basestr): - meta_infos = torch.load(meta_file, map_location='cpu') - meta_archs = meta_infos['archs'] - meta_num_archs = meta_infos['total'] - assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs)) + meta_infos = torch.load(meta_file, map_location="cpu") + meta_archs = meta_infos["archs"] + meta_num_archs = meta_infos["total"] + assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( + meta_num_archs, len(meta_archs) + ) - sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr)))) - print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs))) - for index, sub_dir in enumerate(sub_model_dirs): - arch_info_files = sorted( list(sub_dir.glob('arch-*-seed-*.pth') ) ) - print ('The {:02d}/{:02d}-th directory : {:} : {:} runs.'.format(index, len(sub_model_dirs), sub_dir, len(arch_info_files))) - - arch2infos, evaluated_indexes = dict(), set() - for IDX, sub_dir in enumerate(sub_model_dirs): - ckp_path = sub_dir.parent / 'simplifies' / '{:}.pth'.format(sub_dir.name) - if ckp_path.exists(): - sub_ckps = torch.load(ckp_path, map_location='cpu') - assert sub_ckps['total_archs'] == meta_num_archs and sub_ckps['basestr'] == basestr - xarch2infos = sub_ckps['arch2infos'] - xevalindexs = sub_ckps['evaluated_indexes'] - for eval_index in xevalindexs: - assert eval_index not in evaluated_indexes and eval_index not in arch2infos - #arch2infos[eval_index] = xarch2infos[eval_index].state_dict() - arch2infos[eval_index] = {'full': xarch2infos[eval_index]['full'].state_dict(), - 'less': xarch2infos[eval_index]['less'].state_dict()} - evaluated_indexes.add( eval_index ) - print ('{:} [{:03d}/{:03d}] merge data from {:} with {:} models.'.format(time_string(), IDX, len(sub_model_dirs), ckp_path, len(xevalindexs))) + sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) + print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) + for index, sub_dir in enumerate(sub_model_dirs): + arch_info_files = sorted(list(sub_dir.glob("arch-*-seed-*.pth"))) + print( + "The {:02d}/{:02d}-th directory : {:} : {:} runs.".format( + index, len(sub_model_dirs), sub_dir, len(arch_info_files) + ) + ) + + arch2infos, evaluated_indexes = dict(), set() + for IDX, sub_dir in enumerate(sub_model_dirs): + ckp_path = sub_dir.parent / "simplifies" / "{:}.pth".format(sub_dir.name) + if ckp_path.exists(): + sub_ckps = torch.load(ckp_path, map_location="cpu") + assert sub_ckps["total_archs"] == meta_num_archs and sub_ckps["basestr"] == basestr + xarch2infos = sub_ckps["arch2infos"] + xevalindexs = sub_ckps["evaluated_indexes"] + for eval_index in xevalindexs: + assert eval_index not in evaluated_indexes and eval_index not in arch2infos + # arch2infos[eval_index] = xarch2infos[eval_index].state_dict() + arch2infos[eval_index] = { + "full": xarch2infos[eval_index]["full"].state_dict(), + "less": xarch2infos[eval_index]["less"].state_dict(), + } + evaluated_indexes.add(eval_index) + print( + "{:} [{:03d}/{:03d}] merge data from {:} with {:} models.".format( + time_string(), IDX, len(sub_model_dirs), ckp_path, len(xevalindexs) + ) + ) + else: + raise ValueError("Can not find {:}".format(ckp_path)) + # print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path)) + + evaluated_indexes = sorted(list(evaluated_indexes)) + print("Finally, there are {:} architectures that have been trained and evaluated.".format(len(evaluated_indexes))) + + to_save_simply = save_dir / "simplifies" + if not to_save_simply.exists(): + to_save_simply.mkdir(parents=True, exist_ok=True) + final_infos = { + "meta_archs": meta_archs, + "total_archs": meta_num_archs, + "arch2infos": arch2infos, + "evaluated_indexes": evaluated_indexes, + } + save_file_name = to_save_simply / "{:}-final-infos.pth".format(basestr) + torch.save(final_infos, save_file_name) + print( + "Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name) + ) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="NAS-BENCH-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--mode", type=str, choices=["cal", "merge"], help="The running mode for this script.") + parser.add_argument( + "--base_save_dir", + type=str, + default="./output/NAS-BENCH-201-4", + help="The base-name of folder to save checkpoints and log.", + ) + parser.add_argument("--target_dir", type=str, help="The target directory.") + parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.") + parser.add_argument("--channel", type=int, default=16, help="The number of channels.") + parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + args = parser.parse_args() + + save_dir = Path(args.base_save_dir) + meta_path = save_dir / "meta-node-{:}.pth".format(args.max_node) + assert save_dir.exists(), "invalid save dir path : {:}".format(save_dir) + assert meta_path.exists(), "invalid saved meta path : {:}".format(meta_path) + print("start the statistics of our nas-benchmark from {:} using {:}.".format(save_dir, args.target_dir)) + basestr = "C{:}-N{:}".format(args.channel, args.num_cells) + + if args.mode == "cal": + simplify(save_dir, meta_path, basestr, args.target_dir) + elif args.mode == "merge": + merge_all(save_dir, meta_path, basestr) else: - raise ValueError('Can not find {:}'.format(ckp_path)) - #print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path)) - - evaluated_indexes = sorted( list( evaluated_indexes ) ) - print ('Finally, there are {:} architectures that have been trained and evaluated.'.format(len(evaluated_indexes))) - - to_save_simply = save_dir / 'simplifies' - if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True) - final_infos = {'meta_archs' : meta_archs, - 'total_archs': meta_num_archs, - 'arch2infos' : arch2infos, - 'evaluated_indexes': evaluated_indexes} - save_file_name = to_save_simply / '{:}-final-infos.pth'.format(basestr) - torch.save(final_infos, save_file_name) - print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name)) - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--mode' , type=str, choices=['cal', 'merge'], help='The running mode for this script.') - parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.') - parser.add_argument('--target_dir' , type=str, help='The target directory.') - parser.add_argument('--max_node' , type=int, default=4, help='The maximum node in a cell.') - parser.add_argument('--channel' , type=int, default=16, help='The number of channels.') - parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.') - args = parser.parse_args() - - save_dir = Path(args.base_save_dir) - meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node) - assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir) - assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path) - print ('start the statistics of our nas-benchmark from {:} using {:}.'.format(save_dir, args.target_dir)) - basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells) - - if args.mode == 'cal': - simplify(save_dir, meta_path, basestr, args.target_dir) - elif args.mode == 'merge': - merge_all(save_dir, meta_path, basestr) - else: - raise ValueError('invalid mode : {:}'.format(args.mode)) + raise ValueError("invalid mode : {:}".format(args.mode)) diff --git a/exps/NAS-Bench-201/statistics.py b/exps/NAS-Bench-201/statistics.py index 19b9c90..80985b9 100644 --- a/exps/NAS-Bench-201/statistics.py +++ b/exps/NAS-Bench-201/statistics.py @@ -6,284 +6,504 @@ from copy import deepcopy import torch from pathlib import Path from collections import defaultdict -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from log_utils import AverageMeter, time_string, convert_secs2time from config_utils import load_config, dict2config -from datasets import get_datasets +from datasets import get_datasets + # NAS-Bench-201 related module or function -from models import CellStructure, get_cell_based_tiny_net -from nas_201_api import ArchResults, ResultsCount -from procedures import bench_pure_evaluate as pure_evaluate +from models import CellStructure, get_cell_based_tiny_net +from nas_201_api import ArchResults, ResultsCount +from procedures import bench_pure_evaluate as pure_evaluate def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict): - xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'], \ - results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None) + xresult = ResultsCount( + dataset, + results["net_state_dict"], + results["train_acc1es"], + results["train_losses"], + results["param"], + results["flop"], + arch_config, + used_seed, + results["total_epoch"], + None, + ) - net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes':arch_config['class_num']}, None) - network = get_cell_based_tiny_net(net_config) - network.load_state_dict(xresult.get_net_param()) - if 'train_times' in results: # new version - xresult.update_train_info(results['train_acc1es'], results['train_acc5es'], results['train_losses'], results['train_times']) - xresult.update_eval(results['valid_acc1es'], results['valid_losses'], results['valid_times']) - else: - if dataset == 'cifar10-valid': - xresult.update_OLD_eval('x-valid' , results['valid_acc1es'], results['valid_losses']) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format('cifar10', 'test')], network.cuda()) - xresult.update_OLD_eval('ori-test', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) - xresult.update_latency(latencies) - elif dataset == 'cifar10': - xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses']) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda()) - xresult.update_latency(latencies) - elif dataset == 'cifar100' or dataset == 'ImageNet16-120': - xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses']) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'valid')], network.cuda()) - xresult.update_OLD_eval('x-valid', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda()) - xresult.update_OLD_eval('x-test' , {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) - xresult.update_latency(latencies) + net_config = dict2config( + { + "name": "infer.tiny", + "C": arch_config["channel"], + "N": arch_config["num_cells"], + "genotype": CellStructure.str2structure(arch_config["arch_str"]), + "num_classes": arch_config["class_num"], + }, + None, + ) + network = get_cell_based_tiny_net(net_config) + network.load_state_dict(xresult.get_net_param()) + if "train_times" in results: # new version + xresult.update_train_info( + results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"] + ) + xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"]) else: - raise ValueError('invalid dataset name : {:}'.format(dataset)) - return xresult - + if dataset == "cifar10-valid": + xresult.update_OLD_eval("x-valid", results["valid_acc1es"], results["valid_losses"]) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda() + ) + xresult.update_OLD_eval("ori-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_latency(latencies) + elif dataset == "cifar10": + xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() + ) + xresult.update_latency(latencies) + elif dataset == "cifar100" or dataset == "ImageNet16-120": + xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda() + ) + xresult.update_OLD_eval("x-valid", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() + ) + xresult.update_OLD_eval("x-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_latency(latencies) + else: + raise ValueError("invalid dataset name : {:}".format(dataset)) + return xresult def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dict): - information = ArchResults(arch_index, arch_str) + information = ArchResults(arch_index, arch_str) - for checkpoint_path in checkpoints: - checkpoint = torch.load(checkpoint_path, map_location='cpu') - used_seed = checkpoint_path.name.split('-')[-1].split('.')[0] - for dataset in datasets: - assert dataset in checkpoint, 'Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, checkpoint_path) - results = checkpoint[dataset] - assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, used_seed, dataset, checkpoint_path) - arch_config = {'channel': results['channel'], 'num_cells': results['num_cells'], 'arch_str': arch_str, 'class_num': results['config']['class_num']} - - xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict) - information.update(dataset, int(used_seed), xresult) - return information + for checkpoint_path in checkpoints: + checkpoint = torch.load(checkpoint_path, map_location="cpu") + used_seed = checkpoint_path.name.split("-")[-1].split(".")[0] + for dataset in datasets: + assert dataset in checkpoint, "Can not find {:} in arch-{:} from {:}".format( + dataset, arch_index, checkpoint_path + ) + results = checkpoint[dataset] + assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( + arch_index, used_seed, dataset, checkpoint_path + ) + arch_config = { + "channel": results["channel"], + "num_cells": results["num_cells"], + "arch_str": arch_str, + "class_num": results["config"]["class_num"], + } + + xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict) + information.update(dataset, int(used_seed), xresult) + return information def GET_DataLoaders(workers): - torch.set_num_threads(workers) + torch.set_num_threads(workers) - root_dir = (Path(__file__).parent / '..' / '..').resolve() - torch_dir = Path(os.environ['TORCH_HOME']) - # cifar - cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config' - cifar_config = load_config(cifar_config_path, None, None) - print ('{:} Create data-loader for all datasets'.format(time_string())) - print ('-'*200) - TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1) - print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num)) - cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None) - assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14] - temp_dataset = deepcopy(TRAIN_CIFAR10) - temp_dataset.transform = VALID_CIFAR10.transform - # data loader - trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True) - train_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True) - valid_cifar10_loader = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True) - test__cifar10_loader = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True) - print ('CIFAR-10 : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size)) - print ('CIFAR-10 : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size)) - print ('CIFAR-10 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size)) - print ('CIFAR-10 : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size)) - print ('-'*200) - # CIFAR-100 - TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1) - print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num)) - cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None) - assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24] - train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True) - valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True) - test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True) - print ('CIFAR-100 : train-loader has {:3d} batch'.format(len(train_cifar100_loader))) - print ('CIFAR-100 : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader))) - print ('CIFAR-100 : test--loader has {:3d} batch'.format(len(test__cifar100_loader))) - print ('-'*200) + root_dir = (Path(__file__).parent / ".." / "..").resolve() + torch_dir = Path(os.environ["TORCH_HOME"]) + # cifar + cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config" + cifar_config = load_config(cifar_config_path, None, None) + print("{:} Create data-loader for all datasets".format(time_string())) + print("-" * 200) + TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets("cifar10", str(torch_dir / "cifar.python"), -1) + print( + "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format( + len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num + ) + ) + cifar10_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None) + assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [ + 1, + 2, + 3, + 4, + 6, + 8, + 9, + 10, + 12, + 14, + ] + temp_dataset = deepcopy(TRAIN_CIFAR10) + temp_dataset.transform = VALID_CIFAR10.transform + # data loader + trainval_cifar10_loader = torch.utils.data.DataLoader( + TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + ) + train_cifar10_loader = torch.utils.data.DataLoader( + TRAIN_CIFAR10, + batch_size=cifar_config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), + num_workers=workers, + pin_memory=True, + ) + valid_cifar10_loader = torch.utils.data.DataLoader( + temp_dataset, + batch_size=cifar_config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), + num_workers=workers, + pin_memory=True, + ) + test__cifar10_loader = torch.utils.data.DataLoader( + VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + ) + print( + "CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format( + len(trainval_cifar10_loader), cifar_config.batch_size + ) + ) + print( + "CIFAR-10 : train-loader has {:3d} batch with {:} per batch".format( + len(train_cifar10_loader), cifar_config.batch_size + ) + ) + print( + "CIFAR-10 : valid-loader has {:3d} batch with {:} per batch".format( + len(valid_cifar10_loader), cifar_config.batch_size + ) + ) + print( + "CIFAR-10 : test--loader has {:3d} batch with {:} per batch".format( + len(test__cifar10_loader), cifar_config.batch_size + ) + ) + print("-" * 200) + # CIFAR-100 + TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets("cifar100", str(torch_dir / "cifar.python"), -1) + print( + "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( + len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num + ) + ) + cifar100_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None) + assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [ + 0, + 2, + 6, + 7, + 9, + 11, + 12, + 17, + 20, + 24, + ] + train_cifar100_loader = torch.utils.data.DataLoader( + TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + ) + valid_cifar100_loader = torch.utils.data.DataLoader( + VALID_CIFAR100, + batch_size=cifar_config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), + num_workers=workers, + pin_memory=True, + ) + test__cifar100_loader = torch.utils.data.DataLoader( + VALID_CIFAR100, + batch_size=cifar_config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), + num_workers=workers, + pin_memory=True, + ) + print("CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader))) + print("CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))) + print("CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader))) + print("-" * 200) - imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config' - imagenet16_config = load_config(imagenet16_config_path, None, None) - TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1) - print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num)) - imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None) - assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20] - train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True) - valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True) - test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True) - print ('ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size)) - print ('ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size)) - print ('ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size)) + imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config" + imagenet16_config = load_config(imagenet16_config_path, None, None) + TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets( + "ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1 + ) + print( + "original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( + len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num + ) + ) + imagenet_splits = load_config(root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", None, None) + assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [ + 0, + 4, + 5, + 10, + 11, + 13, + 14, + 15, + 17, + 20, + ] + train_imagenet_loader = torch.utils.data.DataLoader( + TRAIN_ImageNet16_120, + batch_size=imagenet16_config.batch_size, + shuffle=True, + num_workers=workers, + pin_memory=True, + ) + valid_imagenet_loader = torch.utils.data.DataLoader( + VALID_ImageNet16_120, + batch_size=imagenet16_config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), + num_workers=workers, + pin_memory=True, + ) + test__imagenet_loader = torch.utils.data.DataLoader( + VALID_ImageNet16_120, + batch_size=imagenet16_config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest), + num_workers=workers, + pin_memory=True, + ) + print( + "ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch".format( + len(train_imagenet_loader), imagenet16_config.batch_size + ) + ) + print( + "ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch".format( + len(valid_imagenet_loader), imagenet16_config.batch_size + ) + ) + print( + "ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch".format( + len(test__imagenet_loader), imagenet16_config.batch_size + ) + ) - # 'cifar10', 'cifar100', 'ImageNet16-120' - loaders = {'cifar10@trainval': trainval_cifar10_loader, - 'cifar10@train' : train_cifar10_loader, - 'cifar10@valid' : valid_cifar10_loader, - 'cifar10@test' : test__cifar10_loader, - 'cifar100@train' : train_cifar100_loader, - 'cifar100@valid' : valid_cifar100_loader, - 'cifar100@test' : test__cifar100_loader, - 'ImageNet16-120@train': train_imagenet_loader, - 'ImageNet16-120@valid': valid_imagenet_loader, - 'ImageNet16-120@test' : test__imagenet_loader} - return loaders + # 'cifar10', 'cifar100', 'ImageNet16-120' + loaders = { + "cifar10@trainval": trainval_cifar10_loader, + "cifar10@train": train_cifar10_loader, + "cifar10@valid": valid_cifar10_loader, + "cifar10@test": test__cifar10_loader, + "cifar100@train": train_cifar100_loader, + "cifar100@valid": valid_cifar100_loader, + "cifar100@test": test__cifar100_loader, + "ImageNet16-120@train": train_imagenet_loader, + "ImageNet16-120@valid": valid_imagenet_loader, + "ImageNet16-120@test": test__imagenet_loader, + } + return loaders def simplify(save_dir, meta_file, basestr, target_dir): - meta_infos = torch.load(meta_file, map_location='cpu') - meta_archs = meta_infos['archs'] # a list of architecture strings - meta_num_archs = meta_infos['total'] - meta_max_node = meta_infos['max_node'] - assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs)) + meta_infos = torch.load(meta_file, map_location="cpu") + meta_archs = meta_infos["archs"] # a list of architecture strings + meta_num_archs = meta_infos["total"] + meta_max_node = meta_infos["max_node"] + assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( + meta_num_archs, len(meta_archs) + ) - sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr)))) - print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs))) - - subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 - num_seeds = defaultdict(lambda: 0) - for index, sub_dir in enumerate(sub_model_dirs): - xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth')) - arch_indexes = set() - for checkpoint in xcheckpoints: - temp_names = checkpoint.name.split('-') - assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name) - arch_indexes.add( temp_names[1] ) - subdir2archs[sub_dir] = sorted(list(arch_indexes)) - num_evaluated_arch += len(arch_indexes) - # count number of seeds for each architecture - for arch_index in arch_indexes: - num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1 - print('{:} There are {:5d} architectures that have been evaluated ({:} in total).'.format(time_string(), num_evaluated_arch, meta_num_archs)) - for key in sorted( list( num_seeds.keys() ) ): print ('{:} There are {:5d} architectures that are evaluated {:} times.'.format(time_string(), num_seeds[key], key)) + sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) + print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) - dataloader_dict = GET_DataLoaders( 6 ) + subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 + num_seeds = defaultdict(lambda: 0) + for index, sub_dir in enumerate(sub_model_dirs): + xcheckpoints = list(sub_dir.glob("arch-*-seed-*.pth")) + arch_indexes = set() + for checkpoint in xcheckpoints: + temp_names = checkpoint.name.split("-") + assert ( + len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed" + ), "invalid checkpoint name : {:}".format(checkpoint.name) + arch_indexes.add(temp_names[1]) + subdir2archs[sub_dir] = sorted(list(arch_indexes)) + num_evaluated_arch += len(arch_indexes) + # count number of seeds for each architecture + for arch_index in arch_indexes: + num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1 + print( + "{:} There are {:5d} architectures that have been evaluated ({:} in total).".format( + time_string(), num_evaluated_arch, meta_num_archs + ) + ) + for key in sorted(list(num_seeds.keys())): + print( + "{:} There are {:5d} architectures that are evaluated {:} times.".format(time_string(), num_seeds[key], key) + ) - to_save_simply = save_dir / 'simplifies' - to_save_allarc = save_dir / 'simplifies' / 'architectures' - if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True) - if not to_save_allarc.exists(): to_save_allarc.mkdir(parents=True, exist_ok=True) + dataloader_dict = GET_DataLoaders(6) - assert (save_dir / target_dir) in subdir2archs, 'can not find {:}'.format(target_dir) - arch2infos, datasets = {}, ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120') - evaluated_indexes = set() - target_directory = save_dir / target_dir - target_less_dir = save_dir / '{:}-LESS'.format(target_dir) - arch_indexes = subdir2archs[ target_directory ] - num_seeds = defaultdict(lambda: 0) - end_time = time.time() - arch_time = AverageMeter() - for idx, arch_index in enumerate(arch_indexes): - checkpoints = list(target_directory.glob('arch-{:}-seed-*.pth'.format(arch_index))) - ckps_less = list(target_less_dir.glob('arch-{:}-seed-*.pth'.format(arch_index))) - # create the arch info for each architecture - try: - arch_info_full = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict) - arch_info_less = account_one_arch(arch_index, meta_archs[int(arch_index)], ckps_less, ['cifar10-valid'], dataloader_dict) - num_seeds[ len(checkpoints) ] += 1 - except: - print('Loading {:} failed, : {:}'.format(arch_index, checkpoints)) - continue - assert int(arch_index) not in evaluated_indexes, 'conflict arch-index : {:}'.format(arch_index) - assert 0 <= int(arch_index) < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index) - arch_info = {'full': arch_info_full, 'less': arch_info_less} - evaluated_indexes.add( int(arch_index) ) - arch2infos[int(arch_index)] = arch_info - torch.save({'full': arch_info_full.state_dict(), - 'less': arch_info_less.state_dict()}, to_save_allarc / '{:}-FULL.pth'.format(arch_index)) - arch_info['full'].clear_params() - arch_info['less'].clear_params() - torch.save({'full': arch_info_full.state_dict(), - 'less': arch_info_less.state_dict()}, to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index)) - # measure elapsed time - arch_time.update(time.time() - end_time) - end_time = time.time() - need_time = '{:}'.format( convert_secs2time(arch_time.avg * (len(arch_indexes)-idx-1), True) ) - print('{:} {:} [{:03d}/{:03d}] : {:} still need {:}'.format(time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time)) - # measure time - xstrs = ['{:}:{:03d}'.format(key, num_seeds[key]) for key in sorted( list( num_seeds.keys() ) ) ] - print('{:} {:} done : {:}'.format(time_string(), target_dir, xstrs)) - final_infos = {'meta_archs' : meta_archs, - 'total_archs': meta_num_archs, - 'basestr' : basestr, - 'arch2infos' : arch2infos, - 'evaluated_indexes': evaluated_indexes} - save_file_name = to_save_simply / '{:}.pth'.format(target_dir) - torch.save(final_infos, save_file_name) - print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name)) + to_save_simply = save_dir / "simplifies" + to_save_allarc = save_dir / "simplifies" / "architectures" + if not to_save_simply.exists(): + to_save_simply.mkdir(parents=True, exist_ok=True) + if not to_save_allarc.exists(): + to_save_allarc.mkdir(parents=True, exist_ok=True) + + assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format(target_dir) + arch2infos, datasets = {}, ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") + evaluated_indexes = set() + target_directory = save_dir / target_dir + target_less_dir = save_dir / "{:}-LESS".format(target_dir) + arch_indexes = subdir2archs[target_directory] + num_seeds = defaultdict(lambda: 0) + end_time = time.time() + arch_time = AverageMeter() + for idx, arch_index in enumerate(arch_indexes): + checkpoints = list(target_directory.glob("arch-{:}-seed-*.pth".format(arch_index))) + ckps_less = list(target_less_dir.glob("arch-{:}-seed-*.pth".format(arch_index))) + # create the arch info for each architecture + try: + arch_info_full = account_one_arch( + arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict + ) + arch_info_less = account_one_arch( + arch_index, meta_archs[int(arch_index)], ckps_less, ["cifar10-valid"], dataloader_dict + ) + num_seeds[len(checkpoints)] += 1 + except: + print("Loading {:} failed, : {:}".format(arch_index, checkpoints)) + continue + assert int(arch_index) not in evaluated_indexes, "conflict arch-index : {:}".format(arch_index) + assert 0 <= int(arch_index) < len(meta_archs), "invalid arch-index {:} (not found in meta_archs)".format( + arch_index + ) + arch_info = {"full": arch_info_full, "less": arch_info_less} + evaluated_indexes.add(int(arch_index)) + arch2infos[int(arch_index)] = arch_info + torch.save( + {"full": arch_info_full.state_dict(), "less": arch_info_less.state_dict()}, + to_save_allarc / "{:}-FULL.pth".format(arch_index), + ) + arch_info["full"].clear_params() + arch_info["less"].clear_params() + torch.save( + {"full": arch_info_full.state_dict(), "less": arch_info_less.state_dict()}, + to_save_allarc / "{:}-SIMPLE.pth".format(arch_index), + ) + # measure elapsed time + arch_time.update(time.time() - end_time) + end_time = time.time() + need_time = "{:}".format(convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1), True)) + print( + "{:} {:} [{:03d}/{:03d}] : {:} still need {:}".format( + time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time + ) + ) + # measure time + xstrs = ["{:}:{:03d}".format(key, num_seeds[key]) for key in sorted(list(num_seeds.keys()))] + print("{:} {:} done : {:}".format(time_string(), target_dir, xstrs)) + final_infos = { + "meta_archs": meta_archs, + "total_archs": meta_num_archs, + "basestr": basestr, + "arch2infos": arch2infos, + "evaluated_indexes": evaluated_indexes, + } + save_file_name = to_save_simply / "{:}.pth".format(target_dir) + torch.save(final_infos, save_file_name) + print( + "Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name) + ) def merge_all(save_dir, meta_file, basestr): - meta_infos = torch.load(meta_file, map_location='cpu') - meta_archs = meta_infos['archs'] - meta_num_archs = meta_infos['total'] - meta_max_node = meta_infos['max_node'] - assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs)) + meta_infos = torch.load(meta_file, map_location="cpu") + meta_archs = meta_infos["archs"] + meta_num_archs = meta_infos["total"] + meta_max_node = meta_infos["max_node"] + assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( + meta_num_archs, len(meta_archs) + ) - sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr)))) - print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs))) - for index, sub_dir in enumerate(sub_model_dirs): - arch_info_files = sorted( list(sub_dir.glob('arch-*-seed-*.pth') ) ) - print ('The {:02d}/{:02d}-th directory : {:} : {:} runs.'.format(index, len(sub_model_dirs), sub_dir, len(arch_info_files))) - - arch2infos, evaluated_indexes = dict(), set() - for IDX, sub_dir in enumerate(sub_model_dirs): - ckp_path = sub_dir.parent / 'simplifies' / '{:}.pth'.format(sub_dir.name) - if ckp_path.exists(): - sub_ckps = torch.load(ckp_path, map_location='cpu') - assert sub_ckps['total_archs'] == meta_num_archs and sub_ckps['basestr'] == basestr - xarch2infos = sub_ckps['arch2infos'] - xevalindexs = sub_ckps['evaluated_indexes'] - for eval_index in xevalindexs: - assert eval_index not in evaluated_indexes and eval_index not in arch2infos - #arch2infos[eval_index] = xarch2infos[eval_index].state_dict() - arch2infos[eval_index] = {'full': xarch2infos[eval_index]['full'].state_dict(), - 'less': xarch2infos[eval_index]['less'].state_dict()} - evaluated_indexes.add( eval_index ) - print ('{:} [{:03d}/{:03d}] merge data from {:} with {:} models.'.format(time_string(), IDX, len(sub_model_dirs), ckp_path, len(xevalindexs))) + sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) + print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) + for index, sub_dir in enumerate(sub_model_dirs): + arch_info_files = sorted(list(sub_dir.glob("arch-*-seed-*.pth"))) + print( + "The {:02d}/{:02d}-th directory : {:} : {:} runs.".format( + index, len(sub_model_dirs), sub_dir, len(arch_info_files) + ) + ) + + arch2infos, evaluated_indexes = dict(), set() + for IDX, sub_dir in enumerate(sub_model_dirs): + ckp_path = sub_dir.parent / "simplifies" / "{:}.pth".format(sub_dir.name) + if ckp_path.exists(): + sub_ckps = torch.load(ckp_path, map_location="cpu") + assert sub_ckps["total_archs"] == meta_num_archs and sub_ckps["basestr"] == basestr + xarch2infos = sub_ckps["arch2infos"] + xevalindexs = sub_ckps["evaluated_indexes"] + for eval_index in xevalindexs: + assert eval_index not in evaluated_indexes and eval_index not in arch2infos + # arch2infos[eval_index] = xarch2infos[eval_index].state_dict() + arch2infos[eval_index] = { + "full": xarch2infos[eval_index]["full"].state_dict(), + "less": xarch2infos[eval_index]["less"].state_dict(), + } + evaluated_indexes.add(eval_index) + print( + "{:} [{:03d}/{:03d}] merge data from {:} with {:} models.".format( + time_string(), IDX, len(sub_model_dirs), ckp_path, len(xevalindexs) + ) + ) + else: + raise ValueError("Can not find {:}".format(ckp_path)) + # print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path)) + + evaluated_indexes = sorted(list(evaluated_indexes)) + print("Finally, there are {:} architectures that have been trained and evaluated.".format(len(evaluated_indexes))) + + to_save_simply = save_dir / "simplifies" + if not to_save_simply.exists(): + to_save_simply.mkdir(parents=True, exist_ok=True) + final_infos = { + "meta_archs": meta_archs, + "total_archs": meta_num_archs, + "arch2infos": arch2infos, + "evaluated_indexes": evaluated_indexes, + } + save_file_name = to_save_simply / "{:}-final-infos.pth".format(basestr) + torch.save(final_infos, save_file_name) + print( + "Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name) + ) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="NAS-BENCH-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--mode", type=str, choices=["cal", "merge"], help="The running mode for this script.") + parser.add_argument( + "--base_save_dir", + type=str, + default="./output/NAS-BENCH-201-4", + help="The base-name of folder to save checkpoints and log.", + ) + parser.add_argument("--target_dir", type=str, help="The target directory.") + parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.") + parser.add_argument("--channel", type=int, default=16, help="The number of channels.") + parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + args = parser.parse_args() + + save_dir = Path(args.base_save_dir) + meta_path = save_dir / "meta-node-{:}.pth".format(args.max_node) + assert save_dir.exists(), "invalid save dir path : {:}".format(save_dir) + assert meta_path.exists(), "invalid saved meta path : {:}".format(meta_path) + print("start the statistics of our nas-benchmark from {:} using {:}.".format(save_dir, args.target_dir)) + basestr = "C{:}-N{:}".format(args.channel, args.num_cells) + + if args.mode == "cal": + simplify(save_dir, meta_path, basestr, args.target_dir) + elif args.mode == "merge": + merge_all(save_dir, meta_path, basestr) else: - raise ValueError('Can not find {:}'.format(ckp_path)) - #print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path)) - - evaluated_indexes = sorted( list( evaluated_indexes ) ) - print ('Finally, there are {:} architectures that have been trained and evaluated.'.format(len(evaluated_indexes))) - - to_save_simply = save_dir / 'simplifies' - if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True) - final_infos = {'meta_archs' : meta_archs, - 'total_archs': meta_num_archs, - 'arch2infos' : arch2infos, - 'evaluated_indexes': evaluated_indexes} - save_file_name = to_save_simply / '{:}-final-infos.pth'.format(basestr) - torch.save(final_infos, save_file_name) - print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name)) - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--mode' , type=str, choices=['cal', 'merge'], help='The running mode for this script.') - parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.') - parser.add_argument('--target_dir' , type=str, help='The target directory.') - parser.add_argument('--max_node' , type=int, default=4, help='The maximum node in a cell.') - parser.add_argument('--channel' , type=int, default=16, help='The number of channels.') - parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.') - args = parser.parse_args() - - save_dir = Path(args.base_save_dir) - meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node) - assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir) - assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path) - print ('start the statistics of our nas-benchmark from {:} using {:}.'.format(save_dir, args.target_dir)) - basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells) - - if args.mode == 'cal': - simplify(save_dir, meta_path, basestr, args.target_dir) - elif args.mode == 'merge': - merge_all(save_dir, meta_path, basestr) - else: - raise ValueError('invalid mode : {:}'.format(args.mode)) \ No newline at end of file + raise ValueError("invalid mode : {:}".format(args.mode)) diff --git a/exps/NAS-Bench-201/test-correlation.py b/exps/NAS-Bench-201/test-correlation.py index 0a49634..aaf2e14 100644 --- a/exps/NAS-Bench-201/test-correlation.py +++ b/exps/NAS-Bench-201/test-correlation.py @@ -9,123 +9,151 @@ from copy import deepcopy from tqdm import tqdm import torch from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from log_utils import time_string -from models import CellStructure -from nas_201_api import NASBench201API as API + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from log_utils import time_string +from models import CellStructure +from nas_201_api import NASBench201API as API def check_unique_arch(meta_file): - api = API(str(meta_file)) - arch_strs = deepcopy(api.meta_archs) - xarchs = [CellStructure.str2structure(x) for x in arch_strs] - def get_unique_matrix(archs, consider_zero): - UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs] - print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs))) - Unique2Index = dict() - for index, xstr in enumerate(UniquStrs): - if xstr not in Unique2Index: Unique2Index[xstr] = list() - Unique2Index[xstr].append( index ) - sm_matrix = torch.eye(len(archs)).bool() - for _, xlist in Unique2Index.items(): - for i in xlist: - for j in xlist: - sm_matrix[i,j] = True - unique_ids, unique_num = [-1 for _ in archs], 0 - for i in range(len(unique_ids)): - if unique_ids[i] > -1: continue - neighbours = sm_matrix[i].nonzero().view(-1).tolist() - for nghb in neighbours: - assert unique_ids[nghb] == -1, 'impossible' - unique_ids[nghb] = unique_num - unique_num += 1 - return sm_matrix, unique_ids, unique_num + api = API(str(meta_file)) + arch_strs = deepcopy(api.meta_archs) + xarchs = [CellStructure.str2structure(x) for x in arch_strs] - print ('There are {:} valid-archs'.format( sum(arch.check_valid() for arch in xarchs) )) - sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None) - print ('{:} There are {:} unique architectures (considering nothing).'.format(time_string(), unique_num)) - sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False) - print ('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num)) - sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True) - print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num)) + def get_unique_matrix(archs, consider_zero): + UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs] + print("{:} create unique-string ({:}/{:}) done".format(time_string(), len(set(UniquStrs)), len(UniquStrs))) + Unique2Index = dict() + for index, xstr in enumerate(UniquStrs): + if xstr not in Unique2Index: + Unique2Index[xstr] = list() + Unique2Index[xstr].append(index) + sm_matrix = torch.eye(len(archs)).bool() + for _, xlist in Unique2Index.items(): + for i in xlist: + for j in xlist: + sm_matrix[i, j] = True + unique_ids, unique_num = [-1 for _ in archs], 0 + for i in range(len(unique_ids)): + if unique_ids[i] > -1: + continue + neighbours = sm_matrix[i].nonzero().view(-1).tolist() + for nghb in neighbours: + assert unique_ids[nghb] == -1, "impossible" + unique_ids[nghb] = unique_num + unique_num += 1 + return sm_matrix, unique_ids, unique_num + + print("There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs))) + sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None) + print("{:} There are {:} unique architectures (considering nothing).".format(time_string(), unique_num)) + sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False) + print("{:} There are {:} unique architectures (not considering zero).".format(time_string(), unique_num)) + sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True) + print("{:} There are {:} unique architectures (considering zero).".format(time_string(), unique_num)) def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False): - if isinstance(meta_file, API): - api = meta_file - else: - api = API(str(meta_file)) - cifar10_currs = [] - cifar10_valid = [] - cifar10_test = [] - cifar100_valid = [] - cifar100_test = [] - imagenet_test = [] - imagenet_valid = [] - for idx, arch in enumerate(api): - results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) - cifar10_currs.append( results['valid-accuracy'] ) - # --->>>>> - results = api.get_more_info(idx, 'cifar10-valid' , None, False, is_rand) - cifar10_valid.append( results['valid-accuracy'] ) - results = api.get_more_info(idx, 'cifar10' , None, False, is_rand) - cifar10_test.append( results['test-accuracy'] ) - results = api.get_more_info(idx, 'cifar100' , None, False, is_rand) - cifar100_test.append( results['test-accuracy'] ) - cifar100_valid.append( results['valid-accuracy'] ) - results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand) - imagenet_test.append( results['test-accuracy'] ) - imagenet_valid.append( results['valid-accuracy'] ) - def get_cor(A, B): - return float(np.corrcoef(A, B)[0,1]) - cors = [] - for basestr, xlist in zip(['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): - correlation = get_cor(cifar10_currs, xlist) - if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation)) - cors.append( correlation ) - #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) - #print('-'*200) - #print('*'*230) - return cors + if isinstance(meta_file, API): + api = meta_file + else: + api = API(str(meta_file)) + cifar10_currs = [] + cifar10_valid = [] + cifar10_test = [] + cifar100_valid = [] + cifar100_test = [] + imagenet_test = [] + imagenet_valid = [] + for idx, arch in enumerate(api): + results = api.get_more_info(idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand) + cifar10_currs.append(results["valid-accuracy"]) + # --->>>>> + results = api.get_more_info(idx, "cifar10-valid", None, False, is_rand) + cifar10_valid.append(results["valid-accuracy"]) + results = api.get_more_info(idx, "cifar10", None, False, is_rand) + cifar10_test.append(results["test-accuracy"]) + results = api.get_more_info(idx, "cifar100", None, False, is_rand) + cifar100_test.append(results["test-accuracy"]) + cifar100_valid.append(results["valid-accuracy"]) + results = api.get_more_info(idx, "ImageNet16-120", None, False, is_rand) + imagenet_test.append(results["test-accuracy"]) + imagenet_valid.append(results["valid-accuracy"]) + + def get_cor(A, B): + return float(np.corrcoef(A, B)[0, 1]) + + cors = [] + for basestr, xlist in zip( + ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"], + [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test], + ): + correlation = get_cor(cifar10_currs, xlist) + if need_print: + print( + "With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format( + test_epoch, "012" if use_less_or_not else "200", basestr, correlation + ) + ) + cors.append(correlation) + # print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) + # print('-'*200) + # print('*'*230) + return cors def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand): - corrs = [] - for i in tqdm(range(100)): - x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) - corrs.append( x ) - #xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] - xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] - correlations = np.array(corrs) - print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200')) - for idx, xstr in enumerate(xstrs): - print ('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(xstr, correlations[:,idx].mean(), correlations[:,idx].std(), correlations[:,idx].mean(), correlations[:,idx].std())) - print('') + corrs = [] + for i in tqdm(range(100)): + x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) + corrs.append(x) + # xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] + xstrs = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"] + correlations = np.array(corrs) + print("------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(test_epoch, "012" if use_less_or_not else "200")) + for idx, xstr in enumerate(xstrs): + print( + "{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format( + xstr, + correlations[:, idx].mean(), + correlations[:, idx].std(), + correlations[:, idx].mean(), + correlations[:, idx].std(), + ) + ) + print("") -if __name__ == '__main__': - parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") - parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') - parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") + parser.add_argument( + "--save_dir", + type=str, + default="./output/search-cell-nas-bench-201/visuals", + help="The base-name of folder to save checkpoints and log.", + ) + parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.") + args = parser.parse_args() - vis_save_dir = Path(args.save_dir) - vis_save_dir.mkdir(parents=True, exist_ok=True) - meta_file = Path(args.api_path) - assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) + vis_save_dir = Path(args.save_dir) + vis_save_dir.mkdir(parents=True, exist_ok=True) + meta_file = Path(args.api_path) + assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) - #check_unique_arch(meta_file) - api = API(str(meta_file)) - #for iepoch in [11, 25, 50, 100, 150, 175, 200]: - # check_cor_for_bandit(api, 6, iepoch) - # check_cor_for_bandit(api, 12, iepoch) - check_cor_for_bandit_v2(api, 6, True, True) - check_cor_for_bandit_v2(api, 12, True, True) - check_cor_for_bandit_v2(api, 12, False, True) - check_cor_for_bandit_v2(api, 24, False, True) - check_cor_for_bandit_v2(api, 100, False, True) - check_cor_for_bandit_v2(api, 150, False, True) - check_cor_for_bandit_v2(api, 175, False, True) - check_cor_for_bandit_v2(api, 200, False, True) - print('----') + # check_unique_arch(meta_file) + api = API(str(meta_file)) + # for iepoch in [11, 25, 50, 100, 150, 175, 200]: + # check_cor_for_bandit(api, 6, iepoch) + # check_cor_for_bandit(api, 12, iepoch) + check_cor_for_bandit_v2(api, 6, True, True) + check_cor_for_bandit_v2(api, 12, True, True) + check_cor_for_bandit_v2(api, 12, False, True) + check_cor_for_bandit_v2(api, 24, False, True) + check_cor_for_bandit_v2(api, 100, False, True) + check_cor_for_bandit_v2(api, 150, False, True) + check_cor_for_bandit_v2(api, 175, False, True) + check_cor_for_bandit_v2(api, 200, False, True) + print("----") diff --git a/exps/NAS-Bench-201/visualize.py b/exps/NAS-Bench-201/visualize.py index 451e614..8e91a92 100644 --- a/exps/NAS-Bench-201/visualize.py +++ b/exps/NAS-Bench-201/visualize.py @@ -13,534 +13,686 @@ from collections import defaultdict import matplotlib import seaborn as sns from mpl_toolkits.mplot3d import Axes3D -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from log_utils import time_string -from nas_201_api import NASBench201API as API - +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from log_utils import time_string +from nas_201_api import NASBench201API as API def calculate_correlation(*vectors): - matrix = [] - for i, vectori in enumerate(vectors): - x = [] - for j, vectorj in enumerate(vectors): - x.append( np.corrcoef(vectori, vectorj)[0,1] ) - matrix.append( x ) - return np.array(matrix) - + matrix = [] + for i, vectori in enumerate(vectors): + x = [] + for j, vectorj in enumerate(vectors): + x.append(np.corrcoef(vectori, vectorj)[0, 1]) + matrix.append(x) + return np.array(matrix) def visualize_relative_ranking(vis_save_dir): - print ('\n' + '-'*100) - cifar010_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar10') - cifar100_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar100') - imagenet_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('ImageNet16-120') - cifar010_info = torch.load(cifar010_cache_path) - cifar100_info = torch.load(cifar100_cache_path) - imagenet_info = torch.load(imagenet_cache_path) - indexes = list(range(len(cifar010_info['params']))) + print("\n" + "-" * 100) + cifar010_cache_path = vis_save_dir / "{:}-cache-info.pth".format("cifar10") + cifar100_cache_path = vis_save_dir / "{:}-cache-info.pth".format("cifar100") + imagenet_cache_path = vis_save_dir / "{:}-cache-info.pth".format("ImageNet16-120") + cifar010_info = torch.load(cifar010_cache_path) + cifar100_info = torch.load(cifar100_cache_path) + imagenet_info = torch.load(imagenet_cache_path) + indexes = list(range(len(cifar010_info["params"]))) - print ('{:} start to visualize relative ranking'.format(time_string())) - # maximum accuracy with ResNet-level params 11472 - x_010_accs = [ cifar010_info['test_accs'][i] if cifar010_info['params'][i] <= cifar010_info['params'][11472] else -1 for i in indexes] - x_100_accs = [ cifar100_info['test_accs'][i] if cifar100_info['params'][i] <= cifar100_info['params'][11472] else -1 for i in indexes] - x_img_accs = [ imagenet_info['test_accs'][i] if imagenet_info['params'][i] <= imagenet_info['params'][11472] else -1 for i in indexes] - - cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i]) - cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i]) - imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i]) + print("{:} start to visualize relative ranking".format(time_string())) + # maximum accuracy with ResNet-level params 11472 + x_010_accs = [ + cifar010_info["test_accs"][i] if cifar010_info["params"][i] <= cifar010_info["params"][11472] else -1 + for i in indexes + ] + x_100_accs = [ + cifar100_info["test_accs"][i] if cifar100_info["params"][i] <= cifar100_info["params"][11472] else -1 + for i in indexes + ] + x_img_accs = [ + imagenet_info["test_accs"][i] if imagenet_info["params"][i] <= imagenet_info["params"][11472] else -1 + for i in indexes + ] - cifar100_labels, imagenet_labels = [], [] - for idx in cifar010_ord_indexes: - cifar100_labels.append( cifar100_ord_indexes.index(idx) ) - imagenet_labels.append( imagenet_ord_indexes.index(idx) ) - print ('{:} prepare data done.'.format(time_string())) + cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info["test_accs"][i]) + cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info["test_accs"][i]) + imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info["test_accs"][i]) - dpi, width, height = 300, 2600, 2600 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 18, 18 - resnet_scale, resnet_alpha = 120, 0.5 + cifar100_labels, imagenet_labels = [], [] + for idx in cifar010_ord_indexes: + cifar100_labels.append(cifar100_ord_indexes.index(idx)) + imagenet_labels.append(imagenet_ord_indexes.index(idx)) + print("{:} prepare data done.".format(time_string())) - fig = plt.figure(figsize=figsize) - ax = fig.add_subplot(111) - plt.xlim(min(indexes), max(indexes)) - plt.ylim(min(indexes), max(indexes)) - #plt.ylabel('y').set_rotation(0) - plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical') - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize) - #ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100') - #ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8, label='ImageNet-16-120') - #ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10') - ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8) - ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8) - ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8) - ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10') - ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100') - ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120') - plt.grid(zorder=0) - ax.set_axisbelow(True) - plt.legend(loc=0, fontsize=LegendFontsize) - ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize) - ax.set_ylabel('architecture ranking', fontsize=LabelSize) - save_path = (vis_save_dir / 'relative-rank.pdf').resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / 'relative-rank.png').resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - - # calculate correlation - sns_size = 15 - CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs']) - fig = plt.figure(figsize=figsize) - plt.axis('off') - h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5) - save_path = (vis_save_dir / 'co-relation-all.pdf').resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - print ('{:} save into {:}'.format(time_string(), save_path)) - - # calculate correlation - acc_bars = [92, 93] - for acc_bar in acc_bars: - selected_indexes = [] - for i, acc in enumerate(cifar010_info['test_accs']): - if acc > acc_bar: selected_indexes.append( i ) - print ('select {:} architectures'.format(len(selected_indexes))) - cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ] - cifar010_test_accs = np.array(cifar010_info['test_accs']) [ selected_indexes ] - cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ] - cifar100_test_accs = np.array(cifar100_info['test_accs']) [ selected_indexes ] - imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ] - imagenet_test_accs = np.array(imagenet_info['test_accs']) [ selected_indexes ] - CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs) - fig = plt.figure(figsize=figsize) - plt.axis('off') - h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5) - save_path = (vis_save_dir / 'co-relation-top-{:}.pdf'.format(len(selected_indexes))).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') - - - -def visualize_info(meta_file, dataset, vis_save_dir): - print ('{:} start to visualize {:} information'.format(time_string(), dataset)) - cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset) - if not cache_file_path.exists(): - print ('Do not find cache file : {:}'.format(cache_file_path)) - nas_bench = API(str(meta_file)) - params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], [] - for index in range( len(nas_bench) ): - info = nas_bench.query_by_index(index, use_12epochs_result=False) - resx = info.get_comput_costs(dataset) ; flop, param = resx['flops'], resx['params'] - if dataset == 'cifar10': - res = info.get_metrics('cifar10', 'train') ; train_acc = res['accuracy'] - res = info.get_metrics('cifar10-valid', 'x-valid') ; valid_acc = res['accuracy'] - res = info.get_metrics('cifar10', 'ori-test') ; test_acc = res['accuracy'] - res = info.get_metrics('cifar10', 'ori-test') ; otest_acc = res['accuracy'] - else: - res = info.get_metrics(dataset, 'train') ; train_acc = res['accuracy'] - res = info.get_metrics(dataset, 'x-valid') ; valid_acc = res['accuracy'] - res = info.get_metrics(dataset, 'x-test') ; test_acc = res['accuracy'] - res = info.get_metrics(dataset, 'ori-test') ; otest_acc = res['accuracy'] - if index == 11472: # resnet - resnet = {'params':param, 'flops': flop, 'index': 11472, 'train_acc': train_acc, 'valid_acc': valid_acc, 'test_acc': test_acc, 'otest_acc': otest_acc} - flops.append( flop ) - params.append( param ) - train_accs.append( train_acc ) - valid_accs.append( valid_acc ) - test_accs.append( test_acc ) - otest_accs.append( otest_acc ) - #resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97} - info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs} - info['resnet'] = resnet - torch.save(info, cache_file_path) - else: - print ('Find cache file : {:}'.format(cache_file_path)) - info = torch.load(cache_file_path) - params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs'] - resnet = info['resnet'] - print ('{:} collect data done.'.format(time_string())) - - indexes = list(range(len(params))) - dpi, width, height = 300, 2600, 2600 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 22, 22 - resnet_scale, resnet_alpha = 120, 0.5 - - fig = plt.figure(figsize=figsize) - ax = fig.add_subplot(111) - plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize) - if dataset == 'cifar10': - plt.ylim(50, 100) - plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) - elif dataset == 'cifar100': - plt.ylim(25, 75) - plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) - else: - plt.ylim(0, 50) - plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize) - ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue') - ax.scatter([resnet['params']], [resnet['valid_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=0.4) - plt.grid(zorder=0) - ax.set_axisbelow(True) - plt.legend(loc=4, fontsize=LegendFontsize) - ax.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize) - save_path = (vis_save_dir / '{:}-param-vs-valid.pdf'.format(dataset)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / '{:}-param-vs-valid.png'.format(dataset)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - - fig = plt.figure(figsize=figsize) - ax = fig.add_subplot(111) - plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize) - if dataset == 'cifar10': - plt.ylim(50, 100) - plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) - elif dataset == 'cifar100': - plt.ylim(25, 75) - plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) - else: - plt.ylim(0, 50) - plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize) - ax.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue') - ax.scatter([resnet['params']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha) - plt.grid() - ax.set_axisbelow(True) - plt.legend(loc=4, fontsize=LegendFontsize) - ax.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize) - save_path = (vis_save_dir / '{:}-param-vs-test.pdf'.format(dataset)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / '{:}-param-vs-test.png'.format(dataset)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - - fig = plt.figure(figsize=figsize) - ax = fig.add_subplot(111) - plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize) - if dataset == 'cifar10': - plt.ylim(50, 100) - plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) - elif dataset == 'cifar100': - plt.ylim(20, 100) - plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize) - else: - plt.ylim(25, 76) - plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) - ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue') - ax.scatter([resnet['params']], [resnet['train_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha) - plt.grid() - ax.set_axisbelow(True) - plt.legend(loc=4, fontsize=LegendFontsize) - ax.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax.set_ylabel('the trarining accuracy (%)', fontsize=LabelSize) - save_path = (vis_save_dir / '{:}-param-vs-train.pdf'.format(dataset)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / '{:}-param-vs-train.png'.format(dataset)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - - fig = plt.figure(figsize=figsize) - ax = fig.add_subplot(111) - plt.xlim(0, max(indexes)) - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize) - if dataset == 'cifar10': - plt.ylim(50, 100) - plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) - elif dataset == 'cifar100': - plt.ylim(25, 75) - plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) - else: - plt.ylim(0, 50) - plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize) - ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') - ax.scatter([resnet['index']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha) - plt.grid() - ax.set_axisbelow(True) - plt.legend(loc=4, fontsize=LegendFontsize) - ax.set_xlabel('architecture ID', fontsize=LabelSize) - ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize) - save_path = (vis_save_dir / '{:}-test-over-ID.pdf'.format(dataset)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / '{:}-test-over-ID.png'.format(dataset)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') - - - -def visualize_rank_over_time(meta_file, vis_save_dir): - print ('\n' + '-'*150) - vis_save_dir.mkdir(parents=True, exist_ok=True) - print ('{:} start to visualize rank-over-time into {:}'.format(time_string(), vis_save_dir)) - cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth' - if not cache_file_path.exists(): - print ('Do not find cache file : {:}'.format(cache_file_path)) - nas_bench = API(str(meta_file)) - print ('{:} load nas_bench done'.format(time_string())) - params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list) - #for iepoch in range(200): for index in range( len(nas_bench) ): - for index in tqdm(range(len(nas_bench))): - info = nas_bench.query_by_index(index, use_12epochs_result=False) - for iepoch in range(200): - res = info.get_metrics('cifar10' , 'train' , iepoch) ; train_acc = res['accuracy'] - res = info.get_metrics('cifar10-valid', 'x-valid' , iepoch) ; valid_acc = res['accuracy'] - res = info.get_metrics('cifar10' , 'ori-test', iepoch) ; test_acc = res['accuracy'] - res = info.get_metrics('cifar10' , 'ori-test', iepoch) ; otest_acc = res['accuracy'] - train_accs[iepoch].append( train_acc ) - valid_accs[iepoch].append( valid_acc ) - test_accs [iepoch].append( test_acc ) - otest_accs[iepoch].append( otest_acc ) - if iepoch == 0: - res = info.get_comput_costs('cifar10') ; flop, param = res['flops'], res['params'] - flops.append( flop ) - params.append( param ) - info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs} - torch.save(info, cache_file_path) - else: - print ('Find cache file : {:}'.format(cache_file_path)) - info = torch.load(cache_file_path) - params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs'] - print ('{:} collect data done.'.format(time_string())) - #selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199] - selected_epochs = list( range(200) ) - x_xtests = test_accs[199] - indexes = list(range(len(x_xtests))) - ord_idxs = sorted(indexes, key=lambda i: x_xtests[i]) - for sepoch in selected_epochs: - x_valids = valid_accs[sepoch] - valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i]) - valid_ord_lbls = [] - for idx in ord_idxs: - valid_ord_lbls.append( valid_ord_idxs.index(idx) ) - # labeled data dpi, width, height = 300, 2600, 2600 figsize = width / float(dpi), height / float(dpi) LabelSize, LegendFontsize = 18, 18 + resnet_scale, resnet_alpha = 120, 0.5 fig = plt.figure(figsize=figsize) - ax = fig.add_subplot(111) + ax = fig.add_subplot(111) plt.xlim(min(indexes), max(indexes)) plt.ylim(min(indexes), max(indexes)) - plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical') - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize) - ax.scatter(indexes, valid_ord_lbls, marker='^', s=0.5, c='tab:green', alpha=0.8) - ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8) - ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-10 validation') - ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10 test') + # plt.ylabel('y').set_rotation(0) + plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 6), fontsize=LegendFontsize, rotation="vertical") + plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 6), fontsize=LegendFontsize) + # ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100') + # ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8, label='ImageNet-16-120') + # ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10') + ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) + ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8) + ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) + ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10") + ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100") + ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120") plt.grid(zorder=0) ax.set_axisbelow(True) - plt.legend(loc='upper left', fontsize=LegendFontsize) - ax.set_xlabel('architecture ranking in the final test accuracy', fontsize=LabelSize) - ax.set_ylabel('architecture ranking in the validation set', fontsize=LabelSize) - save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + plt.legend(loc=0, fontsize=LegendFontsize) + ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize) + ax.set_ylabel("architecture ranking", fontsize=LabelSize) + save_path = (vis_save_dir / "relative-rank.pdf").resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "relative-rank.png").resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + # calculate correlation + sns_size = 15 + CoRelMatrix = calculate_correlation( + cifar010_info["valid_accs"], + cifar010_info["test_accs"], + cifar100_info["valid_accs"], + cifar100_info["test_accs"], + imagenet_info["valid_accs"], + imagenet_info["test_accs"], + ) + fig = plt.figure(figsize=figsize) + plt.axis("off") + h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={"size": sns_size}, fmt=".3f", linewidths=0.5) + save_path = (vis_save_dir / "co-relation-all.pdf").resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + print("{:} save into {:}".format(time_string(), save_path)) + + # calculate correlation + acc_bars = [92, 93] + for acc_bar in acc_bars: + selected_indexes = [] + for i, acc in enumerate(cifar010_info["test_accs"]): + if acc > acc_bar: + selected_indexes.append(i) + print("select {:} architectures".format(len(selected_indexes))) + cifar010_valid_accs = np.array(cifar010_info["valid_accs"])[selected_indexes] + cifar010_test_accs = np.array(cifar010_info["test_accs"])[selected_indexes] + cifar100_valid_accs = np.array(cifar100_info["valid_accs"])[selected_indexes] + cifar100_test_accs = np.array(cifar100_info["test_accs"])[selected_indexes] + imagenet_valid_accs = np.array(imagenet_info["valid_accs"])[selected_indexes] + imagenet_test_accs = np.array(imagenet_info["test_accs"])[selected_indexes] + CoRelMatrix = calculate_correlation( + cifar010_valid_accs, + cifar010_test_accs, + cifar100_valid_accs, + cifar100_test_accs, + imagenet_valid_accs, + imagenet_test_accs, + ) + fig = plt.figure(figsize=figsize) + plt.axis("off") + h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={"size": sns_size}, fmt=".3f", linewidths=0.5) + save_path = (vis_save_dir / "co-relation-top-{:}.pdf".format(len(selected_indexes))).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") + + +def visualize_info(meta_file, dataset, vis_save_dir): + print("{:} start to visualize {:} information".format(time_string(), dataset)) + cache_file_path = vis_save_dir / "{:}-cache-info.pth".format(dataset) + if not cache_file_path.exists(): + print("Do not find cache file : {:}".format(cache_file_path)) + nas_bench = API(str(meta_file)) + params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], [] + for index in range(len(nas_bench)): + info = nas_bench.query_by_index(index, use_12epochs_result=False) + resx = info.get_comput_costs(dataset) + flop, param = resx["flops"], resx["params"] + if dataset == "cifar10": + res = info.get_metrics("cifar10", "train") + train_acc = res["accuracy"] + res = info.get_metrics("cifar10-valid", "x-valid") + valid_acc = res["accuracy"] + res = info.get_metrics("cifar10", "ori-test") + test_acc = res["accuracy"] + res = info.get_metrics("cifar10", "ori-test") + otest_acc = res["accuracy"] + else: + res = info.get_metrics(dataset, "train") + train_acc = res["accuracy"] + res = info.get_metrics(dataset, "x-valid") + valid_acc = res["accuracy"] + res = info.get_metrics(dataset, "x-test") + test_acc = res["accuracy"] + res = info.get_metrics(dataset, "ori-test") + otest_acc = res["accuracy"] + if index == 11472: # resnet + resnet = { + "params": param, + "flops": flop, + "index": 11472, + "train_acc": train_acc, + "valid_acc": valid_acc, + "test_acc": test_acc, + "otest_acc": otest_acc, + } + flops.append(flop) + params.append(param) + train_accs.append(train_acc) + valid_accs.append(valid_acc) + test_accs.append(test_acc) + otest_accs.append(otest_acc) + # resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97} + info = { + "params": params, + "flops": flops, + "train_accs": train_accs, + "valid_accs": valid_accs, + "test_accs": test_accs, + "otest_accs": otest_accs, + } + info["resnet"] = resnet + torch.save(info, cache_file_path) + else: + print("Find cache file : {:}".format(cache_file_path)) + info = torch.load(cache_file_path) + params, flops, train_accs, valid_accs, test_accs, otest_accs = ( + info["params"], + info["flops"], + info["train_accs"], + info["valid_accs"], + info["test_accs"], + info["otest_accs"], + ) + resnet = info["resnet"] + print("{:} collect data done.".format(time_string())) + + indexes = list(range(len(params))) + dpi, width, height = 300, 2600, 2600 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 22, 22 + resnet_scale, resnet_alpha = 120, 0.5 + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize) + if dataset == "cifar10": + plt.ylim(50, 100) + plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) + elif dataset == "cifar100": + plt.ylim(25, 75) + plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) + else: + plt.ylim(0, 50) + plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize) + ax.scatter(params, valid_accs, marker="o", s=0.5, c="tab:blue") + ax.scatter( + [resnet["params"]], [resnet["valid_acc"]], marker="*", s=resnet_scale, c="tab:orange", label="resnet", alpha=0.4 + ) + plt.grid(zorder=0) + ax.set_axisbelow(True) + plt.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax.set_ylabel("the validation accuracy (%)", fontsize=LabelSize) + save_path = (vis_save_dir / "{:}-param-vs-valid.pdf".format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "{:}-param-vs-valid.png".format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize) + if dataset == "cifar10": + plt.ylim(50, 100) + plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) + elif dataset == "cifar100": + plt.ylim(25, 75) + plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) + else: + plt.ylim(0, 50) + plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize) + ax.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") + ax.scatter( + [resnet["params"]], + [resnet["test_acc"]], + marker="*", + s=resnet_scale, + c="tab:orange", + label="resnet", + alpha=resnet_alpha, + ) + plt.grid() + ax.set_axisbelow(True) + plt.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax.set_ylabel("the test accuracy (%)", fontsize=LabelSize) + save_path = (vis_save_dir / "{:}-param-vs-test.pdf".format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "{:}-param-vs-test.png".format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize) + if dataset == "cifar10": + plt.ylim(50, 100) + plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) + elif dataset == "cifar100": + plt.ylim(20, 100) + plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize) + else: + plt.ylim(25, 76) + plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) + ax.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") + ax.scatter( + [resnet["params"]], + [resnet["train_acc"]], + marker="*", + s=resnet_scale, + c="tab:orange", + label="resnet", + alpha=resnet_alpha, + ) + plt.grid() + ax.set_axisbelow(True) + plt.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax.set_ylabel("the trarining accuracy (%)", fontsize=LabelSize) + save_path = (vis_save_dir / "{:}-param-vs-train.pdf".format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "{:}-param-vs-train.png".format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xlim(0, max(indexes)) + plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize) + if dataset == "cifar10": + plt.ylim(50, 100) + plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) + elif dataset == "cifar100": + plt.ylim(25, 75) + plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize) + else: + plt.ylim(0, 50) + plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize) + ax.scatter(indexes, test_accs, marker="o", s=0.5, c="tab:blue") + ax.scatter( + [resnet["index"]], + [resnet["test_acc"]], + marker="*", + s=resnet_scale, + c="tab:orange", + label="resnet", + alpha=resnet_alpha, + ) + plt.grid() + ax.set_axisbelow(True) + plt.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel("architecture ID", fontsize=LabelSize) + ax.set_ylabel("the test accuracy (%)", fontsize=LabelSize) + save_path = (vis_save_dir / "{:}-test-over-ID.pdf".format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "{:}-test-over-ID.png".format(dataset)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") + + +def visualize_rank_over_time(meta_file, vis_save_dir): + print("\n" + "-" * 150) + vis_save_dir.mkdir(parents=True, exist_ok=True) + print("{:} start to visualize rank-over-time into {:}".format(time_string(), vis_save_dir)) + cache_file_path = vis_save_dir / "rank-over-time-cache-info.pth" + if not cache_file_path.exists(): + print("Do not find cache file : {:}".format(cache_file_path)) + nas_bench = API(str(meta_file)) + print("{:} load nas_bench done".format(time_string())) + params, flops, train_accs, valid_accs, test_accs, otest_accs = ( + [], + [], + defaultdict(list), + defaultdict(list), + defaultdict(list), + defaultdict(list), + ) + # for iepoch in range(200): for index in range( len(nas_bench) ): + for index in tqdm(range(len(nas_bench))): + info = nas_bench.query_by_index(index, use_12epochs_result=False) + for iepoch in range(200): + res = info.get_metrics("cifar10", "train", iepoch) + train_acc = res["accuracy"] + res = info.get_metrics("cifar10-valid", "x-valid", iepoch) + valid_acc = res["accuracy"] + res = info.get_metrics("cifar10", "ori-test", iepoch) + test_acc = res["accuracy"] + res = info.get_metrics("cifar10", "ori-test", iepoch) + otest_acc = res["accuracy"] + train_accs[iepoch].append(train_acc) + valid_accs[iepoch].append(valid_acc) + test_accs[iepoch].append(test_acc) + otest_accs[iepoch].append(otest_acc) + if iepoch == 0: + res = info.get_comput_costs("cifar10") + flop, param = res["flops"], res["params"] + flops.append(flop) + params.append(param) + info = { + "params": params, + "flops": flops, + "train_accs": train_accs, + "valid_accs": valid_accs, + "test_accs": test_accs, + "otest_accs": otest_accs, + } + torch.save(info, cache_file_path) + else: + print("Find cache file : {:}".format(cache_file_path)) + info = torch.load(cache_file_path) + params, flops, train_accs, valid_accs, test_accs, otest_accs = ( + info["params"], + info["flops"], + info["train_accs"], + info["valid_accs"], + info["test_accs"], + info["otest_accs"], + ) + print("{:} collect data done.".format(time_string())) + # selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199] + selected_epochs = list(range(200)) + x_xtests = test_accs[199] + indexes = list(range(len(x_xtests))) + ord_idxs = sorted(indexes, key=lambda i: x_xtests[i]) + for sepoch in selected_epochs: + x_valids = valid_accs[sepoch] + valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i]) + valid_ord_lbls = [] + for idx in ord_idxs: + valid_ord_lbls.append(valid_ord_idxs.index(idx)) + # labeled data + dpi, width, height = 300, 2600, 2600 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 18, 18 + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xlim(min(indexes), max(indexes)) + plt.ylim(min(indexes), max(indexes)) + plt.yticks( + np.arange(min(indexes), max(indexes), max(indexes) // 6), fontsize=LegendFontsize, rotation="vertical" + ) + plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 6), fontsize=LegendFontsize) + ax.scatter(indexes, valid_ord_lbls, marker="^", s=0.5, c="tab:green", alpha=0.8) + ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) + ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-10 validation") + ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10 test") + plt.grid(zorder=0) + ax.set_axisbelow(True) + plt.legend(loc="upper left", fontsize=LegendFontsize) + ax.set_xlabel("architecture ranking in the final test accuracy", fontsize=LabelSize) + ax.set_ylabel("architecture ranking in the validation set", fontsize=LabelSize) + save_path = (vis_save_dir / "time-{:03d}.pdf".format(sepoch)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "time-{:03d}.png".format(sepoch)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") def write_video(save_dir): - import cv2 - video_save_path = save_dir / 'time.avi' - print ('{:} start create video for {:}'.format(time_string(), video_save_path)) - images = sorted( list( save_dir.glob('time-*.png') ) ) - ximage = cv2.imread(str(images[0])) - #shape = (ximage.shape[1], ximage.shape[0]) - shape = (1000, 1000) - #writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 25, shape) - writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape) - for idx, image in enumerate(images): - ximage = cv2.imread(str(image)) - _image = cv2.resize(ximage, shape) - writer.write(_image) - writer.release() - print ('write video [{:} frames] into {:}'.format(len(images), video_save_path)) + import cv2 + video_save_path = save_dir / "time.avi" + print("{:} start create video for {:}".format(time_string(), video_save_path)) + images = sorted(list(save_dir.glob("time-*.png"))) + ximage = cv2.imread(str(images[0])) + # shape = (ximage.shape[1], ximage.shape[0]) + shape = (1000, 1000) + # writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 25, shape) + writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape) + for idx, image in enumerate(images): + ximage = cv2.imread(str(image)) + _image = cv2.resize(ximage, shape) + writer.write(_image) + writer.release() + print("write video [{:} frames] into {:}".format(len(images), video_save_path)) def plot_results_nas_v2(api, dataset_xset_a, dataset_xset_b, root, file_name, y_lims): - #print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) - print ('root-path : {:} and {:}'.format(dataset_xset_a, dataset_xset_b)) - checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth', - './output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth', - './output/search-cell-nas-bench-201/RAND-cifar10/results.pth', - './output/search-cell-nas-bench-201/BOHB-cifar10/results.pth' - ] - legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None - All_Accs_A, All_Accs_B = OrderedDict(), OrderedDict() - for legend, checkpoint in zip(legends, checkpoints): - all_indexes = torch.load(checkpoint, map_location='cpu') - accuracies_A, accuracies_B = [], [] - accuracies = [] - for x in all_indexes: - info = api.arch2infos_full[ x ] - metrics = info.get_metrics(dataset_xset_a[0], dataset_xset_a[1], None, False) - accuracies_A.append( metrics['accuracy'] ) - metrics = info.get_metrics(dataset_xset_b[0], dataset_xset_b[1], None, False) - accuracies_B.append( metrics['accuracy'] ) - accuracies.append( (accuracies_A[-1], accuracies_B[-1]) ) - if indexes is None: indexes = list(range(len(all_indexes))) - accuracies = sorted(accuracies) - All_Accs_A[legend] = [x[0] for x in accuracies] - All_Accs_B[legend] = [x[1] for x in accuracies] - - color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] - dpi, width, height = 300, 3400, 2600 - LabelSize, LegendFontsize = 28, 28 - figsize = width / float(dpi), height / float(dpi) - fig = plt.figure(figsize=figsize) - x_axis = np.arange(0, 600) - plt.xlim(0, max(indexes)) - plt.ylim(y_lims[0], y_lims[1]) - interval_x, interval_y = 100, y_lims[2] - plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) - plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) - plt.grid() - plt.xlabel('The index of runs', fontsize=LabelSize) - plt.ylabel('The accuracy (%)', fontsize=LabelSize) - - for idx, legend in enumerate(legends): - plt.plot(indexes, All_Accs_B[legend], color=color_set[idx], linestyle='--', label='{:}'.format(legend), lw=1, alpha=0.5) - plt.plot(indexes, All_Accs_A[legend], color=color_set[idx], linestyle='-', lw=1) - for All_Accs in [All_Accs_A, All_Accs_B]: - print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend]))) - plt.legend(loc=4, fontsize=LegendFontsize) - save_path = root / '{:}'.format(file_name) - print('save figure into {:}\n'.format(save_path)) - fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') + # print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) + print("root-path : {:} and {:}".format(dataset_xset_a, dataset_xset_b)) + checkpoints = [ + "./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth", + "./output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth", + "./output/search-cell-nas-bench-201/RAND-cifar10/results.pth", + "./output/search-cell-nas-bench-201/BOHB-cifar10/results.pth", + ] + legends, indexes = ["REA", "REINFORCE", "RANDOM", "BOHB"], None + All_Accs_A, All_Accs_B = OrderedDict(), OrderedDict() + for legend, checkpoint in zip(legends, checkpoints): + all_indexes = torch.load(checkpoint, map_location="cpu") + accuracies_A, accuracies_B = [], [] + accuracies = [] + for x in all_indexes: + info = api.arch2infos_full[x] + metrics = info.get_metrics(dataset_xset_a[0], dataset_xset_a[1], None, False) + accuracies_A.append(metrics["accuracy"]) + metrics = info.get_metrics(dataset_xset_b[0], dataset_xset_b[1], None, False) + accuracies_B.append(metrics["accuracy"]) + accuracies.append((accuracies_A[-1], accuracies_B[-1])) + if indexes is None: + indexes = list(range(len(all_indexes))) + accuracies = sorted(accuracies) + All_Accs_A[legend] = [x[0] for x in accuracies] + All_Accs_B[legend] = [x[1] for x in accuracies] + color_set = ["r", "b", "g", "c", "m", "y", "k"] + dpi, width, height = 300, 3400, 2600 + LabelSize, LegendFontsize = 28, 28 + figsize = width / float(dpi), height / float(dpi) + fig = plt.figure(figsize=figsize) + x_axis = np.arange(0, 600) + plt.xlim(0, max(indexes)) + plt.ylim(y_lims[0], y_lims[1]) + interval_x, interval_y = 100, y_lims[2] + plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) + plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize) + plt.grid() + plt.xlabel("The index of runs", fontsize=LabelSize) + plt.ylabel("The accuracy (%)", fontsize=LabelSize) + for idx, legend in enumerate(legends): + plt.plot( + indexes, + All_Accs_B[legend], + color=color_set[idx], + linestyle="--", + label="{:}".format(legend), + lw=1, + alpha=0.5, + ) + plt.plot(indexes, All_Accs_A[legend], color=color_set[idx], linestyle="-", lw=1) + for All_Accs in [All_Accs_A, All_Accs_B]: + print( + "{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}".format( + legend, + np.mean(All_Accs[legend]), + np.std(All_Accs[legend]), + np.mean(All_Accs[legend]), + np.std(All_Accs[legend]), + ) + ) + plt.legend(loc=4, fontsize=LegendFontsize) + save_path = root / "{:}".format(file_name) + print("save figure into {:}\n".format(save_path)) + fig.savefig(str(save_path), dpi=dpi, bbox_inches="tight", format="pdf") def plot_results_nas(api, dataset, xset, root, file_name, y_lims): - print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) - checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth', - './output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth', - './output/search-cell-nas-bench-201/RAND-cifar10/results.pth', - './output/search-cell-nas-bench-201/BOHB-cifar10/results.pth' - ] - legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None - All_Accs = OrderedDict() - for legend, checkpoint in zip(legends, checkpoints): - all_indexes = torch.load(checkpoint, map_location='cpu') - accuracies = [] - for x in all_indexes: - info = api.arch2infos_full[ x ] - metrics = info.get_metrics(dataset, xset, None, False) - accuracies.append( metrics['accuracy'] ) - if indexes is None: indexes = list(range(len(all_indexes))) - All_Accs[legend] = sorted(accuracies) - - color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] - dpi, width, height = 300, 3400, 2600 - LabelSize, LegendFontsize = 28, 28 - figsize = width / float(dpi), height / float(dpi) - fig = plt.figure(figsize=figsize) - x_axis = np.arange(0, 600) - plt.xlim(0, max(indexes)) - plt.ylim(y_lims[0], y_lims[1]) - interval_x, interval_y = 100, y_lims[2] - plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) - plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) - plt.grid() - plt.xlabel('The index of runs', fontsize=LabelSize) - plt.ylabel('The accuracy (%)', fontsize=LabelSize) + print("root-path={:}, dataset={:}, xset={:}".format(root, dataset, xset)) + checkpoints = [ + "./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth", + "./output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth", + "./output/search-cell-nas-bench-201/RAND-cifar10/results.pth", + "./output/search-cell-nas-bench-201/BOHB-cifar10/results.pth", + ] + legends, indexes = ["REA", "REINFORCE", "RANDOM", "BOHB"], None + All_Accs = OrderedDict() + for legend, checkpoint in zip(legends, checkpoints): + all_indexes = torch.load(checkpoint, map_location="cpu") + accuracies = [] + for x in all_indexes: + info = api.arch2infos_full[x] + metrics = info.get_metrics(dataset, xset, None, False) + accuracies.append(metrics["accuracy"]) + if indexes is None: + indexes = list(range(len(all_indexes))) + All_Accs[legend] = sorted(accuracies) - for idx, legend in enumerate(legends): - plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2) - print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend]))) - plt.legend(loc=4, fontsize=LegendFontsize) - save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name) - print('save figure into {:}\n'.format(save_path)) - fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') + color_set = ["r", "b", "g", "c", "m", "y", "k"] + dpi, width, height = 300, 3400, 2600 + LabelSize, LegendFontsize = 28, 28 + figsize = width / float(dpi), height / float(dpi) + fig = plt.figure(figsize=figsize) + x_axis = np.arange(0, 600) + plt.xlim(0, max(indexes)) + plt.ylim(y_lims[0], y_lims[1]) + interval_x, interval_y = 100, y_lims[2] + plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) + plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize) + plt.grid() + plt.xlabel("The index of runs", fontsize=LabelSize) + plt.ylabel("The accuracy (%)", fontsize=LabelSize) + + for idx, legend in enumerate(legends): + plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle="-", label="{:}".format(legend), lw=2) + print( + "{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}".format( + legend, + np.mean(All_Accs[legend]), + np.std(All_Accs[legend]), + np.mean(All_Accs[legend]), + np.std(All_Accs[legend]), + ) + ) + plt.legend(loc=4, fontsize=LegendFontsize) + save_path = root / "{:}-{:}-{:}".format(dataset, xset, file_name) + print("save figure into {:}\n".format(save_path)) + fig.savefig(str(save_path), dpi=dpi, bbox_inches="tight", format="pdf") def just_show(api): - xtimes = {'RSPS' : [8082.5, 7794.2, 8144.7], - 'DARTS-V1': [11582.1, 11347.0, 11948.2], - 'DARTS-V2': [35694.7, 36132.7, 35518.0], - 'GDAS' : [31334.1, 31478.6, 32016.7], - 'SETN' : [33528.8, 33831.5, 35058.3], - 'ENAS' : [14340.2, 13817.3, 14018.9]} - for xkey, xlist in xtimes.items(): - xlist = np.array(xlist) - print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean())) + xtimes = { + "RSPS": [8082.5, 7794.2, 8144.7], + "DARTS-V1": [11582.1, 11347.0, 11948.2], + "DARTS-V2": [35694.7, 36132.7, 35518.0], + "GDAS": [31334.1, 31478.6, 32016.7], + "SETN": [33528.8, 33831.5, 35058.3], + "ENAS": [14340.2, 13817.3, 14018.9], + } + for xkey, xlist in xtimes.items(): + xlist = np.array(xlist) + print("{:4s} : mean-time={:.2f} s".format(xkey, xlist.mean())) - xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/', - 'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/', - 'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/', - 'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/', - 'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/', - 'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/', - } - xseeds = {'RSPS' : [5349, 59613, 5983], - 'DARTS-V1': [11416, 72873, 81184], - 'DARTS-V2': [43330, 79405, 79423], - 'GDAS' : [19677, 884, 95950], - 'SETN' : [20518, 61817, 89144], - 'ENAS' : [3231, 34238, 96929], - } + xpaths = { + "RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/", + "DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/", + "DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/", + "GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/", + "SETN": "output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/", + "ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/", + } + xseeds = { + "RSPS": [5349, 59613, 5983], + "DARTS-V1": [11416, 72873, 81184], + "DARTS-V2": [43330, 79405, 79423], + "GDAS": [19677, 884, 95950], + "SETN": [20518, 61817, 89144], + "ENAS": [3231, 34238, 96929], + } - def get_accs(xdata, index=-1): - if index == -1: - epochs = xdata['epoch'] - genotype = xdata['genotypes'][epochs-1] - index = api.query_index_by_arch(genotype) - pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')] - xresults = [] - for dataset, xset in pairs: - metrics = api.arch2infos_full[index].get_metrics(dataset, xset, None, False) - xresults.append( metrics['accuracy'] ) - return xresults + def get_accs(xdata, index=-1): + if index == -1: + epochs = xdata["epoch"] + genotype = xdata["genotypes"][epochs - 1] + index = api.query_index_by_arch(genotype) + pairs = [ + ("cifar10-valid", "x-valid"), + ("cifar10", "ori-test"), + ("cifar100", "x-valid"), + ("cifar100", "x-test"), + ("ImageNet16-120", "x-valid"), + ("ImageNet16-120", "x-test"), + ] + xresults = [] + for dataset, xset in pairs: + metrics = api.arch2infos_full[index].get_metrics(dataset, xset, None, False) + xresults.append(metrics["accuracy"]) + return xresults - for xkey in xpaths.keys(): - all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ] - all_datas = [torch.load(xpath) for xpath in all_paths] - accyss = [get_accs(xdatas) for xdatas in all_datas] - accyss = np.array( accyss ) - print('\nxkey = {:}'.format(xkey)) - for i in range(accyss.shape[1]): print('---->>>> {:.2f}$\\pm${:.2f}'.format(accyss[:,i].mean(), accyss[:,i].std())) + for xkey in xpaths.keys(): + all_paths = ["{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]] + all_datas = [torch.load(xpath) for xpath in all_paths] + accyss = [get_accs(xdatas) for xdatas in all_datas] + accyss = np.array(accyss) + print("\nxkey = {:}".format(xkey)) + for i in range(accyss.shape[1]): + print("---->>>> {:.2f}$\\pm${:.2f}".format(accyss[:, i].mean(), accyss[:, i].std())) - print('\n{:}'.format(get_accs(None, 11472))) # resnet - pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')] - for dataset, metric_on_set in pairs: - arch_index, highest_acc = api.find_best(dataset, metric_on_set) - print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc)) + print("\n{:}".format(get_accs(None, 11472))) # resnet + pairs = [ + ("cifar10-valid", "x-valid"), + ("cifar10", "ori-test"), + ("cifar100", "x-valid"), + ("cifar100", "x-test"), + ("ImageNet16-120", "x-valid"), + ("ImageNet16-120", "x-test"), + ] + for dataset, metric_on_set in pairs: + arch_index, highest_acc = api.find_best(dataset, metric_on_set) + print("[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}".format(dataset, metric_on_set, arch_index, highest_acc)) def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_lims, x_maxs): - color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] - dpi, width, height = 300, 3400, 2600 - LabelSize, LegendFontsize = 28, 28 - figsize = width / float(dpi), height / float(dpi) - fig = plt.figure(figsize=figsize) - #x_maxs = 250 - plt.xlim(0, x_maxs+1) - plt.ylim(y_lims[0], y_lims[1]) - interval_x, interval_y = x_maxs // 5, y_lims[2] - plt.xticks(np.arange(0, x_maxs+1, interval_x), fontsize=LegendFontsize) - plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) - plt.grid() - plt.xlabel('The searching epoch', fontsize=LabelSize) - plt.ylabel('The accuracy (%)', fontsize=LabelSize) + color_set = ["r", "b", "g", "c", "m", "y", "k"] + dpi, width, height = 300, 3400, 2600 + LabelSize, LegendFontsize = 28, 28 + figsize = width / float(dpi), height / float(dpi) + fig = plt.figure(figsize=figsize) + # x_maxs = 250 + plt.xlim(0, x_maxs + 1) + plt.ylim(y_lims[0], y_lims[1]) + interval_x, interval_y = x_maxs // 5, y_lims[2] + plt.xticks(np.arange(0, x_maxs + 1, interval_x), fontsize=LegendFontsize) + plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize) + plt.grid() + plt.xlabel("The searching epoch", fontsize=LabelSize) + plt.ylabel("The accuracy (%)", fontsize=LabelSize) - xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/'.format(sufix), - 'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/'.format(sufix), - 'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/'.format(sufix), - 'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/'.format(sufix), - 'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/'.format(sufix), - 'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/'.format(sufix), - } - """ + xpaths = { + "RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/".format(sufix), + "DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/".format(sufix), + "DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/".format(sufix), + "GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/".format(sufix), + "SETN": "output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/".format(sufix), + "ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/".format(sufix), + } + """ xseeds = {'RSPS' : [5349, 59613, 5983], 'DARTS-V1': [11416, 72873, 81184, 28640], 'DARTS-V2': [43330, 79405, 79423], @@ -549,76 +701,92 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_l 'ENAS' : [3231, 34238, 96929], } """ - xseeds = {'RSPS' : [23814, 28015, 95809], - 'DARTS-V1': [48349, 80877, 81920], - 'DARTS-V2': [61712, 7941 , 87041] , - 'GDAS' : [72818, 72996, 78877], - 'SETN' : [26985, 55206, 95404], - 'ENAS' : [21792, 36605, 45029] - } + xseeds = { + "RSPS": [23814, 28015, 95809], + "DARTS-V1": [48349, 80877, 81920], + "DARTS-V2": [61712, 7941, 87041], + "GDAS": [72818, 72996, 78877], + "SETN": [26985, 55206, 95404], + "ENAS": [21792, 36605, 45029], + } + def get_accs(xdata): + epochs, xresults = xdata["epoch"], [] + if -1 in xdata["genotypes"]: + metrics = api.arch2infos_full[api.query_index_by_arch(xdata["genotypes"][-1])].get_metrics( + dataset, subset, None, False + ) + else: + metrics = api.arch2infos_full[api.random()].get_metrics(dataset, subset, None, False) + xresults.append(metrics["accuracy"]) + for iepoch in range(epochs): + genotype = xdata["genotypes"][iepoch] + index = api.query_index_by_arch(genotype) + metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False) + xresults.append(metrics["accuracy"]) + return xresults - def get_accs(xdata): - epochs, xresults = xdata['epoch'], [] - if -1 in xdata['genotypes']: - metrics = api.arch2infos_full[ api.query_index_by_arch(xdata['genotypes'][-1]) ].get_metrics(dataset, subset, None, False) + if x_maxs == 50: + xox, xxxstrs = "v2", ["DARTS-V1", "DARTS-V2"] + elif x_maxs == 250: + xox, xxxstrs = "v1", ["RSPS", "GDAS", "SETN", "ENAS"] else: - metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False) - xresults.append( metrics['accuracy'] ) - for iepoch in range(epochs): - genotype = xdata['genotypes'][iepoch] - index = api.query_index_by_arch(genotype) - metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False) - xresults.append( metrics['accuracy'] ) - return xresults - - if x_maxs == 50: - xox, xxxstrs = 'v2', ['DARTS-V1', 'DARTS-V2'] - elif x_maxs == 250: - xox, xxxstrs = 'v1', ['RSPS', 'GDAS', 'SETN', 'ENAS'] - else: raise ValueError('invalid x_maxs={:}'.format(x_maxs)) - - for idx, method in enumerate(xxxstrs): - xkey = method - all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ] - all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths] - accyss = [get_accs(xdatas) for xdatas in all_datas] - accyss = np.array( accyss ) - epochs = list(range(accyss.shape[1])) - plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx], linestyle='-', label='{:}'.format(method), lw=2) - plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx]) - #plt.legend(loc=4, fontsize=LegendFontsize) - plt.legend(loc=0, fontsize=LegendFontsize) - save_path = vis_save_dir / '{:}.pdf'.format(file_name) - print('save figure into {:}\n'.format(save_path)) - fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') + raise ValueError("invalid x_maxs={:}".format(x_maxs)) + for idx, method in enumerate(xxxstrs): + xkey = method + all_paths = ["{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]] + all_datas = [torch.load(xpath, map_location="cpu") for xpath in all_paths] + accyss = [get_accs(xdatas) for xdatas in all_datas] + accyss = np.array(accyss) + epochs = list(range(accyss.shape[1])) + plt.plot( + epochs, + [accyss[:, i].mean() for i in epochs], + color=color_set[idx], + linestyle="-", + label="{:}".format(method), + lw=2, + ) + plt.fill_between( + epochs, + [accyss[:, i].mean() - accyss[:, i].std() for i in epochs], + [accyss[:, i].mean() + accyss[:, i].std() for i in epochs], + alpha=0.2, + color=color_set[idx], + ) + # plt.legend(loc=4, fontsize=LegendFontsize) + plt.legend(loc=0, fontsize=LegendFontsize) + save_path = vis_save_dir / "{:}.pdf".format(file_name) + print("save figure into {:}\n".format(save_path)) + fig.savefig(str(save_path), dpi=dpi, bbox_inches="tight", format="pdf") def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file_name, y_lims, x_maxs): - color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] - dpi, width, height = 300, 3400, 2600 - LabelSize, LegendFontsize = 28, 28 - figsize = width / float(dpi), height / float(dpi) - fig = plt.figure(figsize=figsize) - #x_maxs = 250 - plt.xlim(0, x_maxs+1) - plt.ylim(y_lims[0], y_lims[1]) - interval_x, interval_y = x_maxs // 5, y_lims[2] - plt.xticks(np.arange(0, x_maxs+1, interval_x), fontsize=LegendFontsize) - plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) - plt.grid() - plt.xlabel('The searching epoch', fontsize=LabelSize) - plt.ylabel('The accuracy (%)', fontsize=LabelSize) + color_set = ["r", "b", "g", "c", "m", "y", "k"] + dpi, width, height = 300, 3400, 2600 + LabelSize, LegendFontsize = 28, 28 + figsize = width / float(dpi), height / float(dpi) + fig = plt.figure(figsize=figsize) + # x_maxs = 250 + plt.xlim(0, x_maxs + 1) + plt.ylim(y_lims[0], y_lims[1]) + interval_x, interval_y = x_maxs // 5, y_lims[2] + plt.xticks(np.arange(0, x_maxs + 1, interval_x), fontsize=LegendFontsize) + plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize) + plt.grid() + plt.xlabel("The searching epoch", fontsize=LabelSize) + plt.ylabel("The accuracy (%)", fontsize=LabelSize) - xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/'.format(sufix), - 'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/'.format(sufix), - 'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/'.format(sufix), - 'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/'.format(sufix), - 'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/'.format(sufix), - 'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/'.format(sufix), - } - """ + xpaths = { + "RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/".format(sufix), + "DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/".format(sufix), + "DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/".format(sufix), + "GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/".format(sufix), + "SETN": "output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/".format(sufix), + "ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/".format(sufix), + } + """ xseeds = {'RSPS' : [5349, 59613, 5983], 'DARTS-V1': [11416, 72873, 81184, 28640], 'DARTS-V2': [43330, 79405, 79423], @@ -627,181 +795,281 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file 'ENAS' : [3231, 34238, 96929], } """ - xseeds = {'RSPS' : [23814, 28015, 95809], - 'DARTS-V1': [48349, 80877, 81920], - 'DARTS-V2': [61712, 7941 , 87041] , - 'GDAS' : [72818, 72996, 78877], - 'SETN' : [26985, 55206, 95404], - 'ENAS' : [21792, 36605, 45029] - } + xseeds = { + "RSPS": [23814, 28015, 95809], + "DARTS-V1": [48349, 80877, 81920], + "DARTS-V2": [61712, 7941, 87041], + "GDAS": [72818, 72996, 78877], + "SETN": [26985, 55206, 95404], + "ENAS": [21792, 36605, 45029], + } + def get_accs(xdata, dataset, subset): + epochs, xresults = xdata["epoch"], [] + if -1 in xdata["genotypes"]: + metrics = api.arch2infos_full[api.query_index_by_arch(xdata["genotypes"][-1])].get_metrics( + dataset, subset, None, False + ) + else: + metrics = api.arch2infos_full[api.random()].get_metrics(dataset, subset, None, False) + xresults.append(metrics["accuracy"]) + for iepoch in range(epochs): + genotype = xdata["genotypes"][iepoch] + index = api.query_index_by_arch(genotype) + metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False) + xresults.append(metrics["accuracy"]) + return xresults - def get_accs(xdata, dataset, subset): - epochs, xresults = xdata['epoch'], [] - if -1 in xdata['genotypes']: - metrics = api.arch2infos_full[ api.query_index_by_arch(xdata['genotypes'][-1]) ].get_metrics(dataset, subset, None, False) + if x_maxs == 50: + xox, xxxstrs = "v2", ["DARTS-V1", "DARTS-V2"] + elif x_maxs == 250: + xox, xxxstrs = "v1", ["RSPS", "GDAS", "SETN", "ENAS"] else: - metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False) - xresults.append( metrics['accuracy'] ) - for iepoch in range(epochs): - genotype = xdata['genotypes'][iepoch] - index = api.query_index_by_arch(genotype) - metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False) - xresults.append( metrics['accuracy'] ) - return xresults + raise ValueError("invalid x_maxs={:}".format(x_maxs)) - if x_maxs == 50: - xox, xxxstrs = 'v2', ['DARTS-V1', 'DARTS-V2'] - elif x_maxs == 250: - xox, xxxstrs = 'v1', ['RSPS', 'GDAS', 'SETN', 'ENAS'] - else: raise ValueError('invalid x_maxs={:}'.format(x_maxs)) - - for idx, method in enumerate(xxxstrs): - xkey = method - all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ] - all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths] - accyss_A = np.array( [get_accs(xdatas, data_sub_a[0], data_sub_a[1]) for xdatas in all_datas] ) - accyss_B = np.array( [get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas] ) - epochs = list(range(accyss_A.shape[1])) - for j, accyss in enumerate([accyss_A, accyss_B]): - if x_maxs == 50: - color, line = color_set[idx*2+j], '-' if j==0 else '--' - elif x_maxs == 250: - color, line = color_set[idx], '-' if j==0 else '--' - else: raise ValueError('invalid x-maxs={:}'.format(x_maxs)) - plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color, linestyle=line, label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9) - plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color) - setname = data_sub_a if j == 0 else data_sub_b - print('{:} -- {:} ---- {:.2f}$\\pm${:.2f}'.format(method, setname, accyss[:,-1].mean(), accyss[:,-1].std())) - #plt.legend(loc=4, fontsize=LegendFontsize) - plt.legend(loc=0, fontsize=LegendFontsize) - save_path = vis_save_dir / '{:}-{:}'.format(xox, file_name) - print('save figure into {:}\n'.format(save_path)) - fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') + for idx, method in enumerate(xxxstrs): + xkey = method + all_paths = ["{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]] + all_datas = [torch.load(xpath, map_location="cpu") for xpath in all_paths] + accyss_A = np.array([get_accs(xdatas, data_sub_a[0], data_sub_a[1]) for xdatas in all_datas]) + accyss_B = np.array([get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas]) + epochs = list(range(accyss_A.shape[1])) + for j, accyss in enumerate([accyss_A, accyss_B]): + if x_maxs == 50: + color, line = color_set[idx * 2 + j], "-" if j == 0 else "--" + elif x_maxs == 250: + color, line = color_set[idx], "-" if j == 0 else "--" + else: + raise ValueError("invalid x-maxs={:}".format(x_maxs)) + plt.plot( + epochs, + [accyss[:, i].mean() for i in epochs], + color=color, + linestyle=line, + label="{:} ({:})".format(method, "VALID" if j == 0 else "TEST"), + lw=2, + alpha=0.9, + ) + plt.fill_between( + epochs, + [accyss[:, i].mean() - accyss[:, i].std() for i in epochs], + [accyss[:, i].mean() + accyss[:, i].std() for i in epochs], + alpha=0.2, + color=color, + ) + setname = data_sub_a if j == 0 else data_sub_b + print( + "{:} -- {:} ---- {:.2f}$\\pm${:.2f}".format(method, setname, accyss[:, -1].mean(), accyss[:, -1].std()) + ) + # plt.legend(loc=4, fontsize=LegendFontsize) + plt.legend(loc=0, fontsize=LegendFontsize) + save_path = vis_save_dir / "{:}-{:}".format(xox, file_name) + print("save figure into {:}\n".format(save_path)) + fig.savefig(str(save_path), dpi=dpi, bbox_inches="tight", format="pdf") def show_reinforce(api, root, dataset, xset, file_name, y_lims): - print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) - LRs = ['0.01', '0.02', '0.1', '0.2', '0.5'] - checkpoints = ['./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth'.format(x) for x in LRs] - acc_lr_dict, indexes = {}, None - for lr, checkpoint in zip(LRs, checkpoints): - all_indexes, accuracies = torch.load(checkpoint, map_location='cpu'), [] - for x in all_indexes: - info = api.arch2infos_full[ x ] - metrics = info.get_metrics(dataset, xset, None, False) - accuracies.append( metrics['accuracy'] ) - if indexes is None: indexes = list(range(len(accuracies))) - acc_lr_dict[lr] = np.array( sorted(accuracies) ) - print ('LR={:.3f}, mean={:}, std={:}'.format(float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std())) - - color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] - dpi, width, height = 300, 3400, 2600 - LabelSize, LegendFontsize = 28, 22 - figsize = width / float(dpi), height / float(dpi) - fig = plt.figure(figsize=figsize) - x_axis = np.arange(0, 600) - plt.xlim(0, max(indexes)) - plt.ylim(y_lims[0], y_lims[1]) - interval_x, interval_y = 100, y_lims[2] - plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) - plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) - plt.grid() - plt.xlabel('The index of runs', fontsize=LabelSize) - plt.ylabel('The accuracy (%)', fontsize=LabelSize) + print("root-path={:}, dataset={:}, xset={:}".format(root, dataset, xset)) + LRs = ["0.01", "0.02", "0.1", "0.2", "0.5"] + checkpoints = ["./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth".format(x) for x in LRs] + acc_lr_dict, indexes = {}, None + for lr, checkpoint in zip(LRs, checkpoints): + all_indexes, accuracies = torch.load(checkpoint, map_location="cpu"), [] + for x in all_indexes: + info = api.arch2infos_full[x] + metrics = info.get_metrics(dataset, xset, None, False) + accuracies.append(metrics["accuracy"]) + if indexes is None: + indexes = list(range(len(accuracies))) + acc_lr_dict[lr] = np.array(sorted(accuracies)) + print("LR={:.3f}, mean={:}, std={:}".format(float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std())) - for idx, LR in enumerate(LRs): - legend = 'LR={:.2f}'.format(float(LR)) - #color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.' - color, linestyle = color_set[idx], '-' - plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8) - print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]), np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]))) - plt.legend(loc=4, fontsize=LegendFontsize) - save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name) - print('save figure into {:}\n'.format(save_path)) - fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') + color_set = ["r", "b", "g", "c", "m", "y", "k"] + dpi, width, height = 300, 3400, 2600 + LabelSize, LegendFontsize = 28, 22 + figsize = width / float(dpi), height / float(dpi) + fig = plt.figure(figsize=figsize) + x_axis = np.arange(0, 600) + plt.xlim(0, max(indexes)) + plt.ylim(y_lims[0], y_lims[1]) + interval_x, interval_y = 100, y_lims[2] + plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) + plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize) + plt.grid() + plt.xlabel("The index of runs", fontsize=LabelSize) + plt.ylabel("The accuracy (%)", fontsize=LabelSize) + for idx, LR in enumerate(LRs): + legend = "LR={:.2f}".format(float(LR)) + # color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.' + color, linestyle = color_set[idx], "-" + plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8) + print( + "{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}".format( + legend, + np.mean(acc_lr_dict[LR]), + np.std(acc_lr_dict[LR]), + np.mean(acc_lr_dict[LR]), + np.std(acc_lr_dict[LR]), + ) + ) + plt.legend(loc=4, fontsize=LegendFontsize) + save_path = root / "{:}-{:}-{:}.pdf".format(dataset, xset, file_name) + print("save figure into {:}\n".format(save_path)) + fig.savefig(str(save_path), dpi=dpi, bbox_inches="tight", format="pdf") def show_rea(api, root, dataset, xset, file_name, y_lims): - print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) - SSs = [3, 5, 10] - checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10-SS{:}/results.pth'.format(x) for x in SSs] - acc_ss_dict, indexes = {}, None - for ss, checkpoint in zip(SSs, checkpoints): - all_indexes, accuracies = torch.load(checkpoint, map_location='cpu'), [] - for x in all_indexes: - info = api.arch2infos_full[ x ] - metrics = info.get_metrics(dataset, xset, None, False) - accuracies.append( metrics['accuracy'] ) - if indexes is None: indexes = list(range(len(accuracies))) - acc_ss_dict[ss] = np.array( sorted(accuracies) ) - print ('Sample-Size={:2d}, mean={:}, std={:}'.format(ss, acc_ss_dict[ss].mean(), acc_ss_dict[ss].std())) - - color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] - dpi, width, height = 300, 3400, 2600 - LabelSize, LegendFontsize = 28, 22 - figsize = width / float(dpi), height / float(dpi) - fig = plt.figure(figsize=figsize) - x_axis = np.arange(0, 600) - plt.xlim(0, max(indexes)) - plt.ylim(y_lims[0], y_lims[1]) - interval_x, interval_y = 100, y_lims[2] - plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) - plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) - plt.grid() - plt.xlabel('The index of runs', fontsize=LabelSize) - plt.ylabel('The accuracy (%)', fontsize=LabelSize) + print("root-path={:}, dataset={:}, xset={:}".format(root, dataset, xset)) + SSs = [3, 5, 10] + checkpoints = ["./output/search-cell-nas-bench-201/R-EA-cifar10-SS{:}/results.pth".format(x) for x in SSs] + acc_ss_dict, indexes = {}, None + for ss, checkpoint in zip(SSs, checkpoints): + all_indexes, accuracies = torch.load(checkpoint, map_location="cpu"), [] + for x in all_indexes: + info = api.arch2infos_full[x] + metrics = info.get_metrics(dataset, xset, None, False) + accuracies.append(metrics["accuracy"]) + if indexes is None: + indexes = list(range(len(accuracies))) + acc_ss_dict[ss] = np.array(sorted(accuracies)) + print("Sample-Size={:2d}, mean={:}, std={:}".format(ss, acc_ss_dict[ss].mean(), acc_ss_dict[ss].std())) - for idx, ss in enumerate(SSs): - legend = 'sample-size={:2d}'.format(ss) - #color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.' - color, linestyle = color_set[idx], '-' - plt.plot(indexes, acc_ss_dict[ss], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8) - print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_ss_dict[ss]), np.std(acc_ss_dict[ss]), np.mean(acc_ss_dict[ss]), np.std(acc_ss_dict[ss]))) - plt.legend(loc=4, fontsize=LegendFontsize) - save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name) - print('save figure into {:}\n'.format(save_path)) - fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') + color_set = ["r", "b", "g", "c", "m", "y", "k"] + dpi, width, height = 300, 3400, 2600 + LabelSize, LegendFontsize = 28, 22 + figsize = width / float(dpi), height / float(dpi) + fig = plt.figure(figsize=figsize) + x_axis = np.arange(0, 600) + plt.xlim(0, max(indexes)) + plt.ylim(y_lims[0], y_lims[1]) + interval_x, interval_y = 100, y_lims[2] + plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) + plt.yticks(np.arange(y_lims[0], y_lims[1], interval_y), fontsize=LegendFontsize) + plt.grid() + plt.xlabel("The index of runs", fontsize=LabelSize) + plt.ylabel("The accuracy (%)", fontsize=LabelSize) + + for idx, ss in enumerate(SSs): + legend = "sample-size={:2d}".format(ss) + # color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.' + color, linestyle = color_set[idx], "-" + plt.plot(indexes, acc_ss_dict[ss], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8) + print( + "{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}".format( + legend, + np.mean(acc_ss_dict[ss]), + np.std(acc_ss_dict[ss]), + np.mean(acc_ss_dict[ss]), + np.std(acc_ss_dict[ss]), + ) + ) + plt.legend(loc=4, fontsize=LegendFontsize) + save_path = root / "{:}-{:}-{:}.pdf".format(dataset, xset, file_name) + print("save figure into {:}\n".format(save_path)) + fig.savefig(str(save_path), dpi=dpi, bbox_inches="tight", format="pdf") -if __name__ == '__main__': +if __name__ == "__main__": - parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') - parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.') - args = parser.parse_args() - - vis_save_dir = Path(args.save_dir) - vis_save_dir.mkdir(parents=True, exist_ok=True) - meta_file = Path(args.api_path) - assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) - #visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time') - #write_video(vis_save_dir / 'over-time') - #visualize_info(str(meta_file), 'cifar10' , vis_save_dir) - #visualize_info(str(meta_file), 'cifar100', vis_save_dir) - #visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir) - #visualize_relative_ranking(vis_save_dir) + parser = argparse.ArgumentParser( + description="NAS-Bench-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--save_dir", + type=str, + default="./output/search-cell-nas-bench-201/visuals", + help="The base-name of folder to save checkpoints and log.", + ) + parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.") + args = parser.parse_args() - api = API(args.api_path) - #show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (85, 92, 2)) - #show_rea (api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REA-CIFAR-10', (88, 92, 1)) + vis_save_dir = Path(args.save_dir) + vis_save_dir.mkdir(parents=True, exist_ok=True) + meta_file = Path(args.api_path) + assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) + # visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time') + # write_video(vis_save_dir / 'over-time') + # visualize_info(str(meta_file), 'cifar10' , vis_save_dir) + # visualize_info(str(meta_file), 'cifar100', vis_save_dir) + # visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir) + # visualize_relative_ranking(vis_save_dir) - #plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1)) - #plot_results_nas_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3)) - #plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2)) + api = API(args.api_path) + # show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (85, 92, 2)) + # show_rea (api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REA-CIFAR-10', (88, 92, 1)) - show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test') , vis_save_dir, 'BN0', 'BN0-DARTS-CIFAR010.pdf', (0, 100,10), 50) - show_nas_sharing_w_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ) , vis_save_dir, 'BN0', 'BN0-DARTS-CIFAR100.pdf', (0, 100,10), 50) - show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ) , vis_save_dir, 'BN0', 'BN0-DARTS-ImageNet.pdf', (0, 100,10), 50) + # plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1)) + # plot_results_nas_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3)) + # plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2)) - show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test') , vis_save_dir, 'BN0', 'BN0-OTHER-CIFAR010.pdf', (0, 100,10), 250) - show_nas_sharing_w_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ) , vis_save_dir, 'BN0', 'BN0-OTHER-CIFAR100.pdf', (0, 100,10), 250) - show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ) , vis_save_dir, 'BN0', 'BN0-OTHER-ImageNet.pdf', (0, 100,10), 250) + show_nas_sharing_w_v2( + api, + ("cifar10-valid", "x-valid"), + ("cifar10", "ori-test"), + vis_save_dir, + "BN0", + "BN0-DARTS-CIFAR010.pdf", + (0, 100, 10), + 50, + ) + show_nas_sharing_w_v2( + api, + ("cifar100", "x-valid"), + ("cifar100", "x-test"), + vis_save_dir, + "BN0", + "BN0-DARTS-CIFAR100.pdf", + (0, 100, 10), + 50, + ) + show_nas_sharing_w_v2( + api, + ("ImageNet16-120", "x-valid"), + ("ImageNet16-120", "x-test"), + vis_save_dir, + "BN0", + "BN0-DARTS-ImageNet.pdf", + (0, 100, 10), + 50, + ) - show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-VALID.pdf', (0, 100,10), 250) - show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-TEST.pdf' , (0, 100,10), 250) - """ + show_nas_sharing_w_v2( + api, + ("cifar10-valid", "x-valid"), + ("cifar10", "ori-test"), + vis_save_dir, + "BN0", + "BN0-OTHER-CIFAR010.pdf", + (0, 100, 10), + 250, + ) + show_nas_sharing_w_v2( + api, + ("cifar100", "x-valid"), + ("cifar100", "x-test"), + vis_save_dir, + "BN0", + "BN0-OTHER-CIFAR100.pdf", + (0, 100, 10), + 250, + ) + show_nas_sharing_w_v2( + api, + ("ImageNet16-120", "x-valid"), + ("ImageNet16-120", "x-test"), + vis_save_dir, + "BN0", + "BN0-OTHER-ImageNet.pdf", + (0, 100, 10), + 250, + ) + + show_nas_sharing_w( + api, "cifar10-valid", "x-valid", vis_save_dir, "BN0", "BN0-XX-CIFAR010-VALID.pdf", (0, 100, 10), 250 + ) + show_nas_sharing_w(api, "cifar10", "ori-test", vis_save_dir, "BN0", "BN0-XX-CIFAR010-TEST.pdf", (0, 100, 10), 250) + """ for x_maxs in [50, 250]: show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) diff --git a/exps/NATS-Bench/Analyze-time.py b/exps/NATS-Bench/Analyze-time.py index b6174cc..4db9e05 100644 --- a/exps/NATS-Bench/Analyze-time.py +++ b/exps/NATS-Bench/Analyze-time.py @@ -8,37 +8,45 @@ import os, sys, time, tqdm, argparse from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from datasets import get_datasets from nats_bench import create def show_time(api, epoch=12): - print('Show the time for {:} with {:}-epoch-training'.format(api, epoch)) - all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0 - for index in tqdm.tqdm(range(len(api))): - info = api.get_more_info(index, 'ImageNet16-120', hp=epoch) - imagenet_time = info['train-all-time'] - info = api.get_more_info(index, 'cifar10-valid', hp=epoch) - cifar10_time = info['train-all-time'] - info = api.get_more_info(index, 'cifar100', hp=epoch) - cifar100_time = info['train-all-time'] - # accumulate the time - all_cifar10_time += cifar10_time - all_cifar100_time += cifar100_time - all_imagenet_time += imagenet_time - print('The total training time for CIFAR-10 (held-out train set) is {:} seconds'.format(all_cifar10_time)) - print('The total training time for CIFAR-100 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_cifar100_time, all_cifar100_time / all_cifar10_time)) - print('The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_imagenet_time, all_imagenet_time / all_cifar10_time)) + print("Show the time for {:} with {:}-epoch-training".format(api, epoch)) + all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0 + for index in tqdm.tqdm(range(len(api))): + info = api.get_more_info(index, "ImageNet16-120", hp=epoch) + imagenet_time = info["train-all-time"] + info = api.get_more_info(index, "cifar10-valid", hp=epoch) + cifar10_time = info["train-all-time"] + info = api.get_more_info(index, "cifar100", hp=epoch) + cifar100_time = info["train-all-time"] + # accumulate the time + all_cifar10_time += cifar10_time + all_cifar100_time += cifar100_time + all_imagenet_time += imagenet_time + print("The total training time for CIFAR-10 (held-out train set) is {:} seconds".format(all_cifar10_time)) + print( + "The total training time for CIFAR-100 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format( + all_cifar100_time, all_cifar100_time / all_cifar10_time + ) + ) + print( + "The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format( + all_imagenet_time, all_imagenet_time / all_cifar10_time + ) + ) -if __name__ == '__main__': +if __name__ == "__main__": - api_nats_tss = create(None, 'tss', fast_mode=True, verbose=False) - show_time(api_nats_tss, 12) - - api_nats_sss = create(None, 'sss', fast_mode=True, verbose=False) - show_time(api_nats_sss, 12) + api_nats_tss = create(None, "tss", fast_mode=True, verbose=False) + show_time(api_nats_tss, 12) + api_nats_sss = create(None, "sss", fast_mode=True, verbose=False) + show_time(api_nats_sss, 12) diff --git a/exps/NATS-Bench/draw-correlations.py b/exps/NATS-Bench/draw-correlations.py index 6afac3b..500b088 100644 --- a/exps/NATS-Bench/draw-correlations.py +++ b/exps/NATS-Bench/draw-correlations.py @@ -10,81 +10,88 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict, OrderedDict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from nats_bench import create from log_utils import time_string def get_valid_test_acc(api, arch, dataset): - is_size_space = api.search_space_name == 'size' - if dataset == 'cifar10': - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - test_acc = xinfo['test-accuracy'] - xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False) - valid_acc = xinfo['valid-accuracy'] - else: - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - valid_acc = xinfo['valid-accuracy'] - test_acc = xinfo['test-accuracy'] - return valid_acc, test_acc, 'validation = {:.2f}, test = {:.2f}\n'.format(valid_acc, test_acc) + is_size_space = api.search_space_name == "size" + if dataset == "cifar10": + xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + test_acc = xinfo["test-accuracy"] + xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False) + valid_acc = xinfo["valid-accuracy"] + else: + xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + valid_acc = xinfo["valid-accuracy"] + test_acc = xinfo["test-accuracy"] + return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc) def compute_kendalltau(vectori, vectorj): - # indexes = list(range(len(vectori))) - # rank_1 = sorted(indexes, key=lambda i: vectori[i]) - # rank_2 = sorted(indexes, key=lambda i: vectorj[i]) - # import pdb; pdb.set_trace() - coef, p = scipy.stats.kendalltau(vectori, vectorj) - return coef + # indexes = list(range(len(vectori))) + # rank_1 = sorted(indexes, key=lambda i: vectori[i]) + # rank_2 = sorted(indexes, key=lambda i: vectorj[i]) + # import pdb; pdb.set_trace() + coef, p = scipy.stats.kendalltau(vectori, vectorj) + return coef def compute_spearmanr(vectori, vectorj): - coef, p = scipy.stats.spearmanr(vectori, vectorj) - return coef + coef, p = scipy.stats.spearmanr(vectori, vectorj) + return coef -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') - parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." + ) + parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") + args = parser.parse_args() - save_dir = Path(args.save_dir) + save_dir = Path(args.save_dir) - api = create(None, 'tss', fast_mode=True, verbose=False) - indexes = list(range(1, 10000, 300)) - scores_1 = [] - scores_2 = [] - for index in indexes: - valid_acc, test_acc, _ = get_valid_test_acc(api, index, 'cifar10') - scores_1.append(valid_acc) - scores_2.append(test_acc) - correlation = compute_kendalltau(scores_1, scores_2) - print('The kendall tau correlation of {:} samples : {:}'.format(len(indexes), correlation)) - correlation = compute_spearmanr(scores_1, scores_2) - print('The spearmanr correlation of {:} samples : {:}'.format(len(indexes), correlation)) - # scores_1 = ['{:.2f}'.format(x) for x in scores_1] - # scores_2 = ['{:.2f}'.format(x) for x in scores_2] - # print(', '.join(scores_1)) - # print(', '.join(scores_2)) + api = create(None, "tss", fast_mode=True, verbose=False) + indexes = list(range(1, 10000, 300)) + scores_1 = [] + scores_2 = [] + for index in indexes: + valid_acc, test_acc, _ = get_valid_test_acc(api, index, "cifar10") + scores_1.append(valid_acc) + scores_2.append(test_acc) + correlation = compute_kendalltau(scores_1, scores_2) + print("The kendall tau correlation of {:} samples : {:}".format(len(indexes), correlation)) + correlation = compute_spearmanr(scores_1, scores_2) + print("The spearmanr correlation of {:} samples : {:}".format(len(indexes), correlation)) + # scores_1 = ['{:.2f}'.format(x) for x in scores_1] + # scores_2 = ['{:.2f}'.format(x) for x in scores_2] + # print(', '.join(scores_1)) + # print(', '.join(scores_2)) - dpi, width, height = 250, 1000, 1000 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 14, 14 + dpi, width, height = 250, 1000, 1000 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 14, 14 - fig, ax = plt.subplots(1, 1, figsize=figsize) - ax.scatter(scores_1, scores_2 , marker='^', s=0.5, c='tab:green', alpha=0.8) + fig, ax = plt.subplots(1, 1, figsize=figsize) + ax.scatter(scores_1, scores_2, marker="^", s=0.5, c="tab:green", alpha=0.8) - save_path = '/Users/xuanyidong/Desktop/test-temp-rank.png' - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - plt.close('all') + save_path = "/Users/xuanyidong/Desktop/test-temp-rank.png" + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + plt.close("all") diff --git a/exps/NATS-Bench/draw-fig2_5.py b/exps/NATS-Bench/draw-fig2_5.py index f7cf0b2..771d2e0 100644 --- a/exps/NATS-Bench/draw-fig2_5.py +++ b/exps/NATS-Bench/draw-fig2_5.py @@ -12,16 +12,18 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from log_utils import time_string from models import get_cell_based_tiny_net @@ -29,387 +31,574 @@ from nats_bench import create def visualize_relative_info(api, vis_save_dir, indicator): - vis_save_dir = vis_save_dir.resolve() - # print ('{:} start to visualize {:} information'.format(time_string(), api)) - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + # print ('{:} start to visualize {:} information'.format(time_string(), api)) + vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator) - cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator) - imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator) - cifar010_info = torch.load(cifar010_cache_path) - cifar100_info = torch.load(cifar100_cache_path) - imagenet_info = torch.load(imagenet_cache_path) - indexes = list(range(len(cifar010_info['params']))) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_info = torch.load(cifar010_cache_path) + cifar100_info = torch.load(cifar100_cache_path) + imagenet_info = torch.load(imagenet_cache_path) + indexes = list(range(len(cifar010_info["params"]))) - print ('{:} start to visualize relative ranking'.format(time_string())) + print("{:} start to visualize relative ranking".format(time_string())) - cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i]) - cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i]) - imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i]) + cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info["test_accs"][i]) + cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info["test_accs"][i]) + imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info["test_accs"][i]) - cifar100_labels, imagenet_labels = [], [] - for idx in cifar010_ord_indexes: - cifar100_labels.append( cifar100_ord_indexes.index(idx) ) - imagenet_labels.append( imagenet_ord_indexes.index(idx) ) - print ('{:} prepare data done.'.format(time_string())) + cifar100_labels, imagenet_labels = [], [] + for idx in cifar010_ord_indexes: + cifar100_labels.append(cifar100_ord_indexes.index(idx)) + imagenet_labels.append(imagenet_ord_indexes.index(idx)) + print("{:} prepare data done.".format(time_string())) - dpi, width, height = 200, 1400, 800 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 18, 12 - resnet_scale, resnet_alpha = 120, 0.5 + dpi, width, height = 200, 1400, 800 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 18, 12 + resnet_scale, resnet_alpha = 120, 0.5 - fig = plt.figure(figsize=figsize) - ax = fig.add_subplot(111) - plt.xlim(min(indexes), max(indexes)) - plt.ylim(min(indexes), max(indexes)) - # plt.ylabel('y').set_rotation(30) - plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical') - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize) - ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8) - ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8) - ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8) - ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10') - ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100') - ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120') - plt.grid(zorder=0) - ax.set_axisbelow(True) - plt.legend(loc=0, fontsize=LegendFontsize) - ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize) - ax.set_ylabel('architecture ranking', fontsize=LabelSize) - save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xlim(min(indexes), max(indexes)) + plt.ylim(min(indexes), max(indexes)) + # plt.ylabel('y').set_rotation(30) + plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical") + plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize) + ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) + ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8) + ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) + ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10") + ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100") + ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120") + plt.grid(zorder=0) + ax.set_axisbelow(True) + plt.legend(loc=0, fontsize=LegendFontsize) + ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize) + ax.set_ylabel("architecture ranking", fontsize=LabelSize) + save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) def visualize_sss_info(api, dataset, vis_save_dir): - vis_save_dir = vis_save_dir.resolve() - print ('{:} start to visualize {:} information'.format(time_string(), dataset)) - vis_save_dir.mkdir(parents=True, exist_ok=True) - cache_file_path = vis_save_dir / '{:}-cache-sss-info.pth'.format(dataset) - if not cache_file_path.exists(): - print ('Do not find cache file : {:}'.format(cache_file_path)) - params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] - for index in range(len(api)): - cost_info = api.get_cost_info(index, dataset, hp='90') - params.append(cost_info['params']) - flops.append(cost_info['flops']) - # accuracy - info = api.get_more_info(index, dataset, hp='90', is_random=False) - train_accs.append(info['train-accuracy']) - test_accs.append(info['test-accuracy']) - if dataset == 'cifar10': - info = api.get_more_info(index, 'cifar10-valid', hp='90', is_random=False) - valid_accs.append(info['valid-accuracy']) - else: - valid_accs.append(info['valid-accuracy']) - info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs} - torch.save(info, cache_file_path) - else: - print ('Find cache file : {:}'.format(cache_file_path)) - info = torch.load(cache_file_path) - params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'] - print ('{:} collect data done.'.format(time_string())) + vis_save_dir = vis_save_dir.resolve() + print("{:} start to visualize {:} information".format(time_string(), dataset)) + vis_save_dir.mkdir(parents=True, exist_ok=True) + cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset) + if not cache_file_path.exists(): + print("Do not find cache file : {:}".format(cache_file_path)) + params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] + for index in range(len(api)): + cost_info = api.get_cost_info(index, dataset, hp="90") + params.append(cost_info["params"]) + flops.append(cost_info["flops"]) + # accuracy + info = api.get_more_info(index, dataset, hp="90", is_random=False) + train_accs.append(info["train-accuracy"]) + test_accs.append(info["test-accuracy"]) + if dataset == "cifar10": + info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False) + valid_accs.append(info["valid-accuracy"]) + else: + valid_accs.append(info["valid-accuracy"]) + info = { + "params": params, + "flops": flops, + "train_accs": train_accs, + "valid_accs": valid_accs, + "test_accs": test_accs, + } + torch.save(info, cache_file_path) + else: + print("Find cache file : {:}".format(cache_file_path)) + info = torch.load(cache_file_path) + params, flops, train_accs, valid_accs, test_accs = ( + info["params"], + info["flops"], + info["train_accs"], + info["valid_accs"], + info["test_accs"], + ) + print("{:} collect data done.".format(time_string())) - # pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64'] - pyramid = ['8:16:24:32:40', '8:16:32:48:64', '32:40:48:56:64'] - pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid] - largest_indexes = [api.query_index_by_arch('64:64:64:64:64')] + # pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64'] + pyramid = ["8:16:24:32:40", "8:16:32:48:64", "32:40:48:56:64"] + pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid] + largest_indexes = [api.query_index_by_arch("64:64:64:64:64")] - indexes = list(range(len(params))) - dpi, width, height = 250, 8500, 1300 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 24, 24 - # resnet_scale, resnet_alpha = 120, 0.5 - xscale, xalpha = 120, 0.8 + indexes = list(range(len(params))) + dpi, width, height = 250, 8500, 1300 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 24, 24 + # resnet_scale, resnet_alpha = 120, 0.5 + xscale, xalpha = 120, 0.8 - fig, axs = plt.subplots(1, 4, figsize=figsize) - # ax1, ax2, ax3, ax4, ax5 = axs - for ax in axs: - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f')) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax1, ax2, ax3, ax4 = axs + fig, axs = plt.subplots(1, 4, figsize=figsize) + # ax1, ax2, ax3, ax4, ax5 = axs + for ax in axs: + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f")) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax1, ax2, ax3, ax4 = axs - ax1.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue') - ax1.scatter([params[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax1.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax1.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax1.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax1.legend(loc=4, fontsize=LegendFontsize) + ax1.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") + ax1.scatter( + [params[x] for x in pyramid_indexes], + [train_accs[x] for x in pyramid_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="Pyramid Structure", + alpha=xalpha, + ) + ax1.scatter( + [params[x] for x in largest_indexes], + [train_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax1.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax1.set_ylabel("train accuracy (%)", fontsize=LabelSize) + ax1.legend(loc=4, fontsize=LegendFontsize) - ax2.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue') - ax2.scatter([flops[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax2.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax2.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - # ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax2.legend(loc=4, fontsize=LegendFontsize) + ax2.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue") + ax2.scatter( + [flops[x] for x in pyramid_indexes], + [train_accs[x] for x in pyramid_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="Pyramid Structure", + alpha=xalpha, + ) + ax2.scatter( + [flops[x] for x in largest_indexes], + [train_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax2.set_xlabel("#FLOPs (M)", fontsize=LabelSize) + # ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) + ax2.legend(loc=4, fontsize=LegendFontsize) - ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue') - ax3.scatter([params[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax3.legend(loc=4, fontsize=LegendFontsize) + ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") + ax3.scatter( + [params[x] for x in pyramid_indexes], + [test_accs[x] for x in pyramid_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="Pyramid Structure", + alpha=xalpha, + ) + ax3.scatter( + [params[x] for x in largest_indexes], + [test_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize) + ax3.legend(loc=4, fontsize=LegendFontsize) - ax4.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue') - ax4.scatter([flops[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax4.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - # ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax4.legend(loc=4, fontsize=LegendFontsize) + ax4.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue") + ax4.scatter( + [flops[x] for x in pyramid_indexes], + [test_accs[x] for x in pyramid_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="Pyramid Structure", + alpha=xalpha, + ) + ax4.scatter( + [flops[x] for x in largest_indexes], + [test_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize) + # ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize) + ax4.legend(loc=4, fontsize=LegendFontsize) - save_path = vis_save_dir / 'sss-{:}.png'.format(dataset.lower()) - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + save_path = vis_save_dir / "sss-{:}.png".format(dataset.lower()) + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") def visualize_tss_info(api, dataset, vis_save_dir): - vis_save_dir = vis_save_dir.resolve() - print ('{:} start to visualize {:} information'.format(time_string(), dataset)) - vis_save_dir.mkdir(parents=True, exist_ok=True) - cache_file_path = vis_save_dir / '{:}-cache-tss-info.pth'.format(dataset) - if not cache_file_path.exists(): - print ('Do not find cache file : {:}'.format(cache_file_path)) - params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] - for index in range(len(api)): - cost_info = api.get_cost_info(index, dataset, hp='12') - params.append(cost_info['params']) - flops.append(cost_info['flops']) - # accuracy - info = api.get_more_info(index, dataset, hp='200', is_random=False) - train_accs.append(info['train-accuracy']) - test_accs.append(info['test-accuracy']) - if dataset == 'cifar10': - info = api.get_more_info(index, 'cifar10-valid', hp='200', is_random=False) - valid_accs.append(info['valid-accuracy']) - else: - valid_accs.append(info['valid-accuracy']) - print('') - info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs} - torch.save(info, cache_file_path) - else: - print ('Find cache file : {:}'.format(cache_file_path)) - info = torch.load(cache_file_path) - params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'] - print ('{:} collect data done.'.format(time_string())) + vis_save_dir = vis_save_dir.resolve() + print("{:} start to visualize {:} information".format(time_string(), dataset)) + vis_save_dir.mkdir(parents=True, exist_ok=True) + cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset) + if not cache_file_path.exists(): + print("Do not find cache file : {:}".format(cache_file_path)) + params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] + for index in range(len(api)): + cost_info = api.get_cost_info(index, dataset, hp="12") + params.append(cost_info["params"]) + flops.append(cost_info["flops"]) + # accuracy + info = api.get_more_info(index, dataset, hp="200", is_random=False) + train_accs.append(info["train-accuracy"]) + test_accs.append(info["test-accuracy"]) + if dataset == "cifar10": + info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False) + valid_accs.append(info["valid-accuracy"]) + else: + valid_accs.append(info["valid-accuracy"]) + print("") + info = { + "params": params, + "flops": flops, + "train_accs": train_accs, + "valid_accs": valid_accs, + "test_accs": test_accs, + } + torch.save(info, cache_file_path) + else: + print("Find cache file : {:}".format(cache_file_path)) + info = torch.load(cache_file_path) + params, flops, train_accs, valid_accs, test_accs = ( + info["params"], + info["flops"], + info["train_accs"], + info["valid_accs"], + info["test_accs"], + ) + print("{:} collect data done.".format(time_string())) - resnet = ['|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|'] - resnet_indexes = [api.query_index_by_arch(x) for x in resnet] - largest_indexes = [api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|')] + resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"] + resnet_indexes = [api.query_index_by_arch(x) for x in resnet] + largest_indexes = [ + api.query_index_by_arch( + "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|" + ) + ] - indexes = list(range(len(params))) - dpi, width, height = 250, 8500, 1300 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 24, 24 - # resnet_scale, resnet_alpha = 120, 0.5 - xscale, xalpha = 120, 0.8 + indexes = list(range(len(params))) + dpi, width, height = 250, 8500, 1300 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 24, 24 + # resnet_scale, resnet_alpha = 120, 0.5 + xscale, xalpha = 120, 0.8 - fig, axs = plt.subplots(1, 4, figsize=figsize) - for ax in axs: - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f')) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax1, ax2, ax3, ax4 = axs + fig, axs = plt.subplots(1, 4, figsize=figsize) + for ax in axs: + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f")) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax1, ax2, ax3, ax4 = axs - ax1.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue') - ax1.scatter([params[x] for x in resnet_indexes] , [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax1.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax1.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax1.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax1.legend(loc=4, fontsize=LegendFontsize) + ax1.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") + ax1.scatter( + [params[x] for x in resnet_indexes], + [train_accs[x] for x in resnet_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="ResNet", + alpha=xalpha, + ) + ax1.scatter( + [params[x] for x in largest_indexes], + [train_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax1.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax1.set_ylabel("train accuracy (%)", fontsize=LabelSize) + ax1.legend(loc=4, fontsize=LegendFontsize) - ax2.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue') - ax2.scatter([flops[x] for x in resnet_indexes], [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax2.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax2.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - # ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax2.legend(loc=4, fontsize=LegendFontsize) + ax2.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue") + ax2.scatter( + [flops[x] for x in resnet_indexes], + [train_accs[x] for x in resnet_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="ResNet", + alpha=xalpha, + ) + ax2.scatter( + [flops[x] for x in largest_indexes], + [train_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax2.set_xlabel("#FLOPs (M)", fontsize=LabelSize) + # ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) + ax2.legend(loc=4, fontsize=LegendFontsize) - ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue') - ax3.scatter([params[x] for x in resnet_indexes] , [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax3.legend(loc=4, fontsize=LegendFontsize) + ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") + ax3.scatter( + [params[x] for x in resnet_indexes], + [test_accs[x] for x in resnet_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="ResNet", + alpha=xalpha, + ) + ax3.scatter( + [params[x] for x in largest_indexes], + [test_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize) + ax3.legend(loc=4, fontsize=LegendFontsize) - ax4.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue') - ax4.scatter([flops[x] for x in resnet_indexes], [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax4.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - # ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax4.legend(loc=4, fontsize=LegendFontsize) + ax4.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue") + ax4.scatter( + [flops[x] for x in resnet_indexes], + [test_accs[x] for x in resnet_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="ResNet", + alpha=xalpha, + ) + ax4.scatter( + [flops[x] for x in largest_indexes], + [test_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize) + # ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize) + ax4.legend(loc=4, fontsize=LegendFontsize) - save_path = vis_save_dir / 'tss-{:}.png'.format(dataset.lower()) - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + save_path = vis_save_dir / "tss-{:}.png".format(dataset.lower()) + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") def visualize_rank_info(api, vis_save_dir, indicator): - vis_save_dir = vis_save_dir.resolve() - # print ('{:} start to visualize {:} information'.format(time_string(), api)) - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + # print ('{:} start to visualize {:} information'.format(time_string(), api)) + vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator) - cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator) - imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator) - cifar010_info = torch.load(cifar010_cache_path) - cifar100_info = torch.load(cifar100_cache_path) - imagenet_info = torch.load(imagenet_cache_path) - indexes = list(range(len(cifar010_info['params']))) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_info = torch.load(cifar010_cache_path) + cifar100_info = torch.load(cifar100_cache_path) + imagenet_info = torch.load(imagenet_cache_path) + indexes = list(range(len(cifar010_info["params"]))) - print ('{:} start to visualize relative ranking'.format(time_string())) + print("{:} start to visualize relative ranking".format(time_string())) - dpi, width, height = 250, 3800, 1200 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 14, 14 + dpi, width, height = 250, 3800, 1200 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 14, 14 - fig, axs = plt.subplots(1, 3, figsize=figsize) - ax1, ax2, ax3 = axs + fig, axs = plt.subplots(1, 3, figsize=figsize) + ax1, ax2, ax3 = axs - def get_labels(info): - ord_test_indexes = sorted(indexes, key=lambda i: info['test_accs'][i]) - ord_valid_indexes = sorted(indexes, key=lambda i: info['valid_accs'][i]) - labels = [] - for idx in ord_test_indexes: - labels.append(ord_valid_indexes.index(idx)) - return labels + def get_labels(info): + ord_test_indexes = sorted(indexes, key=lambda i: info["test_accs"][i]) + ord_valid_indexes = sorted(indexes, key=lambda i: info["valid_accs"][i]) + labels = [] + for idx in ord_test_indexes: + labels.append(ord_valid_indexes.index(idx)) + return labels - def plot_ax(labels, ax, name): - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - tick.label.set_rotation(90) - ax.set_xlim(min(indexes), max(indexes)) - ax.set_ylim(min(indexes), max(indexes)) - ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes)//3)) - ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes)//5)) - ax.scatter(indexes, labels , marker='^', s=0.5, c='tab:green', alpha=0.8) - ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue' , alpha=0.8) - ax.scatter([-1], [-1], marker='^', s=100, c='tab:green' , label='{:} test'.format(name)) - ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='{:} validation'.format(name)) - ax.legend(loc=4, fontsize=LegendFontsize) - ax.set_xlabel('ranking on the {:} validation'.format(name), fontsize=LabelSize) - ax.set_ylabel('architecture ranking', fontsize=LabelSize) - labels = get_labels(cifar010_info) - plot_ax(labels, ax1, 'CIFAR-10') - labels = get_labels(cifar100_info) - plot_ax(labels, ax2, 'CIFAR-100') - labels = get_labels(imagenet_info) - plot_ax(labels, ax3, 'ImageNet-16-120') + def plot_ax(labels, ax, name): + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + tick.label.set_rotation(90) + ax.set_xlim(min(indexes), max(indexes)) + ax.set_ylim(min(indexes), max(indexes)) + ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 3)) + ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) + ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8) + ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) + ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)) + ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name)) + ax.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize) + ax.set_ylabel("architecture ranking", fontsize=LabelSize) - save_path = (vis_save_dir / '{:}-same-relative-rank.pdf'.format(indicator)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / '{:}-same-relative-rank.png'.format(indicator)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + labels = get_labels(cifar010_info) + plot_ax(labels, ax1, "CIFAR-10") + labels = get_labels(cifar100_info) + plot_ax(labels, ax2, "CIFAR-100") + labels = get_labels(imagenet_info) + plot_ax(labels, ax3, "ImageNet-16-120") + + save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") def compute_kendalltau(vectori, vectorj): - # indexes = list(range(len(vectori))) - # rank_1 = sorted(indexes, key=lambda i: vectori[i]) - # rank_2 = sorted(indexes, key=lambda i: vectorj[i]) - return scipy.stats.kendalltau(vectori, vectorj).correlation + # indexes = list(range(len(vectori))) + # rank_1 = sorted(indexes, key=lambda i: vectori[i]) + # rank_2 = sorted(indexes, key=lambda i: vectorj[i]) + return scipy.stats.kendalltau(vectori, vectorj).correlation def calculate_correlation(*vectors): - matrix = [] - for i, vectori in enumerate(vectors): - x = [] - for j, vectorj in enumerate(vectors): - # x.append(np.corrcoef(vectori, vectorj)[0,1]) - x.append(compute_kendalltau(vectori, vectorj)) - matrix.append( x ) - return np.array(matrix) + matrix = [] + for i, vectori in enumerate(vectors): + x = [] + for j, vectorj in enumerate(vectors): + # x.append(np.corrcoef(vectori, vectorj)[0,1]) + x.append(compute_kendalltau(vectori, vectorj)) + matrix.append(x) + return np.array(matrix) def visualize_all_rank_info(api, vis_save_dir, indicator): - vis_save_dir = vis_save_dir.resolve() - # print ('{:} start to visualize {:} information'.format(time_string(), api)) - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + # print ('{:} start to visualize {:} information'.format(time_string(), api)) + vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator) - cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator) - imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator) - cifar010_info = torch.load(cifar010_cache_path) - cifar100_info = torch.load(cifar100_cache_path) - imagenet_info = torch.load(imagenet_cache_path) - indexes = list(range(len(cifar010_info['params']))) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_info = torch.load(cifar010_cache_path) + cifar100_info = torch.load(cifar100_cache_path) + imagenet_info = torch.load(imagenet_cache_path) + indexes = list(range(len(cifar010_info["params"]))) - print ('{:} start to visualize relative ranking'.format(time_string())) - + print("{:} start to visualize relative ranking".format(time_string())) - dpi, width, height = 250, 3200, 1400 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 14, 14 + dpi, width, height = 250, 3200, 1400 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 14, 14 - fig, axs = plt.subplots(1, 2, figsize=figsize) - ax1, ax2 = axs + fig, axs = plt.subplots(1, 2, figsize=figsize) + ax1, ax2 = axs - sns_size, xformat = 15, '.2f' - CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs']) - - sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt=xformat, linewidths=0.5, ax=ax1, - xticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'], - yticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T']) - - selected_indexes, acc_bar = [], 92 - for i, acc in enumerate(cifar010_info['test_accs']): - if acc > acc_bar: selected_indexes.append( i ) - cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ] - cifar010_test_accs = np.array(cifar010_info['test_accs']) [ selected_indexes ] - cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ] - cifar100_test_accs = np.array(cifar100_info['test_accs']) [ selected_indexes ] - imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ] - imagenet_test_accs = np.array(imagenet_info['test_accs']) [ selected_indexes ] - CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs) - - sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt=xformat, linewidths=0.5, ax=ax2, - xticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'], - yticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T']) - ax1.set_title('Correlation coefficient over ALL candidates') - ax2.set_title('Correlation coefficient over candidates with accuracy > {:}%'.format(acc_bar)) - save_path = (vis_save_dir / '{:}-all-relative-rank.png'.format(indicator)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + sns_size, xformat = 15, ".2f" + CoRelMatrix = calculate_correlation( + cifar010_info["valid_accs"], + cifar010_info["test_accs"], + cifar100_info["valid_accs"], + cifar100_info["test_accs"], + imagenet_info["valid_accs"], + imagenet_info["test_accs"], + ) + + sns.heatmap( + CoRelMatrix, + annot=True, + annot_kws={"size": sns_size}, + fmt=xformat, + linewidths=0.5, + ax=ax1, + xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], + yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], + ) + + selected_indexes, acc_bar = [], 92 + for i, acc in enumerate(cifar010_info["test_accs"]): + if acc > acc_bar: + selected_indexes.append(i) + cifar010_valid_accs = np.array(cifar010_info["valid_accs"])[selected_indexes] + cifar010_test_accs = np.array(cifar010_info["test_accs"])[selected_indexes] + cifar100_valid_accs = np.array(cifar100_info["valid_accs"])[selected_indexes] + cifar100_test_accs = np.array(cifar100_info["test_accs"])[selected_indexes] + imagenet_valid_accs = np.array(imagenet_info["valid_accs"])[selected_indexes] + imagenet_test_accs = np.array(imagenet_info["test_accs"])[selected_indexes] + CoRelMatrix = calculate_correlation( + cifar010_valid_accs, + cifar010_test_accs, + cifar100_valid_accs, + cifar100_test_accs, + imagenet_valid_accs, + imagenet_test_accs, + ) + + sns.heatmap( + CoRelMatrix, + annot=True, + annot_kws={"size": sns_size}, + fmt=xformat, + linewidths=0.5, + ax=ax2, + xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], + yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], + ) + ax1.set_title("Correlation coefficient over ALL candidates") + ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)) + save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench', help='Folder to save checkpoints and log.') - # use for train the model - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log." + ) + # use for train the model + args = parser.parse_args() - to_save_dir = Path(args.save_dir) + to_save_dir = Path(args.save_dir) - datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] - # Figure 3 (a-c) - api_tss = create(None, 'tss', verbose=True) - for xdata in datasets: - visualize_tss_info(api_tss, xdata, to_save_dir) - # Figure 3 (d-f) - api_sss = create(None, 'size', verbose=True) - for xdata in datasets: - visualize_sss_info(api_sss, xdata, to_save_dir) + datasets = ["cifar10", "cifar100", "ImageNet16-120"] + # Figure 3 (a-c) + api_tss = create(None, "tss", verbose=True) + for xdata in datasets: + visualize_tss_info(api_tss, xdata, to_save_dir) + # Figure 3 (d-f) + api_sss = create(None, "size", verbose=True) + for xdata in datasets: + visualize_sss_info(api_sss, xdata, to_save_dir) - # Figure 2 - visualize_relative_info(None, to_save_dir, 'tss') - visualize_relative_info(None, to_save_dir, 'sss') + # Figure 2 + visualize_relative_info(None, to_save_dir, "tss") + visualize_relative_info(None, to_save_dir, "sss") - # Figure 4 - visualize_rank_info(None, to_save_dir, 'tss') - visualize_rank_info(None, to_save_dir, 'sss') + # Figure 4 + visualize_rank_info(None, to_save_dir, "tss") + visualize_rank_info(None, to_save_dir, "sss") - # Figure 5 - visualize_all_rank_info(None, to_save_dir, 'tss') - visualize_all_rank_info(None, to_save_dir, 'sss') + # Figure 5 + visualize_all_rank_info(None, to_save_dir, "tss") + visualize_all_rank_info(None, to_save_dir, "sss") diff --git a/exps/NATS-Bench/draw-fig6.py b/exps/NATS-Bench/draw-fig6.py index fc7a152..12ca998 100644 --- a/exps/NATS-Bench/draw-fig6.py +++ b/exps/NATS-Bench/draw-fig6.py @@ -12,158 +12,174 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict, OrderedDict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from nats_bench import create from log_utils import time_string -def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): - ss_dir = '{:}-{:}'.format(root_dir, search_space) - alg2name, alg2path = OrderedDict(), OrderedDict() - alg2name['REA'] = 'R-EA-SS3' - alg2name['REINFORCE'] = 'REINFORCE-0.01' - alg2name['RANDOM'] = 'RANDOM' - alg2name['BOHB'] = 'BOHB' - for alg, name in alg2name.items(): - alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth') - assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg]) - alg2data = OrderedDict() - for alg, path in alg2path.items(): - data = torch.load(path) - for index, info in data.items(): - info['time_w_arch'] = [(x, y) for x, y in zip(info['all_total_times'], info['all_archs'])] - for j, arch in enumerate(info['all_archs']): - assert arch != -1, 'invalid arch from {:} {:} {:} ({:}, {:})'.format(alg, search_space, dataset, index, j) - alg2data[alg] = data - return alg2data +def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): + ss_dir = "{:}-{:}".format(root_dir, search_space) + alg2name, alg2path = OrderedDict(), OrderedDict() + alg2name["REA"] = "R-EA-SS3" + alg2name["REINFORCE"] = "REINFORCE-0.01" + alg2name["RANDOM"] = "RANDOM" + alg2name["BOHB"] = "BOHB" + for alg, name in alg2name.items(): + alg2path[alg] = os.path.join(ss_dir, dataset, name, "results.pth") + assert os.path.isfile(alg2path[alg]), "invalid path : {:}".format(alg2path[alg]) + alg2data = OrderedDict() + for alg, path in alg2path.items(): + data = torch.load(path) + for index, info in data.items(): + info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])] + for j, arch in enumerate(info["all_archs"]): + assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( + alg, search_space, dataset, index, j + ) + alg2data[alg] = data + return alg2data def query_performance(api, data, dataset, ticket): - results, is_size_space = [], api.search_space_name == 'size' - for i, info in data.items(): - time_w_arch = sorted(info['time_w_arch'], key=lambda x: abs(x[0]-ticket)) - time_a, arch_a = time_w_arch[0] - time_b, arch_b = time_w_arch[1] - info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy'] - interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b - results.append(interplate) - # return sum(results) / len(results) - return np.mean(results), np.std(results) + results, is_size_space = [], api.search_space_name == "size" + for i, info in data.items(): + time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) + time_a, arch_a = time_w_arch[0] + time_b, arch_b = time_w_arch[1] + info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"] + interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / ( + time_b - time_a + ) * accuracy_b + results.append(interplate) + # return sum(results) / len(results) + return np.mean(results), np.std(results) def show_valid_test(api, data, dataset): - valid_accs, test_accs, is_size_space = [], [], api.search_space_name == 'size' - for i, info in data.items(): - time, arch = info['time_w_arch'][-1] - if dataset == 'cifar10': - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - test_accs.append(xinfo['test-accuracy']) - xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False) - valid_accs.append(xinfo['valid-accuracy']) - else: - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - valid_accs.append(xinfo['valid-accuracy']) - test_accs.append(xinfo['test-accuracy']) - valid_str = '{:.2f}$\pm${:.2f}'.format(np.mean(valid_accs), np.std(valid_accs)) - test_str = '{:.2f}$\pm${:.2f}'.format(np.mean(test_accs), np.std(test_accs)) - return valid_str, test_str + valid_accs, test_accs, is_size_space = [], [], api.search_space_name == "size" + for i, info in data.items(): + time, arch = info["time_w_arch"][-1] + if dataset == "cifar10": + xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + test_accs.append(xinfo["test-accuracy"]) + xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False) + valid_accs.append(xinfo["valid-accuracy"]) + else: + xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + valid_accs.append(xinfo["valid-accuracy"]) + test_accs.append(xinfo["test-accuracy"]) + valid_str = "{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs)) + test_str = "{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs)) + return valid_str, test_str -y_min_s = {('cifar10', 'tss'): 90, - ('cifar10', 'sss'): 92, - ('cifar100', 'tss'): 65, - ('cifar100', 'sss'): 65, - ('ImageNet16-120', 'tss'): 36, - ('ImageNet16-120', 'sss'): 40} +y_min_s = { + ("cifar10", "tss"): 90, + ("cifar10", "sss"): 92, + ("cifar100", "tss"): 65, + ("cifar100", "sss"): 65, + ("ImageNet16-120", "tss"): 36, + ("ImageNet16-120", "sss"): 40, +} -y_max_s = {('cifar10', 'tss'): 94.3, - ('cifar10', 'sss'): 93.3, - ('cifar100', 'tss'): 72.5, - ('cifar100', 'sss'): 70.5, - ('ImageNet16-120', 'tss'): 46, - ('ImageNet16-120', 'sss'): 46} +y_max_s = { + ("cifar10", "tss"): 94.3, + ("cifar10", "sss"): 93.3, + ("cifar100", "tss"): 72.5, + ("cifar100", "sss"): 70.5, + ("ImageNet16-120", "tss"): 46, + ("ImageNet16-120", "sss"): 46, +} -x_axis_s = {('cifar10', 'tss'): 200, - ('cifar10', 'sss'): 200, - ('cifar100', 'tss'): 400, - ('cifar100', 'sss'): 400, - ('ImageNet16-120', 'tss'): 1200, - ('ImageNet16-120', 'sss'): 600} +x_axis_s = { + ("cifar10", "tss"): 200, + ("cifar10", "sss"): 200, + ("cifar100", "tss"): 400, + ("cifar100", "sss"): 400, + ("ImageNet16-120", "tss"): 1200, + ("ImageNet16-120", "sss"): 600, +} -name2label = {'cifar10': 'CIFAR-10', - 'cifar100': 'CIFAR-100', - 'ImageNet16-120': 'ImageNet-16-120'} +name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} def visualize_curve(api, vis_save_dir, search_space): - vis_save_dir = vis_save_dir.resolve() - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + vis_save_dir.mkdir(parents=True, exist_ok=True) - dpi, width, height = 250, 5200, 1400 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 16, 16 + dpi, width, height = 250, 5200, 1400 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 16, 16 - def sub_plot_fn(ax, dataset): - xdataset, max_time = dataset.split('-T') - alg2data = fetch_data(search_space=search_space, dataset=dataset) - alg2accuracies = OrderedDict() - total_tickets = 150 - time_tickets = [float(i) / total_tickets * int(max_time) for i in range(total_tickets)] - colors = ['b', 'g', 'c', 'm', 'y'] - ax.set_xlim(0, x_axis_s[(xdataset, search_space)]) - ax.set_ylim(y_min_s[(xdataset, search_space)], - y_max_s[(xdataset, search_space)]) - for idx, (alg, data) in enumerate(alg2data.items()): - accuracies = [] - for ticket in time_tickets: - accuracy, accuracy_std = query_performance(api, data, xdataset, ticket) - accuracies.append(accuracy) - valid_str, test_str = show_valid_test(api, data, xdataset) - # print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std)) - print('{:} plot alg : {:10s} | validation = {:} | test = {:}'.format(time_string(), alg, valid_str, test_str)) - alg2accuracies[alg] = accuracies - ax.plot([x/100 for x in time_tickets], accuracies, c=colors[idx], label='{:}'.format(alg)) - ax.set_xlabel('Estimated wall-clock time (1e2 seconds)', fontsize=LabelSize) - ax.set_ylabel('Test accuracy on {:}'.format(name2label[xdataset]), fontsize=LabelSize) - ax.set_title('Searching results on {:}'.format(name2label[xdataset]), fontsize=LabelSize+4) - ax.legend(loc=4, fontsize=LegendFontsize) + def sub_plot_fn(ax, dataset): + xdataset, max_time = dataset.split("-T") + alg2data = fetch_data(search_space=search_space, dataset=dataset) + alg2accuracies = OrderedDict() + total_tickets = 150 + time_tickets = [float(i) / total_tickets * int(max_time) for i in range(total_tickets)] + colors = ["b", "g", "c", "m", "y"] + ax.set_xlim(0, x_axis_s[(xdataset, search_space)]) + ax.set_ylim(y_min_s[(xdataset, search_space)], y_max_s[(xdataset, search_space)]) + for idx, (alg, data) in enumerate(alg2data.items()): + accuracies = [] + for ticket in time_tickets: + accuracy, accuracy_std = query_performance(api, data, xdataset, ticket) + accuracies.append(accuracy) + valid_str, test_str = show_valid_test(api, data, xdataset) + # print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std)) + print( + "{:} plot alg : {:10s} | validation = {:} | test = {:}".format(time_string(), alg, valid_str, test_str) + ) + alg2accuracies[alg] = accuracies + ax.plot([x / 100 for x in time_tickets], accuracies, c=colors[idx], label="{:}".format(alg)) + ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize) + ax.set_ylabel("Test accuracy on {:}".format(name2label[xdataset]), fontsize=LabelSize) + ax.set_title("Searching results on {:}".format(name2label[xdataset]), fontsize=LabelSize + 4) + ax.legend(loc=4, fontsize=LegendFontsize) - fig, axs = plt.subplots(1, 3, figsize=figsize) - # datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] - if search_space == 'tss': - datasets = ['cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T120000'] - elif search_space == 'sss': - datasets = ['cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T60000'] - else: - raise ValueError('Unknown search space: {:}'.format(search_space)) - for dataset, ax in zip(datasets, axs): - sub_plot_fn(ax, dataset) - print('sub-plot {:} on {:} done.'.format(dataset, search_space)) - save_path = (vis_save_dir / '{:}-curve.png'.format(search_space)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + fig, axs = plt.subplots(1, 3, figsize=figsize) + # datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] + if search_space == "tss": + datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T120000"] + elif search_space == "sss": + datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T60000"] + else: + raise ValueError("Unknown search space: {:}".format(search_space)) + for dataset, ax in zip(datasets, axs): + sub_plot_fn(ax, dataset) + print("sub-plot {:} on {:} done.".format(dataset, search_space)) + save_path = (vis_save_dir / "{:}-curve.png".format(search_space)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') - parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." + ) + parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") + args = parser.parse_args() - save_dir = Path(args.save_dir) + save_dir = Path(args.save_dir) - api = create(None, args.search_space, fast_mode=True, verbose=False) - visualize_curve(api, save_dir, args.search_space) + api = create(None, args.search_space, fast_mode=True, verbose=False) + visualize_curve(api, save_dir, args.search_space) diff --git a/exps/NATS-Bench/draw-fig7.py b/exps/NATS-Bench/draw-fig7.py index 26a66aa..77c01fe 100644 --- a/exps/NATS-Bench/draw-fig7.py +++ b/exps/NATS-Bench/draw-fig7.py @@ -11,170 +11,179 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict, OrderedDict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from nats_bench import create from log_utils import time_string def get_valid_test_acc(api, arch, dataset): - is_size_space = api.search_space_name == 'size' - if dataset == 'cifar10': - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - test_acc = xinfo['test-accuracy'] - xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False) - valid_acc = xinfo['valid-accuracy'] - else: - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - valid_acc = xinfo['valid-accuracy'] - test_acc = xinfo['test-accuracy'] - return valid_acc, test_acc, 'validation = {:.2f}, test = {:.2f}\n'.format(valid_acc, test_acc) + is_size_space = api.search_space_name == "size" + if dataset == "cifar10": + xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + test_acc = xinfo["test-accuracy"] + xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False) + valid_acc = xinfo["valid-accuracy"] + else: + xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + valid_acc = xinfo["valid-accuracy"] + test_acc = xinfo["test-accuracy"] + return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc) -def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARM0.3'): - ss_dir = '{:}-{:}'.format(root_dir, search_space) - alg2name, alg2path = OrderedDict(), OrderedDict() - seeds = [777, 888, 999] - print('\n[fetch data] from {:} on {:}'.format(search_space, dataset)) - if search_space == 'tss': - alg2name['GDAS'] = 'gdas-affine0_BN0-None' - alg2name['RSPS'] = 'random-affine0_BN0-None' - alg2name['DARTS (1st)'] = 'darts-v1-affine0_BN0-None' - alg2name['DARTS (2nd)'] = 'darts-v2-affine0_BN0-None' - alg2name['ENAS'] = 'enas-affine0_BN0-None' - alg2name['SETN'] = 'setn-affine0_BN0-None' - else: - alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix) - alg2name['masking + Gumbel-Softmax'] = 'mask_gumbel-affine0_BN0-AWD0.001{:}'.format(suffix) - alg2name['masking + sampling'] = 'mask_rl-affine0_BN0-AWD0.0{:}'.format(suffix) - for alg, name in alg2name.items(): - alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth') - alg2data = OrderedDict() - for alg, path in alg2path.items(): - alg2data[alg], ok_num = [], 0 - for seed in seeds: - xpath = path.format(seed) - if os.path.isfile(xpath): - ok_num += 1 - else: - print('This is an invalid path : {:}'.format(xpath)) - continue - data = torch.load(xpath, map_location=torch.device('cpu')) - try: - data = torch.load(data['last_checkpoint'], map_location=torch.device('cpu')) - except: - xpath = str(data['last_checkpoint']).split('E100-') - if len(xpath) == 2 and os.path.isfile(xpath[0] + xpath[1]): - xpath = xpath[0] + xpath[1] - elif 'fbv2' in str(data['last_checkpoint']): - xpath = str(data['last_checkpoint']).replace('fbv2', 'mask_gumbel') - elif 'tunas' in str(data['last_checkpoint']): - xpath = str(data['last_checkpoint']).replace('tunas', 'mask_rl') - else: - raise ValueError('Invalid path: {:}'.format(data['last_checkpoint'])) - data = torch.load(xpath, map_location=torch.device('cpu')) - alg2data[alg].append(data['genotypes']) - print('This algorithm : {:} has {:} valid ckps.'.format(alg, ok_num)) - assert ok_num > 0, 'Must have at least 1 valid ckps.' - return alg2data +def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"): + ss_dir = "{:}-{:}".format(root_dir, search_space) + alg2name, alg2path = OrderedDict(), OrderedDict() + seeds = [777, 888, 999] + print("\n[fetch data] from {:} on {:}".format(search_space, dataset)) + if search_space == "tss": + alg2name["GDAS"] = "gdas-affine0_BN0-None" + alg2name["RSPS"] = "random-affine0_BN0-None" + alg2name["DARTS (1st)"] = "darts-v1-affine0_BN0-None" + alg2name["DARTS (2nd)"] = "darts-v2-affine0_BN0-None" + alg2name["ENAS"] = "enas-affine0_BN0-None" + alg2name["SETN"] = "setn-affine0_BN0-None" + else: + alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(suffix) + alg2name["masking + Gumbel-Softmax"] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix) + alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix) + for alg, name in alg2name.items(): + alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth") + alg2data = OrderedDict() + for alg, path in alg2path.items(): + alg2data[alg], ok_num = [], 0 + for seed in seeds: + xpath = path.format(seed) + if os.path.isfile(xpath): + ok_num += 1 + else: + print("This is an invalid path : {:}".format(xpath)) + continue + data = torch.load(xpath, map_location=torch.device("cpu")) + try: + data = torch.load(data["last_checkpoint"], map_location=torch.device("cpu")) + except: + xpath = str(data["last_checkpoint"]).split("E100-") + if len(xpath) == 2 and os.path.isfile(xpath[0] + xpath[1]): + xpath = xpath[0] + xpath[1] + elif "fbv2" in str(data["last_checkpoint"]): + xpath = str(data["last_checkpoint"]).replace("fbv2", "mask_gumbel") + elif "tunas" in str(data["last_checkpoint"]): + xpath = str(data["last_checkpoint"]).replace("tunas", "mask_rl") + else: + raise ValueError("Invalid path: {:}".format(data["last_checkpoint"])) + data = torch.load(xpath, map_location=torch.device("cpu")) + alg2data[alg].append(data["genotypes"]) + print("This algorithm : {:} has {:} valid ckps.".format(alg, ok_num)) + assert ok_num > 0, "Must have at least 1 valid ckps." + return alg2data -y_min_s = {('cifar10', 'tss'): 90, - ('cifar10', 'sss'): 92, - ('cifar100', 'tss'): 65, - ('cifar100', 'sss'): 65, - ('ImageNet16-120', 'tss'): 36, - ('ImageNet16-120', 'sss'): 40} +y_min_s = { + ("cifar10", "tss"): 90, + ("cifar10", "sss"): 92, + ("cifar100", "tss"): 65, + ("cifar100", "sss"): 65, + ("ImageNet16-120", "tss"): 36, + ("ImageNet16-120", "sss"): 40, +} -y_max_s = {('cifar10', 'tss'): 94.5, - ('cifar10', 'sss'): 93.3, - ('cifar100', 'tss'): 72, - ('cifar100', 'sss'): 70, - ('ImageNet16-120', 'tss'): 44, - ('ImageNet16-120', 'sss'): 46} +y_max_s = { + ("cifar10", "tss"): 94.5, + ("cifar10", "sss"): 93.3, + ("cifar100", "tss"): 72, + ("cifar100", "sss"): 70, + ("ImageNet16-120", "tss"): 44, + ("ImageNet16-120", "sss"): 46, +} -name2label = {'cifar10': 'CIFAR-10', - 'cifar100': 'CIFAR-100', - 'ImageNet16-120': 'ImageNet-16-120'} +name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} + +name2suffix = {("sss", "warm"): "-WARM0.3", ("sss", "none"): "-WARMNone", ("tss", "none"): None, ("tss", None): None} -name2suffix = {('sss', 'warm'): '-WARM0.3', - ('sss', 'none'): '-WARMNone', - ('tss', 'none') : None, - ('tss', None) : None} def visualize_curve(api, vis_save_dir, search_space, suffix): - vis_save_dir = vis_save_dir.resolve() - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + vis_save_dir.mkdir(parents=True, exist_ok=True) - dpi, width, height = 250, 5200, 1400 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 16, 16 + dpi, width, height = 250, 5200, 1400 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 16, 16 - def sub_plot_fn(ax, dataset): - print('{:} plot {:10s}'.format(time_string(), dataset)) - alg2data = fetch_data(search_space=search_space, dataset=dataset, suffix=name2suffix[(search_space, suffix)]) - alg2accuracies = OrderedDict() - epochs = 100 - colors = ['b', 'g', 'c', 'm', 'y', 'r'] - ax.set_xlim(0, epochs) - # ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) - for idx, (alg, data) in enumerate(alg2data.items()): - xs, accuracies = [], [] - for iepoch in range(epochs + 1): - try: - structures, accs = [_[iepoch-1] for _ in data], [] - except: - raise ValueError('This alg {:} on {:} has invalid checkpoints.'.format(alg, dataset)) - for structure in structures: - info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False) - accs.append(info['test-accuracy']) - accuracies.append(sum(accs)/len(accs)) - xs.append(iepoch) - alg2accuracies[alg] = accuracies - ax.plot(xs, accuracies, c=colors[idx], label='{:}'.format(alg)) - ax.set_xlabel('The searching epoch', fontsize=LabelSize) - ax.set_ylabel('Test accuracy on {:}'.format(name2label[dataset]), fontsize=LabelSize) - ax.set_title('Searching results on {:}'.format(name2label[dataset]), fontsize=LabelSize+4) - structures, valid_accs, test_accs = [_[epochs-1] for _ in data], [], [] - print('{:} plot alg : {:} -- final {:} architectures.'.format(time_string(), alg, len(structures))) - for arch in structures: - valid_acc, test_acc, _ = get_valid_test_acc(api, arch, dataset) - test_accs.append(test_acc) - valid_accs.append(valid_acc) - print('{:} plot alg : {:} -- validation: {:.2f}$\pm${:.2f} -- test: {:.2f}$\pm${:.2f}'.format( - time_string(), alg, np.mean(valid_accs), np.std(valid_accs), np.mean(test_accs), np.std(test_accs))) - ax.legend(loc=4, fontsize=LegendFontsize) + def sub_plot_fn(ax, dataset): + print("{:} plot {:10s}".format(time_string(), dataset)) + alg2data = fetch_data(search_space=search_space, dataset=dataset, suffix=name2suffix[(search_space, suffix)]) + alg2accuracies = OrderedDict() + epochs = 100 + colors = ["b", "g", "c", "m", "y", "r"] + ax.set_xlim(0, epochs) + # ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) + for idx, (alg, data) in enumerate(alg2data.items()): + xs, accuracies = [], [] + for iepoch in range(epochs + 1): + try: + structures, accs = [_[iepoch - 1] for _ in data], [] + except: + raise ValueError("This alg {:} on {:} has invalid checkpoints.".format(alg, dataset)) + for structure in structures: + info = api.get_more_info( + structure, dataset=dataset, hp=90 if api.search_space_name == "size" else 200, is_random=False + ) + accs.append(info["test-accuracy"]) + accuracies.append(sum(accs) / len(accs)) + xs.append(iepoch) + alg2accuracies[alg] = accuracies + ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg)) + ax.set_xlabel("The searching epoch", fontsize=LabelSize) + ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize) + ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4) + structures, valid_accs, test_accs = [_[epochs - 1] for _ in data], [], [] + print("{:} plot alg : {:} -- final {:} architectures.".format(time_string(), alg, len(structures))) + for arch in structures: + valid_acc, test_acc, _ = get_valid_test_acc(api, arch, dataset) + test_accs.append(test_acc) + valid_accs.append(valid_acc) + print( + "{:} plot alg : {:} -- validation: {:.2f}$\pm${:.2f} -- test: {:.2f}$\pm${:.2f}".format( + time_string(), alg, np.mean(valid_accs), np.std(valid_accs), np.mean(test_accs), np.std(test_accs) + ) + ) + ax.legend(loc=4, fontsize=LegendFontsize) - fig, axs = plt.subplots(1, 3, figsize=figsize) - datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] - for dataset, ax in zip(datasets, axs): - sub_plot_fn(ax, dataset) - print('sub-plot {:} on {:} done.'.format(dataset, search_space)) - save_path = (vis_save_dir / '{:}-ws-{:}-curve.png'.format(search_space, suffix)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + fig, axs = plt.subplots(1, 3, figsize=figsize) + datasets = ["cifar10", "cifar100", "ImageNet16-120"] + for dataset, ax in zip(datasets, axs): + sub_plot_fn(ax, dataset) + print("sub-plot {:} on {:} done.".format(dataset, search_space)) + save_path = (vis_save_dir / "{:}-ws-{:}-curve.png".format(search_space, suffix)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." + ) + args = parser.parse_args() - save_dir = Path(args.save_dir) + save_dir = Path(args.save_dir) - api_tss = create(None, 'tss', fast_mode=True, verbose=False) - visualize_curve(api_tss, save_dir, 'tss', None) + api_tss = create(None, "tss", fast_mode=True, verbose=False) + visualize_curve(api_tss, save_dir, "tss", None) - api_sss = create(None, 'sss', fast_mode=True, verbose=False) - visualize_curve(api_sss, save_dir, 'sss', 'warm') - visualize_curve(api_sss, save_dir, 'sss', 'none') + api_sss = create(None, "sss", fast_mode=True, verbose=False) + visualize_curve(api_sss, save_dir, "sss", "warm") + visualize_curve(api_sss, save_dir, "sss", "none") diff --git a/exps/NATS-Bench/draw-fig8.py b/exps/NATS-Bench/draw-fig8.py index e650ed7..bc81e19 100644 --- a/exps/NATS-Bench/draw-fig8.py +++ b/exps/NATS-Bench/draw-fig8.py @@ -11,176 +11,194 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict, OrderedDict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from nats_bench import create from log_utils import time_string -plt.rcParams.update({ - "text.usetex": True, - "font.family": "sans-serif", - "font.sans-serif": ["Helvetica"]}) +plt.rcParams.update({"text.usetex": True, "font.family": "sans-serif", "font.sans-serif": ["Helvetica"]}) ## for Palatino and other serif fonts use: -plt.rcParams.update({ - "text.usetex": True, - "font.family": "serif", - "font.serif": ["Palatino"], -}) +plt.rcParams.update( + { + "text.usetex": True, + "font.family": "serif", + "font.serif": ["Palatino"], + } +) -def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): - ss_dir = '{:}-{:}'.format(root_dir, search_space) - alg2all = OrderedDict() - # alg2name['REINFORCE'] = 'REINFORCE-0.01' - # alg2name['RANDOM'] = 'RANDOM' - # alg2name['BOHB'] = 'BOHB' - if search_space == 'tss': - hp = '$\mathcal{H}^{1}$' - if dataset == 'cifar10': - suffixes = ['-T1200000', '-T1200000-FULL'] - elif search_space == 'sss': - hp = '$\mathcal{H}^{2}$' - if dataset == 'cifar10': - suffixes = ['-T200000', '-T200000-FULL'] - else: - raise ValueError('Unkonwn search space: {:}'.format(search_space)) +def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): + ss_dir = "{:}-{:}".format(root_dir, search_space) + alg2all = OrderedDict() + # alg2name['REINFORCE'] = 'REINFORCE-0.01' + # alg2name['RANDOM'] = 'RANDOM' + # alg2name['BOHB'] = 'BOHB' + if search_space == "tss": + hp = "$\mathcal{H}^{1}$" + if dataset == "cifar10": + suffixes = ["-T1200000", "-T1200000-FULL"] + elif search_space == "sss": + hp = "$\mathcal{H}^{2}$" + if dataset == "cifar10": + suffixes = ["-T200000", "-T200000-FULL"] + else: + raise ValueError("Unkonwn search space: {:}".format(search_space)) - alg2all[r'REA ($\mathcal{H}^{0}$)'] = dict( - path=os.path.join(ss_dir, dataset + suffixes[0], 'R-EA-SS3', 'results.pth'), - color='b', linestyle='-') - alg2all[r'REA ({:})'.format(hp)] = dict( - path=os.path.join(ss_dir, dataset + suffixes[1], 'R-EA-SS3', 'results.pth'), - color='b', linestyle='--') + alg2all[r"REA ($\mathcal{H}^{0}$)"] = dict( + path=os.path.join(ss_dir, dataset + suffixes[0], "R-EA-SS3", "results.pth"), color="b", linestyle="-" + ) + alg2all[r"REA ({:})".format(hp)] = dict( + path=os.path.join(ss_dir, dataset + suffixes[1], "R-EA-SS3", "results.pth"), color="b", linestyle="--" + ) - for alg, xdata in alg2all.items(): - data = torch.load(xdata['path']) - for index, info in data.items(): - info['time_w_arch'] = [(x, y) for x, y in zip(info['all_total_times'], info['all_archs'])] - for j, arch in enumerate(info['all_archs']): - assert arch != -1, 'invalid arch from {:} {:} {:} ({:}, {:})'.format(alg, search_space, dataset, index, j) - xdata['data'] = data - return alg2all + for alg, xdata in alg2all.items(): + data = torch.load(xdata["path"]) + for index, info in data.items(): + info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])] + for j, arch in enumerate(info["all_archs"]): + assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( + alg, search_space, dataset, index, j + ) + xdata["data"] = data + return alg2all def query_performance(api, data, dataset, ticket): - results, is_size_space = [], api.search_space_name == 'size' - for i, info in data.items(): - time_w_arch = sorted(info['time_w_arch'], key=lambda x: abs(x[0]-ticket)) - time_a, arch_a = time_w_arch[0] - time_b, arch_b = time_w_arch[1] - info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy'] - interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b - results.append(interplate) - # return sum(results) / len(results) - return np.mean(results), np.std(results) + results, is_size_space = [], api.search_space_name == "size" + for i, info in data.items(): + time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) + time_a, arch_a = time_w_arch[0] + time_b, arch_b = time_w_arch[1] + info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"] + interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / ( + time_b - time_a + ) * accuracy_b + results.append(interplate) + # return sum(results) / len(results) + return np.mean(results), np.std(results) -y_min_s = {('cifar10', 'tss'): 91, - ('cifar10', 'sss'): 91, - ('cifar100', 'tss'): 65, - ('cifar100', 'sss'): 65, - ('ImageNet16-120', 'tss'): 36, - ('ImageNet16-120', 'sss'): 40} +y_min_s = { + ("cifar10", "tss"): 91, + ("cifar10", "sss"): 91, + ("cifar100", "tss"): 65, + ("cifar100", "sss"): 65, + ("ImageNet16-120", "tss"): 36, + ("ImageNet16-120", "sss"): 40, +} -y_max_s = {('cifar10', 'tss'): 94.5, - ('cifar10', 'sss'): 93.5, - ('cifar100', 'tss'): 72.5, - ('cifar100', 'sss'): 70.5, - ('ImageNet16-120', 'tss'): 46, - ('ImageNet16-120', 'sss'): 46} +y_max_s = { + ("cifar10", "tss"): 94.5, + ("cifar10", "sss"): 93.5, + ("cifar100", "tss"): 72.5, + ("cifar100", "sss"): 70.5, + ("ImageNet16-120", "tss"): 46, + ("ImageNet16-120", "sss"): 46, +} -x_axis_s = {('cifar10', 'tss'): 1200000, - ('cifar10', 'sss'): 200000, - ('cifar100', 'tss'): 400, - ('cifar100', 'sss'): 400, - ('ImageNet16-120', 'tss'): 1200, - ('ImageNet16-120', 'sss'): 600} +x_axis_s = { + ("cifar10", "tss"): 1200000, + ("cifar10", "sss"): 200000, + ("cifar100", "tss"): 400, + ("cifar100", "sss"): 400, + ("ImageNet16-120", "tss"): 1200, + ("ImageNet16-120", "sss"): 600, +} -name2label = {'cifar10': 'CIFAR-10', - 'cifar100': 'CIFAR-100', - 'ImageNet16-120': 'ImageNet-16-120'} +name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} -spaces2latex = {'tss': r'$\mathcal{S}_{t}$', - 'sss': r'$\mathcal{S}_{s}$',} +spaces2latex = { + "tss": r"$\mathcal{S}_{t}$", + "sss": r"$\mathcal{S}_{s}$", +} # FuncFormatter can be used as a decorator @ticker.FuncFormatter def major_formatter(x, pos): - if x == 0: - return '0' - else: - return "{:.2f}e5".format(x/1e5) + if x == 0: + return "0" + else: + return "{:.2f}e5".format(x / 1e5) def visualize_curve(api_dict, vis_save_dir): - vis_save_dir = vis_save_dir.resolve() - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + vis_save_dir.mkdir(parents=True, exist_ok=True) - dpi, width, height = 250, 5000, 2000 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 28, 28 + dpi, width, height = 250, 5000, 2000 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 28, 28 - def sub_plot_fn(ax, search_space, dataset): - max_time = x_axis_s[(dataset, search_space)] - alg2data = fetch_data(search_space=search_space, dataset=dataset) - alg2accuracies = OrderedDict() - total_tickets = 200 - time_tickets = [float(i) / total_tickets * int(max_time) for i in range(total_tickets)] - ax.set_xlim(0, x_axis_s[(dataset, search_space)]) - ax.set_ylim(y_min_s[(dataset, search_space)], - y_max_s[(dataset, search_space)]) - for tick in ax.get_xticklabels(): - tick.set_rotation(25) - tick.set_fontsize(LabelSize - 6) - for tick in ax.get_yticklabels(): - tick.set_fontsize(LabelSize - 6) - ax.xaxis.set_major_formatter(major_formatter) - for idx, (alg, xdata) in enumerate(alg2data.items()): - accuracies = [] - for ticket in time_tickets: - # import pdb; pdb.set_trace() - accuracy, accuracy_std = query_performance( - api_dict[search_space], xdata['data'], dataset, ticket) - accuracies.append(accuracy) - # print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std)) - print('{:} plot alg : {:10s} on {:}'.format(time_string(), alg, search_space)) - alg2accuracies[alg] = accuracies - ax.plot(time_tickets, accuracies, c=xdata['color'], linestyle=xdata['linestyle'], label='{:}'.format(alg)) - ax.set_xlabel('Estimated wall-clock time', fontsize=LabelSize) - ax.set_ylabel('Test accuracy', fontsize=LabelSize) - ax.set_title(r'Results on {:} over {:}'.format(name2label[dataset], spaces2latex[search_space]), - fontsize=LabelSize) - ax.legend(loc=4, fontsize=LegendFontsize) + def sub_plot_fn(ax, search_space, dataset): + max_time = x_axis_s[(dataset, search_space)] + alg2data = fetch_data(search_space=search_space, dataset=dataset) + alg2accuracies = OrderedDict() + total_tickets = 200 + time_tickets = [float(i) / total_tickets * int(max_time) for i in range(total_tickets)] + ax.set_xlim(0, x_axis_s[(dataset, search_space)]) + ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) + for tick in ax.get_xticklabels(): + tick.set_rotation(25) + tick.set_fontsize(LabelSize - 6) + for tick in ax.get_yticklabels(): + tick.set_fontsize(LabelSize - 6) + ax.xaxis.set_major_formatter(major_formatter) + for idx, (alg, xdata) in enumerate(alg2data.items()): + accuracies = [] + for ticket in time_tickets: + # import pdb; pdb.set_trace() + accuracy, accuracy_std = query_performance(api_dict[search_space], xdata["data"], dataset, ticket) + accuracies.append(accuracy) + # print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std)) + print("{:} plot alg : {:10s} on {:}".format(time_string(), alg, search_space)) + alg2accuracies[alg] = accuracies + ax.plot(time_tickets, accuracies, c=xdata["color"], linestyle=xdata["linestyle"], label="{:}".format(alg)) + ax.set_xlabel("Estimated wall-clock time", fontsize=LabelSize) + ax.set_ylabel("Test accuracy", fontsize=LabelSize) + ax.set_title( + r"Results on {:} over {:}".format(name2label[dataset], spaces2latex[search_space]), fontsize=LabelSize + ) + ax.legend(loc=4, fontsize=LegendFontsize) - fig, axs = plt.subplots(1, 2, figsize=figsize) - sub_plot_fn(axs[0], 'tss', 'cifar10') - sub_plot_fn(axs[1], 'sss', 'cifar10') - save_path = (vis_save_dir / 'full-curve.png').resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + fig, axs = plt.subplots(1, 2, figsize=figsize) + sub_plot_fn(axs[0], "tss", "cifar10") + sub_plot_fn(axs[1], "sss", "cifar10") + save_path = (vis_save_dir / "full-curve.png").resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--save_dir", + type=str, + default="output/vis-nas-bench/nas-algos-vs-h", + help="Folder to save checkpoints and log.", + ) + args = parser.parse_args() - save_dir = Path(args.save_dir) + save_dir = Path(args.save_dir) - api_tss = create(None, 'tss', fast_mode=True, verbose=False) - api_sss = create(None, 'sss', fast_mode=True, verbose=False) - visualize_curve(dict(tss=api_tss, sss=api_sss), save_dir) + api_tss = create(None, "tss", fast_mode=True, verbose=False) + api_sss = create(None, "sss", fast_mode=True, verbose=False) + visualize_curve(dict(tss=api_tss, sss=api_sss), save_dir) diff --git a/exps/NATS-Bench/draw-ranks.py b/exps/NATS-Bench/draw-ranks.py index ad69f4b..d3a8c02 100644 --- a/exps/NATS-Bench/draw-ranks.py +++ b/exps/NATS-Bench/draw-ranks.py @@ -12,119 +12,132 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict, OrderedDict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from log_utils import time_string from models import get_cell_based_tiny_net from nats_bench import create -name2label = {'cifar10': 'CIFAR-10', - 'cifar100': 'CIFAR-100', - 'ImageNet16-120': 'ImageNet-16-120'} +name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} def visualize_relative_info(vis_save_dir, search_space, indicator, topk): - vis_save_dir = vis_save_dir.resolve() - print ('{:} start to visualize {:} with top-{:} information'.format(time_string(), search_space, topk)) - vis_save_dir.mkdir(parents=True, exist_ok=True) - cache_file_path = vis_save_dir / 'cache-{:}-info.pth'.format(search_space) - datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] - if not cache_file_path.exists(): - api = create(None, search_space, fast_mode=False, verbose=False) - all_infos = OrderedDict() - for index in range(len(api)): - all_info = OrderedDict() - for dataset in datasets: - info_less = api.get_more_info(index, dataset, hp='12', is_random=False) - info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False) - all_info[dataset] = dict(less=info_less['test-accuracy'], - more=info_more['test-accuracy']) - all_infos[index] = all_info - torch.save(all_infos, cache_file_path) - print ('{:} save all cache data into {:}'.format(time_string(), cache_file_path)) - else: - api = create(None, search_space, fast_mode=True, verbose=False) - all_infos = torch.load(cache_file_path) + vis_save_dir = vis_save_dir.resolve() + print("{:} start to visualize {:} with top-{:} information".format(time_string(), search_space, topk)) + vis_save_dir.mkdir(parents=True, exist_ok=True) + cache_file_path = vis_save_dir / "cache-{:}-info.pth".format(search_space) + datasets = ["cifar10", "cifar100", "ImageNet16-120"] + if not cache_file_path.exists(): + api = create(None, search_space, fast_mode=False, verbose=False) + all_infos = OrderedDict() + for index in range(len(api)): + all_info = OrderedDict() + for dataset in datasets: + info_less = api.get_more_info(index, dataset, hp="12", is_random=False) + info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False) + all_info[dataset] = dict(less=info_less["test-accuracy"], more=info_more["test-accuracy"]) + all_infos[index] = all_info + torch.save(all_infos, cache_file_path) + print("{:} save all cache data into {:}".format(time_string(), cache_file_path)) + else: + api = create(None, search_space, fast_mode=True, verbose=False) + all_infos = torch.load(cache_file_path) + + dpi, width, height = 250, 5000, 1300 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 16, 16 + + fig, axs = plt.subplots(1, 3, figsize=figsize) + datasets = ["cifar10", "cifar100", "ImageNet16-120"] + + def sub_plot_fn(ax, dataset, indicator): + performances = [] + # pickup top 10% architectures + for _index in range(len(api)): + performances.append((all_infos[_index][dataset][indicator], _index)) + performances = sorted(performances, reverse=True) + performances = performances[: int(len(api) * topk * 0.01)] + selected_indexes = [x[1] for x in performances] + print( + "{:} plot {:10s} with {:}, {:} architectures".format( + time_string(), dataset, indicator, len(selected_indexes) + ) + ) + standard_scores = [] + random_scores = [] + for idx in selected_indexes: + standard_scores.append( + api.get_more_info( + idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=False + )["test-accuracy"] + ) + random_scores.append( + api.get_more_info( + idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=True + )["test-accuracy"] + ) + indexes = list(range(len(selected_indexes))) + standard_indexes = sorted(indexes, key=lambda i: standard_scores[i]) + random_indexes = sorted(indexes, key=lambda i: random_scores[i]) + random_labels = [] + for idx in standard_indexes: + random_labels.append(random_indexes.index(idx)) + for tick in ax.get_xticklabels(): + tick.set_fontsize(LabelSize - 3) + for tick in ax.get_yticklabels(): + tick.set_rotation(25) + tick.set_fontsize(LabelSize - 3) + ax.set_xlim(0, len(indexes)) + ax.set_ylim(0, len(indexes)) + ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3)) + ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) + ax.scatter(indexes, random_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) + ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) + ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Average Over Multi-Trials") + ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="Randomly Selected Trial") + + coef, p = scipy.stats.kendalltau(standard_scores, random_scores) + ax.set_xlabel("architecture ranking in {:}".format(name2label[dataset]), fontsize=LabelSize) + if dataset == "cifar10": + ax.set_ylabel("architecture ranking", fontsize=LabelSize) + ax.legend(loc=4, fontsize=LegendFontsize) + return coef + + for dataset, ax in zip(datasets, axs): + rank_coef = sub_plot_fn(ax, dataset, indicator) + print("sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(dataset, search_space, rank_coef)) + + save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("Save into {:}".format(save_path)) - dpi, width, height = 250, 5000, 1300 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 16, 16 +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--save_dir", + type=str, + default="output/vis-nas-bench/rank-stability", + help="Folder to save checkpoints and log.", + ) + args = parser.parse_args() + to_save_dir = Path(args.save_dir) - fig, axs = plt.subplots(1, 3, figsize=figsize) - datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] - - def sub_plot_fn(ax, dataset, indicator): - performances = [] - # pickup top 10% architectures - for _index in range(len(api)): - performances.append((all_infos[_index][dataset][indicator], _index)) - performances = sorted(performances, reverse=True) - performances = performances[: int(len(api) * topk * 0.01)] - selected_indexes = [x[1] for x in performances] - print('{:} plot {:10s} with {:}, {:} architectures'.format(time_string(), dataset, indicator, len(selected_indexes))) - standard_scores = [] - random_scores = [] - for idx in selected_indexes: - standard_scores.append( - api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=False)['test-accuracy']) - random_scores.append( - api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=True)['test-accuracy']) - indexes = list(range(len(selected_indexes))) - standard_indexes = sorted(indexes, key=lambda i: standard_scores[i]) - random_indexes = sorted(indexes, key=lambda i: random_scores[i]) - random_labels = [] - for idx in standard_indexes: - random_labels.append(random_indexes.index(idx)) - for tick in ax.get_xticklabels(): - tick.set_fontsize(LabelSize - 3) - for tick in ax.get_yticklabels(): - tick.set_rotation(25) - tick.set_fontsize(LabelSize - 3) - ax.set_xlim(0, len(indexes)) - ax.set_ylim(0, len(indexes)) - ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes)//3)) - ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes)//5)) - ax.scatter(indexes, random_labels, marker='^', s=0.5, c='tab:green', alpha=0.8) - ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8) - ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='Average Over Multi-Trials') - ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='Randomly Selected Trial') - - coef, p = scipy.stats.kendalltau(standard_scores, random_scores) - ax.set_xlabel('architecture ranking in {:}'.format(name2label[dataset]), fontsize=LabelSize) - if dataset == 'cifar10': - ax.set_ylabel('architecture ranking', fontsize=LabelSize) - ax.legend(loc=4, fontsize=LegendFontsize) - return coef - - for dataset, ax in zip(datasets, axs): - rank_coef = sub_plot_fn(ax, dataset, indicator) - print('sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.'.format(dataset, search_space, rank_coef)) - - save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.pdf'.format(search_space, indicator, topk)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.png'.format(search_space, indicator, topk)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print('Save into {:}'.format(save_path)) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/rank-stability', help='Folder to save checkpoints and log.') - args = parser.parse_args() - to_save_dir = Path(args.save_dir) - - for topk in [1, 5, 10, 20]: - visualize_relative_info(to_save_dir, 'tss', 'more', topk) - visualize_relative_info(to_save_dir, 'sss', 'less', topk) - print ('{:} : complete running this file : {:}'.format(time_string(), __file__)) + for topk in [1, 5, 10, 20]: + visualize_relative_info(to_save_dir, "tss", "more", topk) + visualize_relative_info(to_save_dir, "sss", "less", topk) + print("{:} : complete running this file : {:}".format(time_string(), __file__)) diff --git a/exps/NATS-Bench/draw-table.py b/exps/NATS-Bench/draw-table.py index 3d53526..8a52a4f 100644 --- a/exps/NATS-Bench/draw-table.py +++ b/exps/NATS-Bench/draw-table.py @@ -11,149 +11,157 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict, OrderedDict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from nats_bench import create from log_utils import time_string -def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): - ss_dir = '{:}-{:}'.format(root_dir, search_space) - alg2name, alg2path = OrderedDict(), OrderedDict() - alg2name['REA'] = 'R-EA-SS3' - alg2name['REINFORCE'] = 'REINFORCE-0.01' - alg2name['RANDOM'] = 'RANDOM' - alg2name['BOHB'] = 'BOHB' - for alg, name in alg2name.items(): - alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth') - assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg]) - alg2data = OrderedDict() - for alg, path in alg2path.items(): - data = torch.load(path) - for index, info in data.items(): - info['time_w_arch'] = [(x, y) for x, y in zip(info['all_total_times'], info['all_archs'])] - for j, arch in enumerate(info['all_archs']): - assert arch != -1, 'invalid arch from {:} {:} {:} ({:}, {:})'.format(alg, search_space, dataset, index, j) - alg2data[alg] = data - return alg2data +def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): + ss_dir = "{:}-{:}".format(root_dir, search_space) + alg2name, alg2path = OrderedDict(), OrderedDict() + alg2name["REA"] = "R-EA-SS3" + alg2name["REINFORCE"] = "REINFORCE-0.01" + alg2name["RANDOM"] = "RANDOM" + alg2name["BOHB"] = "BOHB" + for alg, name in alg2name.items(): + alg2path[alg] = os.path.join(ss_dir, dataset, name, "results.pth") + assert os.path.isfile(alg2path[alg]), "invalid path : {:}".format(alg2path[alg]) + alg2data = OrderedDict() + for alg, path in alg2path.items(): + data = torch.load(path) + for index, info in data.items(): + info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])] + for j, arch in enumerate(info["all_archs"]): + assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( + alg, search_space, dataset, index, j + ) + alg2data[alg] = data + return alg2data def get_valid_test_acc(api, arch, dataset): - is_size_space = api.search_space_name == 'size' - if dataset == 'cifar10': - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - test_acc = xinfo['test-accuracy'] - xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False) - valid_acc = xinfo['valid-accuracy'] - else: - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - valid_acc = xinfo['valid-accuracy'] - test_acc = xinfo['test-accuracy'] - return valid_acc, test_acc, 'validation = {:.2f}, test = {:.2f}\n'.format(valid_acc, test_acc) + is_size_space = api.search_space_name == "size" + if dataset == "cifar10": + xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + test_acc = xinfo["test-accuracy"] + xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False) + valid_acc = xinfo["valid-accuracy"] + else: + xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + valid_acc = xinfo["valid-accuracy"] + test_acc = xinfo["test-accuracy"] + return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc) def show_valid_test(api, arch): - is_size_space = api.search_space_name == 'size' - final_str = '' - for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']: - valid_acc, test_acc, perf_str = get_valid_test_acc(api, arch, dataset) - final_str += '{:} : {:}\n'.format(dataset, perf_str) - return final_str + is_size_space = api.search_space_name == "size" + final_str = "" + for dataset in ["cifar10", "cifar100", "ImageNet16-120"]: + valid_acc, test_acc, perf_str = get_valid_test_acc(api, arch, dataset) + final_str += "{:} : {:}\n".format(dataset, perf_str) + return final_str def find_best_valid(api, dataset): - all_valid_accs, all_test_accs = [], [] - for index, arch in enumerate(api): - valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset) - all_valid_accs.append((index, valid_acc)) - all_test_accs.append((index, test_acc)) - best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0] - best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0] + all_valid_accs, all_test_accs = [], [] + for index, arch in enumerate(api): + valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset) + all_valid_accs.append((index, valid_acc)) + all_test_accs.append((index, test_acc)) + best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0] + best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0] - print('-' * 50 + '{:10s}'.format(dataset) + '-' * 50) - print('Best ({:}) architecture on validation: {:}'.format(best_valid_index, api[best_valid_index])) - print('Best ({:}) architecture on test: {:}'.format(best_test_index, api[best_test_index])) - _, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset) - print('using validation ::: {:}'.format(perf_str)) - _, _, perf_str = get_valid_test_acc(api, best_test_index, dataset) - print('using test ::: {:}'.format(perf_str)) + print("-" * 50 + "{:10s}".format(dataset) + "-" * 50) + print("Best ({:}) architecture on validation: {:}".format(best_valid_index, api[best_valid_index])) + print("Best ({:}) architecture on test: {:}".format(best_test_index, api[best_test_index])) + _, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset) + print("using validation ::: {:}".format(perf_str)) + _, _, perf_str = get_valid_test_acc(api, best_test_index, dataset) + print("using test ::: {:}".format(perf_str)) def interplate_fn(xpair1, xpair2, x): - (x1, y1) = xpair1 - (x2, y2) = xpair2 - return (x2 - x) / (x2 - x1) * y1 + \ - (x - x1) / (x2 - x1) * y2 + (x1, y1) = xpair1 + (x2, y2) = xpair2 + return (x2 - x) / (x2 - x1) * y1 + (x - x1) / (x2 - x1) * y2 + def query_performance(api, info, dataset, ticket): - info = deepcopy(info) - results, is_size_space = [], api.search_space_name == 'size' - time_w_arch = sorted(info['time_w_arch'], key=lambda x: abs(x[0]-ticket)) - time_a, arch_a = time_w_arch[0] - time_b, arch_b = time_w_arch[1] + info = deepcopy(info) + results, is_size_space = [], api.search_space_name == "size" + time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) + time_a, arch_a = time_w_arch[0] + time_b, arch_b = time_w_arch[1] - v_acc_a, t_acc_a, _ = get_valid_test_acc(api, arch_a, dataset) - v_acc_b, t_acc_b, _ = get_valid_test_acc(api, arch_b, dataset) - v_acc = interplate_fn((time_a, v_acc_a), (time_b, v_acc_b), ticket) - t_acc = interplate_fn((time_a, t_acc_a), (time_b, t_acc_b), ticket) - # if True: - # interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b - # results.append(interplate) - # return sum(results) / len(results) - return v_acc, t_acc + v_acc_a, t_acc_a, _ = get_valid_test_acc(api, arch_a, dataset) + v_acc_b, t_acc_b, _ = get_valid_test_acc(api, arch_b, dataset) + v_acc = interplate_fn((time_a, v_acc_a), (time_b, v_acc_b), ticket) + t_acc = interplate_fn((time_a, t_acc_a), (time_b, t_acc_b), ticket) + # if True: + # interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b + # results.append(interplate) + # return sum(results) / len(results) + return v_acc, t_acc def show_multi_trial(search_space): - api = create(None, search_space, fast_mode=True, verbose=False) - def show(dataset): - print('show {:} on {:} done.'.format(dataset, search_space)) - xdataset, max_time = dataset.split('-T') - alg2data = fetch_data(search_space=search_space, dataset=dataset) - for idx, (alg, data) in enumerate(alg2data.items()): + api = create(None, search_space, fast_mode=True, verbose=False) - valid_accs, test_accs = [], [] - for _, x in data.items(): - v_acc, t_acc = query_performance(api, x, xdataset, float(max_time)) - valid_accs.append(v_acc) - test_accs.append(t_acc) - valid_str = '{:.2f}$\pm${:.2f}'.format(np.mean(valid_accs), np.std(valid_accs)) - test_str = '{:.2f}$\pm${:.2f}'.format(np.mean(test_accs), np.std(test_accs)) - print('{:} plot alg : {:10s} | validation = {:} | test = {:}'.format(time_string(), alg, valid_str, test_str)) - if search_space == 'tss': - datasets = ['cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T120000'] - elif search_space == 'sss': - datasets = ['cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T60000'] - else: - raise ValueError('Unknown search space: {:}'.format(search_space)) - for dataset in datasets: - show(dataset) - print('{:} complete show multi-trial results.\n'.format(time_string())) + def show(dataset): + print("show {:} on {:} done.".format(dataset, search_space)) + xdataset, max_time = dataset.split("-T") + alg2data = fetch_data(search_space=search_space, dataset=dataset) + for idx, (alg, data) in enumerate(alg2data.items()): + + valid_accs, test_accs = [], [] + for _, x in data.items(): + v_acc, t_acc = query_performance(api, x, xdataset, float(max_time)) + valid_accs.append(v_acc) + test_accs.append(t_acc) + valid_str = "{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs)) + test_str = "{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs)) + print( + "{:} plot alg : {:10s} | validation = {:} | test = {:}".format(time_string(), alg, valid_str, test_str) + ) + + if search_space == "tss": + datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T120000"] + elif search_space == "sss": + datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T60000"] + else: + raise ValueError("Unknown search space: {:}".format(search_space)) + for dataset in datasets: + show(dataset) + print("{:} complete show multi-trial results.\n".format(time_string())) -if __name__ == '__main__': - - show_multi_trial('tss') - show_multi_trial('sss') +if __name__ == "__main__": - api_tss = create(None, 'tss', fast_mode=False, verbose=False) - resnet = '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|' - resnet_index = api_tss.query_index_by_arch(resnet) - print(show_valid_test(api_tss, resnet_index)) + show_multi_trial("tss") + show_multi_trial("sss") - for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']: - find_best_valid(api_tss, dataset) + api_tss = create(None, "tss", fast_mode=False, verbose=False) + resnet = "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|" + resnet_index = api_tss.query_index_by_arch(resnet) + print(show_valid_test(api_tss, resnet_index)) - largest = '64:64:64:64:64' - largest_index = api_sss.query_index_by_arch(largest) - print(show_valid_test(api_sss, largest_index)) - for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']: - find_best_valid(api_sss, dataset) + for dataset in ["cifar10", "cifar100", "ImageNet16-120"]: + find_best_valid(api_tss, dataset) + + largest = "64:64:64:64:64" + largest_index = api_sss.query_index_by_arch(largest) + print(show_valid_test(api_sss, largest_index)) + for dataset in ["cifar10", "cifar100", "ImageNet16-120"]: + find_best_valid(api_sss, dataset) diff --git a/exps/NATS-Bench/main-sss.py b/exps/NATS-Bench/main-sss.py index 062289a..dcbd579 100644 --- a/exps/NATS-Bench/main-sss.py +++ b/exps/NATS-Bench/main-sss.py @@ -15,230 +15,349 @@ ############################################################################## import os, sys, time, torch, argparse from typing import List, Text, Dict, Any -from PIL import ImageFile +from PIL import ImageFile + ImageFile.LOAD_TRUNCATED_IMAGES = True -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config -from procedures import bench_evaluate_for_seed -from procedures import get_machine_info -from datasets import get_datasets -from log_utils import Logger, AverageMeter, time_string, convert_secs2time -from utils import split_str2indexes +from procedures import bench_evaluate_for_seed +from procedures import get_machine_info +from datasets import get_datasets +from log_utils import Logger, AverageMeter, time_string, convert_secs2time +from utils import split_str2indexes -def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Text], - splits: List[Text], config_path: Text, seed: int, workers: int, logger): - machine_info = get_machine_info() - all_infos = {'info': machine_info} - all_dataset_keys = [] - # look all the dataset - for dataset, xpath, split in zip(datasets, xpaths, splits): - # the train and valid data - train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) - # load the configuration - if dataset == 'cifar10' or dataset == 'cifar100': - split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None) - elif dataset.startswith('ImageNet16'): - split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None) - else: - raise ValueError('invalid dataset : {:}'.format(dataset)) - config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) - # check whether use the splitted validation set - if bool(split): - assert dataset == 'cifar10' - ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)} - assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid)) - train_data_v2 = deepcopy(train_data) - train_data_v2.transform = valid_data.transform - valid_data = train_data_v2 - # data loader - train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True) - ValLoaders['x-valid'] = valid_loader - else: - # data loader - train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True) - if dataset == 'cifar10': - ValLoaders = {'ori-test': valid_loader} - elif dataset == 'cifar100': - cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None) - ValLoaders = {'ori-test': valid_loader, - 'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True), - 'x-test' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True) - } - elif dataset == 'ImageNet16-120': - imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None) - ValLoaders = {'ori-test': valid_loader, - 'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True), - 'x-test' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True) - } - else: - raise ValueError('invalid dataset : {:}'.format(dataset)) - - dataset_key = '{:}'.format(dataset) - if bool(split): dataset_key = dataset_key + '-valid' - logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size)) - logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config)) - for key, value in ValLoaders.items(): - logger.log('Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value))) - # arch-index= 9930, arch=|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2| - # this genotype is the architecture with the highest accuracy on CIFAR-100 validation set - genotype = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|' - arch_config = dict2config(dict(name='infer.shape.tiny', channels=channels, genotype=genotype, num_classes=class_num), None) - results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger) - all_infos[dataset_key] = results - all_dataset_keys.append( dataset_key ) - all_infos['all_dataset_keys'] = all_dataset_keys - return all_infos - - -def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], - splits: List[int], seeds: List[int], nets: List[str], opt_config: Dict[Text, Any], - to_evaluate_indexes: tuple, cover_mode: bool): - - log_dir = save_dir / 'logs' - log_dir.mkdir(parents=True, exist_ok=True) - logger = Logger(str(log_dir), os.getpid(), False) - - logger.log('xargs : seeds = {:}'.format(seeds)) - logger.log('xargs : cover_mode = {:}'.format(cover_mode)) - logger.log('-' * 100) - logger.log( - 'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes)) - +'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode)) - for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): - logger.log( - '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split)) - logger.log('--->>> optimization config : {:}'.format(opt_config)) - - start_time, epoch_time = time.time(), AverageMeter() - for i, index in enumerate(to_evaluate_indexes): - channelstr = nets[index] - logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}'.format(time_string(), i, - len(to_evaluate_indexes), index, len(nets), seeds, '-' * 15)) - logger.log('{:} {:} {:}'.format('-' * 15, channelstr, '-' * 15)) - - # test this arch on different datasets with different seeds - has_continue = False - for seed in seeds: - to_save_name = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) - if to_save_name.exists(): - if cover_mode: - logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name)) - os.remove(str(to_save_name)) +def evaluate_all_datasets( + channels: Text, + datasets: List[Text], + xpaths: List[Text], + splits: List[Text], + config_path: Text, + seed: int, + workers: int, + logger, +): + machine_info = get_machine_info() + all_infos = {"info": machine_info} + all_dataset_keys = [] + # look all the dataset + for dataset, xpath, split in zip(datasets, xpaths, splits): + # the train and valid data + train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) + # load the configuration + if dataset == "cifar10" or dataset == "cifar100": + split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None) + elif dataset.startswith("ImageNet16"): + split_info = load_config("configs/nas-benchmark/{:}-split.txt".format(dataset), None, None) else: - logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name)) - has_continue = True - continue - results = evaluate_all_datasets(channelstr, - datasets, xpaths, splits, opt_config, seed, - workers, logger) - torch.save(results, to_save_name) - logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i, - len(to_evaluate_indexes), index, len(nets), seeds, to_save_name)) - # measure elapsed time - if not has_continue: epoch_time.update(time.time() - start_time) - start_time = time.time() - need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True)) - logger.log('This arch costs : {:}'.format(convert_secs2time(epoch_time.val, True))) - logger.log('{:}'.format('*' * 100)) - logger.log('{:} {:74s} {:}'.format('*' * 10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len( - to_evaluate_indexes), index, len(nets), need_time), '*' * 10)) - logger.log('{:}'.format('*' * 100)) + raise ValueError("invalid dataset : {:}".format(dataset)) + config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) + # check whether use the splitted validation set + if bool(split): + assert dataset == "cifar10" + ValLoaders = { + "ori-test": torch.utils.data.DataLoader( + valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + ) + } + assert len(train_data) == len(split_info.train) + len( + split_info.valid + ), "invalid length : {:} vs {:} + {:}".format(len(train_data), len(split_info.train), len(split_info.valid)) + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), + num_workers=workers, + pin_memory=True, + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), + num_workers=workers, + pin_memory=True, + ) + ValLoaders["x-valid"] = valid_loader + else: + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, batch_size=config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + ) + if dataset == "cifar10": + ValLoaders = {"ori-test": valid_loader} + elif dataset == "cifar100": + cifar100_splits = load_config("configs/nas-benchmark/cifar100-test-split.txt", None, None) + ValLoaders = { + "ori-test": valid_loader, + "x-valid": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), + num_workers=workers, + pin_memory=True, + ), + "x-test": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), + num_workers=workers, + pin_memory=True, + ), + } + elif dataset == "ImageNet16-120": + imagenet16_splits = load_config("configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None) + ValLoaders = { + "ori-test": valid_loader, + "x-valid": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), + num_workers=workers, + pin_memory=True, + ), + "x-test": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest), + num_workers=workers, + pin_memory=True, + ), + } + else: + raise ValueError("invalid dataset : {:}".format(dataset)) - logger.close() + dataset_key = "{:}".format(dataset) + if bool(split): + dataset_key = dataset_key + "-valid" + logger.log( + "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)) + for key, value in ValLoaders.items(): + logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))) + # arch-index= 9930, arch=|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2| + # this genotype is the architecture with the highest accuracy on CIFAR-100 validation set + genotype = "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|" + arch_config = dict2config( + dict(name="infer.shape.tiny", channels=channels, genotype=genotype, num_classes=class_num), None + ) + results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger) + all_infos[dataset_key] = results + all_dataset_keys.append(dataset_key) + all_infos["all_dataset_keys"] = all_dataset_keys + return all_infos + + +def main( + save_dir: Path, + workers: int, + datasets: List[Text], + xpaths: List[Text], + splits: List[int], + seeds: List[int], + nets: List[str], + opt_config: Dict[Text, Any], + to_evaluate_indexes: tuple, + cover_mode: bool, +): + + log_dir = save_dir / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + logger = Logger(str(log_dir), os.getpid(), False) + + logger.log("xargs : seeds = {:}".format(seeds)) + logger.log("xargs : cover_mode = {:}".format(cover_mode)) + logger.log("-" * 100) + logger.log( + "Start evaluating range =: {:06d} - {:06d}".format(min(to_evaluate_indexes), max(to_evaluate_indexes)) + + "({:} in total) / {:06d} with cover-mode={:}".format(len(to_evaluate_indexes), len(nets), cover_mode) + ) + for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): + logger.log( + "--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}".format( + i, len(datasets), dataset, xpath, split + ) + ) + logger.log("--->>> optimization config : {:}".format(opt_config)) + + start_time, epoch_time = time.time(), AverageMeter() + for i, index in enumerate(to_evaluate_indexes): + channelstr = nets[index] + logger.log( + "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}".format( + time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, "-" * 15 + ) + ) + logger.log("{:} {:} {:}".format("-" * 15, channelstr, "-" * 15)) + + # test this arch on different datasets with different seeds + has_continue = False + for seed in seeds: + to_save_name = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) + if to_save_name.exists(): + if cover_mode: + logger.log("Find existing file : {:}, remove it before evaluation".format(to_save_name)) + os.remove(str(to_save_name)) + else: + logger.log("Find existing file : {:}, skip this evaluation".format(to_save_name)) + has_continue = True + continue + results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger) + torch.save(results, to_save_name) + logger.log( + "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}".format( + time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, to_save_name + ) + ) + # measure elapsed time + if not has_continue: + epoch_time.update(time.time() - start_time) + start_time = time.time() + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True) + ) + logger.log("This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))) + logger.log("{:}".format("*" * 100)) + logger.log( + "{:} {:74s} {:}".format( + "*" * 10, + "{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}".format( + i, len(to_evaluate_indexes), index, len(nets), need_time + ), + "*" * 10, + ) + ) + logger.log("{:}".format("*" * 100)) + + logger.close() def traverse_net(candidates: List[int], N: int): - nets = [''] - for i in range(N): - new_nets = [] - for net in nets: - for C in candidates: - new_nets.append(str(C) if net == '' else "{:}:{:}".format(net,C)) - nets = new_nets - return nets + nets = [""] + for i in range(N): + new_nets = [] + for net in nets: + for C in candidates: + new_nets.append(str(C) if net == "" else "{:}:{:}".format(net, C)) + nets = new_nets + return nets def filter_indexes(xlist, mode, save_dir, seeds): - all_indexes = [] - for index in xlist: - if mode == 'cover': - all_indexes.append(index) - else: - for seed in seeds: - temp_path = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) - if not temp_path.exists(): - all_indexes.append(index) - break - print('{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total'.format(time_string(), len(all_indexes), len(xlist))) + all_indexes = [] + for index in xlist: + if mode == "cover": + all_indexes.append(index) + else: + for seed in seeds: + temp_path = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) + if not temp_path.exists(): + all_indexes.append(index) + break + print( + "{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total".format( + time_string(), len(all_indexes), len(xlist) + ) + ) - SLURM_PROCID, SLURM_NTASKS = 'SLURM_PROCID', 'SLURM_NTASKS' - if SLURM_PROCID in os.environ and SLURM_NTASKS in os.environ: # run on the slurm - proc_id, ntasks = int(os.environ[SLURM_PROCID]), int(os.environ[SLURM_NTASKS]) - assert 0 <= proc_id < ntasks, 'invalid proc_id {:} vs ntasks {:}'.format(proc_id, ntasks) - scales = [int(float(i)/ntasks*len(all_indexes)) for i in range(ntasks)] + [len(all_indexes)] - per_job = [] - for i in range(ntasks): - xs, xe = min(max(scales[i],0), len(all_indexes)-1), min(max(scales[i+1]-1,0), len(all_indexes)-1) - per_job.append((xs, xe)) - for i, srange in enumerate(per_job): - print(' -->> {:2d}/{:02d} : {:}'.format(i, ntasks, srange)) - current_range = per_job[proc_id] - all_indexes = [all_indexes[i] for i in range(current_range[0], current_range[1]+1)] - # set the device id - device = proc_id % torch.cuda.device_count() - torch.cuda.set_device(device) - print(' set the device id = {:}'.format(device)) - print('{:} [FILTER-INDEXES] : after filtering there are {:} architectures in total'.format(time_string(), len(all_indexes))) - return all_indexes + SLURM_PROCID, SLURM_NTASKS = "SLURM_PROCID", "SLURM_NTASKS" + if SLURM_PROCID in os.environ and SLURM_NTASKS in os.environ: # run on the slurm + proc_id, ntasks = int(os.environ[SLURM_PROCID]), int(os.environ[SLURM_NTASKS]) + assert 0 <= proc_id < ntasks, "invalid proc_id {:} vs ntasks {:}".format(proc_id, ntasks) + scales = [int(float(i) / ntasks * len(all_indexes)) for i in range(ntasks)] + [len(all_indexes)] + per_job = [] + for i in range(ntasks): + xs, xe = min(max(scales[i], 0), len(all_indexes) - 1), min(max(scales[i + 1] - 1, 0), len(all_indexes) - 1) + per_job.append((xs, xe)) + for i, srange in enumerate(per_job): + print(" -->> {:2d}/{:02d} : {:}".format(i, ntasks, srange)) + current_range = per_job[proc_id] + all_indexes = [all_indexes[i] for i in range(current_range[0], current_range[1] + 1)] + # set the device id + device = proc_id % torch.cuda.device_count() + torch.cuda.set_device(device) + print(" set the device id = {:}".format(device)) + print( + "{:} [FILTER-INDEXES] : after filtering there are {:} architectures in total".format( + time_string(), len(all_indexes) + ) + ) + return all_indexes -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NATS-Bench (size search space)', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--mode', type=str, required=True, choices=['new', 'cover'], help='The script mode.') - parser.add_argument('--save_dir', type=str, default='output/NATS-Bench-size', help='Folder to save checkpoints and log.') - parser.add_argument('--candidateC', type=int, nargs='+', default=[8, 16, 24, 32, 40, 48, 56, 64], help='.') - parser.add_argument('--num_layers', type=int, default=5, help='The number of layers in a network.') - parser.add_argument('--check_N', type=int, default=32768, help='For safety.') - # use for train the model - parser.add_argument('--workers', type=int, default=8, help='The number of data loading workers (default: 2)') - parser.add_argument('--srange' , type=str, required=True, help='The range of models to be evaluated') - parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.') - parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for this dataset.') - parser.add_argument('--splits', type=int, nargs='+', help='The root path for this dataset.') - parser.add_argument('--hyper', type=str, default='12', choices=['01', '12', '90'], help='The tag for hyper-parameters.') - parser.add_argument('--seeds' , type=int, nargs='+', help='The range of models to be evaluated') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="NATS-Bench (size search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--mode", type=str, required=True, choices=["new", "cover"], help="The script mode.") + parser.add_argument( + "--save_dir", type=str, default="output/NATS-Bench-size", help="Folder to save checkpoints and log." + ) + parser.add_argument("--candidateC", type=int, nargs="+", default=[8, 16, 24, 32, 40, 48, 56, 64], help=".") + parser.add_argument("--num_layers", type=int, default=5, help="The number of layers in a network.") + parser.add_argument("--check_N", type=int, default=32768, help="For safety.") + # use for train the model + parser.add_argument("--workers", type=int, default=8, help="The number of data loading workers (default: 2)") + parser.add_argument("--srange", type=str, required=True, help="The range of models to be evaluated") + parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.") + parser.add_argument("--xpaths", type=str, nargs="+", help="The root path for this dataset.") + parser.add_argument("--splits", type=int, nargs="+", help="The root path for this dataset.") + parser.add_argument( + "--hyper", type=str, default="12", choices=["01", "12", "90"], help="The tag for hyper-parameters." + ) + parser.add_argument("--seeds", type=int, nargs="+", help="The range of models to be evaluated") + args = parser.parse_args() - nets = traverse_net(args.candidateC, args.num_layers) - if len(nets) != args.check_N: - raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N)) + nets = traverse_net(args.candidateC, args.num_layers) + if len(nets) != args.check_N: + raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)) - opt_config = './configs/nas-benchmark/hyper-opts/{:}E.config'.format(args.hyper) - if not os.path.isfile(opt_config): - raise ValueError('{:} is not a file.'.format(opt_config)) - save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper) - save_dir.mkdir(parents=True, exist_ok=True) - to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5) + opt_config = "./configs/nas-benchmark/hyper-opts/{:}E.config".format(args.hyper) + if not os.path.isfile(opt_config): + raise ValueError("{:} is not a file.".format(opt_config)) + save_dir = Path(args.save_dir) / "raw-data-{:}".format(args.hyper) + save_dir.mkdir(parents=True, exist_ok=True) + to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5) - if not len(args.seeds): - raise ValueError('invalid length of seeds args: {:}'.format(args.seeds)) - if not (len(args.datasets) == len(args.xpaths) == len(args.splits)): - raise ValueError('invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))) - if args.workers <= 0: - raise ValueError('invalid number of workers : {:}'.format(args.workers)) + if not len(args.seeds): + raise ValueError("invalid length of seeds args: {:}".format(args.seeds)) + if not (len(args.datasets) == len(args.xpaths) == len(args.splits)): + raise ValueError( + "invalid infos : {:} vs {:} vs {:}".format(len(args.datasets), len(args.xpaths), len(args.splits)) + ) + if args.workers <= 0: + raise ValueError("invalid number of workers : {:}".format(args.workers)) - target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds) - - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.deterministic = True - torch.set_num_threads(args.workers) + target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds) - main(save_dir, args.workers, args.datasets, args.xpaths, args.splits, tuple(args.seeds), nets, opt_config, target_indexes, args.mode == 'cover') + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.deterministic = True + torch.set_num_threads(args.workers) + + main( + save_dir, + args.workers, + args.datasets, + args.xpaths, + args.splits, + tuple(args.seeds), + nets, + opt_config, + target_indexes, + args.mode == "cover", + ) diff --git a/exps/NATS-Bench/main-tss.py b/exps/NATS-Bench/main-tss.py index 189c0e1..d6e9231 100644 --- a/exps/NATS-Bench/main-tss.py +++ b/exps/NATS-Bench/main-tss.py @@ -19,316 +19,505 @@ ############################################################################## import os, sys, time, torch, random, argparse from typing import List, Text, Dict, Any -from PIL import ImageFile +from PIL import ImageFile + ImageFile.LOAD_TRUNCATED_IMAGES = True -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config -from procedures import bench_evaluate_for_seed -from procedures import get_machine_info -from datasets import get_datasets -from log_utils import Logger, AverageMeter, time_string, convert_secs2time -from models import CellStructure, CellArchitectures, get_search_spaces -from utils import split_str2indexes +from procedures import bench_evaluate_for_seed +from procedures import get_machine_info +from datasets import get_datasets +from log_utils import Logger, AverageMeter, time_string, convert_secs2time +from models import CellStructure, CellArchitectures, get_search_spaces +from utils import split_str2indexes -def evaluate_all_datasets(arch: Text, datasets: List[Text], xpaths: List[Text], - splits: List[Text], config_path: Text, seed: int, raw_arch_config, workers, logger): - machine_info, raw_arch_config = get_machine_info(), deepcopy(raw_arch_config) - all_infos = {'info': machine_info} - all_dataset_keys = [] - # look all the datasets - for dataset, xpath, split in zip(datasets, xpaths, splits): - # train valid data - train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) - # load the configuration - if dataset == 'cifar10' or dataset == 'cifar100': - split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None) - elif dataset.startswith('ImageNet16'): - split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None) - else: - raise ValueError('invalid dataset : {:}'.format(dataset)) - config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) - # check whether use splited validation set - if bool(split): - assert dataset == 'cifar10' - ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)} - assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid)) - train_data_v2 = deepcopy(train_data) - train_data_v2.transform = valid_data.transform - valid_data = train_data_v2 - # data loader - train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True) - ValLoaders['x-valid'] = valid_loader - else: - # data loader - train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True) - if dataset == 'cifar10': - ValLoaders = {'ori-test': valid_loader} - elif dataset == 'cifar100': - cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None) - ValLoaders = {'ori-test': valid_loader, - 'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True), - 'x-test' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True) - } - elif dataset == 'ImageNet16-120': - imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None) - ValLoaders = {'ori-test': valid_loader, - 'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True), - 'x-test' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True) - } - else: - raise ValueError('invalid dataset : {:}'.format(dataset)) - - dataset_key = '{:}'.format(dataset) - if bool(split): dataset_key = dataset_key + '-valid' - logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size)) - logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config)) - for key, value in ValLoaders.items(): - logger.log('Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value))) - arch_config = dict2config(dict(name='infer.tiny', C=raw_arch_config['channel'], N=raw_arch_config['num_cells'], - genotype=arch, num_classes=config.class_num), None) - results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger) - all_infos[dataset_key] = results - all_dataset_keys.append( dataset_key ) - all_infos['all_dataset_keys'] = all_dataset_keys - return all_infos - - -def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], - splits: List[int], seeds: List[int], nets: List[str], opt_config: Dict[Text, Any], - to_evaluate_indexes: tuple, cover_mode: bool, arch_config: Dict[Text, Any]): - - log_dir = save_dir / 'logs' - log_dir.mkdir(parents=True, exist_ok=True) - logger = Logger(str(log_dir), os.getpid(), False) - - logger.log('xargs : seeds = {:}'.format(seeds)) - logger.log('xargs : cover_mode = {:}'.format(cover_mode)) - logger.log('-' * 100) - logger.log( - 'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes)) - +'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode)) - for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): - logger.log( - '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split)) - logger.log('--->>> optimization config : {:}'.format(opt_config)) - - start_time, epoch_time = time.time(), AverageMeter() - for i, index in enumerate(to_evaluate_indexes): - arch = nets[index] - logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}'.format(time_string(), i, - len(to_evaluate_indexes), index, len(nets), seeds, '-' * 15)) - logger.log('{:} {:} {:}'.format('-' * 15, arch, '-' * 15)) - - # test this arch on different datasets with different seeds - has_continue = False - for seed in seeds: - to_save_name = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) - if to_save_name.exists(): - if cover_mode: - logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name)) - os.remove(str(to_save_name)) +def evaluate_all_datasets( + arch: Text, + datasets: List[Text], + xpaths: List[Text], + splits: List[Text], + config_path: Text, + seed: int, + raw_arch_config, + workers, + logger, +): + machine_info, raw_arch_config = get_machine_info(), deepcopy(raw_arch_config) + all_infos = {"info": machine_info} + all_dataset_keys = [] + # look all the datasets + for dataset, xpath, split in zip(datasets, xpaths, splits): + # train valid data + train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) + # load the configuration + if dataset == "cifar10" or dataset == "cifar100": + split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None) + elif dataset.startswith("ImageNet16"): + split_info = load_config("configs/nas-benchmark/{:}-split.txt".format(dataset), None, None) else: - logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name)) - has_continue = True - continue - results = evaluate_all_datasets(CellStructure.str2structure(arch), - datasets, xpaths, splits, opt_config, seed, - arch_config, workers, logger) - torch.save(results, to_save_name) - logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i, - len(to_evaluate_indexes), index, len(nets), seeds, to_save_name)) - # measure elapsed time - if not has_continue: epoch_time.update(time.time() - start_time) - start_time = time.time() - need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) ) - logger.log('This arch costs : {:}'.format(convert_secs2time(epoch_time.val, True) )) - logger.log('{:}'.format('*' * 100)) - logger.log('{:} {:74s} {:}'.format('*' * 10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len( - to_evaluate_indexes), index, len(nets), need_time), '*' * 10)) - logger.log('{:}'.format('*' * 100)) + raise ValueError("invalid dataset : {:}".format(dataset)) + config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) + # check whether use splited validation set + if bool(split): + assert dataset == "cifar10" + ValLoaders = { + "ori-test": torch.utils.data.DataLoader( + valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + ) + } + assert len(train_data) == len(split_info.train) + len( + split_info.valid + ), "invalid length : {:} vs {:} + {:}".format(len(train_data), len(split_info.train), len(split_info.valid)) + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), + num_workers=workers, + pin_memory=True, + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), + num_workers=workers, + pin_memory=True, + ) + ValLoaders["x-valid"] = valid_loader + else: + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, batch_size=config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + ) + if dataset == "cifar10": + ValLoaders = {"ori-test": valid_loader} + elif dataset == "cifar100": + cifar100_splits = load_config("configs/nas-benchmark/cifar100-test-split.txt", None, None) + ValLoaders = { + "ori-test": valid_loader, + "x-valid": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), + num_workers=workers, + pin_memory=True, + ), + "x-test": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), + num_workers=workers, + pin_memory=True, + ), + } + elif dataset == "ImageNet16-120": + imagenet16_splits = load_config("configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None) + ValLoaders = { + "ori-test": valid_loader, + "x-valid": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), + num_workers=workers, + pin_memory=True, + ), + "x-test": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest), + num_workers=workers, + pin_memory=True, + ), + } + else: + raise ValueError("invalid dataset : {:}".format(dataset)) - logger.close() + dataset_key = "{:}".format(dataset) + if bool(split): + dataset_key = dataset_key + "-valid" + logger.log( + "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)) + for key, value in ValLoaders.items(): + logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))) + arch_config = dict2config( + dict( + name="infer.tiny", + C=raw_arch_config["channel"], + N=raw_arch_config["num_cells"], + genotype=arch, + num_classes=config.class_num, + ), + None, + ) + results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger) + all_infos[dataset_key] = results + all_dataset_keys.append(dataset_key) + all_infos["all_dataset_keys"] = all_dataset_keys + return all_infos + + +def main( + save_dir: Path, + workers: int, + datasets: List[Text], + xpaths: List[Text], + splits: List[int], + seeds: List[int], + nets: List[str], + opt_config: Dict[Text, Any], + to_evaluate_indexes: tuple, + cover_mode: bool, + arch_config: Dict[Text, Any], +): + + log_dir = save_dir / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + logger = Logger(str(log_dir), os.getpid(), False) + + logger.log("xargs : seeds = {:}".format(seeds)) + logger.log("xargs : cover_mode = {:}".format(cover_mode)) + logger.log("-" * 100) + logger.log( + "Start evaluating range =: {:06d} - {:06d}".format(min(to_evaluate_indexes), max(to_evaluate_indexes)) + + "({:} in total) / {:06d} with cover-mode={:}".format(len(to_evaluate_indexes), len(nets), cover_mode) + ) + for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): + logger.log( + "--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}".format( + i, len(datasets), dataset, xpath, split + ) + ) + logger.log("--->>> optimization config : {:}".format(opt_config)) + + start_time, epoch_time = time.time(), AverageMeter() + for i, index in enumerate(to_evaluate_indexes): + arch = nets[index] + logger.log( + "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}".format( + time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, "-" * 15 + ) + ) + logger.log("{:} {:} {:}".format("-" * 15, arch, "-" * 15)) + + # test this arch on different datasets with different seeds + has_continue = False + for seed in seeds: + to_save_name = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) + if to_save_name.exists(): + if cover_mode: + logger.log("Find existing file : {:}, remove it before evaluation".format(to_save_name)) + os.remove(str(to_save_name)) + else: + logger.log("Find existing file : {:}, skip this evaluation".format(to_save_name)) + has_continue = True + continue + results = evaluate_all_datasets( + CellStructure.str2structure(arch), + datasets, + xpaths, + splits, + opt_config, + seed, + arch_config, + workers, + logger, + ) + torch.save(results, to_save_name) + logger.log( + "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}".format( + time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, to_save_name + ) + ) + # measure elapsed time + if not has_continue: + epoch_time.update(time.time() - start_time) + start_time = time.time() + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True) + ) + logger.log("This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))) + logger.log("{:}".format("*" * 100)) + logger.log( + "{:} {:74s} {:}".format( + "*" * 10, + "{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}".format( + i, len(to_evaluate_indexes), index, len(nets), need_time + ), + "*" * 10, + ) + ) + logger.log("{:}".format("*" * 100)) + + logger.close() def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.deterministic = True - #torch.backends.cudnn.benchmark = True - torch.set_num_threads( workers ) - - save_dir = Path(save_dir) / 'specifics' / '{:}-{:}-{:}-{:}'.format('LESS' if use_less else 'FULL', model_str, arch_config['channel'], arch_config['num_cells']) - logger = Logger(str(save_dir), 0, False) - if model_str in CellArchitectures: - arch = CellArchitectures[model_str] - logger.log('The model string is found in pre-defined architecture dict : {:}'.format(model_str)) - else: - try: - arch = CellStructure.str2structure(model_str) - except: - raise ValueError('Invalid model string : {:}. It can not be found or parsed.'.format(model_str)) - assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch) - logger.log('Start train-evaluate {:}'.format(arch.tostr())) - logger.log('arch_config : {:}'.format(arch_config)) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = True + torch.set_num_threads(workers) - start_time, seed_time = time.time(), AverageMeter() - for _is, seed in enumerate(seeds): - logger.log('\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds), seed)) - to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed) - if to_save_name.exists(): - logger.log('Find the existing file {:}, directly load!'.format(to_save_name)) - checkpoint = torch.load(to_save_name) + save_dir = ( + Path(save_dir) + / "specifics" + / "{:}-{:}-{:}-{:}".format( + "LESS" if use_less else "FULL", model_str, arch_config["channel"], arch_config["num_cells"] + ) + ) + logger = Logger(str(save_dir), 0, False) + if model_str in CellArchitectures: + arch = CellArchitectures[model_str] + logger.log("The model string is found in pre-defined architecture dict : {:}".format(model_str)) else: - logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name)) - checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger) - torch.save(checkpoint, to_save_name) - # log information - logger.log('{:}'.format(checkpoint['info'])) - all_dataset_keys = checkpoint['all_dataset_keys'] - for dataset_key in all_dataset_keys: - logger.log('\n{:} dataset : {:} {:}'.format('-'*15, dataset_key, '-'*15)) - dataset_info = checkpoint[dataset_key] - #logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] )) - logger.log('Flops = {:} MB, Params = {:} MB'.format(dataset_info['flop'], dataset_info['param'])) - logger.log('config : {:}'.format(dataset_info['config'])) - logger.log('Training State (finish) = {:}'.format(dataset_info['finish-train'])) - last_epoch = dataset_info['total_epoch'] - 1 - train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es'] - valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es'] - logger.log('Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%'.format(train_acc1es[last_epoch], train_acc5es[last_epoch], 100-train_acc1es[last_epoch], valid_acc1es[last_epoch], valid_acc5es[last_epoch], 100-valid_acc1es[last_epoch])) - # measure elapsed time - seed_time.update(time.time() - start_time) - start_time = time.time() - need_time = 'Time Left: {:}'.format( convert_secs2time(seed_time.avg * (len(seeds)-_is-1), True) ) - logger.log('\n<<<***>>> The {:02d}/{:02d}-th seed is {:} other procedures need {:}'.format(_is, len(seeds), seed, need_time)) - logger.close() + try: + arch = CellStructure.str2structure(model_str) + except: + raise ValueError("Invalid model string : {:}. It can not be found or parsed.".format(model_str)) + assert arch.check_valid_op(get_search_spaces("cell", "full")), "{:} has the invalid op.".format(arch) + logger.log("Start train-evaluate {:}".format(arch.tostr())) + logger.log("arch_config : {:}".format(arch_config)) + + start_time, seed_time = time.time(), AverageMeter() + for _is, seed in enumerate(seeds): + logger.log( + "\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format( + _is, len(seeds), seed + ) + ) + to_save_name = save_dir / "seed-{:04d}.pth".format(seed) + if to_save_name.exists(): + logger.log("Find the existing file {:}, directly load!".format(to_save_name)) + checkpoint = torch.load(to_save_name) + else: + logger.log("Does not find the existing file {:}, train and evaluate!".format(to_save_name)) + checkpoint = evaluate_all_datasets( + arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger + ) + torch.save(checkpoint, to_save_name) + # log information + logger.log("{:}".format(checkpoint["info"])) + all_dataset_keys = checkpoint["all_dataset_keys"] + for dataset_key in all_dataset_keys: + logger.log("\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)) + dataset_info = checkpoint[dataset_key] + # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] )) + logger.log("Flops = {:} MB, Params = {:} MB".format(dataset_info["flop"], dataset_info["param"])) + logger.log("config : {:}".format(dataset_info["config"])) + logger.log("Training State (finish) = {:}".format(dataset_info["finish-train"])) + last_epoch = dataset_info["total_epoch"] - 1 + train_acc1es, train_acc5es = dataset_info["train_acc1es"], dataset_info["train_acc5es"] + valid_acc1es, valid_acc5es = dataset_info["valid_acc1es"], dataset_info["valid_acc5es"] + logger.log( + "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format( + train_acc1es[last_epoch], + train_acc5es[last_epoch], + 100 - train_acc1es[last_epoch], + valid_acc1es[last_epoch], + valid_acc5es[last_epoch], + 100 - valid_acc1es[last_epoch], + ) + ) + # measure elapsed time + seed_time.update(time.time() - start_time) + start_time = time.time() + need_time = "Time Left: {:}".format(convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)) + logger.log( + "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} other procedures need {:}".format( + _is, len(seeds), seed, need_time + ) + ) + logger.close() def generate_meta_info(save_dir, max_node, divide=40): - aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-201') - archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) - print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2))) + aa_nas_bench_ss = get_search_spaces("cell", "nas-bench-201") + archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) + print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2))) - random.seed( 88 ) # please do not change this line for reproducibility - random.shuffle( archs ) - # to test fixed-random shuffle - #print ('arch [0] : {:}\n---->>>> {:}'.format( archs[0], archs[0].tostr() )) - #print ('arch [9] : {:}\n---->>>> {:}'.format( archs[9], archs[9].tostr() )) - assert archs[0 ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0]) - assert archs[9 ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9]) - assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123]) - total_arch = len(archs) - - num = 50000 - indexes_5W = list(range(num)) - random.seed( 1021 ) - random.shuffle( indexes_5W ) - train_split = sorted( list(set(indexes_5W[:num//2])) ) - valid_split = sorted( list(set(indexes_5W[num//2:])) ) - assert len(train_split) + len(valid_split) == num - assert train_split[0] == 0 and train_split[10] == 26 and train_split[111] == 203 and valid_split[0] == 1 and valid_split[10] == 18 and valid_split[111] == 242, '{:} {:} {:} - {:} {:} {:}'.format(train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111]) - splits = {num: {'train': train_split, 'valid': valid_split} } + random.seed(88) # please do not change this line for reproducibility + random.shuffle(archs) + # to test fixed-random shuffle + # print ('arch [0] : {:}\n---->>>> {:}'.format( archs[0], archs[0].tostr() )) + # print ('arch [9] : {:}\n---->>>> {:}'.format( archs[9], archs[9].tostr() )) + assert ( + archs[0].tostr() + == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" + ), "please check the 0-th architecture : {:}".format(archs[0]) + assert ( + archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" + ), "please check the 9-th architecture : {:}".format(archs[9]) + assert ( + archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" + ), "please check the 123-th architecture : {:}".format(archs[123]) + total_arch = len(archs) - info = {'archs' : [x.tostr() for x in archs], - 'total' : total_arch, - 'max_node' : max_node, - 'splits': splits} + num = 50000 + indexes_5W = list(range(num)) + random.seed(1021) + random.shuffle(indexes_5W) + train_split = sorted(list(set(indexes_5W[: num // 2]))) + valid_split = sorted(list(set(indexes_5W[num // 2 :]))) + assert len(train_split) + len(valid_split) == num + assert ( + train_split[0] == 0 + and train_split[10] == 26 + and train_split[111] == 203 + and valid_split[0] == 1 + and valid_split[10] == 18 + and valid_split[111] == 242 + ), "{:} {:} {:} - {:} {:} {:}".format( + train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111] + ) + splits = {num: {"train": train_split, "valid": valid_split}} - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - save_name = save_dir / 'meta-node-{:}.pth'.format(max_node) - assert not save_name.exists(), '{:} already exist'.format(save_name) - torch.save(info, save_name) - print ('save the meta file into {:}'.format(save_name)) + info = {"archs": [x.tostr() for x in archs], "total": total_arch, "max_node": max_node, "splits": splits} + + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + save_name = save_dir / "meta-node-{:}.pth".format(max_node) + assert not save_name.exists(), "{:} already exist".format(save_name) + torch.save(info, save_name) + print("save the meta file into {:}".format(save_name)) def traverse_net(max_node): - aa_nas_bench_ss = get_search_spaces('cell', 'nats-bench') - archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) - print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2))) + aa_nas_bench_ss = get_search_spaces("cell", "nats-bench") + archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) + print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2))) - random.seed( 88 ) # please do not change this line for reproducibility - random.shuffle( archs ) - assert archs[0 ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0]) - assert archs[9 ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9]) - assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123]) - return [x.tostr() for x in archs] + random.seed(88) # please do not change this line for reproducibility + random.shuffle(archs) + assert ( + archs[0].tostr() + == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" + ), "please check the 0-th architecture : {:}".format(archs[0]) + assert ( + archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" + ), "please check the 9-th architecture : {:}".format(archs[9]) + assert ( + archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" + ), "please check the 123-th architecture : {:}".format(archs[123]) + return [x.tostr() for x in archs] def filter_indexes(xlist, mode, save_dir, seeds): - all_indexes = [] - for index in xlist: - if mode == 'cover': - all_indexes.append(index) + all_indexes = [] + for index in xlist: + if mode == "cover": + all_indexes.append(index) + else: + for seed in seeds: + temp_path = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) + if not temp_path.exists(): + all_indexes.append(index) + break + print( + "{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total".format( + time_string(), len(all_indexes), len(xlist) + ) + ) + return all_indexes + + +if __name__ == "__main__": + # mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()] + parser = argparse.ArgumentParser( + description="NATS-Bench (topology search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--mode", type=str, required=True, help="The script mode.") + parser.add_argument( + "--save_dir", type=str, default="output/NATS-Bench-topology", help="Folder to save checkpoints and log." + ) + parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell (please do not change it).") + # use for train the model + parser.add_argument("--workers", type=int, default=8, help="number of data loading workers (default: 2)") + parser.add_argument("--srange", type=str, required=True, help="The range of models to be evaluated") + parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.") + parser.add_argument("--xpaths", type=str, nargs="+", help="The root path for this dataset.") + parser.add_argument("--splits", type=int, nargs="+", help="The root path for this dataset.") + parser.add_argument( + "--hyper", type=str, default="12", choices=["01", "12", "200"], help="The tag for hyper-parameters." + ) + + parser.add_argument("--seeds", type=int, nargs="+", help="The range of models to be evaluated") + parser.add_argument("--channel", type=int, default=16, help="The number of channels.") + parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + parser.add_argument("--check_N", type=int, default=15625, help="For safety.") + args = parser.parse_args() + + assert args.mode in ["meta", "new", "cover"] or args.mode.startswith("specific-"), "invalid mode : {:}".format( + args.mode + ) + + if args.mode == "meta": + generate_meta_info(args.save_dir, args.max_node) + elif args.mode.startswith("specific"): + assert len(args.mode.split("-")) == 2, "invalid mode : {:}".format(args.mode) + model_str = args.mode.split("-")[1] + train_single_model( + args.save_dir, + args.workers, + args.datasets, + args.xpaths, + args.splits, + args.use_less > 0, + tuple(args.seeds), + model_str, + {"channel": args.channel, "num_cells": args.num_cells}, + ) else: - for seed in seeds: - temp_path = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) - if not temp_path.exists(): - all_indexes.append(index) - break - print('{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total'.format(time_string(), len(all_indexes), len(xlist))) - return all_indexes + nets = traverse_net(args.max_node) + if len(nets) != args.check_N: + raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)) + opt_config = "./configs/nas-benchmark/hyper-opts/{:}E.config".format(args.hyper) + if not os.path.isfile(opt_config): + raise ValueError("{:} is not a file.".format(opt_config)) + save_dir = Path(args.save_dir) / "raw-data-{:}".format(args.hyper) + save_dir.mkdir(parents=True, exist_ok=True) + to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5) + if not len(args.seeds): + raise ValueError("invalid length of seeds args: {:}".format(args.seeds)) + if not (len(args.datasets) == len(args.xpaths) == len(args.splits)): + raise ValueError( + "invalid infos : {:} vs {:} vs {:}".format(len(args.datasets), len(args.xpaths), len(args.splits)) + ) + if args.workers <= 0: + raise ValueError("invalid number of workers : {:}".format(args.workers)) + target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds) -if __name__ == '__main__': - # mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()] - parser = argparse.ArgumentParser(description='NATS-Bench (topology search space)', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--mode' , type=str, required=True, help='The script mode.') - parser.add_argument('--save_dir', type=str, default='output/NATS-Bench-topology', help='Folder to save checkpoints and log.') - parser.add_argument('--max_node', type=int, default=4, help='The maximum node in a cell (please do not change it).') - # use for train the model - parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 2)') - parser.add_argument('--srange' , type=str, required=True, help='The range of models to be evaluated') - parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.') - parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for this dataset.') - parser.add_argument('--splits', type=int, nargs='+', help='The root path for this dataset.') - parser.add_argument('--hyper', type=str, default='12', choices=['01', '12', '200'], help='The tag for hyper-parameters.') + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.deterministic = True + torch.set_num_threads(args.workers) - parser.add_argument('--seeds' , type=int, nargs='+', help='The range of models to be evaluated') - parser.add_argument('--channel', type=int, default=16, help='The number of channels.') - parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.') - parser.add_argument('--check_N', type=int, default=15625, help='For safety.') - args = parser.parse_args() - - assert args.mode in ['meta', 'new', 'cover'] or args.mode.startswith('specific-'), 'invalid mode : {:}'.format(args.mode) - - if args.mode == 'meta': - generate_meta_info(args.save_dir, args.max_node) - elif args.mode.startswith('specific'): - assert len(args.mode.split('-')) == 2, 'invalid mode : {:}'.format(args.mode) - model_str = args.mode.split('-')[1] - train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \ - tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells}) - else: - nets = traverse_net(args.max_node) - if len(nets) != args.check_N: - raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N)) - opt_config = './configs/nas-benchmark/hyper-opts/{:}E.config'.format(args.hyper) - if not os.path.isfile(opt_config): - raise ValueError('{:} is not a file.'.format(opt_config)) - save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper) - save_dir.mkdir(parents=True, exist_ok=True) - to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5) - if not len(args.seeds): - raise ValueError('invalid length of seeds args: {:}'.format(args.seeds)) - if not (len(args.datasets) == len(args.xpaths) == len(args.splits)): - raise ValueError('invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))) - if args.workers <= 0: - raise ValueError('invalid number of workers : {:}'.format(args.workers)) - - target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds) - - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.deterministic = True - torch.set_num_threads(args.workers) - - main(save_dir, args.workers, args.datasets, args.xpaths, args.splits, tuple(args.seeds), nets, opt_config, target_indexes, args.mode == 'cover', \ - {'name': 'infer.tiny', 'channel': args.channel, 'num_cells': args.num_cells}) + main( + save_dir, + args.workers, + args.datasets, + args.xpaths, + args.splits, + tuple(args.seeds), + nets, + opt_config, + target_indexes, + args.mode == "cover", + {"name": "infer.tiny", "channel": args.channel, "num_cells": args.num_cells}, + ) diff --git a/exps/NATS-Bench/sss-collect.py b/exps/NATS-Bench/sss-collect.py index 036a3e7..dc8f650 100644 --- a/exps/NATS-Bench/sss-collect.py +++ b/exps/NATS-Bench/sss-collect.py @@ -16,263 +16,304 @@ from tqdm import tqdm from pathlib import Path from collections import defaultdict, OrderedDict from typing import Dict, Any, Text, List -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from log_utils import AverageMeter, time_string, convert_secs2time from config_utils import dict2config -from models import CellStructure, get_cell_based_tiny_net -from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount -from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders -from utils import get_md5_file +from models import CellStructure, get_cell_based_tiny_net +from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount +from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders +from utils import get_md5_file -NATS_SSS_BASE_NAME = 'NATS-sss-v1_0' # 2020.08.28 +NATS_SSS_BASE_NAME = "NATS-sss-v1_0" # 2020.08.28 def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]) -> ArchResults: - information = ArchResults(arch_index, arch_str) + information = ArchResults(arch_index, arch_str) - for checkpoint_path in checkpoints: - try: - checkpoint = torch.load(checkpoint_path, map_location='cpu') - except: - raise ValueError('This checkpoint failed to be loaded : {:}'.format(checkpoint_path)) - used_seed = checkpoint_path.name.split('-')[-1].split('.')[0] - ok_dataset = 0 - for dataset in datasets: - if dataset not in checkpoint: - print('Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, checkpoint_path)) - continue - else: - ok_dataset += 1 - results = checkpoint[dataset] - assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, used_seed, dataset, checkpoint_path) - arch_config = {'name': 'infer.shape.tiny', 'channels': arch_str, 'arch_str': arch_str, - 'genotype': results['arch_config']['genotype'], - 'class_num': results['arch_config']['num_classes']} - xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'], - results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None) - xresult.update_train_info(results['train_acc1es'], results['train_acc5es'], results['train_losses'], results['train_times']) - xresult.update_eval(results['valid_acc1es'], results['valid_losses'], results['valid_times']) - information.update(dataset, int(used_seed), xresult) - if ok_dataset < len(datasets): raise ValueError('{:} does find enought data : {:} vs {:}'.format(checkpoint_path, ok_dataset, len(datasets))) - return information + for checkpoint_path in checkpoints: + try: + checkpoint = torch.load(checkpoint_path, map_location="cpu") + except: + raise ValueError("This checkpoint failed to be loaded : {:}".format(checkpoint_path)) + used_seed = checkpoint_path.name.split("-")[-1].split(".")[0] + ok_dataset = 0 + for dataset in datasets: + if dataset not in checkpoint: + print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path)) + continue + else: + ok_dataset += 1 + results = checkpoint[dataset] + assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( + arch_index, used_seed, dataset, checkpoint_path + ) + arch_config = { + "name": "infer.shape.tiny", + "channels": arch_str, + "arch_str": arch_str, + "genotype": results["arch_config"]["genotype"], + "class_num": results["arch_config"]["num_classes"], + } + xresult = ResultsCount( + dataset, + results["net_state_dict"], + results["train_acc1es"], + results["train_losses"], + results["param"], + results["flop"], + arch_config, + used_seed, + results["total_epoch"], + None, + ) + xresult.update_train_info( + results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"] + ) + xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"]) + information.update(dataset, int(used_seed), xresult) + if ok_dataset < len(datasets): + raise ValueError( + "{:} does find enought data : {:} vs {:}".format(checkpoint_path, ok_dataset, len(datasets)) + ) + return information def correct_time_related_info(hp2info: Dict[Text, ArchResults]): - # calibrate the latency based on the number of epochs = 01, since they are trained on the same machine. - x1 = hp2info['01'].get_metrics('cifar10-valid', 'x-valid')['all_time'] / 98 - x2 = hp2info['01'].get_metrics('cifar10-valid', 'ori-test')['all_time'] / 40 - cifar010_latency = (x1 + x2) / 2 - for hp, arch_info in hp2info.items(): - arch_info.reset_latency('cifar10-valid', None, cifar010_latency) - arch_info.reset_latency('cifar10', None, cifar010_latency) - # hp2info['01'].get_latency('cifar10') + # calibrate the latency based on the number of epochs = 01, since they are trained on the same machine. + x1 = hp2info["01"].get_metrics("cifar10-valid", "x-valid")["all_time"] / 98 + x2 = hp2info["01"].get_metrics("cifar10-valid", "ori-test")["all_time"] / 40 + cifar010_latency = (x1 + x2) / 2 + for hp, arch_info in hp2info.items(): + arch_info.reset_latency("cifar10-valid", None, cifar010_latency) + arch_info.reset_latency("cifar10", None, cifar010_latency) + # hp2info['01'].get_latency('cifar10') - x1 = hp2info['01'].get_metrics('cifar100', 'ori-test')['all_time'] / 40 - x2 = hp2info['01'].get_metrics('cifar100', 'x-test')['all_time'] / 20 - x3 = hp2info['01'].get_metrics('cifar100', 'x-valid')['all_time'] / 20 - cifar100_latency = (x1 + x2 + x3) / 3 - for hp, arch_info in hp2info.items(): - arch_info.reset_latency('cifar100', None, cifar100_latency) + x1 = hp2info["01"].get_metrics("cifar100", "ori-test")["all_time"] / 40 + x2 = hp2info["01"].get_metrics("cifar100", "x-test")["all_time"] / 20 + x3 = hp2info["01"].get_metrics("cifar100", "x-valid")["all_time"] / 20 + cifar100_latency = (x1 + x2 + x3) / 3 + for hp, arch_info in hp2info.items(): + arch_info.reset_latency("cifar100", None, cifar100_latency) - x1 = hp2info['01'].get_metrics('ImageNet16-120', 'ori-test')['all_time'] / 24 - x2 = hp2info['01'].get_metrics('ImageNet16-120', 'x-test')['all_time'] / 12 - x3 = hp2info['01'].get_metrics('ImageNet16-120', 'x-valid')['all_time'] / 12 - image_latency = (x1 + x2 + x3) / 3 - for hp, arch_info in hp2info.items(): - arch_info.reset_latency('ImageNet16-120', None, image_latency) + x1 = hp2info["01"].get_metrics("ImageNet16-120", "ori-test")["all_time"] / 24 + x2 = hp2info["01"].get_metrics("ImageNet16-120", "x-test")["all_time"] / 12 + x3 = hp2info["01"].get_metrics("ImageNet16-120", "x-valid")["all_time"] / 12 + image_latency = (x1 + x2 + x3) / 3 + for hp, arch_info in hp2info.items(): + arch_info.reset_latency("ImageNet16-120", None, image_latency) - # CIFAR10 VALID - train_per_epoch_time = list(hp2info['01'].query('cifar10-valid', 777).train_times.values()) - train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) - eval_ori_test_time, eval_x_valid_time = [], [] - for key, value in hp2info['01'].query('cifar10-valid', 777).eval_times.items(): - if key.startswith('ori-test@'): - eval_ori_test_time.append(value) - elif key.startswith('x-valid@'): - eval_x_valid_time.append(value) - else: raise ValueError('-- {:} --'.format(key)) - eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) - eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time) - for hp, arch_info in hp2info.items(): - arch_info.reset_pseudo_train_times('cifar10-valid', None, train_per_epoch_time) - arch_info.reset_pseudo_eval_times('cifar10-valid', None, 'x-valid', eval_x_valid_time) - arch_info.reset_pseudo_eval_times('cifar10-valid', None, 'ori-test', eval_ori_test_time) + # CIFAR10 VALID + train_per_epoch_time = list(hp2info["01"].query("cifar10-valid", 777).train_times.values()) + train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) + eval_ori_test_time, eval_x_valid_time = [], [] + for key, value in hp2info["01"].query("cifar10-valid", 777).eval_times.items(): + if key.startswith("ori-test@"): + eval_ori_test_time.append(value) + elif key.startswith("x-valid@"): + eval_x_valid_time.append(value) + else: + raise ValueError("-- {:} --".format(key)) + eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) + eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time) + for hp, arch_info in hp2info.items(): + arch_info.reset_pseudo_train_times("cifar10-valid", None, train_per_epoch_time) + arch_info.reset_pseudo_eval_times("cifar10-valid", None, "x-valid", eval_x_valid_time) + arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_ori_test_time) - # CIFAR10 - train_per_epoch_time = list(hp2info['01'].query('cifar10', 777).train_times.values()) - train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) - eval_ori_test_time = [] - for key, value in hp2info['01'].query('cifar10', 777).eval_times.items(): - if key.startswith('ori-test@'): - eval_ori_test_time.append(value) - else: raise ValueError('-- {:} --'.format(key)) - eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) - for hp, arch_info in hp2info.items(): - arch_info.reset_pseudo_train_times('cifar10', None, train_per_epoch_time) - arch_info.reset_pseudo_eval_times('cifar10', None, 'ori-test', eval_ori_test_time) + # CIFAR10 + train_per_epoch_time = list(hp2info["01"].query("cifar10", 777).train_times.values()) + train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) + eval_ori_test_time = [] + for key, value in hp2info["01"].query("cifar10", 777).eval_times.items(): + if key.startswith("ori-test@"): + eval_ori_test_time.append(value) + else: + raise ValueError("-- {:} --".format(key)) + eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) + for hp, arch_info in hp2info.items(): + arch_info.reset_pseudo_train_times("cifar10", None, train_per_epoch_time) + arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_ori_test_time) - # CIFAR100 - train_per_epoch_time = list(hp2info['01'].query('cifar100', 777).train_times.values()) - train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) - eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], [] - for key, value in hp2info['01'].query('cifar100', 777).eval_times.items(): - if key.startswith('ori-test@'): - eval_ori_test_time.append(value) - elif key.startswith('x-valid@'): - eval_x_valid_time.append(value) - elif key.startswith('x-test@'): - eval_x_test_time.append(value) - else: raise ValueError('-- {:} --'.format(key)) - eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) - eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time) - eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time) - for hp, arch_info in hp2info.items(): - arch_info.reset_pseudo_train_times('cifar100', None, train_per_epoch_time) - arch_info.reset_pseudo_eval_times('cifar100', None, 'x-valid', eval_x_valid_time) - arch_info.reset_pseudo_eval_times('cifar100', None, 'x-test', eval_x_test_time) - arch_info.reset_pseudo_eval_times('cifar100', None, 'ori-test', eval_ori_test_time) + # CIFAR100 + train_per_epoch_time = list(hp2info["01"].query("cifar100", 777).train_times.values()) + train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) + eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], [] + for key, value in hp2info["01"].query("cifar100", 777).eval_times.items(): + if key.startswith("ori-test@"): + eval_ori_test_time.append(value) + elif key.startswith("x-valid@"): + eval_x_valid_time.append(value) + elif key.startswith("x-test@"): + eval_x_test_time.append(value) + else: + raise ValueError("-- {:} --".format(key)) + eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) + eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time) + eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time) + for hp, arch_info in hp2info.items(): + arch_info.reset_pseudo_train_times("cifar100", None, train_per_epoch_time) + arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_x_valid_time) + arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_x_test_time) + arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_ori_test_time) - # ImageNet16-120 - train_per_epoch_time = list(hp2info['01'].query('ImageNet16-120', 777).train_times.values()) - train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) - eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], [] - for key, value in hp2info['01'].query('ImageNet16-120', 777).eval_times.items(): - if key.startswith('ori-test@'): - eval_ori_test_time.append(value) - elif key.startswith('x-valid@'): - eval_x_valid_time.append(value) - elif key.startswith('x-test@'): - eval_x_test_time.append(value) - else: raise ValueError('-- {:} --'.format(key)) - eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) - eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time) - eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time) - for hp, arch_info in hp2info.items(): - arch_info.reset_pseudo_train_times('ImageNet16-120', None, train_per_epoch_time) - arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'x-valid', eval_x_valid_time) - arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'x-test', eval_x_test_time) - arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'ori-test', eval_ori_test_time) - return hp2info + # ImageNet16-120 + train_per_epoch_time = list(hp2info["01"].query("ImageNet16-120", 777).train_times.values()) + train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) + eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], [] + for key, value in hp2info["01"].query("ImageNet16-120", 777).eval_times.items(): + if key.startswith("ori-test@"): + eval_ori_test_time.append(value) + elif key.startswith("x-valid@"): + eval_x_valid_time.append(value) + elif key.startswith("x-test@"): + eval_x_test_time.append(value) + else: + raise ValueError("-- {:} --".format(key)) + eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) + eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time) + eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time) + for hp, arch_info in hp2info.items(): + arch_info.reset_pseudo_train_times("ImageNet16-120", None, train_per_epoch_time) + arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "x-valid", eval_x_valid_time) + arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "x-test", eval_x_test_time) + arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "ori-test", eval_ori_test_time) + return hp2info def simplify(save_dir, save_name, nets, total): - - hps, seeds = ['01', '12', '90'], set() - for hp in hps: - sub_save_dir = save_dir / 'raw-data-{:}'.format(hp) - ckps = sorted(list(sub_save_dir.glob('arch-*-seed-*.pth'))) - seed2names = defaultdict(list) - for ckp in ckps: - parts = re.split('-|\.', ckp.name) - seed2names[parts[3]].append(ckp.name) - print('DIR : {:}'.format(sub_save_dir)) - nums = [] - for seed, xlist in seed2names.items(): - seeds.add(seed) - nums.append(len(xlist)) - print(' [seed={:}] there are {:} checkpoints.'.format(seed, len(xlist))) - assert len(nets) == total == max(nums), 'there are some missed files : {:} vs {:}'.format(max(nums), total) - print('{:} start simplify the checkpoint.'.format(time_string())) - - datasets = ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120') - - # Create the directory to save the processed data - # full_save_dir contains all benchmark files with trained weights. - # simplify_save_dir contains all benchmark files without trained weights. - full_save_dir = save_dir / (save_name + '-FULL') - simple_save_dir = save_dir / (save_name + '-SIMPLIFY') - full_save_dir.mkdir(parents=True, exist_ok=True) - simple_save_dir.mkdir(parents=True, exist_ok=True) - # all data in memory - arch2infos, evaluated_indexes = dict(), set() - end_time, arch_time = time.time(), AverageMeter() - - for index in tqdm(range(total)): - arch_str = nets[index] - hp2info = OrderedDict() - - full_save_path = full_save_dir / '{:06d}.pickle'.format(index) - simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index) + hps, seeds = ["01", "12", "90"], set() for hp in hps: - sub_save_dir = save_dir / 'raw-data-{:}'.format(hp) - ckps = [sub_save_dir / 'arch-{:06d}-seed-{:}.pth'.format(index, seed) for seed in seeds] - ckps = [x for x in ckps if x.exists()] - if len(ckps) == 0: - raise ValueError('Invalid data : index={:}, hp={:}'.format(index, hp)) + sub_save_dir = save_dir / "raw-data-{:}".format(hp) + ckps = sorted(list(sub_save_dir.glob("arch-*-seed-*.pth"))) + seed2names = defaultdict(list) + for ckp in ckps: + parts = re.split("-|\.", ckp.name) + seed2names[parts[3]].append(ckp.name) + print("DIR : {:}".format(sub_save_dir)) + nums = [] + for seed, xlist in seed2names.items(): + seeds.add(seed) + nums.append(len(xlist)) + print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist))) + assert len(nets) == total == max(nums), "there are some missed files : {:} vs {:}".format(max(nums), total) + print("{:} start simplify the checkpoint.".format(time_string())) - arch_info = account_one_arch(index, arch_str, ckps, datasets) - hp2info[hp] = arch_info - - hp2info = correct_time_related_info(hp2info) - evaluated_indexes.add(index) + datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") - hp2info['01'].clear_params() # to save some spaces... - to_save_data = OrderedDict({'01': hp2info['01'].state_dict(), - '12': hp2info['12'].state_dict(), - '90': hp2info['90'].state_dict()}) - pickle_save(to_save_data, str(full_save_path)) - - for hp in hps: hp2info[hp].clear_params() - to_save_data = OrderedDict({'01': hp2info['01'].state_dict(), - '12': hp2info['12'].state_dict(), - '90': hp2info['90'].state_dict()}) - pickle_save(to_save_data, str(simple_save_path)) - arch2infos[index] = to_save_data - # measure elapsed time - arch_time.update(time.time() - end_time) - end_time = time.time() - need_time = '{:}'.format(convert_secs2time(arch_time.avg * (total-index-1), True)) - # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) - print('{:} {:} done.'.format(time_string(), save_name)) - final_infos = {'meta_archs' : nets, - 'total_archs': total, - 'arch2infos' : arch2infos, - 'evaluated_indexes': evaluated_indexes} - save_file_name = save_dir / '{:}.pickle'.format(save_name) - pickle_save(final_infos, str(save_file_name)) - # move the benchmark file to a new path - hd5sum = get_md5_file(str(save_file_name) + '.pbz2') - hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(NATS_SSS_BASE_NAME, hd5sum) - shutil.move(str(save_file_name) + '.pbz2', hd5_file_name) - print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name)) - # move the directory to a new path - hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_SSS_BASE_NAME, hd5sum) - hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_SSS_BASE_NAME, hd5sum) - shutil.move(full_save_dir, hd5_full_save_dir) - shutil.move(simple_save_dir, hd5_simple_save_dir) - # save the meta information for simple and full - final_infos['arch2infos'] = None - final_infos['evaluated_indexes'] = set() - pickle_save(final_infos, str(hd5_full_save_dir / 'meta.pickle')) - pickle_save(final_infos, str(hd5_simple_save_dir / 'meta.pickle')) + # Create the directory to save the processed data + # full_save_dir contains all benchmark files with trained weights. + # simplify_save_dir contains all benchmark files without trained weights. + full_save_dir = save_dir / (save_name + "-FULL") + simple_save_dir = save_dir / (save_name + "-SIMPLIFY") + full_save_dir.mkdir(parents=True, exist_ok=True) + simple_save_dir.mkdir(parents=True, exist_ok=True) + # all data in memory + arch2infos, evaluated_indexes = dict(), set() + end_time, arch_time = time.time(), AverageMeter() + + for index in tqdm(range(total)): + arch_str = nets[index] + hp2info = OrderedDict() + + full_save_path = full_save_dir / "{:06d}.pickle".format(index) + simple_save_path = simple_save_dir / "{:06d}.pickle".format(index) + + for hp in hps: + sub_save_dir = save_dir / "raw-data-{:}".format(hp) + ckps = [sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) for seed in seeds] + ckps = [x for x in ckps if x.exists()] + if len(ckps) == 0: + raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp)) + + arch_info = account_one_arch(index, arch_str, ckps, datasets) + hp2info[hp] = arch_info + + hp2info = correct_time_related_info(hp2info) + evaluated_indexes.add(index) + + hp2info["01"].clear_params() # to save some spaces... + to_save_data = OrderedDict( + {"01": hp2info["01"].state_dict(), "12": hp2info["12"].state_dict(), "90": hp2info["90"].state_dict()} + ) + pickle_save(to_save_data, str(full_save_path)) + + for hp in hps: + hp2info[hp].clear_params() + to_save_data = OrderedDict( + {"01": hp2info["01"].state_dict(), "12": hp2info["12"].state_dict(), "90": hp2info["90"].state_dict()} + ) + pickle_save(to_save_data, str(simple_save_path)) + arch2infos[index] = to_save_data + # measure elapsed time + arch_time.update(time.time() - end_time) + end_time = time.time() + need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True)) + # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) + print("{:} {:} done.".format(time_string(), save_name)) + final_infos = { + "meta_archs": nets, + "total_archs": total, + "arch2infos": arch2infos, + "evaluated_indexes": evaluated_indexes, + } + save_file_name = save_dir / "{:}.pickle".format(save_name) + pickle_save(final_infos, str(save_file_name)) + # move the benchmark file to a new path + hd5sum = get_md5_file(str(save_file_name) + ".pbz2") + hd5_file_name = save_dir / "{:}-{:}.pickle.pbz2".format(NATS_SSS_BASE_NAME, hd5sum) + shutil.move(str(save_file_name) + ".pbz2", hd5_file_name) + print( + "Save {:} / {:} architecture results into {:} -> {:}.".format( + len(evaluated_indexes), total, save_file_name, hd5_file_name + ) + ) + # move the directory to a new path + hd5_full_save_dir = save_dir / "{:}-{:}-full".format(NATS_SSS_BASE_NAME, hd5sum) + hd5_simple_save_dir = save_dir / "{:}-{:}-simple".format(NATS_SSS_BASE_NAME, hd5sum) + shutil.move(full_save_dir, hd5_full_save_dir) + shutil.move(simple_save_dir, hd5_simple_save_dir) + # save the meta information for simple and full + final_infos["arch2infos"] = None + final_infos["evaluated_indexes"] = set() + pickle_save(final_infos, str(hd5_full_save_dir / "meta.pickle")) + pickle_save(final_infos, str(hd5_simple_save_dir / "meta.pickle")) def traverse_net(candidates: List[int], N: int): - nets = [''] - for i in range(N): - new_nets = [] - for net in nets: - for C in candidates: - new_nets.append(str(C) if net == '' else "{:}:{:}".format(net,C)) - nets = new_nets - return nets + nets = [""] + for i in range(N): + new_nets = [] + for net in nets: + for C in candidates: + new_nets.append(str(C) if net == "" else "{:}:{:}".format(net, C)) + nets = new_nets + return nets -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NATS-Bench (size search space)', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--base_save_dir', type=str, default='./output/NATS-Bench-size', help='The base-name of folder to save checkpoints and log.') - parser.add_argument('--candidateC' , type=int, nargs='+', default=[8, 16, 24, 32, 40, 48, 56, 64], help='.') - parser.add_argument('--num_layers' , type=int, default=5, help='The number of layers in a network.') - parser.add_argument('--check_N' , type=int, default=32768, help='For safety.') - parser.add_argument('--save_name' , type=str, default='process', help='The save directory.') - args = parser.parse_args() - - nets = traverse_net(args.candidateC, args.num_layers) - if len(nets) != args.check_N: - raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N)) +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="NATS-Bench (size search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--base_save_dir", + type=str, + default="./output/NATS-Bench-size", + help="The base-name of folder to save checkpoints and log.", + ) + parser.add_argument("--candidateC", type=int, nargs="+", default=[8, 16, 24, 32, 40, 48, 56, 64], help=".") + parser.add_argument("--num_layers", type=int, default=5, help="The number of layers in a network.") + parser.add_argument("--check_N", type=int, default=32768, help="For safety.") + parser.add_argument("--save_name", type=str, default="process", help="The save directory.") + args = parser.parse_args() - save_dir = Path(args.base_save_dir) - simplify(save_dir, args.save_name, nets, args.check_N) + nets = traverse_net(args.candidateC, args.num_layers) + if len(nets) != args.check_N: + raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)) + + save_dir = Path(args.base_save_dir) + simplify(save_dir, args.save_name, nets, args.check_N) diff --git a/exps/NATS-Bench/sss-file-manager.py b/exps/NATS-Bench/sss-file-manager.py index 93373ff..292a2ff 100644 --- a/exps/NATS-Bench/sss-file-manager.py +++ b/exps/NATS-Bench/sss-file-manager.py @@ -9,72 +9,82 @@ import os, sys, time, torch, argparse from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config -from procedures import bench_evaluate_for_seed -from procedures import get_machine_info -from datasets import get_datasets -from log_utils import Logger, AverageMeter, time_string, convert_secs2time +from procedures import bench_evaluate_for_seed +from procedures import get_machine_info +from datasets import get_datasets +from log_utils import Logger, AverageMeter, time_string, convert_secs2time def obtain_valid_ckp(save_dir: Text, total: int): - possible_seeds = [777, 888, 999] - seed2ckps = defaultdict(list) - miss2ckps = defaultdict(list) - for i in range(total): - for seed in possible_seeds: - path = os.path.join(save_dir, 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed)) - if os.path.exists(path): - seed2ckps[seed].append(i) - else: - miss2ckps[seed].append(i) - for seed, xlist in seed2ckps.items(): - print('[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}'.format(save_dir, seed, len(xlist), total, total-len(xlist), total)) - return dict(seed2ckps), dict(miss2ckps) - + possible_seeds = [777, 888, 999] + seed2ckps = defaultdict(list) + miss2ckps = defaultdict(list) + for i in range(total): + for seed in possible_seeds: + path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed)) + if os.path.exists(path): + seed2ckps[seed].append(i) + else: + miss2ckps[seed].append(i) + for seed, xlist in seed2ckps.items(): + print( + "[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format( + save_dir, seed, len(xlist), total, total - len(xlist), total + ) + ) + return dict(seed2ckps), dict(miss2ckps) + def copy_data(source_dir, target_dir, meta_path): - target_dir = Path(target_dir) - target_dir.mkdir(parents=True, exist_ok=True) - miss2ckps = torch.load(meta_path)['miss2ckps'] - s2t = {} - for seed, xlist in miss2ckps.items(): - for i in xlist: - file_name = 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed) - source_path = os.path.join(source_dir, file_name) - target_path = os.path.join(target_dir, file_name) - if os.path.exists(source_path): - s2t[source_path] = target_path - print('Map from {:} to {:}, find {:} missed ckps.'.format(source_dir, target_dir, len(s2t))) - for s, t in s2t.items(): - copyfile(s, t) + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + miss2ckps = torch.load(meta_path)["miss2ckps"] + s2t = {} + for seed, xlist in miss2ckps.items(): + for i in xlist: + file_name = "arch-{:06d}-seed-{:04d}.pth".format(i, seed) + source_path = os.path.join(source_dir, file_name) + target_path = os.path.join(target_dir, file_name) + if os.path.exists(source_path): + s2t[source_path] = target_path + print("Map from {:} to {:}, find {:} missed ckps.".format(source_dir, target_dir, len(s2t))) + for s, t in s2t.items(): + copyfile(s, t) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NATS-Bench (size search space) file manager.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--mode', type=str, required=True, choices=['check', 'copy'], help='The script mode.') - parser.add_argument('--save_dir', type=str, default='output/NATS-Bench-size', help='Folder to save checkpoints and log.') - parser.add_argument('--check_N', type=int, default=32768, help='For safety.') - # use for train the model - args = parser.parse_args() - possible_configs = ['01', '12', '90'] - if args.mode == 'check': - for config in possible_configs: - cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config) - seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N) - torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config)) - elif args.mode == 'copy': - for config in possible_configs: - cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config) - cur_copy_dir = '{:}/copy-{:}'.format(args.save_dir, config) - cur_meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config) - if os.path.exists(cur_meta_path): - copy_data(cur_save_dir, cur_copy_dir, cur_meta_path) - else: - print('Do not find : {:}'.format(cur_meta_path)) - else: - raise ValueError('invalid mode : {:}'.format(args.mode)) +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="NATS-Bench (size search space) file manager.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--mode", type=str, required=True, choices=["check", "copy"], help="The script mode.") + parser.add_argument( + "--save_dir", type=str, default="output/NATS-Bench-size", help="Folder to save checkpoints and log." + ) + parser.add_argument("--check_N", type=int, default=32768, help="For safety.") + # use for train the model + args = parser.parse_args() + possible_configs = ["01", "12", "90"] + if args.mode == "check": + for config in possible_configs: + cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) + seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N) + torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), "{:}/meta-{:}.pth".format(args.save_dir, config)) + elif args.mode == "copy": + for config in possible_configs: + cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) + cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config) + cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config) + if os.path.exists(cur_meta_path): + copy_data(cur_save_dir, cur_copy_dir, cur_meta_path) + else: + print("Do not find : {:}".format(cur_meta_path)) + else: + raise ValueError("invalid mode : {:}".format(args.mode)) diff --git a/exps/NATS-Bench/test-nats-api.py b/exps/NATS-Bench/test-nats-api.py index a9ca1e4..75ce7fd 100644 --- a/exps/NATS-Bench/test-nats-api.py +++ b/exps/NATS-Bench/test-nats-api.py @@ -10,16 +10,18 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from nats_bench import create from log_utils import time_string @@ -27,78 +29,78 @@ from models import get_cell_based_tiny_net, CellStructure def test_api(api, sss_or_tss=True): - print('{:} start testing the api : {:}'.format(time_string(), api)) - api.clear_params(12) - api.reload(index=12) - - # Query the informations of 1113-th architecture - info_strs = api.query_info_str_by_arch(1113) - print(info_strs) - info = api.query_by_index(113) - print('{:}\n'.format(info)) - info = api.query_by_index(113, 'cifar100') - print('{:}\n'.format(info)) + print("{:} start testing the api : {:}".format(time_string(), api)) + api.clear_params(12) + api.reload(index=12) - info = api.query_meta_info_by_index(115, '90' if sss_or_tss else '200') - print('{:}\n'.format(info)) + # Query the informations of 1113-th architecture + info_strs = api.query_info_str_by_arch(1113) + print(info_strs) + info = api.query_by_index(113) + print("{:}\n".format(info)) + info = api.query_by_index(113, "cifar100") + print("{:}\n".format(info)) - for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']: - for xset in ['train', 'test', 'valid']: - best_index, highest_accuracy = api.find_best(dataset, xset) - print('') - params = api.get_net_param(12, 'cifar10', None) + info = api.query_meta_info_by_index(115, "90" if sss_or_tss else "200") + print("{:}\n".format(info)) - # Obtain the config and create the network - config = api.get_net_config(12, 'cifar10') - print('{:}\n'.format(config)) - network = get_cell_based_tiny_net(config) - network.load_state_dict(next(iter(params.values()))) + for dataset in ["cifar10", "cifar100", "ImageNet16-120"]: + for xset in ["train", "test", "valid"]: + best_index, highest_accuracy = api.find_best(dataset, xset) + print("") + params = api.get_net_param(12, "cifar10", None) - # Obtain the cost information - info = api.get_cost_info(12, 'cifar10') - print('{:}\n'.format(info)) - info = api.get_latency(12, 'cifar10') - print('{:}\n'.format(info)) - for index in [13, 15, 19, 200]: - info = api.get_latency(index, 'cifar10') + # Obtain the config and create the network + config = api.get_net_config(12, "cifar10") + print("{:}\n".format(config)) + network = get_cell_based_tiny_net(config) + network.load_state_dict(next(iter(params.values()))) - # Count the number of architectures - info = api.statistics('cifar100', '12') - print('{:} statistics results : {:}\n'.format(time_string(), info)) + # Obtain the cost information + info = api.get_cost_info(12, "cifar10") + print("{:}\n".format(info)) + info = api.get_latency(12, "cifar10") + print("{:}\n".format(info)) + for index in [13, 15, 19, 200]: + info = api.get_latency(index, "cifar10") - # Show the information of the 123-th architecture - api.show(123) + # Count the number of architectures + info = api.statistics("cifar100", "12") + print("{:} statistics results : {:}\n".format(time_string(), info)) - # Obtain both cost and performance information - info = api.get_more_info(1234, 'cifar10') - print('{:}\n'.format(info)) - print('{:} finish testing the api : {:}'.format(time_string(), api)) + # Show the information of the 123-th architecture + api.show(123) - if not sss_or_tss: - arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|' - matrix = api.str2matrix(arch_str) - print('Compute the adjacency matrix of {:}'.format(arch_str)) - print(matrix) - info = api.simulate_train_eval(123, 'cifar10') - print('simulate_train_eval : {:}\n\n'.format(info)) + # Obtain both cost and performance information + info = api.get_more_info(1234, "cifar10") + print("{:}\n".format(info)) + print("{:} finish testing the api : {:}".format(time_string(), api)) + + if not sss_or_tss: + arch_str = "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|" + matrix = api.str2matrix(arch_str) + print("Compute the adjacency matrix of {:}".format(arch_str)) + print(matrix) + info = api.simulate_train_eval(123, "cifar10") + print("simulate_train_eval : {:}\n\n".format(info)) -if __name__ == '__main__': +if __name__ == "__main__": - # api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True) - for fast_mode in [True, False]: - for verbose in [True, False]: - api_nats_tss = create(None, 'tss', fast_mode=fast_mode, verbose=True) - print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose)) - test_api(api_nats_tss, False) - del api_nats_tss - gc.collect() + # api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True) + for fast_mode in [True, False]: + for verbose in [True, False]: + api_nats_tss = create(None, "tss", fast_mode=fast_mode, verbose=True) + print("{:} create with fast_mode={:} and verbose={:}".format(time_string(), fast_mode, verbose)) + test_api(api_nats_tss, False) + del api_nats_tss + gc.collect() - for fast_mode in [True, False]: - for verbose in [True, False]: - print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose)) - api_nats_sss = create(None, 'size', fast_mode=fast_mode, verbose=True) - print('{:} --->>> {:}'.format(time_string(), api_nats_sss)) - test_api(api_nats_sss, True) - del api_nats_sss - gc.collect() + for fast_mode in [True, False]: + for verbose in [True, False]: + print("{:} create with fast_mode={:} and verbose={:}".format(time_string(), fast_mode, verbose)) + api_nats_sss = create(None, "size", fast_mode=fast_mode, verbose=True) + print("{:} --->>> {:}".format(time_string(), api_nats_sss)) + test_api(api_nats_sss, True) + del api_nats_sss + gc.collect() diff --git a/exps/NATS-Bench/tss-collect-patcher.py b/exps/NATS-Bench/tss-collect-patcher.py index 3a9414e..5b4b13f 100644 --- a/exps/NATS-Bench/tss-collect-patcher.py +++ b/exps/NATS-Bench/tss-collect-patcher.py @@ -18,112 +18,140 @@ from tqdm import tqdm from pathlib import Path from collections import defaultdict, OrderedDict from typing import Dict, Any, Text, List -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from log_utils import AverageMeter, time_string, convert_secs2time from config_utils import load_config, dict2config -from datasets import get_datasets -from models import CellStructure, get_cell_based_tiny_net, get_search_spaces -from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount -from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders -from utils import get_md5_file -from nas_201_api import NASBench201API +from datasets import get_datasets +from models import CellStructure, get_cell_based_tiny_net, get_search_spaces +from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount +from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders +from utils import get_md5_file +from nas_201_api import NASBench201API -NATS_TSS_BASE_NAME = 'NATS-tss-v1_0' # 2020.08.28 +NATS_TSS_BASE_NAME = "NATS-tss-v1_0" # 2020.08.28 def simplify(save_dir, save_name, nets, total, sup_config): - hps, seeds = ['12', '200'], set() - for hp in hps: - sub_save_dir = save_dir / 'raw-data-{:}'.format(hp) - ckps = sorted(list(sub_save_dir.glob('arch-*-seed-*.pth'))) - seed2names = defaultdict(list) - for ckp in ckps: - parts = re.split('-|\.', ckp.name) - seed2names[parts[3]].append(ckp.name) - print('DIR : {:}'.format(sub_save_dir)) - nums = [] - for seed, xlist in seed2names.items(): - seeds.add(seed) - nums.append(len(xlist)) - print(' [seed={:}] there are {:} checkpoints.'.format(seed, len(xlist))) - assert len(nets) == total == max(nums), 'there are some missed files : {:} vs {:}'.format(max(nums), total) - print('{:} start simplify the checkpoint.'.format(time_string())) + hps, seeds = ["12", "200"], set() + for hp in hps: + sub_save_dir = save_dir / "raw-data-{:}".format(hp) + ckps = sorted(list(sub_save_dir.glob("arch-*-seed-*.pth"))) + seed2names = defaultdict(list) + for ckp in ckps: + parts = re.split("-|\.", ckp.name) + seed2names[parts[3]].append(ckp.name) + print("DIR : {:}".format(sub_save_dir)) + nums = [] + for seed, xlist in seed2names.items(): + seeds.add(seed) + nums.append(len(xlist)) + print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist))) + assert len(nets) == total == max(nums), "there are some missed files : {:} vs {:}".format(max(nums), total) + print("{:} start simplify the checkpoint.".format(time_string())) - datasets = ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120') + datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") - # Create the directory to save the processed data - # full_save_dir contains all benchmark files with trained weights. - # simplify_save_dir contains all benchmark files without trained weights. - full_save_dir = save_dir / (save_name + '-FULL') - simple_save_dir = save_dir / (save_name + '-SIMPLIFY') - full_save_dir.mkdir(parents=True, exist_ok=True) - simple_save_dir.mkdir(parents=True, exist_ok=True) - # all data in memory - arch2infos, evaluated_indexes = dict(), set() - end_time, arch_time = time.time(), AverageMeter() - # save the meta information - for index in tqdm(range(total)): - arch_str = nets[index] - hp2info = OrderedDict() + # Create the directory to save the processed data + # full_save_dir contains all benchmark files with trained weights. + # simplify_save_dir contains all benchmark files without trained weights. + full_save_dir = save_dir / (save_name + "-FULL") + simple_save_dir = save_dir / (save_name + "-SIMPLIFY") + full_save_dir.mkdir(parents=True, exist_ok=True) + simple_save_dir.mkdir(parents=True, exist_ok=True) + # all data in memory + arch2infos, evaluated_indexes = dict(), set() + end_time, arch_time = time.time(), AverageMeter() + # save the meta information + for index in tqdm(range(total)): + arch_str = nets[index] + hp2info = OrderedDict() - simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index) - - arch2infos[index] = pickle_load(simple_save_path) - evaluated_indexes.add(index) + simple_save_path = simple_save_dir / "{:06d}.pickle".format(index) - # measure elapsed time - arch_time.update(time.time() - end_time) - end_time = time.time() - need_time = '{:}'.format(convert_secs2time(arch_time.avg * (total-index-1), True)) - # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) - print('{:} {:} done.'.format(time_string(), save_name)) - final_infos = {'meta_archs' : nets, - 'total_archs': total, - 'arch2infos' : arch2infos, - 'evaluated_indexes': evaluated_indexes} - save_file_name = save_dir / '{:}.pickle'.format(save_name) - pickle_save(final_infos, str(save_file_name)) - # move the benchmark file to a new path - hd5sum = get_md5_file(str(save_file_name) + '.pbz2') - hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(NATS_TSS_BASE_NAME, hd5sum) - shutil.move(str(save_file_name) + '.pbz2', hd5_file_name) - print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name)) - # move the directory to a new path - hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum) - hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum) - shutil.move(full_save_dir, hd5_full_save_dir) - shutil.move(simple_save_dir, hd5_simple_save_dir) + arch2infos[index] = pickle_load(simple_save_path) + evaluated_indexes.add(index) + + # measure elapsed time + arch_time.update(time.time() - end_time) + end_time = time.time() + need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True)) + # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) + print("{:} {:} done.".format(time_string(), save_name)) + final_infos = { + "meta_archs": nets, + "total_archs": total, + "arch2infos": arch2infos, + "evaluated_indexes": evaluated_indexes, + } + save_file_name = save_dir / "{:}.pickle".format(save_name) + pickle_save(final_infos, str(save_file_name)) + # move the benchmark file to a new path + hd5sum = get_md5_file(str(save_file_name) + ".pbz2") + hd5_file_name = save_dir / "{:}-{:}.pickle.pbz2".format(NATS_TSS_BASE_NAME, hd5sum) + shutil.move(str(save_file_name) + ".pbz2", hd5_file_name) + print( + "Save {:} / {:} architecture results into {:} -> {:}.".format( + len(evaluated_indexes), total, save_file_name, hd5_file_name + ) + ) + # move the directory to a new path + hd5_full_save_dir = save_dir / "{:}-{:}-full".format(NATS_TSS_BASE_NAME, hd5sum) + hd5_simple_save_dir = save_dir / "{:}-{:}-simple".format(NATS_TSS_BASE_NAME, hd5sum) + shutil.move(full_save_dir, hd5_full_save_dir) + shutil.move(simple_save_dir, hd5_simple_save_dir) def traverse_net(max_node): - aa_nas_bench_ss = get_search_spaces('cell', 'nats-bench') - archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) - print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2))) + aa_nas_bench_ss = get_search_spaces("cell", "nats-bench") + archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) + print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2))) - random.seed( 88 ) # please do not change this line for reproducibility - random.shuffle( archs ) - assert archs[0 ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0]) - assert archs[9 ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9]) - assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123]) - return [x.tostr() for x in archs] + random.seed(88) # please do not change this line for reproducibility + random.shuffle(archs) + assert ( + archs[0].tostr() + == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" + ), "please check the 0-th architecture : {:}".format(archs[0]) + assert ( + archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" + ), "please check the 9-th architecture : {:}".format(archs[9]) + assert ( + archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" + ), "please check the 123-th architecture : {:}".format(archs[123]) + return [x.tostr() for x in archs] -if __name__ == '__main__': +if __name__ == "__main__": - parser = argparse.ArgumentParser(description='NATS-Bench (topology search space)', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--base_save_dir', type=str, default='./output/NATS-Bench-topology', help='The base-name of folder to save checkpoints and log.') - parser.add_argument('--max_node' , type=int, default=4, help='The maximum node in a cell.') - parser.add_argument('--channel' , type=int, default=16, help='The number of channels.') - parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.') - parser.add_argument('--check_N' , type=int, default=15625, help='For safety.') - parser.add_argument('--save_name' , type=str, default='process', help='The save directory.') - args = parser.parse_args() - - nets = traverse_net(args.max_node) - if len(nets) != args.check_N: - raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N)) - - save_dir = Path(args.base_save_dir) - simplify(save_dir, args.save_name, nets, args.check_N, {'name': 'infer.tiny', 'channel': args.channel, 'num_cells': args.num_cells}) + parser = argparse.ArgumentParser( + description="NATS-Bench (topology search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--base_save_dir", + type=str, + default="./output/NATS-Bench-topology", + help="The base-name of folder to save checkpoints and log.", + ) + parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.") + parser.add_argument("--channel", type=int, default=16, help="The number of channels.") + parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + parser.add_argument("--check_N", type=int, default=15625, help="For safety.") + parser.add_argument("--save_name", type=str, default="process", help="The save directory.") + args = parser.parse_args() + + nets = traverse_net(args.max_node) + if len(nets) != args.check_N: + raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)) + + save_dir = Path(args.base_save_dir) + simplify( + save_dir, + args.save_name, + nets, + args.check_N, + {"name": "infer.tiny", "channel": args.channel, "num_cells": args.num_cells}, + ) diff --git a/exps/NATS-Bench/tss-collect.py b/exps/NATS-Bench/tss-collect.py index 62cc0dd..51017df 100644 --- a/exps/NATS-Bench/tss-collect.py +++ b/exps/NATS-Bench/tss-collect.py @@ -18,246 +18,335 @@ from tqdm import tqdm from pathlib import Path from collections import defaultdict, OrderedDict from typing import Dict, Any, Text, List -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from log_utils import AverageMeter, time_string, convert_secs2time from config_utils import load_config, dict2config -from datasets import get_datasets -from models import CellStructure, get_cell_based_tiny_net, get_search_spaces -from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount -from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders -from utils import get_md5_file -from nas_201_api import NASBench201API +from datasets import get_datasets +from models import CellStructure, get_cell_based_tiny_net, get_search_spaces +from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount +from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders +from utils import get_md5_file +from nas_201_api import NASBench201API -api = NASBench201API('{:}/.torch/NAS-Bench-201-v1_0-e61699.pth'.format(os.environ['HOME'])) +api = NASBench201API("{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"])) -NATS_TSS_BASE_NAME = 'NATS-tss-v1_0' # 2020.08.28 +NATS_TSS_BASE_NAME = "NATS-tss-v1_0" # 2020.08.28 -def create_result_count(used_seed: int, dataset: Text, arch_config: Dict[Text, Any], - results: Dict[Text, Any], dataloader_dict: Dict[Text, Any]) -> ResultsCount: - xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'], - results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None) - net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes': arch_config['class_num']}, None) - if 'train_times' in results: # new version - xresult.update_train_info(results['train_acc1es'], results['train_acc5es'], results['train_losses'], results['train_times']) - xresult.update_eval(results['valid_acc1es'], results['valid_losses'], results['valid_times']) - else: - network = get_cell_based_tiny_net(net_config) - network.load_state_dict(xresult.get_net_param()) - if dataset == 'cifar10-valid': - xresult.update_OLD_eval('x-valid' , results['valid_acc1es'], results['valid_losses']) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format('cifar10', 'test')], network.cuda()) - xresult.update_OLD_eval('ori-test', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) - xresult.update_latency(latencies) - elif dataset == 'cifar10': - xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses']) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda()) - xresult.update_latency(latencies) - elif dataset == 'cifar100' or dataset == 'ImageNet16-120': - xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses']) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'valid')], network.cuda()) - xresult.update_OLD_eval('x-valid', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) - loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda()) - xresult.update_OLD_eval('x-test' , {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) - xresult.update_latency(latencies) +def create_result_count( + used_seed: int, + dataset: Text, + arch_config: Dict[Text, Any], + results: Dict[Text, Any], + dataloader_dict: Dict[Text, Any], +) -> ResultsCount: + xresult = ResultsCount( + dataset, + results["net_state_dict"], + results["train_acc1es"], + results["train_losses"], + results["param"], + results["flop"], + arch_config, + used_seed, + results["total_epoch"], + None, + ) + net_config = dict2config( + { + "name": "infer.tiny", + "C": arch_config["channel"], + "N": arch_config["num_cells"], + "genotype": CellStructure.str2structure(arch_config["arch_str"]), + "num_classes": arch_config["class_num"], + }, + None, + ) + if "train_times" in results: # new version + xresult.update_train_info( + results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"] + ) + xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"]) else: - raise ValueError('invalid dataset name : {:}'.format(dataset)) - return xresult + network = get_cell_based_tiny_net(net_config) + network.load_state_dict(xresult.get_net_param()) + if dataset == "cifar10-valid": + xresult.update_OLD_eval("x-valid", results["valid_acc1es"], results["valid_losses"]) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda() + ) + xresult.update_OLD_eval("ori-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_latency(latencies) + elif dataset == "cifar10": + xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() + ) + xresult.update_latency(latencies) + elif dataset == "cifar100" or dataset == "ImageNet16-120": + xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda() + ) + xresult.update_OLD_eval("x-valid", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + loss, top1, top5, latencies = pure_evaluate( + dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() + ) + xresult.update_OLD_eval("x-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_latency(latencies) + else: + raise ValueError("invalid dataset name : {:}".format(dataset)) + return xresult def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dict): - information = ArchResults(arch_index, arch_str) + information = ArchResults(arch_index, arch_str) - for checkpoint_path in checkpoints: - checkpoint = torch.load(checkpoint_path, map_location='cpu') - used_seed = checkpoint_path.name.split('-')[-1].split('.')[0] - ok_dataset = 0 - for dataset in datasets: - if dataset not in checkpoint: - print('Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, checkpoint_path)) - continue - else: - ok_dataset += 1 - results = checkpoint[dataset] - assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, used_seed, dataset, checkpoint_path) - arch_config = {'channel': results['channel'], 'num_cells': results['num_cells'], 'arch_str': arch_str, 'class_num': results['config']['class_num']} - - xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict) - information.update(dataset, int(used_seed), xresult) - if ok_dataset == 0: raise ValueError('{:} does not find any data'.format(checkpoint_path)) - return information + for checkpoint_path in checkpoints: + checkpoint = torch.load(checkpoint_path, map_location="cpu") + used_seed = checkpoint_path.name.split("-")[-1].split(".")[0] + ok_dataset = 0 + for dataset in datasets: + if dataset not in checkpoint: + print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path)) + continue + else: + ok_dataset += 1 + results = checkpoint[dataset] + assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( + arch_index, used_seed, dataset, checkpoint_path + ) + arch_config = { + "channel": results["channel"], + "num_cells": results["num_cells"], + "arch_str": arch_str, + "class_num": results["config"]["class_num"], + } + + xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict) + information.update(dataset, int(used_seed), xresult) + if ok_dataset == 0: + raise ValueError("{:} does not find any data".format(checkpoint_path)) + return information def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResults]): - # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth - cifar010_latency = (api.get_latency(arch_index, 'cifar10-valid', hp='200') + api.get_latency(arch_index, 'cifar10', hp='200')) / 2 - cifar100_latency = api.get_latency(arch_index, 'cifar100', hp='200') - image_latency = api.get_latency(arch_index, 'ImageNet16-120', hp='200') - for hp, arch_info in arch_infos.items(): - arch_info.reset_latency('cifar10-valid', None, cifar010_latency) - arch_info.reset_latency('cifar10', None, cifar010_latency) - arch_info.reset_latency('cifar100', None, cifar100_latency) - arch_info.reset_latency('ImageNet16-120', None, image_latency) + # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth + cifar010_latency = ( + api.get_latency(arch_index, "cifar10-valid", hp="200") + api.get_latency(arch_index, "cifar10", hp="200") + ) / 2 + cifar100_latency = api.get_latency(arch_index, "cifar100", hp="200") + image_latency = api.get_latency(arch_index, "ImageNet16-120", hp="200") + for hp, arch_info in arch_infos.items(): + arch_info.reset_latency("cifar10-valid", None, cifar010_latency) + arch_info.reset_latency("cifar10", None, cifar010_latency) + arch_info.reset_latency("cifar100", None, cifar100_latency) + arch_info.reset_latency("ImageNet16-120", None, image_latency) - train_per_epoch_time = list(arch_infos['12'].query('cifar10-valid', 777).train_times.values()) - train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) - eval_ori_test_time, eval_x_valid_time = [], [] - for key, value in arch_infos['12'].query('cifar10-valid', 777).eval_times.items(): - if key.startswith('ori-test@'): - eval_ori_test_time.append(value) - elif key.startswith('x-valid@'): - eval_x_valid_time.append(value) - else: raise ValueError('-- {:} --'.format(key)) - eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time)) - nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000, 'ImageNet16-120-test': 6000, - 'cifar10-valid-train': 25000, 'cifar10-valid-valid': 25000, - 'cifar10-train': 50000, 'cifar10-test': 10000, - 'cifar100-train': 50000, 'cifar100-test': 10000, 'cifar100-valid': 5000} - eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums['cifar10-valid-valid'] + nums['cifar10-test']) - for hp, arch_info in arch_infos.items(): - arch_info.reset_pseudo_train_times('cifar10-valid', None, - train_per_epoch_time / nums['cifar10-valid-train'] * nums['cifar10-valid-train']) - arch_info.reset_pseudo_train_times('cifar10', None, - train_per_epoch_time / nums['cifar10-valid-train'] * nums['cifar10-train']) - arch_info.reset_pseudo_train_times('cifar100', None, - train_per_epoch_time / nums['cifar10-valid-train'] * nums['cifar100-train']) - arch_info.reset_pseudo_train_times('ImageNet16-120', None, - train_per_epoch_time / nums['cifar10-valid-train'] * nums['ImageNet16-120-train']) - arch_info.reset_pseudo_eval_times('cifar10-valid', None, 'x-valid', eval_per_sample*nums['cifar10-valid-valid']) - arch_info.reset_pseudo_eval_times('cifar10-valid', None, 'ori-test', eval_per_sample * nums['cifar10-test']) - arch_info.reset_pseudo_eval_times('cifar10', None, 'ori-test', eval_per_sample * nums['cifar10-test']) - arch_info.reset_pseudo_eval_times('cifar100', None, 'x-valid', eval_per_sample * nums['cifar100-valid']) - arch_info.reset_pseudo_eval_times('cifar100', None, 'x-test', eval_per_sample * nums['cifar100-valid']) - arch_info.reset_pseudo_eval_times('cifar100', None, 'ori-test', eval_per_sample * nums['cifar100-test']) - arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'x-valid', eval_per_sample * nums['ImageNet16-120-valid']) - arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'x-test', eval_per_sample * nums['ImageNet16-120-valid']) - arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'ori-test', eval_per_sample * nums['ImageNet16-120-test']) - return arch_infos + train_per_epoch_time = list(arch_infos["12"].query("cifar10-valid", 777).train_times.values()) + train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) + eval_ori_test_time, eval_x_valid_time = [], [] + for key, value in arch_infos["12"].query("cifar10-valid", 777).eval_times.items(): + if key.startswith("ori-test@"): + eval_ori_test_time.append(value) + elif key.startswith("x-valid@"): + eval_x_valid_time.append(value) + else: + raise ValueError("-- {:} --".format(key)) + eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time)) + nums = { + "ImageNet16-120-train": 151700, + "ImageNet16-120-valid": 3000, + "ImageNet16-120-test": 6000, + "cifar10-valid-train": 25000, + "cifar10-valid-valid": 25000, + "cifar10-train": 50000, + "cifar10-test": 10000, + "cifar100-train": 50000, + "cifar100-test": 10000, + "cifar100-valid": 5000, + } + eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums["cifar10-valid-valid"] + nums["cifar10-test"]) + for hp, arch_info in arch_infos.items(): + arch_info.reset_pseudo_train_times( + "cifar10-valid", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-valid-train"] + ) + arch_info.reset_pseudo_train_times( + "cifar10", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"] + ) + arch_info.reset_pseudo_train_times( + "cifar100", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"] + ) + arch_info.reset_pseudo_train_times( + "ImageNet16-120", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["ImageNet16-120-train"] + ) + arch_info.reset_pseudo_eval_times( + "cifar10-valid", None, "x-valid", eval_per_sample * nums["cifar10-valid-valid"] + ) + arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"]) + arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"]) + arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"]) + arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"]) + arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"]) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", None, "x-valid", eval_per_sample * nums["ImageNet16-120-valid"] + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", None, "x-test", eval_per_sample * nums["ImageNet16-120-valid"] + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", None, "ori-test", eval_per_sample * nums["ImageNet16-120-test"] + ) + return arch_infos def simplify(save_dir, save_name, nets, total, sup_config): - dataloader_dict = get_nas_bench_loaders(6) - hps, seeds = ['12', '200'], set() - for hp in hps: - sub_save_dir = save_dir / 'raw-data-{:}'.format(hp) - ckps = sorted(list(sub_save_dir.glob('arch-*-seed-*.pth'))) - seed2names = defaultdict(list) - for ckp in ckps: - parts = re.split('-|\.', ckp.name) - seed2names[parts[3]].append(ckp.name) - print('DIR : {:}'.format(sub_save_dir)) - nums = [] - for seed, xlist in seed2names.items(): - seeds.add(seed) - nums.append(len(xlist)) - print(' [seed={:}] there are {:} checkpoints.'.format(seed, len(xlist))) - assert len(nets) == total == max(nums), 'there are some missed files : {:} vs {:}'.format(max(nums), total) - print('{:} start simplify the checkpoint.'.format(time_string())) - - datasets = ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120') - - # Create the directory to save the processed data - # full_save_dir contains all benchmark files with trained weights. - # simplify_save_dir contains all benchmark files without trained weights. - full_save_dir = save_dir / (save_name + '-FULL') - simple_save_dir = save_dir / (save_name + '-SIMPLIFY') - full_save_dir.mkdir(parents=True, exist_ok=True) - simple_save_dir.mkdir(parents=True, exist_ok=True) - # all data in memory - arch2infos, evaluated_indexes = dict(), set() - end_time, arch_time = time.time(), AverageMeter() - # save the meta information - temp_final_infos = {'meta_archs' : nets, - 'total_archs': total, - 'arch2infos' : None, - 'evaluated_indexes': set()} - pickle_save(temp_final_infos, str(full_save_dir / 'meta.pickle')) - pickle_save(temp_final_infos, str(simple_save_dir / 'meta.pickle')) - - for index in tqdm(range(total)): - arch_str = nets[index] - hp2info = OrderedDict() - - full_save_path = full_save_dir / '{:06d}.pickle'.format(index) - simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index) + dataloader_dict = get_nas_bench_loaders(6) + hps, seeds = ["12", "200"], set() for hp in hps: - sub_save_dir = save_dir / 'raw-data-{:}'.format(hp) - ckps = [sub_save_dir / 'arch-{:06d}-seed-{:}.pth'.format(index, seed) for seed in seeds] - ckps = [x for x in ckps if x.exists()] - if len(ckps) == 0: - raise ValueError('Invalid data : index={:}, hp={:}'.format(index, hp)) + sub_save_dir = save_dir / "raw-data-{:}".format(hp) + ckps = sorted(list(sub_save_dir.glob("arch-*-seed-*.pth"))) + seed2names = defaultdict(list) + for ckp in ckps: + parts = re.split("-|\.", ckp.name) + seed2names[parts[3]].append(ckp.name) + print("DIR : {:}".format(sub_save_dir)) + nums = [] + for seed, xlist in seed2names.items(): + seeds.add(seed) + nums.append(len(xlist)) + print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist))) + assert len(nets) == total == max(nums), "there are some missed files : {:} vs {:}".format(max(nums), total) + print("{:} start simplify the checkpoint.".format(time_string())) - arch_info = account_one_arch(index, arch_str, ckps, datasets, dataloader_dict) - hp2info[hp] = arch_info - - hp2info = correct_time_related_info(index, hp2info) - evaluated_indexes.add(index) - - to_save_data = OrderedDict({'12': hp2info['12'].state_dict(), - '200': hp2info['200'].state_dict()}) - pickle_save(to_save_data, str(full_save_path)) - - for hp in hps: hp2info[hp].clear_params() - to_save_data = OrderedDict({'12': hp2info['12'].state_dict(), - '200': hp2info['200'].state_dict()}) - pickle_save(to_save_data, str(simple_save_path)) - arch2infos[index] = to_save_data - # measure elapsed time - arch_time.update(time.time() - end_time) - end_time = time.time() - need_time = '{:}'.format(convert_secs2time(arch_time.avg * (total-index-1), True)) - # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) - print('{:} {:} done.'.format(time_string(), save_name)) - final_infos = {'meta_archs' : nets, - 'total_archs': total, - 'arch2infos' : arch2infos, - 'evaluated_indexes': evaluated_indexes} - save_file_name = save_dir / '{:}.pickle'.format(save_name) - pickle_save(final_infos, str(save_file_name)) - # move the benchmark file to a new path - hd5sum = get_md5_file(str(save_file_name) + '.pbz2') - hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(NATS_TSS_BASE_NAME, hd5sum) - shutil.move(str(save_file_name) + '.pbz2', hd5_file_name) - print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name)) - # move the directory to a new path - hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum) - hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum) - shutil.move(full_save_dir, hd5_full_save_dir) - shutil.move(simple_save_dir, hd5_simple_save_dir) - # save the meta information for simple and full - # final_infos['arch2infos'] = None - # final_infos['evaluated_indexes'] = set() + datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") + + # Create the directory to save the processed data + # full_save_dir contains all benchmark files with trained weights. + # simplify_save_dir contains all benchmark files without trained weights. + full_save_dir = save_dir / (save_name + "-FULL") + simple_save_dir = save_dir / (save_name + "-SIMPLIFY") + full_save_dir.mkdir(parents=True, exist_ok=True) + simple_save_dir.mkdir(parents=True, exist_ok=True) + # all data in memory + arch2infos, evaluated_indexes = dict(), set() + end_time, arch_time = time.time(), AverageMeter() + # save the meta information + temp_final_infos = {"meta_archs": nets, "total_archs": total, "arch2infos": None, "evaluated_indexes": set()} + pickle_save(temp_final_infos, str(full_save_dir / "meta.pickle")) + pickle_save(temp_final_infos, str(simple_save_dir / "meta.pickle")) + + for index in tqdm(range(total)): + arch_str = nets[index] + hp2info = OrderedDict() + + full_save_path = full_save_dir / "{:06d}.pickle".format(index) + simple_save_path = simple_save_dir / "{:06d}.pickle".format(index) + for hp in hps: + sub_save_dir = save_dir / "raw-data-{:}".format(hp) + ckps = [sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) for seed in seeds] + ckps = [x for x in ckps if x.exists()] + if len(ckps) == 0: + raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp)) + + arch_info = account_one_arch(index, arch_str, ckps, datasets, dataloader_dict) + hp2info[hp] = arch_info + + hp2info = correct_time_related_info(index, hp2info) + evaluated_indexes.add(index) + + to_save_data = OrderedDict({"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()}) + pickle_save(to_save_data, str(full_save_path)) + + for hp in hps: + hp2info[hp].clear_params() + to_save_data = OrderedDict({"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()}) + pickle_save(to_save_data, str(simple_save_path)) + arch2infos[index] = to_save_data + # measure elapsed time + arch_time.update(time.time() - end_time) + end_time = time.time() + need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True)) + # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) + print("{:} {:} done.".format(time_string(), save_name)) + final_infos = { + "meta_archs": nets, + "total_archs": total, + "arch2infos": arch2infos, + "evaluated_indexes": evaluated_indexes, + } + save_file_name = save_dir / "{:}.pickle".format(save_name) + pickle_save(final_infos, str(save_file_name)) + # move the benchmark file to a new path + hd5sum = get_md5_file(str(save_file_name) + ".pbz2") + hd5_file_name = save_dir / "{:}-{:}.pickle.pbz2".format(NATS_TSS_BASE_NAME, hd5sum) + shutil.move(str(save_file_name) + ".pbz2", hd5_file_name) + print( + "Save {:} / {:} architecture results into {:} -> {:}.".format( + len(evaluated_indexes), total, save_file_name, hd5_file_name + ) + ) + # move the directory to a new path + hd5_full_save_dir = save_dir / "{:}-{:}-full".format(NATS_TSS_BASE_NAME, hd5sum) + hd5_simple_save_dir = save_dir / "{:}-{:}-simple".format(NATS_TSS_BASE_NAME, hd5sum) + shutil.move(full_save_dir, hd5_full_save_dir) + shutil.move(simple_save_dir, hd5_simple_save_dir) + # save the meta information for simple and full + # final_infos['arch2infos'] = None + # final_infos['evaluated_indexes'] = set() def traverse_net(max_node): - aa_nas_bench_ss = get_search_spaces('cell', 'nats-bench') - archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) - print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2))) + aa_nas_bench_ss = get_search_spaces("cell", "nats-bench") + archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) + print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2))) - random.seed( 88 ) # please do not change this line for reproducibility - random.shuffle( archs ) - assert archs[0 ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0]) - assert archs[9 ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9]) - assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123]) - return [x.tostr() for x in archs] + random.seed(88) # please do not change this line for reproducibility + random.shuffle(archs) + assert ( + archs[0].tostr() + == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" + ), "please check the 0-th architecture : {:}".format(archs[0]) + assert ( + archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" + ), "please check the 9-th architecture : {:}".format(archs[9]) + assert ( + archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" + ), "please check the 123-th architecture : {:}".format(archs[123]) + return [x.tostr() for x in archs] -if __name__ == '__main__': +if __name__ == "__main__": - parser = argparse.ArgumentParser(description='NATS-Bench (topology search space)', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--base_save_dir', type=str, default='./output/NATS-Bench-topology', help='The base-name of folder to save checkpoints and log.') - parser.add_argument('--max_node' , type=int, default=4, help='The maximum node in a cell.') - parser.add_argument('--channel' , type=int, default=16, help='The number of channels.') - parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.') - parser.add_argument('--check_N' , type=int, default=15625, help='For safety.') - parser.add_argument('--save_name' , type=str, default='process', help='The save directory.') - args = parser.parse_args() - - nets = traverse_net(args.max_node) - if len(nets) != args.check_N: - raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N)) - - save_dir = Path(args.base_save_dir) - simplify(save_dir, args.save_name, nets, args.check_N, {'name': 'infer.tiny', 'channel': args.channel, 'num_cells': args.num_cells}) + parser = argparse.ArgumentParser( + description="NATS-Bench (topology search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--base_save_dir", + type=str, + default="./output/NATS-Bench-topology", + help="The base-name of folder to save checkpoints and log.", + ) + parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.") + parser.add_argument("--channel", type=int, default=16, help="The number of channels.") + parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + parser.add_argument("--check_N", type=int, default=15625, help="For safety.") + parser.add_argument("--save_name", type=str, default="process", help="The save directory.") + args = parser.parse_args() + + nets = traverse_net(args.max_node) + if len(nets) != args.check_N: + raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)) + + save_dir = Path(args.base_save_dir) + simplify( + save_dir, + args.save_name, + nets, + args.check_N, + {"name": "infer.tiny", "channel": args.channel, "num_cells": args.num_cells}, + ) diff --git a/exps/NATS-Bench/tss-file-manager.py b/exps/NATS-Bench/tss-file-manager.py index 7c0ca6c..10c94eb 100644 --- a/exps/NATS-Bench/tss-file-manager.py +++ b/exps/NATS-Bench/tss-file-manager.py @@ -9,72 +9,82 @@ import os, sys, time, torch, argparse from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config -from procedures import bench_evaluate_for_seed -from procedures import get_machine_info -from datasets import get_datasets -from log_utils import Logger, AverageMeter, time_string, convert_secs2time +from procedures import bench_evaluate_for_seed +from procedures import get_machine_info +from datasets import get_datasets +from log_utils import Logger, AverageMeter, time_string, convert_secs2time def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]): - seed2ckps = defaultdict(list) - miss2ckps = defaultdict(list) - for i in range(total): - for seed in possible_seeds: - path = os.path.join(save_dir, 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed)) - if os.path.exists(path): - seed2ckps[seed].append(i) - else: - miss2ckps[seed].append(i) - for seed, xlist in seed2ckps.items(): - print('[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}'.format(save_dir, seed, len(xlist), total, total-len(xlist), total)) - return dict(seed2ckps), dict(miss2ckps) - + seed2ckps = defaultdict(list) + miss2ckps = defaultdict(list) + for i in range(total): + for seed in possible_seeds: + path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed)) + if os.path.exists(path): + seed2ckps[seed].append(i) + else: + miss2ckps[seed].append(i) + for seed, xlist in seed2ckps.items(): + print( + "[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format( + save_dir, seed, len(xlist), total, total - len(xlist), total + ) + ) + return dict(seed2ckps), dict(miss2ckps) + def copy_data(source_dir, target_dir, meta_path): - target_dir = Path(target_dir) - target_dir.mkdir(parents=True, exist_ok=True) - miss2ckps = torch.load(meta_path)['miss2ckps'] - s2t = {} - for seed, xlist in miss2ckps.items(): - for i in xlist: - file_name = 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed) - source_path = os.path.join(source_dir, file_name) - target_path = os.path.join(target_dir, file_name) - if os.path.exists(source_path): - s2t[source_path] = target_path - print('Map from {:} to {:}, find {:} missed ckps.'.format(source_dir, target_dir, len(s2t))) - for s, t in s2t.items(): - copyfile(s, t) + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + miss2ckps = torch.load(meta_path)["miss2ckps"] + s2t = {} + for seed, xlist in miss2ckps.items(): + for i in xlist: + file_name = "arch-{:06d}-seed-{:04d}.pth".format(i, seed) + source_path = os.path.join(source_dir, file_name) + target_path = os.path.join(target_dir, file_name) + if os.path.exists(source_path): + s2t[source_path] = target_path + print("Map from {:} to {:}, find {:} missed ckps.".format(source_dir, target_dir, len(s2t))) + for s, t in s2t.items(): + copyfile(s, t) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NATS-Bench (topology search space) file manager.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--mode', type=str, required=True, choices=['check', 'copy'], help='The script mode.') - parser.add_argument('--save_dir', type=str, default='output/NATS-Bench-topology', help='Folder to save checkpoints and log.') - parser.add_argument('--check_N', type=int, default=15625, help='For safety.') - # use for train the model - args = parser.parse_args() - possible_configs = ['12', '200'] - possible_seedss = [[111, 777], [777, 888, 999]] - if args.mode == 'check': - for config, possible_seeds in zip(possible_configs, possible_seedss): - cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config) - seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds) - torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config)) - elif args.mode == 'copy': - for config in possible_configs: - cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config) - cur_copy_dir = '{:}/copy-{:}'.format(args.save_dir, config) - cur_meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config) - if os.path.exists(cur_meta_path): - copy_data(cur_save_dir, cur_copy_dir, cur_meta_path) - else: - print('Do not find : {:}'.format(cur_meta_path)) - else: - raise ValueError('invalid mode : {:}'.format(args.mode)) +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="NATS-Bench (topology search space) file manager.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--mode", type=str, required=True, choices=["check", "copy"], help="The script mode.") + parser.add_argument( + "--save_dir", type=str, default="output/NATS-Bench-topology", help="Folder to save checkpoints and log." + ) + parser.add_argument("--check_N", type=int, default=15625, help="For safety.") + # use for train the model + args = parser.parse_args() + possible_configs = ["12", "200"] + possible_seedss = [[111, 777], [777, 888, 999]] + if args.mode == "check": + for config, possible_seeds in zip(possible_configs, possible_seedss): + cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) + seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds) + torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), "{:}/meta-{:}.pth".format(args.save_dir, config)) + elif args.mode == "copy": + for config in possible_configs: + cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) + cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config) + cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config) + if os.path.exists(cur_meta_path): + copy_data(cur_save_dir, cur_copy_dir, cur_meta_path) + else: + print("Do not find : {:}".format(cur_meta_path)) + else: + raise ValueError("invalid mode : {:}".format(args.mode)) diff --git a/exps/NATS-algos/bohb.py b/exps/NATS-algos/bohb.py index 0d8af60..e7840b1 100644 --- a/exps/NATS-algos/bohb.py +++ b/exps/NATS-algos/bohb.py @@ -12,14 +12,17 @@ import os, sys, time, random, argparse, collections from copy import deepcopy from pathlib import Path import torch -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config -from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger -from log_utils import AverageMeter, time_string, convert_secs2time -from nats_bench import create -from models import CellStructure, get_search_spaces +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger +from log_utils import AverageMeter, time_string, convert_secs2time +from nats_bench import create +from models import CellStructure, get_search_spaces + # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 import ConfigSpace from hpbandster.optimizers.bohb import BOHB @@ -28,161 +31,193 @@ from hpbandster.core.worker import Worker def get_topology_config_space(search_space, max_nodes=4): - cs = ConfigSpace.ConfigurationSpace() - #edge2index = {} - for i in range(1, max_nodes): - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) - return cs + cs = ConfigSpace.ConfigurationSpace() + # edge2index = {} + for i in range(1, max_nodes): + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) + return cs def get_size_config_space(search_space): - cs = ConfigSpace.ConfigurationSpace() - for ilayer in range(search_space['numbers']): - node_str = 'layer-{:}'.format(ilayer) - cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space['candidates'])) - return cs + cs = ConfigSpace.ConfigurationSpace() + for ilayer in range(search_space["numbers"]): + node_str = "layer-{:}".format(ilayer) + cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space["candidates"])) + return cs def config2topology_func(max_nodes=4): - def config2structure(config): - genotypes = [] - for i in range(1, max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_name = config[node_str] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return CellStructure( genotypes ) - return config2structure + def config2structure(config): + genotypes = [] + for i in range(1, max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + op_name = config[node_str] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return CellStructure(genotypes) + + return config2structure def config2size_func(search_space): - def config2structure(config): - channels = [] - for ilayer in range(search_space['numbers']): - node_str = 'layer-{:}'.format(ilayer) - channels.append(str(config[node_str])) - return ':'.join(channels) - return config2structure + def config2structure(config): + channels = [] + for ilayer in range(search_space["numbers"]): + node_str = "layer-{:}".format(ilayer) + channels.append(str(config[node_str])) + return ":".join(channels) + + return config2structure class MyWorker(Worker): + def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs): + super().__init__(*args, **kwargs) + self.convert_func = convert_func + self._dataset = dataset + self._api = api + self.total_times = [] + self.trajectory = [] - def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs): - super().__init__(*args, **kwargs) - self.convert_func = convert_func - self._dataset = dataset - self._api = api - self.total_times = [] - self.trajectory = [] - - def compute(self, config, budget, **kwargs): - arch = self.convert_func( config ) - accuracy, latency, time_cost, total_time = self._api.simulate_train_eval(arch, self._dataset, iepoch=int(budget)-1, hp='12') - self.trajectory.append((accuracy, arch)) - self.total_times.append(total_time) - return ({'loss': 100 - accuracy, - 'info': self._api.query_index_by_arch(arch)}) + def compute(self, config, budget, **kwargs): + arch = self.convert_func(config) + accuracy, latency, time_cost, total_time = self._api.simulate_train_eval( + arch, self._dataset, iepoch=int(budget) - 1, hp="12" + ) + self.trajectory.append((accuracy, arch)) + self.total_times.append(total_time) + return {"loss": 100 - accuracy, "info": self._api.query_index_by_arch(arch)} def main(xargs, api): - torch.set_num_threads(4) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + torch.set_num_threads(4) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - logger.log('{:} use api : {:}'.format(time_string(), api)) - api.reset_time() - search_space = get_search_spaces(xargs.search_space, 'nats-bench') - if xargs.search_space == 'tss': - cs = get_topology_config_space(search_space) - config2structure = config2topology_func() - else: - cs = get_size_config_space(search_space) - config2structure = config2size_func(search_space) - - hb_run_id = '0' + logger.log("{:} use api : {:}".format(time_string(), api)) + api.reset_time() + search_space = get_search_spaces(xargs.search_space, "nats-bench") + if xargs.search_space == "tss": + cs = get_topology_config_space(search_space) + config2structure = config2topology_func() + else: + cs = get_size_config_space(search_space) + config2structure = config2size_func(search_space) - NS = hpns.NameServer(run_id=hb_run_id, host='localhost', port=0) - ns_host, ns_port = NS.start() - num_workers = 1 + hb_run_id = "0" - workers = [] - for i in range(num_workers): - w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataset=xargs.dataset, api=api, run_id=hb_run_id, id=i) - w.run(background=True) - workers.append(w) + NS = hpns.NameServer(run_id=hb_run_id, host="localhost", port=0) + ns_host, ns_port = NS.start() + num_workers = 1 - start_time = time.time() - bohb = BOHB(configspace=cs, run_id=hb_run_id, - eta=3, min_budget=1, max_budget=12, - nameserver=ns_host, - nameserver_port=ns_port, - num_samples=xargs.num_samples, - random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor, - ping_interval=10, min_bandwidth=xargs.min_bandwidth) - - results = bohb.run(xargs.n_iters, min_n_workers=num_workers) + workers = [] + for i in range(num_workers): + w = MyWorker( + nameserver=ns_host, + nameserver_port=ns_port, + convert_func=config2structure, + dataset=xargs.dataset, + api=api, + run_id=hb_run_id, + id=i, + ) + w.run(background=True) + workers.append(w) - bohb.shutdown(shutdown_workers=True) - NS.shutdown() + start_time = time.time() + bohb = BOHB( + configspace=cs, + run_id=hb_run_id, + eta=3, + min_budget=1, + max_budget=12, + nameserver=ns_host, + nameserver_port=ns_port, + num_samples=xargs.num_samples, + random_fraction=xargs.random_fraction, + bandwidth_factor=xargs.bandwidth_factor, + ping_interval=10, + min_bandwidth=xargs.min_bandwidth, + ) - # print('There are {:} runs.'.format(len(results.get_all_runs()))) - # workers[0].total_times - # workers[0].trajectory - current_best_index = [] - for idx in range(len(workers[0].trajectory)): - trajectory = workers[0].trajectory[:idx+1] - arch = max(trajectory, key=lambda x: x[0])[1] - current_best_index.append(api.query_index_by_arch(arch)) - - best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1] - logger.log('Best found configuration: {:} within {:.3f} s'.format(best_arch, workers[0].total_times[-1])) - info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90') - logger.log('{:}'.format(info)) - logger.log('-'*100) - logger.close() + results = bohb.run(xargs.n_iters, min_n_workers=num_workers) - return logger.log_dir, current_best_index, workers[0].total_times + bohb.shutdown(shutdown_workers=True) + NS.shutdown() + + # print('There are {:} runs.'.format(len(results.get_all_runs()))) + # workers[0].total_times + # workers[0].trajectory + current_best_index = [] + for idx in range(len(workers[0].trajectory)): + trajectory = workers[0].trajectory[: idx + 1] + arch = max(trajectory, key=lambda x: x[0])[1] + current_best_index.append(api.query_index_by_arch(arch)) + + best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1] + logger.log("Best found configuration: {:} within {:.3f} s".format(best_arch, workers[0].total_times[-1])) + info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90") + logger.log("{:}".format(info)) + logger.log("-" * 100) + logger.close() + + return logger.log_dir, current_best_index, workers[0].total_times -if __name__ == '__main__': - parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale") - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # general arg - parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') - parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).') - parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') - # BOHB - parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function') - parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE') - parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function') - parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations') - parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth') - parser.add_argument('--n_iters', default=300, type=int, nargs='?', help='number of iterations for optimization method') - # log - parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') - parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') - args = parser.parse_args() - - api = create(None, args.search_space, fast_mode=False, verbose=False) +if __name__ == "__main__": + parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # general arg + parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") + parser.add_argument( + "--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)." + ) + parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.") + # BOHB + parser.add_argument( + "--strategy", default="sampling", type=str, nargs="?", help="optimization strategy for the acquisition function" + ) + parser.add_argument("--min_bandwidth", default=0.3, type=float, nargs="?", help="minimum bandwidth for KDE") + parser.add_argument( + "--num_samples", default=64, type=int, nargs="?", help="number of samples for the acquisition function" + ) + parser.add_argument( + "--random_fraction", default=0.33, type=float, nargs="?", help="fraction of random configurations" + ) + parser.add_argument("--bandwidth_factor", default=3, type=int, nargs="?", help="factor multiplied to the bandwidth") + parser.add_argument( + "--n_iters", default=300, type=int, nargs="?", help="number of iterations for optimization method" + ) + # log + parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() - args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), - '{:}-T{:}'.format(args.dataset, args.time_budget), 'BOHB') - print('save-dir : {:}'.format(args.save_dir)) + api = create(None, args.search_space, fast_mode=False, verbose=False) - if args.rand_seed < 0: - save_dir, all_info = None, collections.OrderedDict() - for i in range(args.loops_if_rand): - print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand)) - args.rand_seed = random.randint(1, 100000) - save_dir, all_archs, all_total_times = main(args, api) - all_info[i] = {'all_archs': all_archs, - 'all_total_times': all_total_times} - save_path = save_dir / 'results.pth' - print('save into {:}'.format(save_path)) - torch.save(all_info, save_path) - else: - main(args, api) + args.save_dir = os.path.join( + "{:}-{:}".format(args.save_dir, args.search_space), "{:}-T{:}".format(args.dataset, args.time_budget), "BOHB" + ) + print("save-dir : {:}".format(args.save_dir)) + + if args.rand_seed < 0: + save_dir, all_info = None, collections.OrderedDict() + for i in range(args.loops_if_rand): + print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand)) + args.rand_seed = random.randint(1, 100000) + save_dir, all_archs, all_total_times = main(args, api) + all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times} + save_path = save_dir / "results.pth" + print("save into {:}".format(save_path)) + torch.save(all_info, save_path) + else: + main(args, api) diff --git a/exps/NATS-algos/random_wo_share.py b/exps/NATS-algos/random_wo_share.py index efe03f3..5c1d3ad 100644 --- a/exps/NATS-algos/random_wo_share.py +++ b/exps/NATS-algos/random_wo_share.py @@ -13,80 +13,93 @@ from copy import deepcopy import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_search_spaces -from nats_bench import create +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_search_spaces +from nats_bench import create from regularized_ea import random_topology_func, random_size_func def main(xargs, api): - torch.set_num_threads(4) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + torch.set_num_threads(4) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - logger.log('{:} use api : {:}'.format(time_string(), api)) - api.reset_time() + logger.log("{:} use api : {:}".format(time_string(), api)) + api.reset_time() - search_space = get_search_spaces(xargs.search_space, 'nats-bench') - if xargs.search_space == 'tss': - random_arch = random_topology_func(search_space) - else: - random_arch = random_size_func(search_space) + search_space = get_search_spaces(xargs.search_space, "nats-bench") + if xargs.search_space == "tss": + random_arch = random_topology_func(search_space) + else: + random_arch = random_size_func(search_space) - best_arch, best_acc, total_time_cost, history = None, -1, [], [] - current_best_index = [] - while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget: - arch = random_arch() - accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12') - total_time_cost.append(total_cost) - history.append(arch) - if best_arch is None or best_acc < accuracy: - best_acc, best_arch = accuracy, arch - logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy)) - current_best_index.append(api.query_index_by_arch(best_arch)) - logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost[-1])) - - info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90') - logger.log('{:}'.format(info)) - logger.log('-'*100) - logger.close() - return logger.log_dir, current_best_index, total_time_cost + best_arch, best_acc, total_time_cost, history = None, -1, [], [] + current_best_index = [] + while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget: + arch = random_arch() + accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, hp="12") + total_time_cost.append(total_cost) + history.append(arch) + if best_arch is None or best_acc < accuracy: + best_acc, best_arch = accuracy, arch + logger.log("[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy)) + current_best_index.append(api.query_index_by_arch(best_arch)) + logger.log( + "{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.".format( + time_string(), best_arch, best_acc, len(history), total_time_cost[-1] + ) + ) + + info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90") + logger.log("{:}".format(info)) + logger.log("-" * 100) + logger.close() + return logger.log_dir, current_best_index, total_time_cost -if __name__ == '__main__': - parser = argparse.ArgumentParser("Random NAS") - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') +if __name__ == "__main__": + parser = argparse.ArgumentParser("Random NAS") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") - parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).') - parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') - # log - parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') - parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') - args = parser.parse_args() - - api = create(None, args.search_space, fast_mode=True, verbose=False) + parser.add_argument( + "--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)." + ) + parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.") + # log + parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() - args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), - '{:}-T{:}'.format(args.dataset, args.time_budget), 'RANDOM') - print('save-dir : {:}'.format(args.save_dir)) + api = create(None, args.search_space, fast_mode=True, verbose=False) - if args.rand_seed < 0: - save_dir, all_info = None, collections.OrderedDict() - for i in range(args.loops_if_rand): - print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand)) - args.rand_seed = random.randint(1, 100000) - save_dir, all_archs, all_total_times = main(args, api) - all_info[i] = {'all_archs': all_archs, - 'all_total_times': all_total_times} - save_path = save_dir / 'results.pth' - print('save into {:}'.format(save_path)) - torch.save(all_info, save_path) - else: - main(args, api) + args.save_dir = os.path.join( + "{:}-{:}".format(args.save_dir, args.search_space), "{:}-T{:}".format(args.dataset, args.time_budget), "RANDOM" + ) + print("save-dir : {:}".format(args.save_dir)) + + if args.rand_seed < 0: + save_dir, all_info = None, collections.OrderedDict() + for i in range(args.loops_if_rand): + print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand)) + args.rand_seed = random.randint(1, 100000) + save_dir, all_archs, all_total_times = main(args, api) + all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times} + save_path = save_dir / "results.pth" + print("save into {:}".format(save_path)) + torch.save(all_info, save_path) + else: + main(args, api) diff --git a/exps/NATS-algos/regularized_ea.py b/exps/NATS-algos/regularized_ea.py index be63992..86c21e3 100644 --- a/exps/NATS-algos/regularized_ea.py +++ b/exps/NATS-algos/regularized_ea.py @@ -17,214 +17,242 @@ from copy import deepcopy import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import CellStructure, get_search_spaces -from nats_bench import create +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import CellStructure, get_search_spaces +from nats_bench import create class Model(object): + def __init__(self): + self.arch = None + self.accuracy = None + + def __str__(self): + """Prints a readable version of this bitstring.""" + return "{:}".format(self.arch) - def __init__(self): - self.arch = None - self.accuracy = None - - def __str__(self): - """Prints a readable version of this bitstring.""" - return '{:}'.format(self.arch) - def random_topology_func(op_names, max_nodes=4): - # Return a random architecture - def random_architecture(): - genotypes = [] - for i in range(1, max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_name = random.choice( op_names ) - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return CellStructure( genotypes ) - return random_architecture + # Return a random architecture + def random_architecture(): + genotypes = [] + for i in range(1, max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + op_name = random.choice(op_names) + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return CellStructure(genotypes) + + return random_architecture def random_size_func(info): - # Return a random architecture - def random_architecture(): - channels = [] - for i in range(info['numbers']): - channels.append( - str(random.choice(info['candidates']))) - return ':'.join(channels) - return random_architecture + # Return a random architecture + def random_architecture(): + channels = [] + for i in range(info["numbers"]): + channels.append(str(random.choice(info["candidates"]))) + return ":".join(channels) + + return random_architecture def mutate_topology_func(op_names): - """Computes the architecture for a child of the given parent architecture. - The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another. - """ - def mutate_topology_func(parent_arch): - child_arch = deepcopy( parent_arch ) - node_id = random.randint(0, len(child_arch.nodes)-1) - node_info = list( child_arch.nodes[node_id] ) - snode_id = random.randint(0, len(node_info)-1) - xop = random.choice( op_names ) - while xop == node_info[snode_id][0]: - xop = random.choice( op_names ) - node_info[snode_id] = (xop, node_info[snode_id][1]) - child_arch.nodes[node_id] = tuple( node_info ) - return child_arch - return mutate_topology_func + """Computes the architecture for a child of the given parent architecture. + The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another. + """ + + def mutate_topology_func(parent_arch): + child_arch = deepcopy(parent_arch) + node_id = random.randint(0, len(child_arch.nodes) - 1) + node_info = list(child_arch.nodes[node_id]) + snode_id = random.randint(0, len(node_info) - 1) + xop = random.choice(op_names) + while xop == node_info[snode_id][0]: + xop = random.choice(op_names) + node_info[snode_id] = (xop, node_info[snode_id][1]) + child_arch.nodes[node_id] = tuple(node_info) + return child_arch + + return mutate_topology_func def mutate_size_func(info): - """Computes the architecture for a child of the given parent architecture. - The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another. - """ - def mutate_size_func(parent_arch): - child_arch = deepcopy(parent_arch) - child_arch = child_arch.split(':') - index = random.randint(0, len(child_arch)-1) - child_arch[index] = str(random.choice(info['candidates'])) - return ':'.join(child_arch) - return mutate_size_func + """Computes the architecture for a child of the given parent architecture. + The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another. + """ + + def mutate_size_func(parent_arch): + child_arch = deepcopy(parent_arch) + child_arch = child_arch.split(":") + index = random.randint(0, len(child_arch) - 1) + child_arch[index] = str(random.choice(info["candidates"])) + return ":".join(child_arch) + + return mutate_size_func -def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset): - """Algorithm for regularized evolution (i.e. aging evolution). - - Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image - Classifier Architecture Search". - - Args: - cycles: the number of cycles the algorithm should run for. - population_size: the number of individuals to keep in the population. - sample_size: the number of individuals that should participate in each tournament. - time_budget: the upper bound of searching cost +def regularized_evolution( + cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset +): + """Algorithm for regularized evolution (i.e. aging evolution). - Returns: - history: a list of `Model` instances, representing all the models computed - during the evolution experiment. - """ - population = collections.deque() - api.reset_time() - history, total_time_cost = [], [] # Not used by the algorithm, only used to report results. - current_best_index = [] - # Initialize the population with random models. - while len(population) < population_size: - model = Model() - model.arch = random_arch() - model.accuracy, _, _, total_cost = api.simulate_train_eval( - model.arch, dataset, hp='12' if use_proxy else api.full_train_epochs) - # Append the info - population.append(model) - history.append((model.accuracy, model.arch)) - total_time_cost.append(total_cost) - current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1])) + Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image + Classifier Architecture Search". - # Carry out evolution in cycles. Each cycle produces a model and removes another. - while total_time_cost[-1] < time_budget: - # Sample randomly chosen models from the current population. - start_time, sample = time.time(), [] - while len(sample) < sample_size: - # Inefficient, but written this way for clarity. In the case of neural - # nets, the efficiency of this line is irrelevant because training neural - # nets is the rate-determining step. - candidate = random.choice(list(population)) - sample.append(candidate) + Args: + cycles: the number of cycles the algorithm should run for. + population_size: the number of individuals to keep in the population. + sample_size: the number of individuals that should participate in each tournament. + time_budget: the upper bound of searching cost - # The parent is the best model in the sample. - parent = max(sample, key=lambda i: i.accuracy) + Returns: + history: a list of `Model` instances, representing all the models computed + during the evolution experiment. + """ + population = collections.deque() + api.reset_time() + history, total_time_cost = [], [] # Not used by the algorithm, only used to report results. + current_best_index = [] + # Initialize the population with random models. + while len(population) < population_size: + model = Model() + model.arch = random_arch() + model.accuracy, _, _, total_cost = api.simulate_train_eval( + model.arch, dataset, hp="12" if use_proxy else api.full_train_epochs + ) + # Append the info + population.append(model) + history.append((model.accuracy, model.arch)) + total_time_cost.append(total_cost) + current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1])) - # Create the child model and store it. - child = Model() - child.arch = mutate_arch(parent.arch) - child.accuracy, _, _, total_cost = api.simulate_train_eval( - child.arch, dataset, hp='12' if use_proxy else api.full_train_epochs) - # Append the info - population.append(child) - history.append((child.accuracy, child.arch)) - current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1])) - total_time_cost.append(total_cost) + # Carry out evolution in cycles. Each cycle produces a model and removes another. + while total_time_cost[-1] < time_budget: + # Sample randomly chosen models from the current population. + start_time, sample = time.time(), [] + while len(sample) < sample_size: + # Inefficient, but written this way for clarity. In the case of neural + # nets, the efficiency of this line is irrelevant because training neural + # nets is the rate-determining step. + candidate = random.choice(list(population)) + sample.append(candidate) - # Remove the oldest model. - population.popleft() - return history, current_best_index, total_time_cost + # The parent is the best model in the sample. + parent = max(sample, key=lambda i: i.accuracy) + + # Create the child model and store it. + child = Model() + child.arch = mutate_arch(parent.arch) + child.accuracy, _, _, total_cost = api.simulate_train_eval( + child.arch, dataset, hp="12" if use_proxy else api.full_train_epochs + ) + # Append the info + population.append(child) + history.append((child.accuracy, child.arch)) + current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1])) + total_time_cost.append(total_cost) + + # Remove the oldest model. + population.popleft() + return history, current_best_index, total_time_cost def main(xargs, api): - torch.set_num_threads(4) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + torch.set_num_threads(4) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - search_space = get_search_spaces(xargs.search_space, 'nats-bench') - if xargs.search_space == 'tss': - random_arch = random_topology_func(search_space) - mutate_arch = mutate_topology_func(search_space) - else: - random_arch = random_size_func(search_space) - mutate_arch = mutate_size_func(search_space) + search_space = get_search_spaces(xargs.search_space, "nats-bench") + if xargs.search_space == "tss": + random_arch = random_topology_func(search_space) + mutate_arch = mutate_topology_func(search_space) + else: + random_arch = random_size_func(search_space) + mutate_arch = mutate_size_func(search_space) - x_start_time = time.time() - logger.log('{:} use api : {:}'.format(time_string(), api)) - logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) - history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, - xargs.ea_population, - xargs.ea_sample_size, - xargs.time_budget, - random_arch, mutate_arch, api, xargs.use_proxy > 0, xargs.dataset) - logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time)) - best_arch = max(history, key=lambda x: x[0])[1] - logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) - - info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90') - logger.log('{:}'.format(info)) - logger.log('-'*100) - logger.close() - return logger.log_dir, current_best_index, total_times + x_start_time = time.time() + logger.log("{:} use api : {:}".format(time_string(), api)) + logger.log("-" * 30 + " start searching with the time budget of {:} s".format(xargs.time_budget)) + history, current_best_index, total_times = regularized_evolution( + xargs.ea_cycles, + xargs.ea_population, + xargs.ea_sample_size, + xargs.time_budget, + random_arch, + mutate_arch, + api, + xargs.use_proxy > 0, + xargs.dataset, + ) + logger.log( + "{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).".format( + time_string(), len(history), total_times[-1], time.time() - x_start_time + ) + ) + best_arch = max(history, key=lambda x: x[0])[1] + logger.log("{:} best arch is {:}".format(time_string(), best_arch)) + + info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90") + logger.log("{:}".format(info)) + logger.log("-" * 100) + logger.close() + return logger.log_dir, current_best_index, total_times -if __name__ == '__main__': - parser = argparse.ArgumentParser("Regularized Evolution Algorithm") - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') - # hyperparameters for REA - parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.') - parser.add_argument('--ea_population', type=int, help='The population size in EA.') - parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.') - parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).') - parser.add_argument('--use_proxy', type=int, default=1, help='Whether to use the proxy (H0) task or not.') - # - parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') - # log - parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') - parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser("Regularized Evolution Algorithm") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") + # hyperparameters for REA + parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.") + parser.add_argument("--ea_population", type=int, help="The population size in EA.") + parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.") + parser.add_argument( + "--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)." + ) + parser.add_argument("--use_proxy", type=int, default=1, help="Whether to use the proxy (H0) task or not.") + # + parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.") + # log + parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() - api = create(None, args.search_space, fast_mode=True, verbose=False) + api = create(None, args.search_space, fast_mode=True, verbose=False) - args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), - '{:}-T{:}{:}'.format(args.dataset, args.time_budget, '' if args.use_proxy > 0 else '-FULL'), - 'R-EA-SS{:}'.format(args.ea_sample_size)) - print('save-dir : {:}'.format(args.save_dir)) - print('xargs : {:}'.format(args)) + args.save_dir = os.path.join( + "{:}-{:}".format(args.save_dir, args.search_space), + "{:}-T{:}{:}".format(args.dataset, args.time_budget, "" if args.use_proxy > 0 else "-FULL"), + "R-EA-SS{:}".format(args.ea_sample_size), + ) + print("save-dir : {:}".format(args.save_dir)) + print("xargs : {:}".format(args)) - if args.rand_seed < 0: - save_dir, all_info = None, collections.OrderedDict() - for i in range(args.loops_if_rand): - print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand)) - args.rand_seed = random.randint(1, 100000) - save_dir, all_archs, all_total_times = main(args, api) - all_info[i] = {'all_archs': all_archs, - 'all_total_times': all_total_times} - save_path = save_dir / 'results.pth' - print('save into {:}'.format(save_path)) - torch.save(all_info, save_path) - else: - main(args, api) + if args.rand_seed < 0: + save_dir, all_info = None, collections.OrderedDict() + for i in range(args.loops_if_rand): + print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand)) + args.rand_seed = random.randint(1, 100000) + save_dir, all_archs, all_total_times = main(args, api) + all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times} + save_path = save_dir / "results.pth" + print("save into {:}".format(save_path)) + torch.save(all_info, save_path) + else: + main(args, api) diff --git a/exps/NATS-algos/reinforce.py b/exps/NATS-algos/reinforce.py index 91778dc..1280ff8 100644 --- a/exps/NATS-algos/reinforce.py +++ b/exps/NATS-algos/reinforce.py @@ -3,12 +3,12 @@ ##################################################################################################### # modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py # ##################################################################################################### -# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.01 -# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01 -# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.01 -# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.01 -# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01 -# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.01 +# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.01 +# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01 +# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.01 +# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.01 +# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01 +# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.01 ##################################################################################################### import os, sys, time, glob, random, argparse import numpy as np, collections @@ -17,197 +17,216 @@ from pathlib import Path import torch import torch.nn as nn from torch.distributions import Categorical -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import CellStructure, get_search_spaces -from nats_bench import create +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import CellStructure, get_search_spaces +from nats_bench import create class PolicyTopology(nn.Module): + def __init__(self, search_space, max_nodes=4): + super(PolicyTopology, self).__init__() + self.max_nodes = max_nodes + self.search_space = deepcopy(search_space) + self.edge2index = {} + for i in range(1, max_nodes): + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + self.edge2index[node_str] = len(self.edge2index) + self.arch_parameters = nn.Parameter(1e-3 * torch.randn(len(self.edge2index), len(search_space))) - def __init__(self, search_space, max_nodes=4): - super(PolicyTopology, self).__init__() - self.max_nodes = max_nodes - self.search_space = deepcopy(search_space) - self.edge2index = {} - for i in range(1, max_nodes): - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - self.edge2index[ node_str ] = len(self.edge2index) - self.arch_parameters = nn.Parameter(1e-3*torch.randn(len(self.edge2index), len(search_space))) + def generate_arch(self, actions): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + op_name = self.search_space[actions[self.edge2index[node_str]]] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return CellStructure(genotypes) - def generate_arch(self, actions): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_name = self.search_space[ actions[ self.edge2index[ node_str ] ] ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return CellStructure( genotypes ) + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[self.edge2index[node_str]] + op_name = self.search_space[weights.argmax().item()] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return CellStructure(genotypes) - def genotype(self): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self.search_space[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return CellStructure( genotypes ) - - def forward(self): - alphas = nn.functional.softmax(self.arch_parameters, dim=-1) - return alphas + def forward(self): + alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + return alphas class PolicySize(nn.Module): + def __init__(self, search_space): + super(PolicySize, self).__init__() + self.candidates = search_space["candidates"] + self.numbers = search_space["numbers"] + self.arch_parameters = nn.Parameter(1e-3 * torch.randn(self.numbers, len(self.candidates))) - def __init__(self, search_space): - super(PolicySize, self).__init__() - self.candidates = search_space['candidates'] - self.numbers = search_space['numbers'] - self.arch_parameters = nn.Parameter(1e-3*torch.randn(self.numbers, len(self.candidates))) + def generate_arch(self, actions): + channels = [str(self.candidates[i]) for i in actions] + return ":".join(channels) - def generate_arch(self, actions): - channels = [str(self.candidates[i]) for i in actions] - return ':'.join(channels) + def genotype(self): + channels = [] + for i in range(self.numbers): + index = self.arch_parameters[i].argmax().item() + channels.append(str(self.candidates[index])) + return ":".join(channels) - def genotype(self): - channels = [] - for i in range(self.numbers): - index = self.arch_parameters[i].argmax().item() - channels.append(str(self.candidates[index])) - return ':'.join(channels) - - def forward(self): - alphas = nn.functional.softmax(self.arch_parameters, dim=-1) - return alphas + def forward(self): + alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + return alphas class ExponentialMovingAverage(object): - """Class that maintains an exponential moving average.""" + """Class that maintains an exponential moving average.""" - def __init__(self, momentum): - self._numerator = 0 - self._denominator = 0 - self._momentum = momentum + def __init__(self, momentum): + self._numerator = 0 + self._denominator = 0 + self._momentum = momentum - def update(self, value): - self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value - self._denominator = self._momentum * self._denominator + (1 - self._momentum) + def update(self, value): + self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value + self._denominator = self._momentum * self._denominator + (1 - self._momentum) - def value(self): - """Return the current value of the moving average""" - return self._numerator / self._denominator + def value(self): + """Return the current value of the moving average""" + return self._numerator / self._denominator def select_action(policy): - probs = policy() - m = Categorical(probs) - action = m.sample() - # policy.saved_log_probs.append(m.log_prob(action)) - return m.log_prob(action), action.cpu().tolist() + probs = policy() + m = Categorical(probs) + action = m.sample() + # policy.saved_log_probs.append(m.log_prob(action)) + return m.log_prob(action), action.cpu().tolist() def main(xargs, api): - torch.set_num_threads(4) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) - - search_space = get_search_spaces(xargs.search_space, 'nats-bench') - if xargs.search_space == 'tss': - policy = PolicyTopology(search_space) - else: - policy = PolicySize(search_space) - optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate) - #optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate) - eps = np.finfo(np.float32).eps.item() - baseline = ExponentialMovingAverage(xargs.EMA_momentum) - logger.log('policy : {:}'.format(policy)) - logger.log('optimizer : {:}'.format(optimizer)) - logger.log('eps : {:}'.format(eps)) + torch.set_num_threads(4) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - # nas dataset load - logger.log('{:} use api : {:}'.format(time_string(), api)) - api.reset_time() + search_space = get_search_spaces(xargs.search_space, "nats-bench") + if xargs.search_space == "tss": + policy = PolicyTopology(search_space) + else: + policy = PolicySize(search_space) + optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate) + # optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate) + eps = np.finfo(np.float32).eps.item() + baseline = ExponentialMovingAverage(xargs.EMA_momentum) + logger.log("policy : {:}".format(policy)) + logger.log("optimizer : {:}".format(optimizer)) + logger.log("eps : {:}".format(eps)) - # REINFORCE - x_start_time = time.time() - logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget)) - total_steps, total_costs, trace = 0, [], [] - current_best_index = [] - while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget: - start_time = time.time() - log_prob, action = select_action( policy ) - arch = policy.generate_arch( action ) - reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12') - trace.append((reward, arch)) - total_costs.append(current_total_cost) + # nas dataset load + logger.log("{:} use api : {:}".format(time_string(), api)) + api.reset_time() - baseline.update(reward) - # calculate loss - policy_loss = ( -log_prob * (reward - baseline.value()) ).sum() - optimizer.zero_grad() - policy_loss.backward() - optimizer.step() - # accumulate time - total_steps += 1 - logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype())) - # to analyze - current_best_index.append(api.query_index_by_arch(max(trace, key=lambda x: x[0])[1])) - # best_arch = policy.genotype() # first version - best_arch = max(trace, key=lambda x: x[0])[1] - logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs[-1], time.time()-x_start_time)) - info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90') - logger.log('{:}'.format(info)) - logger.log('-'*100) - logger.close() + # REINFORCE + x_start_time = time.time() + logger.log("Will start searching with time budget of {:} s.".format(xargs.time_budget)) + total_steps, total_costs, trace = 0, [], [] + current_best_index = [] + while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget: + start_time = time.time() + log_prob, action = select_action(policy) + arch = policy.generate_arch(action) + reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, hp="12") + trace.append((reward, arch)) + total_costs.append(current_total_cost) - return logger.log_dir, current_best_index, total_costs + baseline.update(reward) + # calculate loss + policy_loss = (-log_prob * (reward - baseline.value())).sum() + optimizer.zero_grad() + policy_loss.backward() + optimizer.step() + # accumulate time + total_steps += 1 + logger.log( + "step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}".format( + total_steps, baseline.value(), policy_loss.item(), policy.genotype() + ) + ) + # to analyze + current_best_index.append(api.query_index_by_arch(max(trace, key=lambda x: x[0])[1])) + # best_arch = policy.genotype() # first version + best_arch = max(trace, key=lambda x: x[0])[1] + logger.log( + "REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).".format( + total_steps, total_costs[-1], time.time() - x_start_time + ) + ) + info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90") + logger.log("{:}".format(info)) + logger.log("-" * 100) + logger.close() + + return logger.log_dir, current_best_index, total_costs -if __name__ == '__main__': - parser = argparse.ArgumentParser("The REINFORCE Algorithm") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') - parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.') - parser.add_argument('--EMA_momentum', type=float, default=0.9, help='The momentum value for EMA.') - parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).') - parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') - # log - parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser("The REINFORCE Algorithm") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") + parser.add_argument("--learning_rate", type=float, help="The learning rate for REINFORCE.") + parser.add_argument("--EMA_momentum", type=float, default=0.9, help="The momentum value for EMA.") + parser.add_argument( + "--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)." + ) + parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.") + # log + parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() - api = create(None, args.search_space, fast_mode=True, verbose=False) + api = create(None, args.search_space, fast_mode=True, verbose=False) - args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), - '{:}-T{:}'.format(args.dataset, args.time_budget), 'REINFORCE-{:}'.format(args.learning_rate)) - print('save-dir : {:}'.format(args.save_dir)) + args.save_dir = os.path.join( + "{:}-{:}".format(args.save_dir, args.search_space), + "{:}-T{:}".format(args.dataset, args.time_budget), + "REINFORCE-{:}".format(args.learning_rate), + ) + print("save-dir : {:}".format(args.save_dir)) - if args.rand_seed < 0: - save_dir, all_info = None, collections.OrderedDict() - for i in range(args.loops_if_rand): - print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand)) - args.rand_seed = random.randint(1, 100000) - save_dir, all_archs, all_total_times = main(args, api) - all_info[i] = {'all_archs': all_archs, - 'all_total_times': all_total_times} - save_path = save_dir / 'results.pth' - print('save into {:}'.format(save_path)) - torch.save(all_info, save_path) - else: - main(args, api) + if args.rand_seed < 0: + save_dir, all_info = None, collections.OrderedDict() + for i in range(args.loops_if_rand): + print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand)) + args.rand_seed = random.randint(1, 100000) + save_dir, all_archs, all_total_times = main(args, api) + all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times} + save_path = save_dir / "results.pth" + print("save into {:}".format(save_path)) + torch.save(all_info, save_path) + else: + main(args, api) diff --git a/exps/NATS-algos/search-cell.py b/exps/NATS-algos/search-cell.py index cedd3cb..fd51703 100644 --- a/exps/NATS-algos/search-cell.py +++ b/exps/NATS-algos/search-cell.py @@ -31,494 +31,637 @@ from copy import deepcopy import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import count_parameters_in_MB, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces -from nats_bench import create +from datasets import get_datasets, get_nas_search_loaders +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import count_parameters_in_MB, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_cell_based_tiny_net, get_search_spaces +from nats_bench import create # The following three functions are used for DARTS-V2 def _concat(xs): - return torch.cat([x.view(-1) for x in xs]) + return torch.cat([x.view(-1) for x in xs]) def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2): - R = r / _concat(vector).norm() - for p, v in zip(network.weights, vector): - p.data.add_(R, v) - _, logits = network(base_inputs) - loss = criterion(logits, base_targets) - grads_p = torch.autograd.grad(loss, network.alphas) + R = r / _concat(vector).norm() + for p, v in zip(network.weights, vector): + p.data.add_(R, v) + _, logits = network(base_inputs) + loss = criterion(logits, base_targets) + grads_p = torch.autograd.grad(loss, network.alphas) - for p, v in zip(network.weights, vector): - p.data.sub_(2*R, v) - _, logits = network(base_inputs) - loss = criterion(logits, base_targets) - grads_n = torch.autograd.grad(loss, network.alphas) + for p, v in zip(network.weights, vector): + p.data.sub_(2 * R, v) + _, logits = network(base_inputs) + loss = criterion(logits, base_targets) + grads_n = torch.autograd.grad(loss, network.alphas) - for p, v in zip(network.weights, vector): - p.data.add_(R, v) - return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)] + for p, v in zip(network.weights, vector): + p.data.add_(R, v) + return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)] def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets): - # _compute_unrolled_model - _, logits = network(base_inputs) - loss = criterion(logits, base_targets) - LR, WD, momentum = w_optimizer.param_groups[0]['lr'], w_optimizer.param_groups[0]['weight_decay'], w_optimizer.param_groups[0]['momentum'] - with torch.no_grad(): - theta = _concat(network.weights) - try: - moment = _concat(w_optimizer.state[v]['momentum_buffer'] for v in network.weights) - moment = moment.mul_(momentum) - except: - moment = torch.zeros_like(theta) - dtheta = _concat(torch.autograd.grad(loss, network.weights)) + WD*theta - params = theta.sub(LR, moment+dtheta) - unrolled_model = deepcopy(network) - model_dict = unrolled_model.state_dict() - new_params, offset = {}, 0 - for k, v in network.named_parameters(): - if 'arch_parameters' in k: continue - v_length = np.prod(v.size()) - new_params[k] = params[offset: offset+v_length].view(v.size()) - offset += v_length - model_dict.update(new_params) - unrolled_model.load_state_dict(model_dict) + # _compute_unrolled_model + _, logits = network(base_inputs) + loss = criterion(logits, base_targets) + LR, WD, momentum = ( + w_optimizer.param_groups[0]["lr"], + w_optimizer.param_groups[0]["weight_decay"], + w_optimizer.param_groups[0]["momentum"], + ) + with torch.no_grad(): + theta = _concat(network.weights) + try: + moment = _concat(w_optimizer.state[v]["momentum_buffer"] for v in network.weights) + moment = moment.mul_(momentum) + except: + moment = torch.zeros_like(theta) + dtheta = _concat(torch.autograd.grad(loss, network.weights)) + WD * theta + params = theta.sub(LR, moment + dtheta) + unrolled_model = deepcopy(network) + model_dict = unrolled_model.state_dict() + new_params, offset = {}, 0 + for k, v in network.named_parameters(): + if "arch_parameters" in k: + continue + v_length = np.prod(v.size()) + new_params[k] = params[offset : offset + v_length].view(v.size()) + offset += v_length + model_dict.update(new_params) + unrolled_model.load_state_dict(model_dict) - unrolled_model.zero_grad() - _, unrolled_logits = unrolled_model(arch_inputs) - unrolled_loss = criterion(unrolled_logits, arch_targets) - unrolled_loss.backward() + unrolled_model.zero_grad() + _, unrolled_logits = unrolled_model(arch_inputs) + unrolled_loss = criterion(unrolled_logits, arch_targets) + unrolled_loss.backward() - dalpha = unrolled_model.arch_parameters.grad - vector = [v.grad.data for v in unrolled_model.weights] - [implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets) - - dalpha.data.sub_(LR, implicit_grads.data) + dalpha = unrolled_model.arch_parameters.grad + vector = [v.grad.data for v in unrolled_model.weights] + [implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets) - if network.arch_parameters.grad is None: - network.arch_parameters.grad = deepcopy( dalpha ) - else: - network.arch_parameters.grad.data.copy_( dalpha.data ) - return unrolled_loss.detach(), unrolled_logits.detach() + dalpha.data.sub_(LR, implicit_grads.data) + + if network.arch_parameters.grad is None: + network.arch_parameters.grad = deepcopy(dalpha) + else: + network.arch_parameters.grad.data.copy_(dalpha.data) + return unrolled_loss.detach(), unrolled_logits.detach() def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, algo, logger): - data_time, batch_time = AverageMeter(), AverageMeter() - base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - end = time.time() - network.train() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): - scheduler.update(None, 1.0 * step / len(xloader)) - base_inputs = base_inputs.cuda(non_blocking=True) - arch_inputs = arch_inputs.cuda(non_blocking=True) - base_targets = base_targets.cuda(non_blocking=True) - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - - # Update the weights - if algo == 'setn': - sampled_arch = network.dync_genotype(True) - network.set_cal_mode('dynamic', sampled_arch) - elif algo == 'gdas': - network.set_cal_mode('gdas', None) - elif algo.startswith('darts'): - network.set_cal_mode('joint', None) - elif algo == 'random': - network.set_cal_mode('urs', None) - elif algo == 'enas': - with torch.no_grad(): - network.controller.eval() - _, _, sampled_arch = network.controller() - network.set_cal_mode('dynamic', sampled_arch) - else: - raise ValueError('Invalid algo name : {:}'.format(algo)) - - network.zero_grad() - _, logits = network(base_inputs) - base_loss = criterion(logits, base_targets) - base_loss.backward() - w_optimizer.step() - # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) - base_losses.update(base_loss.item(), base_inputs.size(0)) - base_top1.update (base_prec1.item(), base_inputs.size(0)) - base_top5.update (base_prec5.item(), base_inputs.size(0)) - - # update the architecture-weight - if algo == 'setn': - network.set_cal_mode('joint') - elif algo == 'gdas': - network.set_cal_mode('gdas', None) - elif algo.startswith('darts'): - network.set_cal_mode('joint', None) - elif algo == 'random': - network.set_cal_mode('urs', None) - elif algo != 'enas': - raise ValueError('Invalid algo name : {:}'.format(algo)) - network.zero_grad() - if algo == 'darts-v2': - arch_loss, logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets) - a_optimizer.step() - elif algo == 'random' or algo == 'enas': - with torch.no_grad(): - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - else: - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - arch_loss.backward() - a_optimizer.step() - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) + data_time, batch_time = AverageMeter(), AverageMeter() + base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() end = time.time() + network.train() + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + scheduler.update(None, 1.0 * step / len(xloader)) + base_inputs = base_inputs.cuda(non_blocking=True) + arch_inputs = arch_inputs.cuda(non_blocking=True) + base_targets = base_targets.cuda(non_blocking=True) + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) - if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)) - Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) - Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5) - Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5) - logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr) - return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg + # Update the weights + if algo == "setn": + sampled_arch = network.dync_genotype(True) + network.set_cal_mode("dynamic", sampled_arch) + elif algo == "gdas": + network.set_cal_mode("gdas", None) + elif algo.startswith("darts"): + network.set_cal_mode("joint", None) + elif algo == "random": + network.set_cal_mode("urs", None) + elif algo == "enas": + with torch.no_grad(): + network.controller.eval() + _, _, sampled_arch = network.controller() + network.set_cal_mode("dynamic", sampled_arch) + else: + raise ValueError("Invalid algo name : {:}".format(algo)) + + network.zero_grad() + _, logits = network(base_inputs) + base_loss = criterion(logits, base_targets) + base_loss.backward() + w_optimizer.step() + # record + base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_losses.update(base_loss.item(), base_inputs.size(0)) + base_top1.update(base_prec1.item(), base_inputs.size(0)) + base_top5.update(base_prec5.item(), base_inputs.size(0)) + + # update the architecture-weight + if algo == "setn": + network.set_cal_mode("joint") + elif algo == "gdas": + network.set_cal_mode("gdas", None) + elif algo.startswith("darts"): + network.set_cal_mode("joint", None) + elif algo == "random": + network.set_cal_mode("urs", None) + elif algo != "enas": + raise ValueError("Invalid algo name : {:}".format(algo)) + network.zero_grad() + if algo == "darts-v2": + arch_loss, logits = backward_step_unrolled( + network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets + ) + a_optimizer.step() + elif algo == "random" or algo == "enas": + with torch.no_grad(): + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + else: + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + arch_loss.backward() + a_optimizer.step() + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if step % print_freq == 0 or step + 1 == len(xloader): + Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( + batch_time=batch_time, data_time=data_time + ) + Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=base_losses, top1=base_top1, top5=base_top5 + ) + Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=arch_losses, top1=arch_top1, top5=arch_top5 + ) + logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) + return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger): - # config. (containing some necessary arg) - # baseline: The baseline score (i.e. average val_acc) from the previous epoch - data_time, batch_time = AverageMeter(), AverageMeter() - GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time() - - controller_num_aggregate = 20 - controller_train_steps = 50 - controller_bl_dec = 0.99 - controller_entropy_weight = 0.0001 + # config. (containing some necessary arg) + # baseline: The baseline score (i.e. average val_acc) from the previous epoch + data_time, batch_time = AverageMeter(), AverageMeter() + GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = ( + AverageMeter(), + AverageMeter(), + AverageMeter(), + AverageMeter(), + AverageMeter(), + AverageMeter(), + time.time(), + ) - network.eval() - network.controller.train() - network.controller.zero_grad() - loader_iter = iter(xloader) - for step in range(controller_train_steps * controller_num_aggregate): - try: - inputs, targets = next(loader_iter) - except: - loader_iter = iter(xloader) - inputs, targets = next(loader_iter) - inputs = inputs.cuda(non_blocking=True) - targets = targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - xend) - - log_prob, entropy, sampled_arch = network.controller() - with torch.no_grad(): - network.set_cal_mode('dynamic', sampled_arch) - _, logits = network(inputs) - val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) - val_top1 = val_top1.view(-1) / 100 - reward = val_top1 + controller_entropy_weight * entropy - if prev_baseline is None: - baseline = val_top1 - else: - baseline = prev_baseline - (1 - controller_bl_dec) * (prev_baseline - reward) - - loss = -1 * log_prob * (reward - baseline) - - # account - RewardMeter.update(reward.item()) - BaselineMeter.update(baseline.item()) - ValAccMeter.update(val_top1.item()*100) - LossMeter.update(loss.item()) - EntropyMeter.update(entropy.item()) - - # Average gradient over controller_num_aggregate samples - loss = loss / controller_num_aggregate - loss.backward(retain_graph=True) + controller_num_aggregate = 20 + controller_train_steps = 50 + controller_bl_dec = 0.99 + controller_entropy_weight = 0.0001 - # measure elapsed time - batch_time.update(time.time() - xend) - xend = time.time() - if (step+1) % controller_num_aggregate == 0: - grad_norm = torch.nn.utils.clip_grad_norm_(network.controller.parameters(), 5.0) - GradnormMeter.update(grad_norm) - optimizer.step() - network.controller.zero_grad() + network.eval() + network.controller.train() + network.controller.zero_grad() + loader_iter = iter(xloader) + for step in range(controller_train_steps * controller_num_aggregate): + try: + inputs, targets = next(loader_iter) + except: + loader_iter = iter(xloader) + inputs, targets = next(loader_iter) + inputs = inputs.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - xend) - if step % print_freq == 0: - Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, controller_train_steps * controller_num_aggregate) - Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) - Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter) - Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg) - logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr) + log_prob, entropy, sampled_arch = network.controller() + with torch.no_grad(): + network.set_cal_mode("dynamic", sampled_arch) + _, logits = network(inputs) + val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) + val_top1 = val_top1.view(-1) / 100 + reward = val_top1 + controller_entropy_weight * entropy + if prev_baseline is None: + baseline = val_top1 + else: + baseline = prev_baseline - (1 - controller_bl_dec) * (prev_baseline - reward) - return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg + loss = -1 * log_prob * (reward - baseline) + + # account + RewardMeter.update(reward.item()) + BaselineMeter.update(baseline.item()) + ValAccMeter.update(val_top1.item() * 100) + LossMeter.update(loss.item()) + EntropyMeter.update(entropy.item()) + + # Average gradient over controller_num_aggregate samples + loss = loss / controller_num_aggregate + loss.backward(retain_graph=True) + + # measure elapsed time + batch_time.update(time.time() - xend) + xend = time.time() + if (step + 1) % controller_num_aggregate == 0: + grad_norm = torch.nn.utils.clip_grad_norm_(network.controller.parameters(), 5.0) + GradnormMeter.update(grad_norm) + optimizer.step() + network.controller.zero_grad() + + if step % print_freq == 0: + Sstr = ( + "*Train-Controller* " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, controller_train_steps * controller_num_aggregate) + ) + Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( + batch_time=batch_time, data_time=data_time + ) + Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})".format( + loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter + ) + Estr = "Entropy={:.4f} ({:.4f})".format(EntropyMeter.val, EntropyMeter.avg) + logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Estr) + + return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg def get_best_arch(xloader, network, n_samples, algo): - with torch.no_grad(): - network.eval() - if algo == 'random': - archs, valid_accs = network.return_topK(n_samples, True), [] - elif algo == 'setn': - archs, valid_accs = network.return_topK(n_samples, False), [] - elif algo.startswith('darts') or algo == 'gdas': - arch = network.genotype - archs, valid_accs = [arch], [] - elif algo == 'enas': - archs, valid_accs = [], [] - for _ in range(n_samples): - _, _, sampled_arch = network.controller() - archs.append(sampled_arch) - else: - raise ValueError('Invalid algorithm name : {:}'.format(algo)) - loader_iter = iter(xloader) - for i, sampled_arch in enumerate(archs): - network.set_cal_mode('dynamic', sampled_arch) - try: - inputs, targets = next(loader_iter) - except: + with torch.no_grad(): + network.eval() + if algo == "random": + archs, valid_accs = network.return_topK(n_samples, True), [] + elif algo == "setn": + archs, valid_accs = network.return_topK(n_samples, False), [] + elif algo.startswith("darts") or algo == "gdas": + arch = network.genotype + archs, valid_accs = [arch], [] + elif algo == "enas": + archs, valid_accs = [], [] + for _ in range(n_samples): + _, _, sampled_arch = network.controller() + archs.append(sampled_arch) + else: + raise ValueError("Invalid algorithm name : {:}".format(algo)) loader_iter = iter(xloader) - inputs, targets = next(loader_iter) - _, logits = network(inputs.cuda(non_blocking=True)) - val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) - valid_accs.append(val_top1.item()) - best_idx = np.argmax(valid_accs) - best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] - return best_arch, best_valid_acc + for i, sampled_arch in enumerate(archs): + network.set_cal_mode("dynamic", sampled_arch) + try: + inputs, targets = next(loader_iter) + except: + loader_iter = iter(xloader) + inputs, targets = next(loader_iter) + _, logits = network(inputs.cuda(non_blocking=True)) + val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) + valid_accs.append(val_top1.item()) + best_idx = np.argmax(valid_accs) + best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] + return best_arch, best_valid_acc def valid_func(xloader, network, criterion, algo, logger): - data_time, batch_time = AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - end = time.time() - with torch.no_grad(): - network.eval() - for step, (arch_inputs, arch_targets) in enumerate(xloader): - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - # prediction - _, logits = network(arch_inputs.cuda(non_blocking=True)) - arch_loss = criterion(logits, arch_targets) - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - return arch_losses.avg, arch_top1.avg, arch_top5.avg + data_time, batch_time = AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + end = time.time() + with torch.no_grad(): + network.eval() + for step, (arch_inputs, arch_targets) in enumerate(xloader): + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) + # prediction + _, logits = network(arch_inputs.cuda(non_blocking=True)) + arch_loss = criterion(logits, arch_targets) + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + return arch_losses.avg, arch_top1.avg, arch_top5.avg def main(xargs): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - if xargs.overwite_epochs is None: - extra_info = {'class_num': class_num, 'xshape': xshape} - else: - extra_info = {'class_num': class_num, 'xshape': xshape, 'epochs': xargs.overwite_epochs} - config = load_config(xargs.config_path, extra_info, logger) - search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', (config.batch_size, config.test_batch_size), xargs.workers) - logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - - search_space = get_search_spaces(xargs.search_space, 'nats-bench') - - model_config = dict2config( - dict(name='generic', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num, - space=search_space, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None) - logger.log('search space : {:}'.format(search_space)) - logger.log('model config : {:}'.format(model_config)) - search_model = get_cell_based_tiny_net(model_config) - search_model.set_algo(xargs.algo) - logger.log('{:}'.format(search_model)) - - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) - a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay, eps=xargs.arch_eps) - logger.log('w-optimizer : {:}'.format(w_optimizer)) - logger.log('a-optimizer : {:}'.format(a_optimizer)) - logger.log('w-scheduler : {:}'.format(w_scheduler)) - logger.log('criterion : {:}'.format(criterion)) - params = count_parameters_in_MB(search_model) - logger.log('The parameters of the search model = {:.2f} MB'.format(params)) - logger.log('search-space : {:}'.format(search_space)) - if bool(xargs.use_api): - api = create(None, 'topology', fast_mode=True, verbose=False) - else: - api = None - logger.log('{:} create API = {:} done'.format(time_string(), api)) - - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU - - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - - if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) - last_info = torch.load(last_info) - start_epoch = last_info['epoch'] - checkpoint = torch.load(last_info['last_checkpoint']) - genotypes = checkpoint['genotypes'] - baseline = checkpoint['baseline'] - valid_accuracies = checkpoint['valid_accuracies'] - search_model.load_state_dict( checkpoint['search_model'] ) - w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) - w_optimizer.load_state_dict ( checkpoint['w_optimizer'] ) - a_optimizer.load_state_dict ( checkpoint['a_optimizer'] ) - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: network.return_topK(1, True)[0]} - baseline = None - - # start training - start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup - for epoch in range(start_epoch, total_epoch): - w_scheduler.update(epoch, 0.0) - need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch-epoch), True)) - epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) - logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) - - network.set_drop_path(float(epoch+1) / total_epoch, xargs.drop_path_rate) - if xargs.algo == 'gdas': - network.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) ) - logger.log('[RESET tau as : {:} and drop_path as {:}]'.format(network.tau, network.drop_path)) - search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ - = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger) - search_time.update(time.time() - start_time) - logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) - logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5)) - if xargs.algo == 'enas': - ctl_loss, ctl_acc, baseline, ctl_reward \ - = train_controller(valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger) - logger.log('[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}'.format(epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward)) - - genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo) - if xargs.algo == 'setn' or xargs.algo == 'enas': - network.set_cal_mode('dynamic', genotype) - elif xargs.algo == 'gdas': - network.set_cal_mode('gdas', None) - elif xargs.algo.startswith('darts'): - network.set_cal_mode('joint', None) - elif xargs.algo == 'random': - network.set_cal_mode('urs', None) + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + if xargs.overwite_epochs is None: + extra_info = {"class_num": class_num, "xshape": xshape} else: - raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo)) - logger.log('[{:}] - [get_best_arch] : {:} -> {:}'.format(epoch_str, genotype, temp_accuracy)) - valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger) - logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype)) - valid_accuracies[epoch] = valid_a_top1 + extra_info = {"class_num": class_num, "xshape": xshape, "epochs": xargs.overwite_epochs} + config = load_config(xargs.config_path, extra_info, logger) + search_loader, train_loader, valid_loader = get_nas_search_loaders( + train_data, + valid_data, + xargs.dataset, + "configs/nas-benchmark/", + (config.batch_size, config.test_batch_size), + xargs.workers, + ) + logger.log( + "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(search_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - genotypes[epoch] = genotype - logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) - # save checkpoint - save_path = save_checkpoint({'epoch' : epoch + 1, - 'args' : deepcopy(xargs), - 'baseline' : baseline, - 'search_model': search_model.state_dict(), - 'w_optimizer' : w_optimizer.state_dict(), - 'a_optimizer' : a_optimizer.state_dict(), - 'w_scheduler' : w_scheduler.state_dict(), - 'genotypes' : genotypes, - 'valid_accuracies' : valid_accuracies}, - model_base_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch + 1, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) - with torch.no_grad(): - logger.log('{:}'.format(search_model.show_alphas())) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) - # measure elapsed time - epoch_time.update(time.time() - start_time) + search_space = get_search_spaces(xargs.search_space, "nats-bench") + + model_config = dict2config( + dict( + name="generic", + C=xargs.channel, + N=xargs.num_cells, + max_nodes=xargs.max_nodes, + num_classes=class_num, + space=search_space, + affine=bool(xargs.affine), + track_running_stats=bool(xargs.track_running_stats), + ), + None, + ) + logger.log("search space : {:}".format(search_space)) + logger.log("model config : {:}".format(model_config)) + search_model = get_cell_based_tiny_net(model_config) + search_model.set_algo(xargs.algo) + logger.log("{:}".format(search_model)) + + w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) + a_optimizer = torch.optim.Adam( + search_model.alphas, + lr=xargs.arch_learning_rate, + betas=(0.5, 0.999), + weight_decay=xargs.arch_weight_decay, + eps=xargs.arch_eps, + ) + logger.log("w-optimizer : {:}".format(w_optimizer)) + logger.log("a-optimizer : {:}".format(a_optimizer)) + logger.log("w-scheduler : {:}".format(w_scheduler)) + logger.log("criterion : {:}".format(criterion)) + params = count_parameters_in_MB(search_model) + logger.log("The parameters of the search model = {:.2f} MB".format(params)) + logger.log("search-space : {:}".format(search_space)) + if bool(xargs.use_api): + api = create(None, "topology", fast_mode=True, verbose=False) + else: + api = None + logger.log("{:} create API = {:} done".format(time_string(), api)) + + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU + + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + + if last_info.exists(): # automatically resume from previous checkpoint + logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + last_info = torch.load(last_info) + start_epoch = last_info["epoch"] + checkpoint = torch.load(last_info["last_checkpoint"]) + genotypes = checkpoint["genotypes"] + baseline = checkpoint["baseline"] + valid_accuracies = checkpoint["valid_accuracies"] + search_model.load_state_dict(checkpoint["search_model"]) + w_scheduler.load_state_dict(checkpoint["w_scheduler"]) + w_optimizer.load_state_dict(checkpoint["w_optimizer"]) + a_optimizer.load_state_dict(checkpoint["a_optimizer"]) + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + ) + else: + logger.log("=> do not find the last-info file : {:}".format(last_info)) + start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: network.return_topK(1, True)[0]} + baseline = None + + # start training + start_time, search_time, epoch_time, total_epoch = ( + time.time(), + AverageMeter(), + AverageMeter(), + config.epochs + config.warmup, + ) + for epoch in range(start_epoch, total_epoch): + w_scheduler.update(epoch, 0.0) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) + logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr()))) + + network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate) + if xargs.algo == "gdas": + network.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)) + logger.log("[RESET tau as : {:} and drop_path as {:}]".format(network.tau, network.drop_path)) + search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func( + search_loader, + network, + criterion, + w_scheduler, + w_optimizer, + a_optimizer, + epoch_str, + xargs.print_freq, + xargs.algo, + logger, + ) + search_time.update(time.time() - start_time) + logger.log( + "[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( + epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum + ) + ) + logger.log( + "[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( + epoch_str, search_a_loss, search_a_top1, search_a_top5 + ) + ) + if xargs.algo == "enas": + ctl_loss, ctl_acc, baseline, ctl_reward = train_controller( + valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger + ) + logger.log( + "[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}".format( + epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward + ) + ) + + genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo) + if xargs.algo == "setn" or xargs.algo == "enas": + network.set_cal_mode("dynamic", genotype) + elif xargs.algo == "gdas": + network.set_cal_mode("gdas", None) + elif xargs.algo.startswith("darts"): + network.set_cal_mode("joint", None) + elif xargs.algo == "random": + network.set_cal_mode("urs", None) + else: + raise ValueError("Invalid algorithm name : {:}".format(xargs.algo)) + logger.log("[{:}] - [get_best_arch] : {:} -> {:}".format(epoch_str, genotype, temp_accuracy)) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger) + logger.log( + "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format( + epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype + ) + ) + valid_accuracies[epoch] = valid_a_top1 + + genotypes[epoch] = genotype + logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(xargs), + "baseline": baseline, + "search_model": search_model.state_dict(), + "w_optimizer": w_optimizer.state_dict(), + "a_optimizer": a_optimizer.state_dict(), + "w_scheduler": w_scheduler.state_dict(), + "genotypes": genotypes, + "valid_accuracies": valid_accuracies, + }, + model_base_path, + logger, + ) + last_info = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + with torch.no_grad(): + logger.log("{:}".format(search_model.show_alphas())) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + # the final post procedure : count the time start_time = time.time() + genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo) + if xargs.algo == "setn" or xargs.algo == "enas": + network.set_cal_mode("dynamic", genotype) + elif xargs.algo == "gdas": + network.set_cal_mode("gdas", None) + elif xargs.algo.startswith("darts"): + network.set_cal_mode("joint", None) + elif xargs.algo == "random": + network.set_cal_mode("urs", None) + else: + raise ValueError("Invalid algorithm name : {:}".format(xargs.algo)) + search_time.update(time.time() - start_time) - # the final post procedure : count the time - start_time = time.time() - genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo) - if xargs.algo == 'setn' or xargs.algo == 'enas': - network.set_cal_mode('dynamic', genotype) - elif xargs.algo == 'gdas': - network.set_cal_mode('gdas', None) - elif xargs.algo.startswith('darts'): - network.set_cal_mode('joint', None) - elif xargs.algo == 'random': - network.set_cal_mode('urs', None) - else: - raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo)) - search_time.update(time.time() - start_time) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger) + logger.log("Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(genotype, valid_a_top1)) - valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger) - logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1)) + logger.log("\n" + "-" * 100) + # check the performance from the architecture dataset + logger.log( + "[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( + xargs.algo, total_epoch, search_time.sum, genotype + ) + ) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotype, "200"))) + logger.close() - logger.log('\n' + '-'*100) - # check the performance from the architecture dataset - logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(xargs.algo, total_epoch, search_time.sum, genotype)) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '200') )) - logger.close() - -if __name__ == '__main__': - parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.") - parser.add_argument('--data_path' , type=str, help='Path to dataset') - parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - parser.add_argument('--search_space', type=str, default='tss', choices=['tss'], help='The search space name.') - parser.add_argument('--algo' , type=str, choices=['darts-v1', 'darts-v2', 'gdas', 'setn', 'random', 'enas'], help='The search space name.') - parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).') - # FOR GDAS - parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.') - parser.add_argument('--tau_max', type=float, default=10, help='The maximum tau for Gumbel Softmax.') - # channels and number-of-cells - parser.add_argument('--max_nodes' , type=int, default=4, help='The maximum number of nodes.') - parser.add_argument('--channel' , type=int, default=16, help='The number of channels.') - parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.') - # - parser.add_argument('--eval_candidate_num', type=int, default=100, help='The number of selected architectures to evaluate.') - # - parser.add_argument('--track_running_stats',type=int, default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') - parser.add_argument('--affine' , type=int, default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.') - parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.') - parser.add_argument('--overwite_epochs', type=int, help='The number of epochs to overwrite that value in config files.') - # architecture leraning rate - parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') - parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding') - parser.add_argument('--arch_eps' , type=float, default=1e-8, help='weight decay for arch encoding') - parser.add_argument('--drop_path_rate' , type=float, help='The drop path rate.') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') - parser.add_argument('--print_freq', type=int, default=200, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - if args.overwite_epochs is None: - args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), - args.dataset, - '{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate)) - else: - args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), - args.dataset, - '{:}-affine{:}_BN{:}-E{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.overwite_epochs, args.drop_path_rate)) +if __name__ == "__main__": + parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + parser.add_argument("--search_space", type=str, default="tss", choices=["tss"], help="The search space name.") + parser.add_argument( + "--algo", + type=str, + choices=["darts-v1", "darts-v2", "gdas", "setn", "random", "enas"], + help="The search space name.", + ) + parser.add_argument( + "--use_api", type=int, default=1, choices=[0, 1], help="Whether use API or not (which will cost much memory)." + ) + # FOR GDAS + parser.add_argument("--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax.") + parser.add_argument("--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax.") + # channels and number-of-cells + parser.add_argument("--max_nodes", type=int, default=4, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, default=16, help="The number of channels.") + parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + # + parser.add_argument( + "--eval_candidate_num", type=int, default=100, help="The number of selected architectures to evaluate." + ) + # + parser.add_argument( + "--track_running_stats", + type=int, + default=0, + choices=[0, 1], + help="Whether use track_running_stats or not in the BN layer.", + ) + parser.add_argument( + "--affine", type=int, default=0, choices=[0, 1], help="Whether use affine=True or False in the BN layer." + ) + parser.add_argument( + "--config_path", + type=str, + default="./configs/nas-benchmark/algos/weight-sharing.config", + help="The path of configuration.", + ) + parser.add_argument( + "--overwite_epochs", type=int, help="The number of epochs to overwrite that value in config files." + ) + # architecture leraning rate + parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") + parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") + parser.add_argument("--arch_eps", type=float, default=1e-8, help="weight decay for arch encoding") + parser.add_argument("--drop_path_rate", type=float, help="The drop path rate.") + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") + parser.add_argument("--print_freq", type=int, default=200, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + if args.overwite_epochs is None: + args.save_dir = os.path.join( + "{:}-{:}".format(args.save_dir, args.search_space), + args.dataset, + "{:}-affine{:}_BN{:}-{:}".format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate), + ) + else: + args.save_dir = os.path.join( + "{:}-{:}".format(args.save_dir, args.search_space), + args.dataset, + "{:}-affine{:}_BN{:}-E{:}-{:}".format( + args.algo, args.affine, args.track_running_stats, args.overwite_epochs, args.drop_path_rate + ), + ) - main(args) + main(args) diff --git a/exps/NATS-algos/search-size.py b/exps/NATS-algos/search-size.py index 0cfba71..1c978cb 100644 --- a/exps/NATS-algos/search-size.py +++ b/exps/NATS-algos/search-size.py @@ -32,294 +32,420 @@ from copy import deepcopy import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import count_parameters_in_MB, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces -from nats_bench import create +from datasets import get_datasets, get_nas_search_loaders +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import count_parameters_in_MB, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_cell_based_tiny_net, get_search_spaces +from nats_bench import create # Ad-hoc for RL algorithms. class ExponentialMovingAverage(object): - """Class that maintains an exponential moving average.""" + """Class that maintains an exponential moving average.""" - def __init__(self, momentum): - self._numerator = 0 - self._denominator = 0 - self._momentum = momentum + def __init__(self, momentum): + self._numerator = 0 + self._denominator = 0 + self._momentum = momentum - def update(self, value): - self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value - self._denominator = self._momentum * self._denominator + (1 - self._momentum) + def update(self, value): + self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value + self._denominator = self._momentum * self._denominator + (1 - self._momentum) + + @property + def value(self): + """Return the current value of the moving average""" + return self._numerator / self._denominator - @property - def value(self): - """Return the current value of the moving average""" - return self._numerator / self._denominator RL_BASELINE_EMA = ExponentialMovingAverage(0.95) -def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, enable_controller, algo, epoch_str, print_freq, logger): - data_time, batch_time = AverageMeter(), AverageMeter() - base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - end = time.time() - network.train() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): - scheduler.update(None, 1.0 * step / len(xloader)) - base_inputs = base_inputs.cuda(non_blocking=True) - arch_inputs = arch_inputs.cuda(non_blocking=True) - base_targets = base_targets.cuda(non_blocking=True) - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - - # Update the weights - network.zero_grad() - _, logits, _ = network(base_inputs) - base_loss = criterion(logits, base_targets) - base_loss.backward() - w_optimizer.step() - # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) - base_losses.update(base_loss.item(), base_inputs.size(0)) - base_top1.update (base_prec1.item(), base_inputs.size(0)) - base_top5.update (base_prec5.item(), base_inputs.size(0)) - - # update the architecture-weight - network.zero_grad() - a_optimizer.zero_grad() - _, logits, log_probs = network(arch_inputs) - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - if algo == 'mask_rl': - with torch.no_grad(): - RL_BASELINE_EMA.update(arch_prec1.item()) - rl_advantage = arch_prec1 - RL_BASELINE_EMA.value - rl_log_prob = sum(log_probs) - arch_loss = - rl_advantage * rl_log_prob - elif algo == 'tas' or algo == 'mask_gumbel': - arch_loss = criterion(logits, arch_targets) - else: - raise ValueError('invalid algorightm name: {:}'.format(algo)) - if enable_controller: - arch_loss.backward() - a_optimizer.step() - # record - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) +def search_func( + xloader, + network, + criterion, + scheduler, + w_optimizer, + a_optimizer, + enable_controller, + algo, + epoch_str, + print_freq, + logger, +): + data_time, batch_time = AverageMeter(), AverageMeter() + base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() end = time.time() + network.train() + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + scheduler.update(None, 1.0 * step / len(xloader)) + base_inputs = base_inputs.cuda(non_blocking=True) + arch_inputs = arch_inputs.cuda(non_blocking=True) + base_targets = base_targets.cuda(non_blocking=True) + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) - if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)) - Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) - Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5) - Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5) - logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr) - return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg + # Update the weights + network.zero_grad() + _, logits, _ = network(base_inputs) + base_loss = criterion(logits, base_targets) + base_loss.backward() + w_optimizer.step() + # record + base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_losses.update(base_loss.item(), base_inputs.size(0)) + base_top1.update(base_prec1.item(), base_inputs.size(0)) + base_top5.update(base_prec5.item(), base_inputs.size(0)) + + # update the architecture-weight + network.zero_grad() + a_optimizer.zero_grad() + _, logits, log_probs = network(arch_inputs) + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + if algo == "mask_rl": + with torch.no_grad(): + RL_BASELINE_EMA.update(arch_prec1.item()) + rl_advantage = arch_prec1 - RL_BASELINE_EMA.value + rl_log_prob = sum(log_probs) + arch_loss = -rl_advantage * rl_log_prob + elif algo == "tas" or algo == "mask_gumbel": + arch_loss = criterion(logits, arch_targets) + else: + raise ValueError("invalid algorightm name: {:}".format(algo)) + if enable_controller: + arch_loss.backward() + a_optimizer.step() + # record + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if step % print_freq == 0 or step + 1 == len(xloader): + Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( + batch_time=batch_time, data_time=data_time + ) + Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=base_losses, top1=base_top1, top5=base_top5 + ) + Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=arch_losses, top1=arch_top1, top5=arch_top5 + ) + logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) + return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg def valid_func(xloader, network, criterion, logger): - data_time, batch_time = AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - end = time.time() - with torch.no_grad(): - network.eval() - for step, (arch_inputs, arch_targets) in enumerate(xloader): - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - # prediction - _, logits, _ = network(arch_inputs.cuda(non_blocking=True)) - arch_loss = criterion(logits, arch_targets) - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - return arch_losses.avg, arch_top1.avg, arch_top5.avg + data_time, batch_time = AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + end = time.time() + with torch.no_grad(): + network.eval() + for step, (arch_inputs, arch_targets) in enumerate(xloader): + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) + # prediction + _, logits, _ = network(arch_inputs.cuda(non_blocking=True)) + arch_loss = criterion(logits, arch_targets) + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + return arch_losses.avg, arch_top1.avg, arch_top5.avg def main(xargs): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - if xargs.overwite_epochs is None: - extra_info = {'class_num': class_num, 'xshape': xshape} - else: - extra_info = {'class_num': class_num, 'xshape': xshape, 'epochs': xargs.overwite_epochs} - config = load_config(xargs.config_path, extra_info, logger) - search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', (config.batch_size, config.test_batch_size), xargs.workers) - logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - - search_space = get_search_spaces(xargs.search_space, 'nats-bench') - - model_config = dict2config( - dict(name='generic', super_type='search-shape', candidate_Cs=search_space['candidates'], max_num_Cs=search_space['numbers'], num_classes=class_num, - genotype=args.genotype, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None) - logger.log('search space : {:}'.format(search_space)) - logger.log('model config : {:}'.format(model_config)) - search_model = get_cell_based_tiny_net(model_config) - search_model.set_algo(xargs.algo) - logger.log('{:}'.format(search_model)) - - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) - a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay, eps=xargs.arch_eps) - logger.log('w-optimizer : {:}'.format(w_optimizer)) - logger.log('a-optimizer : {:}'.format(a_optimizer)) - logger.log('w-scheduler : {:}'.format(w_scheduler)) - logger.log('criterion : {:}'.format(criterion)) - params = count_parameters_in_MB(search_model) - logger.log('The parameters of the search model = {:.2f} MB'.format(params)) - logger.log('search-space : {:}'.format(search_space)) - if bool(xargs.use_api): - api = create(None, 'size', fast_mode=True, verbose=False) - else: - api = None - logger.log('{:} create API = {:} done'.format(time_string(), api)) - - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU - - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - - if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) - last_info = torch.load(last_info) - start_epoch = last_info['epoch'] - checkpoint = torch.load(last_info['last_checkpoint']) - genotypes = checkpoint['genotypes'] - valid_accuracies = checkpoint['valid_accuracies'] - search_model.load_state_dict( checkpoint['search_model'] ) - w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) - w_optimizer.load_state_dict ( checkpoint['w_optimizer'] ) - a_optimizer.load_state_dict ( checkpoint['a_optimizer'] ) - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: network.random} - - # start training - start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup - for epoch in range(start_epoch, total_epoch): - w_scheduler.update(epoch, 0.0) - need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch-epoch), True)) - epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) - - if xargs.warmup_ratio is None or xargs.warmup_ratio <= float(epoch) / total_epoch: - enable_controller = True - network.set_warmup_ratio(None) + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + if xargs.overwite_epochs is None: + extra_info = {"class_num": class_num, "xshape": xshape} else: - enable_controller = False - network.set_warmup_ratio(1.0 - float(epoch) / total_epoch / xargs.warmup_ratio) + extra_info = {"class_num": class_num, "xshape": xshape, "epochs": xargs.overwite_epochs} + config = load_config(xargs.config_path, extra_info, logger) + search_loader, train_loader, valid_loader = get_nas_search_loaders( + train_data, + valid_data, + xargs.dataset, + "configs/nas-benchmark/", + (config.batch_size, config.test_batch_size), + xargs.workers, + ) + logger.log( + "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(search_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller)) + search_space = get_search_spaces(xargs.search_space, "nats-bench") - if xargs.algo == 'mask_gumbel' or xargs.algo == 'tas': - network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1)) - logger.log('[RESET tau as : {:}]'.format(network.tau)) - search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ - = search_func(search_loader, network, criterion, w_scheduler, - w_optimizer, a_optimizer, enable_controller, xargs.algo, epoch_str, xargs.print_freq, logger) - search_time.update(time.time() - start_time) - logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) - logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5)) + model_config = dict2config( + dict( + name="generic", + super_type="search-shape", + candidate_Cs=search_space["candidates"], + max_num_Cs=search_space["numbers"], + num_classes=class_num, + genotype=args.genotype, + affine=bool(xargs.affine), + track_running_stats=bool(xargs.track_running_stats), + ), + None, + ) + logger.log("search space : {:}".format(search_space)) + logger.log("model config : {:}".format(model_config)) + search_model = get_cell_based_tiny_net(model_config) + search_model.set_algo(xargs.algo) + logger.log("{:}".format(search_model)) - genotype = network.genotype - logger.log('[{:}] - [get_best_arch] : {:}'.format(epoch_str, genotype)) - valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion, logger) - logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype)) - valid_accuracies[epoch] = valid_a_top1 + w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) + a_optimizer = torch.optim.Adam( + search_model.alphas, + lr=xargs.arch_learning_rate, + betas=(0.5, 0.999), + weight_decay=xargs.arch_weight_decay, + eps=xargs.arch_eps, + ) + logger.log("w-optimizer : {:}".format(w_optimizer)) + logger.log("a-optimizer : {:}".format(a_optimizer)) + logger.log("w-scheduler : {:}".format(w_scheduler)) + logger.log("criterion : {:}".format(criterion)) + params = count_parameters_in_MB(search_model) + logger.log("The parameters of the search model = {:.2f} MB".format(params)) + logger.log("search-space : {:}".format(search_space)) + if bool(xargs.use_api): + api = create(None, "size", fast_mode=True, verbose=False) + else: + api = None + logger.log("{:} create API = {:} done".format(time_string(), api)) - genotypes[epoch] = genotype - logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) - # save checkpoint - save_path = save_checkpoint({'epoch' : epoch + 1, - 'args' : deepcopy(xargs), - 'search_model': search_model.state_dict(), - 'w_optimizer' : w_optimizer.state_dict(), - 'a_optimizer' : a_optimizer.state_dict(), - 'w_scheduler' : w_scheduler.state_dict(), - 'genotypes' : genotypes, - 'valid_accuracies' : valid_accuracies}, - model_base_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch + 1, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) - with torch.no_grad(): - logger.log('{:}'.format(search_model.show_alphas())) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '90'))) - # measure elapsed time - epoch_time.update(time.time() - start_time) + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU + + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + + if last_info.exists(): # automatically resume from previous checkpoint + logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + last_info = torch.load(last_info) + start_epoch = last_info["epoch"] + checkpoint = torch.load(last_info["last_checkpoint"]) + genotypes = checkpoint["genotypes"] + valid_accuracies = checkpoint["valid_accuracies"] + search_model.load_state_dict(checkpoint["search_model"]) + w_scheduler.load_state_dict(checkpoint["w_scheduler"]) + w_optimizer.load_state_dict(checkpoint["w_optimizer"]) + a_optimizer.load_state_dict(checkpoint["a_optimizer"]) + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + ) + else: + logger.log("=> do not find the last-info file : {:}".format(last_info)) + start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: network.random} + + # start training + start_time, search_time, epoch_time, total_epoch = ( + time.time(), + AverageMeter(), + AverageMeter(), + config.epochs + config.warmup, + ) + for epoch in range(start_epoch, total_epoch): + w_scheduler.update(epoch, 0.0) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) + + if xargs.warmup_ratio is None or xargs.warmup_ratio <= float(epoch) / total_epoch: + enable_controller = True + network.set_warmup_ratio(None) + else: + enable_controller = False + network.set_warmup_ratio(1.0 - float(epoch) / total_epoch / xargs.warmup_ratio) + + logger.log( + "\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}".format( + epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller + ) + ) + + if xargs.algo == "mask_gumbel" or xargs.algo == "tas": + network.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)) + logger.log("[RESET tau as : {:}]".format(network.tau)) + search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func( + search_loader, + network, + criterion, + w_scheduler, + w_optimizer, + a_optimizer, + enable_controller, + xargs.algo, + epoch_str, + xargs.print_freq, + logger, + ) + search_time.update(time.time() - start_time) + logger.log( + "[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( + epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum + ) + ) + logger.log( + "[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( + epoch_str, search_a_loss, search_a_top1, search_a_top5 + ) + ) + + genotype = network.genotype + logger.log("[{:}] - [get_best_arch] : {:}".format(epoch_str, genotype)) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, logger) + logger.log( + "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format( + epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype + ) + ) + valid_accuracies[epoch] = valid_a_top1 + + genotypes[epoch] = genotype + logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(xargs), + "search_model": search_model.state_dict(), + "w_optimizer": w_optimizer.state_dict(), + "a_optimizer": a_optimizer.state_dict(), + "w_scheduler": w_scheduler.state_dict(), + "genotypes": genotypes, + "valid_accuracies": valid_accuracies, + }, + model_base_path, + logger, + ) + last_info = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + with torch.no_grad(): + logger.log("{:}".format(search_model.show_alphas())) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "90"))) + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + # the final post procedure : count the time start_time = time.time() + genotype = network.genotype + search_time.update(time.time() - start_time) - # the final post procedure : count the time - start_time = time.time() - genotype = network.genotype - search_time.update(time.time() - start_time) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, logger) + logger.log("Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(genotype, valid_a_top1)) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, logger) - logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1)) + logger.log("\n" + "-" * 100) + # check the performance from the architecture dataset + logger.log( + "[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( + xargs.algo, total_epoch, search_time.sum, genotype + ) + ) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotype, "90"))) + logger.close() - logger.log('\n' + '-'*100) - # check the performance from the architecture dataset - logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(xargs.algo, total_epoch, search_time.sum, genotype)) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '90') )) - logger.close() - -if __name__ == '__main__': - parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.") - parser.add_argument('--data_path' , type=str, help='Path to dataset') - parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.') - parser.add_argument('--algo' , type=str, choices=['tas', 'mask_gumbel', 'mask_rl'], help='The search space name.') - parser.add_argument('--genotype' , type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.') - parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).') - # FOR GDAS - parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.') - parser.add_argument('--tau_max', type=float, default=10, help='The maximum tau for Gumbel Softmax.') - # FOR ALL - parser.add_argument('--warmup_ratio', type=float, help='The warmup ratio, if None, not use warmup.') - # - parser.add_argument('--track_running_stats',type=int, default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') - parser.add_argument('--affine' , type=int, default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.') - parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.') - parser.add_argument('--overwite_epochs', type=int, help='The number of epochs to overwrite that value in config files.') - # architecture leraning rate - parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') - parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding') - parser.add_argument('--arch_eps' , type=float, default=1e-8, help='weight decay for arch encoding') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') - parser.add_argument('--print_freq', type=int, default=200, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - dirname = '{:}-affine{:}_BN{:}-AWD{:}-WARM{:}'.format(args.algo, args.affine, args.track_running_stats, args.arch_weight_decay, args.warmup_ratio) - if args.overwite_epochs is not None: - dirname = dirname + '-E{:}'.format(args.overwite_epochs) - args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, dirname) +if __name__ == "__main__": + parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + parser.add_argument("--search_space", type=str, default="sss", choices=["sss"], help="The search space name.") + parser.add_argument("--algo", type=str, choices=["tas", "mask_gumbel", "mask_rl"], help="The search space name.") + parser.add_argument( + "--genotype", + type=str, + default="|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|", + help="The genotype.", + ) + parser.add_argument( + "--use_api", type=int, default=1, choices=[0, 1], help="Whether use API or not (which will cost much memory)." + ) + # FOR GDAS + parser.add_argument("--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax.") + parser.add_argument("--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax.") + # FOR ALL + parser.add_argument("--warmup_ratio", type=float, help="The warmup ratio, if None, not use warmup.") + # + parser.add_argument( + "--track_running_stats", + type=int, + default=0, + choices=[0, 1], + help="Whether use track_running_stats or not in the BN layer.", + ) + parser.add_argument( + "--affine", type=int, default=0, choices=[0, 1], help="Whether use affine=True or False in the BN layer." + ) + parser.add_argument( + "--config_path", + type=str, + default="./configs/nas-benchmark/algos/weight-sharing.config", + help="The path of configuration.", + ) + parser.add_argument( + "--overwite_epochs", type=int, help="The number of epochs to overwrite that value in config files." + ) + # architecture leraning rate + parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") + parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") + parser.add_argument("--arch_eps", type=float, default=1e-8, help="weight decay for arch encoding") + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") + parser.add_argument("--print_freq", type=int, default=200, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + dirname = "{:}-affine{:}_BN{:}-AWD{:}-WARM{:}".format( + args.algo, args.affine, args.track_running_stats, args.arch_weight_decay, args.warmup_ratio + ) + if args.overwite_epochs is not None: + dirname = dirname + "-E{:}".format(args.overwite_epochs) + args.save_dir = os.path.join("{:}-{:}".format(args.save_dir, args.search_space), args.dataset, dirname) - main(args) + main(args) diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py index c6c45f2..9d3f0a7 100644 --- a/exps/algos/BOHB.py +++ b/exps/algos/BOHB.py @@ -11,14 +11,17 @@ import os, sys, time, random, argparse from copy import deepcopy from pathlib import Path import torch -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config -from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger -from log_utils import AverageMeter, time_string, convert_secs2time -from nas_201_api import NASBench201API as API -from models import CellStructure, get_search_spaces +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger +from log_utils import AverageMeter, time_string, convert_secs2time +from nas_201_api import NASBench201API as API +from models import CellStructure, get_search_spaces + # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 import ConfigSpace from hpbandster.optimizers.bohb import BOHB @@ -27,209 +30,258 @@ from hpbandster.core.worker import Worker def get_configuration_space(max_nodes, search_space): - cs = ConfigSpace.ConfigurationSpace() - #edge2index = {} - for i in range(1, max_nodes): - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) - return cs + cs = ConfigSpace.ConfigurationSpace() + # edge2index = {} + for i in range(1, max_nodes): + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) + return cs def config2structure_func(max_nodes): - def config2structure(config): - genotypes = [] - for i in range(1, max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_name = config[node_str] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return CellStructure( genotypes ) - return config2structure + def config2structure(config): + genotypes = [] + for i in range(1, max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + op_name = config[node_str] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return CellStructure(genotypes) + + return config2structure class MyWorker(Worker): + def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs): + super().__init__(*args, **kwargs) + self.convert_func = convert_func + self._dataname = dataname + self._nas_bench = nas_bench + self.time_budget = time_budget + self.seen_archs = [] + self.sim_cost_time = 0 + self.real_cost_time = 0 + self.is_end = False - def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs): - super().__init__(*args, **kwargs) - self.convert_func = convert_func - self._dataname = dataname - self._nas_bench = nas_bench - self.time_budget = time_budget - self.seen_archs = [] - self.sim_cost_time = 0 - self.real_cost_time = 0 - self.is_end = False + def get_the_best(self): + assert len(self.seen_archs) > 0 + best_index, best_acc = -1, None + for arch_index in self.seen_archs: + info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp="200", is_random=True) + vacc = info["valid-accuracy"] + if best_acc is None or best_acc < vacc: + best_acc = vacc + best_index = arch_index + assert best_index != -1 + return best_index - def get_the_best(self): - assert len(self.seen_archs) > 0 - best_index, best_acc = -1, None - for arch_index in self.seen_archs: - info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True) - vacc = info['valid-accuracy'] - if best_acc is None or best_acc < vacc: - best_acc = vacc - best_index = arch_index - assert best_index != -1 - return best_index - - def compute(self, config, budget, **kwargs): - start_time = time.time() - structure = self.convert_func( config ) - arch_index = self._nas_bench.query_index_by_arch( structure ) - info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True) - cur_time = info['train-all-time'] + info['valid-per-time'] - cur_vacc = info['valid-accuracy'] - self.real_cost_time += (time.time() - start_time) - if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end: - self.sim_cost_time += cur_time - self.seen_archs.append( arch_index ) - return ({'loss': 100 - float(cur_vacc), - 'info': {'seen-arch' : len(self.seen_archs), - 'sim-test-time' : self.sim_cost_time, - 'current-arch' : arch_index} - }) - else: - self.is_end = True - return ({'loss': 100, - 'info': {'seen-arch' : len(self.seen_archs), - 'sim-test-time' : self.sim_cost_time, - 'current-arch' : None} - }) + def compute(self, config, budget, **kwargs): + start_time = time.time() + structure = self.convert_func(config) + arch_index = self._nas_bench.query_index_by_arch(structure) + info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp="200", is_random=True) + cur_time = info["train-all-time"] + info["valid-per-time"] + cur_vacc = info["valid-accuracy"] + self.real_cost_time += time.time() - start_time + if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end: + self.sim_cost_time += cur_time + self.seen_archs.append(arch_index) + return { + "loss": 100 - float(cur_vacc), + "info": { + "seen-arch": len(self.seen_archs), + "sim-test-time": self.sim_cost_time, + "current-arch": arch_index, + }, + } + else: + self.is_end = True + return { + "loss": 100, + "info": {"seen-arch": len(self.seen_archs), "sim-test-time": self.sim_cost_time, "current-arch": None}, + } def main(xargs, nas_bench): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - if xargs.dataset == 'cifar10': - dataname = 'cifar10-valid' - else: - dataname = xargs.dataset - if xargs.data_path is not None: - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - split_Fpath = 'configs/nas-benchmark/cifar-split.txt' - cifar_split = load_config(split_Fpath, None, None) - train_split, valid_split = cifar_split.train, cifar_split.valid - logger.log('Load split file from {:}'.format(split_Fpath)) - config_path = 'configs/nas-benchmark/algos/R-EA.config' - config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) - # To split data - train_data_v2 = deepcopy(train_data) - train_data_v2.transform = valid_data.transform - valid_data = train_data_v2 - search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) - # data loader - train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) - logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} - else: - config_path = 'configs/nas-benchmark/algos/R-EA.config' - config = load_config(config_path, None, logger) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} + if xargs.dataset == "cifar10": + dataname = "cifar10-valid" + else: + dataname = xargs.dataset + if xargs.data_path is not None: + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + split_Fpath = "configs/nas-benchmark/cifar-split.txt" + cifar_split = load_config(split_Fpath, None, None) + train_split, valid_split = cifar_split.train, cifar_split.valid + logger.log("Load split file from {:}".format(split_Fpath)) + config_path = "configs/nas-benchmark/algos/R-EA.config" + config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger) + # To split data + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), + num_workers=xargs.workers, + pin_memory=True, + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), + num_workers=xargs.workers, + pin_memory=True, + ) + logger.log( + "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(train_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) + extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader} + else: + config_path = "configs/nas-benchmark/algos/R-EA.config" + config = load_config(config_path, None, logger) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) + extra_info = {"config": config, "train_loader": None, "valid_loader": None} - # nas dataset load - assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset) - search_space = get_search_spaces('cell', xargs.search_space_name) - cs = get_configuration_space(xargs.max_nodes, search_space) + # nas dataset load + assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset) + search_space = get_search_spaces("cell", xargs.search_space_name) + cs = get_configuration_space(xargs.max_nodes, search_space) - config2structure = config2structure_func(xargs.max_nodes) - hb_run_id = '0' + config2structure = config2structure_func(xargs.max_nodes) + hb_run_id = "0" - NS = hpns.NameServer(run_id=hb_run_id, host='localhost', port=0) - ns_host, ns_port = NS.start() - num_workers = 1 + NS = hpns.NameServer(run_id=hb_run_id, host="localhost", port=0) + ns_host, ns_port = NS.start() + num_workers = 1 - #nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) - #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) - workers = [] - for i in range(num_workers): - w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataname=dataname, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i) - w.run(background=True) - workers.append(w) - - start_time = time.time() - bohb = BOHB(configspace=cs, - run_id=hb_run_id, - eta=3, min_budget=12, max_budget=200, + # nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) + # logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) + workers = [] + for i in range(num_workers): + w = MyWorker( nameserver=ns_host, nameserver_port=ns_port, - num_samples=xargs.num_samples, - random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor, - ping_interval=10, min_bandwidth=xargs.min_bandwidth) - - results = bohb.run(xargs.n_iters, min_n_workers=num_workers) + convert_func=config2structure, + dataname=dataname, + nas_bench=nas_bench, + time_budget=xargs.time_budget, + run_id=hb_run_id, + id=i, + ) + w.run(background=True) + workers.append(w) - bohb.shutdown(shutdown_workers=True) - NS.shutdown() + start_time = time.time() + bohb = BOHB( + configspace=cs, + run_id=hb_run_id, + eta=3, + min_budget=12, + max_budget=200, + nameserver=ns_host, + nameserver_port=ns_port, + num_samples=xargs.num_samples, + random_fraction=xargs.random_fraction, + bandwidth_factor=xargs.bandwidth_factor, + ping_interval=10, + min_bandwidth=xargs.min_bandwidth, + ) - real_cost_time = time.time() - start_time + results = bohb.run(xargs.n_iters, min_n_workers=num_workers) - id2config = results.get_id2config_mapping() - incumbent = results.get_incumbent_id() - logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time)) - best_arch = config2structure( id2config[incumbent]['config'] ) + bohb.shutdown(shutdown_workers=True) + NS.shutdown() - info = nas_bench.query_by_arch(best_arch, '200') - if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) - else : logger.log('{:}'.format(info)) - logger.log('-'*100) + real_cost_time = time.time() - start_time - logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs))) - logger.close() - return logger.log_dir, nas_bench.query_index_by_arch( best_arch ), real_cost_time - + id2config = results.get_id2config_mapping() + incumbent = results.get_incumbent_id() + logger.log("Best found configuration: {:} within {:.3f} s".format(id2config[incumbent]["config"], real_cost_time)) + best_arch = config2structure(id2config[incumbent]["config"]) + + info = nas_bench.query_by_arch(best_arch, "200") + if info is None: + logger.log("Did not find this architecture : {:}.".format(best_arch)) + else: + logger.log("{:}".format(info)) + logger.log("-" * 100) + + logger.log("workers : {:.1f}s with {:} archs".format(workers[0].time_budget, len(workers[0].seen_archs))) + logger.close() + return logger.log_dir, nas_bench.query_index_by_arch(best_arch), real_cost_time -if __name__ == '__main__': - parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # channels and number-of-cells - parser.add_argument('--search_space_name', type=str, help='The search space name.') - parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).') - # BOHB - parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function') - parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE') - parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function') - parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations') - parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth') - parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') - args = parser.parse_args() - #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): - nas_bench = None - else: - print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) - nas_bench = API(args.arch_nas_dataset) - if args.rand_seed < 0: - save_dir, all_indexes, num, all_times = None, [], 500, [] - for i in range(num): - print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) - args.rand_seed = random.randint(1, 100000) - save_dir, index, ctime = main(args, nas_bench) - all_indexes.append( index ) - all_times.append( ctime ) - print ('\n average time : {:.3f} s'.format(sum(all_times)/len(all_times))) - torch.save(all_indexes, save_dir / 'results.pth') - else: - main(args, nas_bench) +if __name__ == "__main__": + parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # channels and number-of-cells + parser.add_argument("--search_space_name", type=str, help="The search space name.") + parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).") + # BOHB + parser.add_argument( + "--strategy", default="sampling", type=str, nargs="?", help="optimization strategy for the acquisition function" + ) + parser.add_argument("--min_bandwidth", default=0.3, type=float, nargs="?", help="minimum bandwidth for KDE") + parser.add_argument( + "--num_samples", default=64, type=int, nargs="?", help="number of samples for the acquisition function" + ) + parser.add_argument( + "--random_fraction", default=0.33, type=float, nargs="?", help="fraction of random configurations" + ) + parser.add_argument("--bandwidth_factor", default=3, type=int, nargs="?", help="factor multiplied to the bandwidth") + parser.add_argument( + "--n_iters", default=100, type=int, nargs="?", help="number of iterations for optimization method" + ) + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, help="manual seed") + args = parser.parse_args() + # if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): + nas_bench = None + else: + print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset)) + nas_bench = API(args.arch_nas_dataset) + if args.rand_seed < 0: + save_dir, all_indexes, num, all_times = None, [], 500, [] + for i in range(num): + print("{:} : {:03d}/{:03d}".format(time_string(), i, num)) + args.rand_seed = random.randint(1, 100000) + save_dir, index, ctime = main(args, nas_bench) + all_indexes.append(index) + all_times.append(ctime) + print("\n average time : {:.3f} s".format(sum(all_times) / len(all_times))) + torch.save(all_indexes, save_dir / "results.pth") + else: + main(args, nas_bench) diff --git a/exps/algos/DARTS-V1.py b/exps/algos/DARTS-V1.py index 705bc0d..44df9ce 100644 --- a/exps/algos/DARTS-V1.py +++ b/exps/algos/DARTS-V1.py @@ -7,232 +7,330 @@ import sys, time, random, argparse from copy import deepcopy import torch from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces -from nas_201_api import NASBench201API as API +from datasets import get_datasets, get_nas_search_loaders +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_cell_based_tiny_net, get_search_spaces +from nas_201_api import NASBench201API as API -def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger, gradient_clip): - data_time, batch_time = AverageMeter(), AverageMeter() - base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - network.train() - end = time.time() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): - scheduler.update(None, 1.0 * step / len(xloader)) - base_targets = base_targets.cuda(non_blocking=True) - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - - # update the weights - w_optimizer.zero_grad() - _, logits = network(base_inputs) - base_loss = criterion(logits, base_targets) - base_loss.backward() - if gradient_clip > 0: torch.nn.utils.clip_grad_norm_(network.parameters(), gradient_clip) - w_optimizer.step() - # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) - base_losses.update(base_loss.item(), base_inputs.size(0)) - base_top1.update (base_prec1.item(), base_inputs.size(0)) - base_top5.update (base_prec5.item(), base_inputs.size(0)) - - # update the architecture-weight - a_optimizer.zero_grad() - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - arch_loss.backward() - a_optimizer.step() - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) +def search_func( + xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger, gradient_clip +): + data_time, batch_time = AverageMeter(), AverageMeter() + base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + network.train() end = time.time() + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + scheduler.update(None, 1.0 * step / len(xloader)) + base_targets = base_targets.cuda(non_blocking=True) + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) - if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)) - Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) - Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5) - Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5) - logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr) - return base_losses.avg, base_top1.avg, base_top5.avg + # update the weights + w_optimizer.zero_grad() + _, logits = network(base_inputs) + base_loss = criterion(logits, base_targets) + base_loss.backward() + if gradient_clip > 0: + torch.nn.utils.clip_grad_norm_(network.parameters(), gradient_clip) + w_optimizer.step() + # record + base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_losses.update(base_loss.item(), base_inputs.size(0)) + base_top1.update(base_prec1.item(), base_inputs.size(0)) + base_top5.update(base_prec5.item(), base_inputs.size(0)) + + # update the architecture-weight + a_optimizer.zero_grad() + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + arch_loss.backward() + a_optimizer.step() + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if step % print_freq == 0 or step + 1 == len(xloader): + Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( + batch_time=batch_time, data_time=data_time + ) + Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=base_losses, top1=base_top1, top5=base_top5 + ) + Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=arch_losses, top1=arch_top1, top5=arch_top5 + ) + logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) + return base_losses.avg, base_top1.avg, base_top5.avg def valid_func(xloader, network, criterion): - data_time, batch_time = AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - network.eval() - end = time.time() - with torch.no_grad(): - for step, (arch_inputs, arch_targets) in enumerate(xloader): - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - # prediction - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - return arch_losses.avg, arch_top1.avg, arch_top5.avg + data_time, batch_time = AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + network.eval() + end = time.time() + with torch.no_grad(): + for step, (arch_inputs, arch_targets) in enumerate(xloader): + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) + # prediction + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + return arch_losses.avg, arch_top1.avg, arch_top5.avg def main(xargs): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - #config_path = 'configs/nas-benchmark/algos/DARTS.config' - config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) - search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers) - logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + # config_path = 'configs/nas-benchmark/algos/DARTS.config' + config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + search_loader, _, valid_loader = get_nas_search_loaders( + train_data, valid_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers + ) + logger.log( + "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(search_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - search_space = get_search_spaces('cell', xargs.search_space_name) - if xargs.model_config is None: - model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells, - 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space, - 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) - else: - model_config = load_config(xargs.model_config, {'num_classes': class_num, 'space' : search_space, - 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) - search_model = get_cell_based_tiny_net(model_config) - logger.log('search-model :\n{:}'.format(search_model)) - - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) - a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) - logger.log('w-optimizer : {:}'.format(w_optimizer)) - logger.log('a-optimizer : {:}'.format(a_optimizer)) - logger.log('w-scheduler : {:}'.format(w_scheduler)) - logger.log('criterion : {:}'.format(criterion)) - flop, param = get_model_infos(search_model, xshape) - #logger.log('{:}'.format(search_model)) - logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) - if xargs.arch_nas_dataset is None: - api = None - else: - api = API(xargs.arch_nas_dataset) - logger.log('{:} create API = {:} done'.format(time_string(), api)) + search_space = get_search_spaces("cell", xargs.search_space_name) + if xargs.model_config is None: + model_config = dict2config( + { + "name": "DARTS-V1", + "C": xargs.channel, + "N": xargs.num_cells, + "max_nodes": xargs.max_nodes, + "num_classes": class_num, + "space": search_space, + "affine": False, + "track_running_stats": bool(xargs.track_running_stats), + }, + None, + ) + else: + model_config = load_config( + xargs.model_config, + { + "num_classes": class_num, + "space": search_space, + "affine": False, + "track_running_stats": bool(xargs.track_running_stats), + }, + None, + ) + search_model = get_cell_based_tiny_net(model_config) + logger.log("search-model :\n{:}".format(search_model)) - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() + w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) + a_optimizer = torch.optim.Adam( + search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay + ) + logger.log("w-optimizer : {:}".format(w_optimizer)) + logger.log("a-optimizer : {:}".format(a_optimizer)) + logger.log("w-scheduler : {:}".format(w_scheduler)) + logger.log("criterion : {:}".format(criterion)) + flop, param = get_model_infos(search_model, xshape) + # logger.log('{:}'.format(search_model)) + logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param)) + if xargs.arch_nas_dataset is None: + api = None + else: + api = API(xargs.arch_nas_dataset) + logger.log("{:} create API = {:} done".format(time_string(), api)) - if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) - last_info = torch.load(last_info) - start_epoch = last_info['epoch'] - checkpoint = torch.load(last_info['last_checkpoint']) - genotypes = checkpoint['genotypes'] - valid_accuracies = checkpoint['valid_accuracies'] - search_model.load_state_dict( checkpoint['search_model'] ) - w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) - w_optimizer.load_state_dict ( checkpoint['w_optimizer'] ) - a_optimizer.load_state_dict ( checkpoint['a_optimizer'] ) - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()} + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() - # start training - start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup - for epoch in range(start_epoch, total_epoch): - w_scheduler.update(epoch, 0.0) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) - epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) - logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) + if last_info.exists(): # automatically resume from previous checkpoint + logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + last_info = torch.load(last_info) + start_epoch = last_info["epoch"] + checkpoint = torch.load(last_info["last_checkpoint"]) + genotypes = checkpoint["genotypes"] + valid_accuracies = checkpoint["valid_accuracies"] + search_model.load_state_dict(checkpoint["search_model"]) + w_scheduler.load_state_dict(checkpoint["w_scheduler"]) + w_optimizer.load_state_dict(checkpoint["w_optimizer"]) + a_optimizer.load_state_dict(checkpoint["a_optimizer"]) + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + ) + else: + logger.log("=> do not find the last-info file : {:}".format(last_info)) + start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: search_model.genotype()} - search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger, xargs.gradient_clip) - search_time.update(time.time() - start_time) - logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) - valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) - logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) - # check the best accuracy - valid_accuracies[epoch] = valid_a_top1 - if valid_a_top1 > valid_accuracies['best']: - valid_accuracies['best'] = valid_a_top1 - genotypes['best'] = search_model.genotype() - find_best = True - else: find_best = False + # start training + start_time, search_time, epoch_time, total_epoch = ( + time.time(), + AverageMeter(), + AverageMeter(), + config.epochs + config.warmup, + ) + for epoch in range(start_epoch, total_epoch): + w_scheduler.update(epoch, 0.0) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) + logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr()))) - genotypes[epoch] = search_model.genotype() - logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) - # save checkpoint - save_path = save_checkpoint({'epoch' : epoch + 1, - 'args' : deepcopy(xargs), - 'search_model': search_model.state_dict(), - 'w_optimizer' : w_optimizer.state_dict(), - 'a_optimizer' : a_optimizer.state_dict(), - 'w_scheduler' : w_scheduler.state_dict(), - 'genotypes' : genotypes, - 'valid_accuracies' : valid_accuracies}, - model_base_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch + 1, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) - if find_best: - logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) - copy_checkpoint(model_base_path, model_best_path, logger) - with torch.no_grad(): - #logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) - logger.log('{:}'.format(search_model.show_alphas())) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) - # measure elapsed time - epoch_time.update(time.time() - start_time) - start_time = time.time() + search_w_loss, search_w_top1, search_w_top5 = search_func( + search_loader, + network, + criterion, + w_scheduler, + w_optimizer, + a_optimizer, + epoch_str, + xargs.print_freq, + logger, + xargs.gradient_clip, + ) + search_time.update(time.time() - start_time) + logger.log( + "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( + epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum + ) + ) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) + logger.log( + "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( + epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 + ) + ) + # check the best accuracy + valid_accuracies[epoch] = valid_a_top1 + if valid_a_top1 > valid_accuracies["best"]: + valid_accuracies["best"] = valid_a_top1 + genotypes["best"] = search_model.genotype() + find_best = True + else: + find_best = False - logger.log('\n' + '-'*100) - logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200'))) - logger.close() - + genotypes[epoch] = search_model.genotype() + logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(xargs), + "search_model": search_model.state_dict(), + "w_optimizer": w_optimizer.state_dict(), + "a_optimizer": a_optimizer.state_dict(), + "w_scheduler": w_scheduler.state_dict(), + "genotypes": genotypes, + "valid_accuracies": valid_accuracies, + }, + model_base_path, + logger, + ) + last_info = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + if find_best: + logger.log( + "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format( + epoch_str, valid_a_top1 + ) + ) + copy_checkpoint(model_base_path, model_best_path, logger) + with torch.no_grad(): + # logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) + logger.log("{:}".format(search_model.show_alphas())) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("\n" + "-" * 100) + logger.log( + "DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( + total_epoch, search_time.sum, genotypes[total_epoch - 1] + ) + ) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1], "200"))) + logger.close() -if __name__ == '__main__': - parser = argparse.ArgumentParser("DARTS first order") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # channels and number-of-cells - parser.add_argument('--search_space_name', type=str, help='The search space name.') - parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') - parser.add_argument('--config_path', type=str, help='The config path.') - parser.add_argument('--model_config', type=str, help='The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.') - parser.add_argument('--gradient_clip', type=float, default=5, help='') - # architecture leraning rate - parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') - parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - main(args) +if __name__ == "__main__": + parser = argparse.ArgumentParser("DARTS first order") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # channels and number-of-cells + parser.add_argument("--search_space_name", type=str, help="The search space name.") + parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument( + "--track_running_stats", + type=int, + choices=[0, 1], + help="Whether use track_running_stats or not in the BN layer.", + ) + parser.add_argument("--config_path", type=str, help="The config path.") + parser.add_argument( + "--model_config", + type=str, + help="The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.", + ) + parser.add_argument("--gradient_clip", type=float, default=5, help="") + # architecture leraning rate + parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") + parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + main(args) diff --git a/exps/algos/DARTS-V2.py b/exps/algos/DARTS-V2.py index 798ad85..802bbdd 100644 --- a/exps/algos/DARTS-V2.py +++ b/exps/algos/DARTS-V2.py @@ -9,290 +9,381 @@ from copy import deepcopy import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces -from nas_201_api import NASBench201API as API +from datasets import get_datasets, get_nas_search_loaders +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_cell_based_tiny_net, get_search_spaces +from nas_201_api import NASBench201API as API def _concat(xs): - return torch.cat([x.view(-1) for x in xs]) + return torch.cat([x.view(-1) for x in xs]) def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2): - R = r / _concat(vector).norm() - for p, v in zip(network.module.get_weights(), vector): - p.data.add_(R, v) - _, logits = network(base_inputs) - loss = criterion(logits, base_targets) - grads_p = torch.autograd.grad(loss, network.module.get_alphas()) + R = r / _concat(vector).norm() + for p, v in zip(network.module.get_weights(), vector): + p.data.add_(R, v) + _, logits = network(base_inputs) + loss = criterion(logits, base_targets) + grads_p = torch.autograd.grad(loss, network.module.get_alphas()) - for p, v in zip(network.module.get_weights(), vector): - p.data.sub_(2*R, v) - _, logits = network(base_inputs) - loss = criterion(logits, base_targets) - grads_n = torch.autograd.grad(loss, network.module.get_alphas()) + for p, v in zip(network.module.get_weights(), vector): + p.data.sub_(2 * R, v) + _, logits = network(base_inputs) + loss = criterion(logits, base_targets) + grads_n = torch.autograd.grad(loss, network.module.get_alphas()) - for p, v in zip(network.module.get_weights(), vector): - p.data.add_(R, v) - return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)] + for p, v in zip(network.module.get_weights(), vector): + p.data.add_(R, v) + return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)] def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets): - # _compute_unrolled_model - _, logits = network(base_inputs) - loss = criterion(logits, base_targets) - LR, WD, momentum = w_optimizer.param_groups[0]['lr'], w_optimizer.param_groups[0]['weight_decay'], w_optimizer.param_groups[0]['momentum'] - with torch.no_grad(): - theta = _concat(network.module.get_weights()) - try: - moment = _concat(w_optimizer.state[v]['momentum_buffer'] for v in network.module.get_weights()) - moment = moment.mul_(momentum) - except: - moment = torch.zeros_like(theta) - dtheta = _concat(torch.autograd.grad(loss, network.module.get_weights())) + WD*theta - params = theta.sub(LR, moment+dtheta) - unrolled_model = deepcopy(network) - model_dict = unrolled_model.state_dict() - new_params, offset = {}, 0 - for k, v in network.named_parameters(): - if 'arch_parameters' in k: continue - v_length = np.prod(v.size()) - new_params[k] = params[offset: offset+v_length].view(v.size()) - offset += v_length - model_dict.update(new_params) - unrolled_model.load_state_dict(model_dict) + # _compute_unrolled_model + _, logits = network(base_inputs) + loss = criterion(logits, base_targets) + LR, WD, momentum = ( + w_optimizer.param_groups[0]["lr"], + w_optimizer.param_groups[0]["weight_decay"], + w_optimizer.param_groups[0]["momentum"], + ) + with torch.no_grad(): + theta = _concat(network.module.get_weights()) + try: + moment = _concat(w_optimizer.state[v]["momentum_buffer"] for v in network.module.get_weights()) + moment = moment.mul_(momentum) + except: + moment = torch.zeros_like(theta) + dtheta = _concat(torch.autograd.grad(loss, network.module.get_weights())) + WD * theta + params = theta.sub(LR, moment + dtheta) + unrolled_model = deepcopy(network) + model_dict = unrolled_model.state_dict() + new_params, offset = {}, 0 + for k, v in network.named_parameters(): + if "arch_parameters" in k: + continue + v_length = np.prod(v.size()) + new_params[k] = params[offset : offset + v_length].view(v.size()) + offset += v_length + model_dict.update(new_params) + unrolled_model.load_state_dict(model_dict) - unrolled_model.zero_grad() - _, unrolled_logits = unrolled_model(arch_inputs) - unrolled_loss = criterion(unrolled_logits, arch_targets) - unrolled_loss.backward() + unrolled_model.zero_grad() + _, unrolled_logits = unrolled_model(arch_inputs) + unrolled_loss = criterion(unrolled_logits, arch_targets) + unrolled_loss.backward() - dalpha = unrolled_model.module.arch_parameters.grad - vector = [v.grad.data for v in unrolled_model.module.get_weights()] - [implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets) - - dalpha.data.sub_(LR, implicit_grads.data) + dalpha = unrolled_model.module.arch_parameters.grad + vector = [v.grad.data for v in unrolled_model.module.get_weights()] + [implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets) + + dalpha.data.sub_(LR, implicit_grads.data) + + if network.module.arch_parameters.grad is None: + network.module.arch_parameters.grad = deepcopy(dalpha) + else: + network.module.arch_parameters.grad.data.copy_(dalpha.data) + return unrolled_loss.detach(), unrolled_logits.detach() - if network.module.arch_parameters.grad is None: - network.module.arch_parameters.grad = deepcopy( dalpha ) - else: - network.module.arch_parameters.grad.data.copy_( dalpha.data ) - return unrolled_loss.detach(), unrolled_logits.detach() - def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): - data_time, batch_time = AverageMeter(), AverageMeter() - base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - network.train() - end = time.time() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): - scheduler.update(None, 1.0 * step / len(xloader)) - base_targets = base_targets.cuda(non_blocking=True) - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - - # update the architecture-weight - a_optimizer.zero_grad() - arch_loss, arch_logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets) - a_optimizer.step() - # record - arch_prec1, arch_prec5 = obtain_accuracy(arch_logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - - # update the weights - w_optimizer.zero_grad() - _, logits = network(base_inputs) - base_loss = criterion(logits, base_targets) - base_loss.backward() - torch.nn.utils.clip_grad_norm_(network.parameters(), 5) - w_optimizer.step() - # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) - base_losses.update(base_loss.item(), base_inputs.size(0)) - base_top1.update (base_prec1.item(), base_inputs.size(0)) - base_top5.update (base_prec5.item(), base_inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) + data_time, batch_time = AverageMeter(), AverageMeter() + base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + network.train() end = time.time() + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + scheduler.update(None, 1.0 * step / len(xloader)) + base_targets = base_targets.cuda(non_blocking=True) + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) - if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)) - Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) - Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5) - Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5) - logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr) - return base_losses.avg, base_top1.avg, base_top5.avg + # update the architecture-weight + a_optimizer.zero_grad() + arch_loss, arch_logits = backward_step_unrolled( + network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets + ) + a_optimizer.step() + # record + arch_prec1, arch_prec5 = obtain_accuracy(arch_logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + + # update the weights + w_optimizer.zero_grad() + _, logits = network(base_inputs) + base_loss = criterion(logits, base_targets) + base_loss.backward() + torch.nn.utils.clip_grad_norm_(network.parameters(), 5) + w_optimizer.step() + # record + base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_losses.update(base_loss.item(), base_inputs.size(0)) + base_top1.update(base_prec1.item(), base_inputs.size(0)) + base_top5.update(base_prec5.item(), base_inputs.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if step % print_freq == 0 or step + 1 == len(xloader): + Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( + batch_time=batch_time, data_time=data_time + ) + Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=base_losses, top1=base_top1, top5=base_top5 + ) + Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=arch_losses, top1=arch_top1, top5=arch_top5 + ) + logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) + return base_losses.avg, base_top1.avg, base_top5.avg def valid_func(xloader, network, criterion): - data_time, batch_time = AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - network.eval() - end = time.time() - with torch.no_grad(): - for step, (arch_inputs, arch_targets) in enumerate(xloader): - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - # prediction - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - return arch_losses.avg, arch_top1.avg, arch_top5.avg + data_time, batch_time = AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + network.eval() + end = time.time() + with torch.no_grad(): + for step, (arch_inputs, arch_targets) in enumerate(xloader): + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) + # prediction + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + return arch_losses.avg, arch_top1.avg, arch_top5.avg def main(xargs): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) - search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers) - logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + search_loader, _, valid_loader = get_nas_search_loaders( + train_data, valid_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers + ) + logger.log( + "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(search_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - search_space = get_search_spaces('cell', xargs.search_space_name) - model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells, - 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space, - 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) - search_model = get_cell_based_tiny_net(model_config) - logger.log('search-model :\n{:}'.format(search_model)) - - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) - a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) - logger.log('w-optimizer : {:}'.format(w_optimizer)) - logger.log('a-optimizer : {:}'.format(a_optimizer)) - logger.log('w-scheduler : {:}'.format(w_scheduler)) - logger.log('criterion : {:}'.format(criterion)) - flop, param = get_model_infos(search_model, xshape) - #logger.log('{:}'.format(search_model)) - logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) - if xargs.arch_nas_dataset is None: - api = None - else: - api = API(xargs.arch_nas_dataset) - logger.log('{:} create API = {:} done'.format(time_string(), api)) + search_space = get_search_spaces("cell", xargs.search_space_name) + model_config = dict2config( + { + "name": "DARTS-V2", + "C": xargs.channel, + "N": xargs.num_cells, + "max_nodes": xargs.max_nodes, + "num_classes": class_num, + "space": search_space, + "affine": False, + "track_running_stats": bool(xargs.track_running_stats), + }, + None, + ) + search_model = get_cell_based_tiny_net(model_config) + logger.log("search-model :\n{:}".format(search_model)) - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() + w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) + a_optimizer = torch.optim.Adam( + search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay + ) + logger.log("w-optimizer : {:}".format(w_optimizer)) + logger.log("a-optimizer : {:}".format(a_optimizer)) + logger.log("w-scheduler : {:}".format(w_scheduler)) + logger.log("criterion : {:}".format(criterion)) + flop, param = get_model_infos(search_model, xshape) + # logger.log('{:}'.format(search_model)) + logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param)) + if xargs.arch_nas_dataset is None: + api = None + else: + api = API(xargs.arch_nas_dataset) + logger.log("{:} create API = {:} done".format(time_string(), api)) - if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) - last_info = torch.load(last_info) - start_epoch = last_info['epoch'] - checkpoint = torch.load(last_info['last_checkpoint']) - genotypes = checkpoint['genotypes'] - valid_accuracies = checkpoint['valid_accuracies'] - search_model.load_state_dict( checkpoint['search_model'] ) - w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) - w_optimizer.load_state_dict ( checkpoint['w_optimizer'] ) - a_optimizer.load_state_dict ( checkpoint['a_optimizer'] ) - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()} + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() - # start training - start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup - for epoch in range(start_epoch, total_epoch): - w_scheduler.update(epoch, 0.0) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) - epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) - min_LR = min(w_scheduler.get_lr()) - logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR)) + if last_info.exists(): # automatically resume from previous checkpoint + logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + last_info = torch.load(last_info) + start_epoch = last_info["epoch"] + checkpoint = torch.load(last_info["last_checkpoint"]) + genotypes = checkpoint["genotypes"] + valid_accuracies = checkpoint["valid_accuracies"] + search_model.load_state_dict(checkpoint["search_model"]) + w_scheduler.load_state_dict(checkpoint["w_scheduler"]) + w_optimizer.load_state_dict(checkpoint["w_optimizer"]) + a_optimizer.load_state_dict(checkpoint["a_optimizer"]) + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + ) + else: + logger.log("=> do not find the last-info file : {:}".format(last_info)) + start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: search_model.genotype()} - search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) - search_time.update(time.time() - start_time) - logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) - valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) - logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) - # check the best accuracy - valid_accuracies[epoch] = valid_a_top1 - if valid_a_top1 > valid_accuracies['best']: - valid_accuracies['best'] = valid_a_top1 - genotypes['best'] = search_model.genotype() - find_best = True - else: find_best = False + # start training + start_time, search_time, epoch_time, total_epoch = ( + time.time(), + AverageMeter(), + AverageMeter(), + config.epochs + config.warmup, + ) + for epoch in range(start_epoch, total_epoch): + w_scheduler.update(epoch, 0.0) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) + min_LR = min(w_scheduler.get_lr()) + logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min_LR)) - genotypes[epoch] = search_model.genotype() - logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) - # save checkpoint - save_path = save_checkpoint({'epoch' : epoch + 1, - 'args' : deepcopy(xargs), - 'search_model': search_model.state_dict(), - 'w_optimizer' : w_optimizer.state_dict(), - 'a_optimizer' : a_optimizer.state_dict(), - 'w_scheduler' : w_scheduler.state_dict(), - 'genotypes' : genotypes, - 'valid_accuracies' : valid_accuracies}, - model_base_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch + 1, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) - if find_best: - logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) - copy_checkpoint(model_base_path, model_best_path, logger) - with torch.no_grad(): - logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) - # measure elapsed time - epoch_time.update(time.time() - start_time) - start_time = time.time() + search_w_loss, search_w_top1, search_w_top5 = search_func( + search_loader, + network, + criterion, + w_scheduler, + w_optimizer, + a_optimizer, + epoch_str, + xargs.print_freq, + logger, + ) + search_time.update(time.time() - start_time) + logger.log( + "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( + epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum + ) + ) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) + logger.log( + "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( + epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 + ) + ) + # check the best accuracy + valid_accuracies[epoch] = valid_a_top1 + if valid_a_top1 > valid_accuracies["best"]: + valid_accuracies["best"] = valid_a_top1 + genotypes["best"] = search_model.genotype() + find_best = True + else: + find_best = False - logger.log('\n' + '-'*100) - # check the performance from the architecture dataset - logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1]), '200')) - logger.close() - + genotypes[epoch] = search_model.genotype() + logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(xargs), + "search_model": search_model.state_dict(), + "w_optimizer": w_optimizer.state_dict(), + "a_optimizer": a_optimizer.state_dict(), + "w_scheduler": w_scheduler.state_dict(), + "genotypes": genotypes, + "valid_accuracies": valid_accuracies, + }, + model_base_path, + logger, + ) + last_info = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + if find_best: + logger.log( + "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format( + epoch_str, valid_a_top1 + ) + ) + copy_checkpoint(model_base_path, model_best_path, logger) + with torch.no_grad(): + logger.log( + "arch-parameters :\n{:}".format(nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu()) + ) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("\n" + "-" * 100) + # check the performance from the architecture dataset + logger.log( + "DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( + total_epoch, search_time.sum, genotypes[total_epoch - 1] + ) + ) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1]), "200")) + logger.close() -if __name__ == '__main__': - parser = argparse.ArgumentParser("DARTS Second Order") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # channels and number-of-cells - parser.add_argument('--config_path', type=str, help='The config path.') - parser.add_argument('--search_space_name', type=str, help='The search space name.') - parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') - # architecture leraning rate - parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') - parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - main(args) +if __name__ == "__main__": + parser = argparse.ArgumentParser("DARTS Second Order") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # channels and number-of-cells + parser.add_argument("--config_path", type=str, help="The config path.") + parser.add_argument("--search_space_name", type=str, help="The search space name.") + parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument( + "--track_running_stats", + type=int, + choices=[0, 1], + help="Whether use track_running_stats or not in the BN layer.", + ) + # architecture leraning rate + parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") + parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + main(args) diff --git a/exps/algos/ENAS.py b/exps/algos/ENAS.py index 0af87ce..2a850bc 100644 --- a/exps/algos/ENAS.py +++ b/exps/algos/ENAS.py @@ -9,336 +9,447 @@ from copy import deepcopy import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces -from nas_201_api import NASBench201API as API +from datasets import get_datasets, get_nas_search_loaders +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_cell_based_tiny_net, get_search_spaces +from nas_201_api import NASBench201API as API def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger): - data_time, batch_time = AverageMeter(), AverageMeter() - losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time() - - shared_cnn.train() - controller.eval() + data_time, batch_time = AverageMeter(), AverageMeter() + losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time() - for step, (inputs, targets) in enumerate(xloader): - scheduler.update(None, 1.0 * step / len(xloader)) - targets = targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - xend) - - with torch.no_grad(): - _, _, sampled_arch = controller() + shared_cnn.train() + controller.eval() - optimizer.zero_grad() - shared_cnn.module.update_arch(sampled_arch) - _, logits = shared_cnn(inputs) - loss = criterion(logits, targets) - loss.backward() - torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5) - optimizer.step() - # record - prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) - losses.update(loss.item(), inputs.size(0)) - top1s.update (prec1.item(), inputs.size(0)) - top5s.update (prec5.item(), inputs.size(0)) + for step, (inputs, targets) in enumerate(xloader): + scheduler.update(None, 1.0 * step / len(xloader)) + targets = targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - xend) - # measure elapsed time - batch_time.update(time.time() - xend) - xend = time.time() + with torch.no_grad(): + _, _, sampled_arch = controller() - if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = '*Train-Shared-CNN* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)) - Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) - Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=losses, top1=top1s, top5=top5s) - logger.log(Sstr + ' ' + Tstr + ' ' + Wstr) - return losses.avg, top1s.avg, top5s.avg + optimizer.zero_grad() + shared_cnn.module.update_arch(sampled_arch) + _, logits = shared_cnn(inputs) + loss = criterion(logits, targets) + loss.backward() + torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5) + optimizer.step() + # record + prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) + losses.update(loss.item(), inputs.size(0)) + top1s.update(prec1.item(), inputs.size(0)) + top5s.update(prec5.item(), inputs.size(0)) + + # measure elapsed time + batch_time.update(time.time() - xend) + xend = time.time() + + if step % print_freq == 0 or step + 1 == len(xloader): + Sstr = "*Train-Shared-CNN* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( + batch_time=batch_time, data_time=data_time + ) + Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=losses, top1=top1s, top5=top5s + ) + logger.log(Sstr + " " + Tstr + " " + Wstr) + return losses.avg, top1s.avg, top5s.avg def train_controller(xloader, shared_cnn, controller, criterion, optimizer, config, epoch_str, print_freq, logger): - # config. (containing some necessary arg) - # baseline: The baseline score (i.e. average val_acc) from the previous epoch - data_time, batch_time = AverageMeter(), AverageMeter() - GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time() - - shared_cnn.eval() - controller.train() - controller.zero_grad() - #for step, (inputs, targets) in enumerate(xloader): - loader_iter = iter(xloader) - for step in range(config.ctl_train_steps * config.ctl_num_aggre): - try: - inputs, targets = next(loader_iter) - except: - loader_iter = iter(xloader) - inputs, targets = next(loader_iter) - targets = targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - xend) - - log_prob, entropy, sampled_arch = controller() - with torch.no_grad(): - shared_cnn.module.update_arch(sampled_arch) - _, logits = shared_cnn(inputs) - val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) - val_top1 = val_top1.view(-1) / 100 - reward = val_top1 + config.ctl_entropy_w * entropy - if config.baseline is None: - baseline = val_top1 - else: - baseline = config.baseline - (1 - config.ctl_bl_dec) * (config.baseline - reward) - - loss = -1 * log_prob * (reward - baseline) - - # account - RewardMeter.update(reward.item()) - BaselineMeter.update(baseline.item()) - ValAccMeter.update(val_top1.item()*100) - LossMeter.update(loss.item()) - EntropyMeter.update(entropy.item()) - - # Average gradient over controller_num_aggregate samples - loss = loss / config.ctl_num_aggre - loss.backward(retain_graph=True) + # config. (containing some necessary arg) + # baseline: The baseline score (i.e. average val_acc) from the previous epoch + data_time, batch_time = AverageMeter(), AverageMeter() + GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = ( + AverageMeter(), + AverageMeter(), + AverageMeter(), + AverageMeter(), + AverageMeter(), + AverageMeter(), + time.time(), + ) - # measure elapsed time - batch_time.update(time.time() - xend) - xend = time.time() - if (step+1) % config.ctl_num_aggre == 0: - grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0) - GradnormMeter.update(grad_norm) - optimizer.step() - controller.zero_grad() + shared_cnn.eval() + controller.train() + controller.zero_grad() + # for step, (inputs, targets) in enumerate(xloader): + loader_iter = iter(xloader) + for step in range(config.ctl_train_steps * config.ctl_num_aggre): + try: + inputs, targets = next(loader_iter) + except: + loader_iter = iter(xloader) + inputs, targets = next(loader_iter) + targets = targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - xend) - if step % print_freq == 0: - Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre) - Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) - Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter) - Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg) - logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr) + log_prob, entropy, sampled_arch = controller() + with torch.no_grad(): + shared_cnn.module.update_arch(sampled_arch) + _, logits = shared_cnn(inputs) + val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) + val_top1 = val_top1.view(-1) / 100 + reward = val_top1 + config.ctl_entropy_w * entropy + if config.baseline is None: + baseline = val_top1 + else: + baseline = config.baseline - (1 - config.ctl_bl_dec) * (config.baseline - reward) - return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg, baseline.item() + loss = -1 * log_prob * (reward - baseline) + + # account + RewardMeter.update(reward.item()) + BaselineMeter.update(baseline.item()) + ValAccMeter.update(val_top1.item() * 100) + LossMeter.update(loss.item()) + EntropyMeter.update(entropy.item()) + + # Average gradient over controller_num_aggregate samples + loss = loss / config.ctl_num_aggre + loss.backward(retain_graph=True) + + # measure elapsed time + batch_time.update(time.time() - xend) + xend = time.time() + if (step + 1) % config.ctl_num_aggre == 0: + grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0) + GradnormMeter.update(grad_norm) + optimizer.step() + controller.zero_grad() + + if step % print_freq == 0: + Sstr = ( + "*Train-Controller* " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre) + ) + Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( + batch_time=batch_time, data_time=data_time + ) + Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})".format( + loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter + ) + Estr = "Entropy={:.4f} ({:.4f})".format(EntropyMeter.val, EntropyMeter.avg) + logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Estr) + + return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg, baseline.item() def get_best_arch(controller, shared_cnn, xloader, n_samples=10): - with torch.no_grad(): - controller.eval() - shared_cnn.eval() - archs, valid_accs = [], [] - loader_iter = iter(xloader) - for i in range(n_samples): - try: - inputs, targets = next(loader_iter) - except: + with torch.no_grad(): + controller.eval() + shared_cnn.eval() + archs, valid_accs = [], [] loader_iter = iter(xloader) - inputs, targets = next(loader_iter) + for i in range(n_samples): + try: + inputs, targets = next(loader_iter) + except: + loader_iter = iter(xloader) + inputs, targets = next(loader_iter) - _, _, sampled_arch = controller() - arch = shared_cnn.module.update_arch(sampled_arch) - _, logits = shared_cnn(inputs) - val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) + _, _, sampled_arch = controller() + arch = shared_cnn.module.update_arch(sampled_arch) + _, logits = shared_cnn(inputs) + val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) - archs.append( arch ) - valid_accs.append( val_top1.item() ) + archs.append(arch) + valid_accs.append(val_top1.item()) - best_idx = np.argmax(valid_accs) - best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] - return best_arch, best_valid_acc + best_idx = np.argmax(valid_accs) + best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] + return best_arch, best_valid_acc def valid_func(xloader, network, criterion): - data_time, batch_time = AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - network.eval() - end = time.time() - with torch.no_grad(): - for step, (arch_inputs, arch_targets) in enumerate(xloader): - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - # prediction - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - return arch_losses.avg, arch_top1.avg, arch_top5.avg + data_time, batch_time = AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + network.eval() + end = time.time() + with torch.no_grad(): + for step, (arch_inputs, arch_targets) in enumerate(xloader): + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) + # prediction + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + return arch_losses.avg, arch_top1.avg, arch_top5.avg def main(xargs): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - logger.log('use config from : {:}'.format(xargs.config_path)) - config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) - _, train_loader, valid_loader = get_nas_search_loaders(train_data, test_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers) - # since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader - valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform) - if hasattr(valid_loader.dataset, 'transforms'): - valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms) - # data loader - logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) + train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + logger.log("use config from : {:}".format(xargs.config_path)) + config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + _, train_loader, valid_loader = get_nas_search_loaders( + train_data, test_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers + ) + # since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader + valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform) + if hasattr(valid_loader.dataset, "transforms"): + valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms) + # data loader + logger.log( + "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(train_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - search_space = get_search_spaces('cell', xargs.search_space_name) - model_config = dict2config({'name': 'ENAS', 'C': xargs.channel, 'N': xargs.num_cells, - 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space, - 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) - shared_cnn = get_cell_based_tiny_net(model_config) - controller = shared_cnn.create_controller() - - w_optimizer, w_scheduler, criterion = get_optim_scheduler(shared_cnn.parameters(), config) - a_optimizer = torch.optim.Adam(controller.parameters(), lr=config.controller_lr, betas=config.controller_betas, eps=config.controller_eps) - logger.log('w-optimizer : {:}'.format(w_optimizer)) - logger.log('a-optimizer : {:}'.format(a_optimizer)) - logger.log('w-scheduler : {:}'.format(w_scheduler)) - logger.log('criterion : {:}'.format(criterion)) - #flop, param = get_model_infos(shared_cnn, xshape) - #logger.log('{:}'.format(shared_cnn)) - #logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) - logger.log('search-space : {:}'.format(search_space)) - if xargs.arch_nas_dataset is None: - api = None - else: - api = API(xargs.arch_nas_dataset) - logger.log('{:} create API = {:} done'.format(time_string(), api)) - shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda() + search_space = get_search_spaces("cell", xargs.search_space_name) + model_config = dict2config( + { + "name": "ENAS", + "C": xargs.channel, + "N": xargs.num_cells, + "max_nodes": xargs.max_nodes, + "num_classes": class_num, + "space": search_space, + "affine": False, + "track_running_stats": bool(xargs.track_running_stats), + }, + None, + ) + shared_cnn = get_cell_based_tiny_net(model_config) + controller = shared_cnn.create_controller() - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') + w_optimizer, w_scheduler, criterion = get_optim_scheduler(shared_cnn.parameters(), config) + a_optimizer = torch.optim.Adam( + controller.parameters(), lr=config.controller_lr, betas=config.controller_betas, eps=config.controller_eps + ) + logger.log("w-optimizer : {:}".format(w_optimizer)) + logger.log("a-optimizer : {:}".format(a_optimizer)) + logger.log("w-scheduler : {:}".format(w_scheduler)) + logger.log("criterion : {:}".format(criterion)) + # flop, param = get_model_infos(shared_cnn, xshape) + # logger.log('{:}'.format(shared_cnn)) + # logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) + logger.log("search-space : {:}".format(search_space)) + if xargs.arch_nas_dataset is None: + api = None + else: + api = API(xargs.arch_nas_dataset) + logger.log("{:} create API = {:} done".format(time_string(), api)) + shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda() - if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) - last_info = torch.load(last_info) - start_epoch = last_info['epoch'] - checkpoint = torch.load(last_info['last_checkpoint']) - genotypes = checkpoint['genotypes'] - baseline = checkpoint['baseline'] - valid_accuracies = checkpoint['valid_accuracies'] - shared_cnn.load_state_dict( checkpoint['shared_cnn'] ) - controller.load_state_dict( checkpoint['controller'] ) - w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) - w_optimizer.load_state_dict ( checkpoint['w_optimizer'] ) - a_optimizer.load_state_dict ( checkpoint['a_optimizer'] ) - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") - # start training - start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup - for epoch in range(start_epoch, total_epoch): - w_scheduler.update(epoch, 0.0) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) - epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) - logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), baseline)) + if last_info.exists(): # automatically resume from previous checkpoint + logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + last_info = torch.load(last_info) + start_epoch = last_info["epoch"] + checkpoint = torch.load(last_info["last_checkpoint"]) + genotypes = checkpoint["genotypes"] + baseline = checkpoint["baseline"] + valid_accuracies = checkpoint["valid_accuracies"] + shared_cnn.load_state_dict(checkpoint["shared_cnn"]) + controller.load_state_dict(checkpoint["controller"]) + w_scheduler.load_state_dict(checkpoint["w_scheduler"]) + w_optimizer.load_state_dict(checkpoint["w_optimizer"]) + a_optimizer.load_state_dict(checkpoint["a_optimizer"]) + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + ) + else: + logger.log("=> do not find the last-info file : {:}".format(last_info)) + start_epoch, valid_accuracies, genotypes, baseline = 0, {"best": -1}, {}, None - cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(train_loader, shared_cnn, controller, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger) - logger.log('[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, cnn_loss, cnn_top1, cnn_top5)) - ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline \ - = train_controller(valid_loader, shared_cnn, controller, criterion, a_optimizer, \ - dict2config({'baseline': baseline, - 'ctl_train_steps': xargs.controller_train_steps, 'ctl_num_aggre': xargs.controller_num_aggregate, - 'ctl_entropy_w': xargs.controller_entropy_weight, - 'ctl_bl_dec' : xargs.controller_bl_dec}, None), \ - epoch_str, xargs.print_freq, logger) - search_time.update(time.time() - start_time) - logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum)) - best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader) - shared_cnn.module.update_arch(best_arch) - _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion) + # start training + start_time, search_time, epoch_time, total_epoch = ( + time.time(), + AverageMeter(), + AverageMeter(), + config.epochs + config.warmup, + ) + for epoch in range(start_epoch, total_epoch): + w_scheduler.update(epoch, 0.0) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) + logger.log( + "\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}".format( + epoch_str, need_time, min(w_scheduler.get_lr()), baseline + ) + ) - genotypes[epoch] = best_arch - # check the best accuracy - valid_accuracies[epoch] = best_valid_acc - if best_valid_acc > valid_accuracies['best']: - valid_accuracies['best'] = best_valid_acc - genotypes['best'] = best_arch - find_best = True - else: find_best = False + cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn( + train_loader, + shared_cnn, + controller, + criterion, + w_scheduler, + w_optimizer, + epoch_str, + xargs.print_freq, + logger, + ) + logger.log( + "[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( + epoch_str, cnn_loss, cnn_top1, cnn_top5 + ) + ) + ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller( + valid_loader, + shared_cnn, + controller, + criterion, + a_optimizer, + dict2config( + { + "baseline": baseline, + "ctl_train_steps": xargs.controller_train_steps, + "ctl_num_aggre": xargs.controller_num_aggregate, + "ctl_entropy_w": xargs.controller_entropy_weight, + "ctl_bl_dec": xargs.controller_bl_dec, + }, + None, + ), + epoch_str, + xargs.print_freq, + logger, + ) + search_time.update(time.time() - start_time) + logger.log( + "[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s".format( + epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum + ) + ) + best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader) + shared_cnn.module.update_arch(best_arch) + _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion) - logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) - # save checkpoint - save_path = save_checkpoint({'epoch' : epoch + 1, - 'args' : deepcopy(xargs), - 'baseline' : baseline, - 'shared_cnn' : shared_cnn.state_dict(), - 'controller' : controller.state_dict(), - 'w_optimizer' : w_optimizer.state_dict(), - 'a_optimizer' : a_optimizer.state_dict(), - 'w_scheduler' : w_scheduler.state_dict(), - 'genotypes' : genotypes, - 'valid_accuracies' : valid_accuracies}, - model_base_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch + 1, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) - if find_best: - logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc)) - copy_checkpoint(model_base_path, model_best_path, logger) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) - # measure elapsed time - epoch_time.update(time.time() - start_time) + genotypes[epoch] = best_arch + # check the best accuracy + valid_accuracies[epoch] = best_valid_acc + if best_valid_acc > valid_accuracies["best"]: + valid_accuracies["best"] = best_valid_acc + genotypes["best"] = best_arch + find_best = True + else: + find_best = False + + logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(xargs), + "baseline": baseline, + "shared_cnn": shared_cnn.state_dict(), + "controller": controller.state_dict(), + "w_optimizer": w_optimizer.state_dict(), + "a_optimizer": a_optimizer.state_dict(), + "w_scheduler": w_scheduler.state_dict(), + "genotypes": genotypes, + "valid_accuracies": valid_accuracies, + }, + model_base_path, + logger, + ) + last_info = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + if find_best: + logger.log( + "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format( + epoch_str, best_valid_acc + ) + ) + copy_checkpoint(model_base_path, model_best_path, logger) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("\n" + "-" * 100) + logger.log("During searching, the best architecture is {:}".format(genotypes["best"])) + logger.log("Its accuracy is {:.2f}%".format(valid_accuracies["best"])) + logger.log("Randomly select {:} architectures and select the best.".format(xargs.controller_num_samples)) start_time = time.time() - - logger.log('\n' + '-'*100) - logger.log('During searching, the best architecture is {:}'.format(genotypes['best'])) - logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best'])) - logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples)) - start_time = time.time() - final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples) - search_time.update(time.time() - start_time) - shared_cnn.module.update_arch(final_arch) - final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion) - logger.log('The Selected Final Architecture : {:}'.format(final_arch)) - logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5)) - logger.log('ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, final_arch)) - if api is not None: logger.log('{:}'.format( api.query_by_arch(final_arch) )) - logger.close() - + final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples) + search_time.update(time.time() - start_time) + shared_cnn.module.update_arch(final_arch) + final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion) + logger.log("The Selected Final Architecture : {:}".format(final_arch)) + logger.log("Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%".format(final_loss, final_top1, final_top5)) + logger.log( + "ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(total_epoch, search_time.sum, final_arch) + ) + if api is not None: + logger.log("{:}".format(api.query_by_arch(final_arch))) + logger.close() -if __name__ == '__main__': - parser = argparse.ArgumentParser("ENAS") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # channels and number-of-cells - parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') - parser.add_argument('--search_space_name', type=str, help='The search space name.') - parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - parser.add_argument('--config_path', type=str, help='The config file to train ENAS.') - parser.add_argument('--controller_train_steps', type=int, help='.') - parser.add_argument('--controller_num_aggregate', type=int, help='.') - parser.add_argument('--controller_entropy_weight', type=float, help='The weight for the entropy of the controller.') - parser.add_argument('--controller_bl_dec' , type=float, help='.') - parser.add_argument('--controller_num_samples' , type=int, help='.') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - main(args) +if __name__ == "__main__": + parser = argparse.ArgumentParser("ENAS") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # channels and number-of-cells + parser.add_argument( + "--track_running_stats", + type=int, + choices=[0, 1], + help="Whether use track_running_stats or not in the BN layer.", + ) + parser.add_argument("--search_space_name", type=str, help="The search space name.") + parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument("--config_path", type=str, help="The config file to train ENAS.") + parser.add_argument("--controller_train_steps", type=int, help=".") + parser.add_argument("--controller_num_aggregate", type=int, help=".") + parser.add_argument("--controller_entropy_weight", type=float, help="The weight for the entropy of the controller.") + parser.add_argument("--controller_bl_dec", type=float, help=".") + parser.add_argument("--controller_num_samples", type=int, help=".") + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + main(args) diff --git a/exps/algos/GDAS.py b/exps/algos/GDAS.py index 329073c..758d052 100644 --- a/exps/algos/GDAS.py +++ b/exps/algos/GDAS.py @@ -7,211 +7,308 @@ import sys, time, random, argparse from copy import deepcopy import torch from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config -from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces -from nas_201_api import NASBench201API as API +from datasets import get_datasets, get_nas_search_loaders +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_cell_based_tiny_net, get_search_spaces +from nas_201_api import NASBench201API as API def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): - data_time, batch_time = AverageMeter(), AverageMeter() - base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - network.train() - end = time.time() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): - scheduler.update(None, 1.0 * step / len(xloader)) - base_targets = base_targets.cuda(non_blocking=True) - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - - # update the weights - w_optimizer.zero_grad() - _, logits = network(base_inputs) - base_loss = criterion(logits, base_targets) - base_loss.backward() - torch.nn.utils.clip_grad_norm_(network.parameters(), 5) - w_optimizer.step() - # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) - base_losses.update(base_loss.item(), base_inputs.size(0)) - base_top1.update (base_prec1.item(), base_inputs.size(0)) - base_top5.update (base_prec5.item(), base_inputs.size(0)) - - # update the architecture-weight - a_optimizer.zero_grad() - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - arch_loss.backward() - a_optimizer.step() - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) + data_time, batch_time = AverageMeter(), AverageMeter() + base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + network.train() end = time.time() + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + scheduler.update(None, 1.0 * step / len(xloader)) + base_targets = base_targets.cuda(non_blocking=True) + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) - if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)) - Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) - Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5) - Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5) - logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr) - return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg + # update the weights + w_optimizer.zero_grad() + _, logits = network(base_inputs) + base_loss = criterion(logits, base_targets) + base_loss.backward() + torch.nn.utils.clip_grad_norm_(network.parameters(), 5) + w_optimizer.step() + # record + base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_losses.update(base_loss.item(), base_inputs.size(0)) + base_top1.update(base_prec1.item(), base_inputs.size(0)) + base_top5.update(base_prec5.item(), base_inputs.size(0)) + + # update the architecture-weight + a_optimizer.zero_grad() + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + arch_loss.backward() + a_optimizer.step() + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if step % print_freq == 0 or step + 1 == len(xloader): + Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( + batch_time=batch_time, data_time=data_time + ) + Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=base_losses, top1=base_top1, top5=base_top5 + ) + Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=arch_losses, top1=arch_top1, top5=arch_top5 + ) + logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) + return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg def main(xargs): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - #config_path = 'configs/nas-benchmark/algos/GDAS.config' - config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) - search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers) - logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + # config_path = 'configs/nas-benchmark/algos/GDAS.config' + config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + search_loader, _, valid_loader = get_nas_search_loaders( + train_data, valid_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers + ) + logger.log( + "||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(search_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - search_space = get_search_spaces('cell', xargs.search_space_name) - if xargs.model_config is None: - model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells, - 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space, - 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) - else: - model_config = load_config(xargs.model_config, {'num_classes': class_num, 'space' : search_space, - 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) - search_model = get_cell_based_tiny_net(model_config) - logger.log('search-model :\n{:}'.format(search_model)) - logger.log('model-config : {:}'.format(model_config)) - - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) - a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) - logger.log('w-optimizer : {:}'.format(w_optimizer)) - logger.log('a-optimizer : {:}'.format(a_optimizer)) - logger.log('w-scheduler : {:}'.format(w_scheduler)) - logger.log('criterion : {:}'.format(criterion)) - flop, param = get_model_infos(search_model, xshape) - logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) - logger.log('search-space [{:} ops] : {:}'.format(len(search_space), search_space)) - if xargs.arch_nas_dataset is None: - api = None - else: - api = API(xargs.arch_nas_dataset) - logger.log('{:} create API = {:} done'.format(time_string(), api)) + search_space = get_search_spaces("cell", xargs.search_space_name) + if xargs.model_config is None: + model_config = dict2config( + { + "name": "GDAS", + "C": xargs.channel, + "N": xargs.num_cells, + "max_nodes": xargs.max_nodes, + "num_classes": class_num, + "space": search_space, + "affine": False, + "track_running_stats": bool(xargs.track_running_stats), + }, + None, + ) + else: + model_config = load_config( + xargs.model_config, + { + "num_classes": class_num, + "space": search_space, + "affine": False, + "track_running_stats": bool(xargs.track_running_stats), + }, + None, + ) + search_model = get_cell_based_tiny_net(model_config) + logger.log("search-model :\n{:}".format(search_model)) + logger.log("model-config : {:}".format(model_config)) - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() + w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) + a_optimizer = torch.optim.Adam( + search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay + ) + logger.log("w-optimizer : {:}".format(w_optimizer)) + logger.log("a-optimizer : {:}".format(a_optimizer)) + logger.log("w-scheduler : {:}".format(w_scheduler)) + logger.log("criterion : {:}".format(criterion)) + flop, param = get_model_infos(search_model, xshape) + logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param)) + logger.log("search-space [{:} ops] : {:}".format(len(search_space), search_space)) + if xargs.arch_nas_dataset is None: + api = None + else: + api = API(xargs.arch_nas_dataset) + logger.log("{:} create API = {:} done".format(time_string(), api)) - if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) - last_info = torch.load(last_info) - start_epoch = last_info['epoch'] - checkpoint = torch.load(last_info['last_checkpoint']) - genotypes = checkpoint['genotypes'] - valid_accuracies = checkpoint['valid_accuracies'] - search_model.load_state_dict( checkpoint['search_model'] ) - w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) - w_optimizer.load_state_dict ( checkpoint['w_optimizer'] ) - a_optimizer.load_state_dict ( checkpoint['a_optimizer'] ) - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()} + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() - # start training - start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup - for epoch in range(start_epoch, total_epoch): - w_scheduler.update(epoch, 0.0) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) - epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) - search_model.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) ) - logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(epoch_str, need_time, search_model.get_tau(), min(w_scheduler.get_lr()))) + if last_info.exists(): # automatically resume from previous checkpoint + logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + last_info = torch.load(last_info) + start_epoch = last_info["epoch"] + checkpoint = torch.load(last_info["last_checkpoint"]) + genotypes = checkpoint["genotypes"] + valid_accuracies = checkpoint["valid_accuracies"] + search_model.load_state_dict(checkpoint["search_model"]) + w_scheduler.load_state_dict(checkpoint["w_scheduler"]) + w_optimizer.load_state_dict(checkpoint["w_optimizer"]) + a_optimizer.load_state_dict(checkpoint["a_optimizer"]) + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + ) + else: + logger.log("=> do not find the last-info file : {:}".format(last_info)) + start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: search_model.genotype()} - search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \ - = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) - search_time.update(time.time() - start_time) - logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) - logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 )) - # check the best accuracy - valid_accuracies[epoch] = valid_a_top1 - if valid_a_top1 > valid_accuracies['best']: - valid_accuracies['best'] = valid_a_top1 - genotypes['best'] = search_model.genotype() - find_best = True - else: find_best = False + # start training + start_time, search_time, epoch_time, total_epoch = ( + time.time(), + AverageMeter(), + AverageMeter(), + config.epochs + config.warmup, + ) + for epoch in range(start_epoch, total_epoch): + w_scheduler.update(epoch, 0.0) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) + search_model.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)) + logger.log( + "\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}".format( + epoch_str, need_time, search_model.get_tau(), min(w_scheduler.get_lr()) + ) + ) - genotypes[epoch] = search_model.genotype() - logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) - # save checkpoint - save_path = save_checkpoint({'epoch' : epoch + 1, - 'args' : deepcopy(xargs), - 'search_model': search_model.state_dict(), - 'w_optimizer' : w_optimizer.state_dict(), - 'a_optimizer' : a_optimizer.state_dict(), - 'w_scheduler' : w_scheduler.state_dict(), - 'genotypes' : genotypes, - 'valid_accuracies' : valid_accuracies}, - model_base_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch + 1, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) - if find_best: - logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) - copy_checkpoint(model_base_path, model_best_path, logger) - with torch.no_grad(): - logger.log('{:}'.format(search_model.show_alphas())) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) - # measure elapsed time - epoch_time.update(time.time() - start_time) - start_time = time.time() + search_w_loss, search_w_top1, search_w_top5, valid_a_loss, valid_a_top1, valid_a_top5 = search_func( + search_loader, + network, + criterion, + w_scheduler, + w_optimizer, + a_optimizer, + epoch_str, + xargs.print_freq, + logger, + ) + search_time.update(time.time() - start_time) + logger.log( + "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( + epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum + ) + ) + logger.log( + "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( + epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 + ) + ) + # check the best accuracy + valid_accuracies[epoch] = valid_a_top1 + if valid_a_top1 > valid_accuracies["best"]: + valid_accuracies["best"] = valid_a_top1 + genotypes["best"] = search_model.genotype() + find_best = True + else: + find_best = False - logger.log('\n' + '-'*100) - # check the performance from the architecture dataset - logger.log('GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200'))) - logger.close() - + genotypes[epoch] = search_model.genotype() + logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(xargs), + "search_model": search_model.state_dict(), + "w_optimizer": w_optimizer.state_dict(), + "a_optimizer": a_optimizer.state_dict(), + "w_scheduler": w_scheduler.state_dict(), + "genotypes": genotypes, + "valid_accuracies": valid_accuracies, + }, + model_base_path, + logger, + ) + last_info = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + if find_best: + logger.log( + "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format( + epoch_str, valid_a_top1 + ) + ) + copy_checkpoint(model_base_path, model_best_path, logger) + with torch.no_grad(): + logger.log("{:}".format(search_model.show_alphas())) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("\n" + "-" * 100) + # check the performance from the architecture dataset + logger.log( + "GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( + total_epoch, search_time.sum, genotypes[total_epoch - 1] + ) + ) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1], "200"))) + logger.close() -if __name__ == '__main__': - parser = argparse.ArgumentParser("GDAS") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # channels and number-of-cells - parser.add_argument('--search_space_name', type=str, help='The search space name.') - parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') - parser.add_argument('--config_path', type=str, help='The path of the configuration.') - parser.add_argument('--model_config', type=str, help='The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.') - # architecture leraning rate - parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') - parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding') - parser.add_argument('--tau_min', type=float, help='The minimum tau for Gumbel') - parser.add_argument('--tau_max', type=float, help='The maximum tau for Gumbel') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - main(args) +if __name__ == "__main__": + parser = argparse.ArgumentParser("GDAS") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # channels and number-of-cells + parser.add_argument("--search_space_name", type=str, help="The search space name.") + parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument( + "--track_running_stats", + type=int, + choices=[0, 1], + help="Whether use track_running_stats or not in the BN layer.", + ) + parser.add_argument("--config_path", type=str, help="The path of the configuration.") + parser.add_argument( + "--model_config", + type=str, + help="The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.", + ) + # architecture leraning rate + parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") + parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") + parser.add_argument("--tau_min", type=float, help="The minimum tau for Gumbel") + parser.add_argument("--tau_max", type=float, help="The maximum tau for Gumbel") + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + main(args) diff --git a/exps/algos/RANDOM-NAS.py b/exps/algos/RANDOM-NAS.py index 78eddda..d6b8a3a 100644 --- a/exps/algos/RANDOM-NAS.py +++ b/exps/algos/RANDOM-NAS.py @@ -9,229 +9,306 @@ from copy import deepcopy import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces -from nas_201_api import NASBench201API as API +from datasets import get_datasets, get_nas_search_loaders +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_cell_based_tiny_net, get_search_spaces +from nas_201_api import NASBench201API as API def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger): - data_time, batch_time = AverageMeter(), AverageMeter() - base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() - network.train() - end = time.time() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): - scheduler.update(None, 1.0 * step / len(xloader)) - base_targets = base_targets.cuda(non_blocking=True) - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - - # update the weights - network.module.random_genotype( True ) - w_optimizer.zero_grad() - _, logits = network(base_inputs) - base_loss = criterion(logits, base_targets) - base_loss.backward() - nn.utils.clip_grad_norm_(network.parameters(), 5) - w_optimizer.step() - # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) - base_losses.update(base_loss.item(), base_inputs.size(0)) - base_top1.update (base_prec1.item(), base_inputs.size(0)) - base_top5.update (base_prec5.item(), base_inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) + data_time, batch_time = AverageMeter(), AverageMeter() + base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() + network.train() end = time.time() + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + scheduler.update(None, 1.0 * step / len(xloader)) + base_targets = base_targets.cuda(non_blocking=True) + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) - if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)) - Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) - Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5) - logger.log(Sstr + ' ' + Tstr + ' ' + Wstr) - return base_losses.avg, base_top1.avg, base_top5.avg + # update the weights + network.module.random_genotype(True) + w_optimizer.zero_grad() + _, logits = network(base_inputs) + base_loss = criterion(logits, base_targets) + base_loss.backward() + nn.utils.clip_grad_norm_(network.parameters(), 5) + w_optimizer.step() + # record + base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_losses.update(base_loss.item(), base_inputs.size(0)) + base_top1.update(base_prec1.item(), base_inputs.size(0)) + base_top5.update(base_prec5.item(), base_inputs.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if step % print_freq == 0 or step + 1 == len(xloader): + Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( + batch_time=batch_time, data_time=data_time + ) + Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=base_losses, top1=base_top1, top5=base_top5 + ) + logger.log(Sstr + " " + Tstr + " " + Wstr) + return base_losses.avg, base_top1.avg, base_top5.avg def valid_func(xloader, network, criterion): - data_time, batch_time = AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - network.eval() - end = time.time() - with torch.no_grad(): - for step, (arch_inputs, arch_targets) in enumerate(xloader): - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - # prediction + data_time, batch_time = AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + network.eval() + end = time.time() + with torch.no_grad(): + for step, (arch_inputs, arch_targets) in enumerate(xloader): + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) + # prediction - network.module.random_genotype( True ) - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - return arch_losses.avg, arch_top1.avg, arch_top5.avg + network.module.random_genotype(True) + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + return arch_losses.avg, arch_top1.avg, arch_top5.avg def search_find_best(xloader, network, n_samples): - with torch.no_grad(): - network.eval() - archs, valid_accs = [], [] - #print ('obtain the top-{:} architectures'.format(n_samples)) - loader_iter = iter(xloader) - for i in range(n_samples): - arch = network.module.random_genotype( True ) - try: - inputs, targets = next(loader_iter) - except: + with torch.no_grad(): + network.eval() + archs, valid_accs = [], [] + # print ('obtain the top-{:} architectures'.format(n_samples)) loader_iter = iter(xloader) - inputs, targets = next(loader_iter) + for i in range(n_samples): + arch = network.module.random_genotype(True) + try: + inputs, targets = next(loader_iter) + except: + loader_iter = iter(xloader) + inputs, targets = next(loader_iter) - _, logits = network(inputs) - val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) + _, logits = network(inputs) + val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) - archs.append( arch ) - valid_accs.append( val_top1.item() ) + archs.append(arch) + valid_accs.append(val_top1.item()) - best_idx = np.argmax(valid_accs) - best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] - return best_arch, best_valid_acc + best_idx = np.argmax(valid_accs) + best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] + return best_arch, best_valid_acc def main(xargs): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) - search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \ - (config.batch_size, config.test_batch_size), xargs.workers) - logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + search_loader, _, valid_loader = get_nas_search_loaders( + train_data, + valid_data, + xargs.dataset, + "configs/nas-benchmark/", + (config.batch_size, config.test_batch_size), + xargs.workers, + ) + logger.log( + "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(search_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - search_space = get_search_spaces('cell', xargs.search_space_name) - model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells, - 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space, - 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) - search_model = get_cell_based_tiny_net(model_config) - - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config) - logger.log('w-optimizer : {:}'.format(w_optimizer)) - logger.log('w-scheduler : {:}'.format(w_scheduler)) - logger.log('criterion : {:}'.format(criterion)) - if xargs.arch_nas_dataset is None: api = None - else : api = API(xargs.arch_nas_dataset) - logger.log('{:} create API = {:} done'.format(time_string(), api)) + search_space = get_search_spaces("cell", xargs.search_space_name) + model_config = dict2config( + { + "name": "RANDOM", + "C": xargs.channel, + "N": xargs.num_cells, + "max_nodes": xargs.max_nodes, + "num_classes": class_num, + "space": search_space, + "affine": False, + "track_running_stats": bool(xargs.track_running_stats), + }, + None, + ) + search_model = get_cell_based_tiny_net(model_config) - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() + w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config) + logger.log("w-optimizer : {:}".format(w_optimizer)) + logger.log("w-scheduler : {:}".format(w_scheduler)) + logger.log("criterion : {:}".format(criterion)) + if xargs.arch_nas_dataset is None: + api = None + else: + api = API(xargs.arch_nas_dataset) + logger.log("{:} create API = {:} done".format(time_string(), api)) - if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) - last_info = torch.load(last_info) - start_epoch = last_info['epoch'] - checkpoint = torch.load(last_info['last_checkpoint']) - genotypes = checkpoint['genotypes'] - valid_accuracies = checkpoint['valid_accuracies'] - search_model.load_state_dict( checkpoint['search_model'] ) - w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) - w_optimizer.load_state_dict ( checkpoint['w_optimizer'] ) - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() - # start training - start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup - for epoch in range(start_epoch, total_epoch): - w_scheduler.update(epoch, 0.0) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) - epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) - logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) + if last_info.exists(): # automatically resume from previous checkpoint + logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + last_info = torch.load(last_info) + start_epoch = last_info["epoch"] + checkpoint = torch.load(last_info["last_checkpoint"]) + genotypes = checkpoint["genotypes"] + valid_accuracies = checkpoint["valid_accuracies"] + search_model.load_state_dict(checkpoint["search_model"]) + w_scheduler.load_state_dict(checkpoint["w_scheduler"]) + w_optimizer.load_state_dict(checkpoint["w_optimizer"]) + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + ) + else: + logger.log("=> do not find the last-info file : {:}".format(last_info)) + start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {} - # selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num) - search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger) - search_time.update(time.time() - start_time) - logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) - valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) - logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) - cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num) - logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc)) - genotypes[epoch] = cur_arch - # check the best accuracy - valid_accuracies[epoch] = valid_a_top1 - if valid_a_top1 > valid_accuracies['best']: - valid_accuracies['best'] = valid_a_top1 - find_best = True - else: find_best = False + # start training + start_time, search_time, epoch_time, total_epoch = ( + time.time(), + AverageMeter(), + AverageMeter(), + config.epochs + config.warmup, + ) + for epoch in range(start_epoch, total_epoch): + w_scheduler.update(epoch, 0.0) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) + logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr()))) - # save checkpoint - save_path = save_checkpoint({'epoch' : epoch + 1, - 'args' : deepcopy(xargs), - 'search_model': search_model.state_dict(), - 'w_optimizer' : w_optimizer.state_dict(), - 'w_scheduler' : w_scheduler.state_dict(), - 'genotypes' : genotypes, - 'valid_accuracies' : valid_accuracies}, - model_base_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch + 1, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) - if find_best: - logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) - copy_checkpoint(model_base_path, model_best_path, logger) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) - # measure elapsed time - epoch_time.update(time.time() - start_time) + # selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num) + search_w_loss, search_w_top1, search_w_top5 = search_func( + search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger + ) + search_time.update(time.time() - start_time) + logger.log( + "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( + epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum + ) + ) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) + logger.log( + "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( + epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 + ) + ) + cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num) + logger.log("[{:}] find-the-best : {:}, accuracy@1={:.2f}%".format(epoch_str, cur_arch, cur_valid_acc)) + genotypes[epoch] = cur_arch + # check the best accuracy + valid_accuracies[epoch] = valid_a_top1 + if valid_a_top1 > valid_accuracies["best"]: + valid_accuracies["best"] = valid_a_top1 + find_best = True + else: + find_best = False + + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(xargs), + "search_model": search_model.state_dict(), + "w_optimizer": w_optimizer.state_dict(), + "w_scheduler": w_scheduler.state_dict(), + "genotypes": genotypes, + "valid_accuracies": valid_accuracies, + }, + model_base_path, + logger, + ) + last_info = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + if find_best: + logger.log( + "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format( + epoch_str, valid_a_top1 + ) + ) + copy_checkpoint(model_base_path, model_best_path, logger) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("\n" + "-" * 200) + logger.log("Pre-searching costs {:.1f} s".format(search_time.sum)) start_time = time.time() - - logger.log('\n' + '-'*200) - logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum)) - start_time = time.time() - best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num) - search_time.update(time.time() - start_time) - logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum)) - if api is not None: logger.log('{:}'.format(api.query_by_arch(best_arch, '200'))) - logger.close() + best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num) + search_time.update(time.time() - start_time) + logger.log( + "RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.".format( + best_arch, best_acc, search_time.sum + ) + ) + if api is not None: + logger.log("{:}".format(api.query_by_arch(best_arch, "200"))) + logger.close() -if __name__ == '__main__': - parser = argparse.ArgumentParser("Random search for NAS.") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # channels and number-of-cells - parser.add_argument('--search_space_name', type=str, help='The search space name.') - parser.add_argument('--config_path', type=str, help='The path to the configuration.') - parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.') - parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - main(args) +if __name__ == "__main__": + parser = argparse.ArgumentParser("Random search for NAS.") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # channels and number-of-cells + parser.add_argument("--search_space_name", type=str, help="The search space name.") + parser.add_argument("--config_path", type=str, help="The path to the configuration.") + parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument("--select_num", type=int, help="The number of selected architectures to evaluate.") + parser.add_argument( + "--track_running_stats", + type=int, + choices=[0, 1], + help="Whether use track_running_stats or not in the BN layer.", + ) + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + main(args) diff --git a/exps/algos/RANDOM.py b/exps/algos/RANDOM.py index e38bf60..150f6ec 100644 --- a/exps/algos/RANDOM.py +++ b/exps/algos/RANDOM.py @@ -7,113 +7,145 @@ from copy import deepcopy import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_search_spaces -from nas_201_api import NASBench201API as API -from R_EA import train_and_eval, random_architecture_func +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_search_spaces +from nas_201_api import NASBench201API as API +from R_EA import train_and_eval, random_architecture_func def main(xargs, nas_bench): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - if xargs.dataset == 'cifar10': - dataname = 'cifar10-valid' - else: - dataname = xargs.dataset - if xargs.data_path is not None: - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - split_Fpath = 'configs/nas-benchmark/cifar-split.txt' - cifar_split = load_config(split_Fpath, None, None) - train_split, valid_split = cifar_split.train, cifar_split.valid - logger.log('Load split file from {:}'.format(split_Fpath)) - config_path = 'configs/nas-benchmark/algos/R-EA.config' - config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) - # To split data - train_data_v2 = deepcopy(train_data) - train_data_v2.transform = valid_data.transform - valid_data = train_data_v2 - search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) - # data loader - train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) - logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} - else: - config_path = 'configs/nas-benchmark/algos/R-EA.config' - config = load_config(config_path, None, logger) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} - search_space = get_search_spaces('cell', xargs.search_space_name) - random_arch = random_architecture_func(xargs.max_nodes, search_space) - #x =random_arch() ; y = mutate_arch(x) - x_start_time = time.time() - logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) - best_arch, best_acc, total_time_cost, history = None, -1, 0, [] - #for idx in range(xargs.random_num): - while total_time_cost < xargs.time_budget: - arch = random_arch() - accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname) - if total_time_cost + cost_time > xargs.time_budget: break - else: total_time_cost += cost_time - history.append(arch) - if best_arch is None or best_acc < accuracy: - best_acc, best_arch = accuracy, arch - logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy)) - logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time)) - - info = nas_bench.query_by_arch(best_arch, '200') - if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) - else : logger.log('{:}'.format(info)) - logger.log('-'*100) - logger.close() - return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) + if xargs.dataset == "cifar10": + dataname = "cifar10-valid" + else: + dataname = xargs.dataset + if xargs.data_path is not None: + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + split_Fpath = "configs/nas-benchmark/cifar-split.txt" + cifar_split = load_config(split_Fpath, None, None) + train_split, valid_split = cifar_split.train, cifar_split.valid + logger.log("Load split file from {:}".format(split_Fpath)) + config_path = "configs/nas-benchmark/algos/R-EA.config" + config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger) + # To split data + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), + num_workers=xargs.workers, + pin_memory=True, + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), + num_workers=xargs.workers, + pin_memory=True, + ) + logger.log( + "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(train_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) + extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader} + else: + config_path = "configs/nas-benchmark/algos/R-EA.config" + config = load_config(config_path, None, logger) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) + extra_info = {"config": config, "train_loader": None, "valid_loader": None} + search_space = get_search_spaces("cell", xargs.search_space_name) + random_arch = random_architecture_func(xargs.max_nodes, search_space) + # x =random_arch() ; y = mutate_arch(x) + x_start_time = time.time() + logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench)) + best_arch, best_acc, total_time_cost, history = None, -1, 0, [] + # for idx in range(xargs.random_num): + while total_time_cost < xargs.time_budget: + arch = random_arch() + accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname) + if total_time_cost + cost_time > xargs.time_budget: + break + else: + total_time_cost += cost_time + history.append(arch) + if best_arch is None or best_acc < accuracy: + best_acc, best_arch = accuracy, arch + logger.log("[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy)) + logger.log( + "{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).".format( + time_string(), best_arch, best_acc, len(history), total_time_cost, time.time() - x_start_time + ) + ) + + info = nas_bench.query_by_arch(best_arch, "200") + if info is None: + logger.log("Did not find this architecture : {:}.".format(best_arch)) + else: + logger.log("{:}".format(info)) + logger.log("-" * 100) + logger.close() + return logger.log_dir, nas_bench.query_index_by_arch(best_arch) - -if __name__ == '__main__': - parser = argparse.ArgumentParser("Random NAS") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # channels and number-of-cells - parser.add_argument('--search_space_name', type=str, help='The search space name.') - parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - #parser.add_argument('--random_num', type=int, help='The number of random selected architectures.') - parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') - args = parser.parse_args() - #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): - nas_bench = None - else: - print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) - nas_bench = API(args.arch_nas_dataset) - if args.rand_seed < 0: - save_dir, all_indexes, num = None, [], 500 - for i in range(num): - print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) - args.rand_seed = random.randint(1, 100000) - save_dir, index = main(args, nas_bench) - all_indexes.append( index ) - torch.save(all_indexes, save_dir / 'results.pth') - else: - main(args, nas_bench) +if __name__ == "__main__": + parser = argparse.ArgumentParser("Random NAS") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # channels and number-of-cells + parser.add_argument("--search_space_name", type=str, help="The search space name.") + parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + # parser.add_argument('--random_num', type=int, help='The number of random selected architectures.') + parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).") + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, help="manual seed") + args = parser.parse_args() + # if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): + nas_bench = None + else: + print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset)) + nas_bench = API(args.arch_nas_dataset) + if args.rand_seed < 0: + save_dir, all_indexes, num = None, [], 500 + for i in range(num): + print("{:} : {:03d}/{:03d}".format(time_string(), i, num)) + args.rand_seed = random.randint(1, 100000) + save_dir, index = main(args, nas_bench) + all_indexes.append(index) + torch.save(all_indexes, save_dir / "results.pth") + else: + main(args, nas_bench) diff --git a/exps/algos/R_EA.py b/exps/algos/R_EA.py index ddfcde8..d6920cf 100644 --- a/exps/algos/R_EA.py +++ b/exps/algos/R_EA.py @@ -9,260 +9,324 @@ from copy import deepcopy import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from nas_201_api import NASBench201API as API -from models import CellStructure, get_search_spaces +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from nas_201_api import NASBench201API as API +from models import CellStructure, get_search_spaces class Model(object): + def __init__(self): + self.arch = None + self.accuracy = None + + def __str__(self): + """Prints a readable version of this bitstring.""" + return "{:}".format(self.arch) - def __init__(self): - self.arch = None - self.accuracy = None - - def __str__(self): - """Prints a readable version of this bitstring.""" - return '{:}'.format(self.arch) - # This function is to mimic the training and evaluatinig procedure for a single architecture `arch`. # The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch. # For use_012_epoch_training = True, the architecture is trained for 12 epochs, with LR being decaded from 0.1 to 0. # In this case, the LR schedular is converged. # For use_012_epoch_training = False, the architecture is planed to be trained for 200 epochs, but we early stop its procedure. -# -def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_012_epoch_training=True): +# +def train_and_eval(arch, nas_bench, extra_info, dataname="cifar10-valid", use_012_epoch_training=True): - if use_012_epoch_training and nas_bench is not None: - arch_index = nas_bench.query_index_by_arch( arch ) - assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) - info = nas_bench.get_more_info(arch_index, dataname, iepoch=None, hp='12', is_random=True) - valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] - #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs - elif not use_012_epoch_training and nas_bench is not None: - # Please contact me if you want to use the following logic, because it has some potential issues. - # Please use `use_012_epoch_training=False` for cifar10 only. - # It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details) - arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25 - assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) - xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12') - xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200') - info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', is_random=True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready). - cost = nas_bench.get_cost_info(arch_index, dataname, hp='200') - # The following codes are used to estimate the time cost. - # When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record. - # When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared. - nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000, - 'cifar10-valid-train' : 25000, 'cifar10-valid-valid' : 25000, - 'cifar100-train' : 50000, 'cifar100-valid' : 5000} - estimated_train_cost = xoinfo['train-per-time'] / nums['cifar10-valid-train'] * nums['{:}-train'.format(dataname)] / xocost['latency'] * cost['latency'] * nepoch - estimated_valid_cost = xoinfo['valid-per-time'] / nums['cifar10-valid-valid'] * nums['{:}-valid'.format(dataname)] / xocost['latency'] * cost['latency'] - try: - valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost - except: - valid_acc, time_cost = info['valtest-accuracy'], estimated_train_cost + estimated_valid_cost - else: - # train a model from scratch. - raise ValueError('NOT IMPLEMENT YET') - return valid_acc, time_cost + if use_012_epoch_training and nas_bench is not None: + arch_index = nas_bench.query_index_by_arch(arch) + assert arch_index >= 0, "can not find this arch : {:}".format(arch) + info = nas_bench.get_more_info(arch_index, dataname, iepoch=None, hp="12", is_random=True) + valid_acc, time_cost = info["valid-accuracy"], info["train-all-time"] + info["valid-per-time"] + # _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs + elif not use_012_epoch_training and nas_bench is not None: + # Please contact me if you want to use the following logic, because it has some potential issues. + # Please use `use_012_epoch_training=False` for cifar10 only. + # It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details) + arch_index, nepoch = nas_bench.query_index_by_arch(arch), 25 + assert arch_index >= 0, "can not find this arch : {:}".format(arch) + xoinfo = nas_bench.get_more_info(arch_index, "cifar10-valid", iepoch=None, hp="12") + xocost = nas_bench.get_cost_info(arch_index, "cifar10-valid", hp="200") + info = nas_bench.get_more_info( + arch_index, dataname, nepoch, hp="200", is_random=True + ) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready). + cost = nas_bench.get_cost_info(arch_index, dataname, hp="200") + # The following codes are used to estimate the time cost. + # When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record. + # When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared. + nums = { + "ImageNet16-120-train": 151700, + "ImageNet16-120-valid": 3000, + "cifar10-valid-train": 25000, + "cifar10-valid-valid": 25000, + "cifar100-train": 50000, + "cifar100-valid": 5000, + } + estimated_train_cost = ( + xoinfo["train-per-time"] + / nums["cifar10-valid-train"] + * nums["{:}-train".format(dataname)] + / xocost["latency"] + * cost["latency"] + * nepoch + ) + estimated_valid_cost = ( + xoinfo["valid-per-time"] + / nums["cifar10-valid-valid"] + * nums["{:}-valid".format(dataname)] + / xocost["latency"] + * cost["latency"] + ) + try: + valid_acc, time_cost = info["valid-accuracy"], estimated_train_cost + estimated_valid_cost + except: + valid_acc, time_cost = info["valtest-accuracy"], estimated_train_cost + estimated_valid_cost + else: + # train a model from scratch. + raise ValueError("NOT IMPLEMENT YET") + return valid_acc, time_cost def random_architecture_func(max_nodes, op_names): - # return a random architecture - def random_architecture(): - genotypes = [] - for i in range(1, max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_name = random.choice( op_names ) - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return CellStructure( genotypes ) - return random_architecture + # return a random architecture + def random_architecture(): + genotypes = [] + for i in range(1, max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + op_name = random.choice(op_names) + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return CellStructure(genotypes) + + return random_architecture def mutate_arch_func(op_names): - """Computes the architecture for a child of the given parent architecture. - The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another. - """ - def mutate_arch_func(parent_arch): - child_arch = deepcopy( parent_arch ) - node_id = random.randint(0, len(child_arch.nodes)-1) - node_info = list( child_arch.nodes[node_id] ) - snode_id = random.randint(0, len(node_info)-1) - xop = random.choice( op_names ) - while xop == node_info[snode_id][0]: - xop = random.choice( op_names ) - node_info[snode_id] = (xop, node_info[snode_id][1]) - child_arch.nodes[node_id] = tuple( node_info ) - return child_arch - return mutate_arch_func + """Computes the architecture for a child of the given parent architecture. + The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another. + """ + + def mutate_arch_func(parent_arch): + child_arch = deepcopy(parent_arch) + node_id = random.randint(0, len(child_arch.nodes) - 1) + node_info = list(child_arch.nodes[node_id]) + snode_id = random.randint(0, len(node_info) - 1) + xop = random.choice(op_names) + while xop == node_info[snode_id][0]: + xop = random.choice(op_names) + node_info[snode_id] = (xop, node_info[snode_id][1]) + child_arch.nodes[node_id] = tuple(node_info) + return child_arch + + return mutate_arch_func -def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info, dataname): - """Algorithm for regularized evolution (i.e. aging evolution). - - Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image - Classifier Architecture Search". - - Args: - cycles: the number of cycles the algorithm should run for. - population_size: the number of individuals to keep in the population. - sample_size: the number of individuals that should participate in each tournament. - time_budget: the upper bound of searching cost +def regularized_evolution( + cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info, dataname +): + """Algorithm for regularized evolution (i.e. aging evolution). - Returns: - history: a list of `Model` instances, representing all the models computed - during the evolution experiment. - """ - population = collections.deque() - history, total_time_cost = [], 0 # Not used by the algorithm, only used to report results. + Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image + Classifier Architecture Search". - # Initialize the population with random models. - while len(population) < population_size: - model = Model() - model.arch = random_arch() - model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info, dataname) - population.append(model) - history.append(model) - total_time_cost += time_cost + Args: + cycles: the number of cycles the algorithm should run for. + population_size: the number of individuals to keep in the population. + sample_size: the number of individuals that should participate in each tournament. + time_budget: the upper bound of searching cost - # Carry out evolution in cycles. Each cycle produces a model and removes - # another. - #while len(history) < cycles: - while total_time_cost < time_budget: - # Sample randomly chosen models from the current population. - start_time, sample = time.time(), [] - while len(sample) < sample_size: - # Inefficient, but written this way for clarity. In the case of neural - # nets, the efficiency of this line is irrelevant because training neural - # nets is the rate-determining step. - candidate = random.choice(list(population)) - sample.append(candidate) + Returns: + history: a list of `Model` instances, representing all the models computed + during the evolution experiment. + """ + population = collections.deque() + history, total_time_cost = [], 0 # Not used by the algorithm, only used to report results. - # The parent is the best model in the sample. - parent = max(sample, key=lambda i: i.accuracy) + # Initialize the population with random models. + while len(population) < population_size: + model = Model() + model.arch = random_arch() + model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info, dataname) + population.append(model) + history.append(model) + total_time_cost += time_cost - # Create the child model and store it. - child = Model() - child.arch = mutate_arch(parent.arch) - total_time_cost += time.time() - start_time - child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info, dataname) - if total_time_cost + time_cost > time_budget: # return - return history, total_time_cost - else: - total_time_cost += time_cost - population.append(child) - history.append(child) + # Carry out evolution in cycles. Each cycle produces a model and removes + # another. + # while len(history) < cycles: + while total_time_cost < time_budget: + # Sample randomly chosen models from the current population. + start_time, sample = time.time(), [] + while len(sample) < sample_size: + # Inefficient, but written this way for clarity. In the case of neural + # nets, the efficiency of this line is irrelevant because training neural + # nets is the rate-determining step. + candidate = random.choice(list(population)) + sample.append(candidate) - # Remove the oldest model. - population.popleft() - return history, total_time_cost + # The parent is the best model in the sample. + parent = max(sample, key=lambda i: i.accuracy) + + # Create the child model and store it. + child = Model() + child.arch = mutate_arch(parent.arch) + total_time_cost += time.time() - start_time + child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info, dataname) + if total_time_cost + time_cost > time_budget: # return + return history, total_time_cost + else: + total_time_cost += time_cost + population.append(child) + history.append(child) + + # Remove the oldest model. + population.popleft() + return history, total_time_cost def main(xargs, nas_bench): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - if xargs.dataset == 'cifar10': - dataname = 'cifar10-valid' - else: - dataname = xargs.dataset - if xargs.data_path is not None: - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - split_Fpath = 'configs/nas-benchmark/cifar-split.txt' - cifar_split = load_config(split_Fpath, None, None) - train_split, valid_split = cifar_split.train, cifar_split.valid - logger.log('Load split file from {:}'.format(split_Fpath)) - config_path = 'configs/nas-benchmark/algos/R-EA.config' - config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) - # To split data - train_data_v2 = deepcopy(train_data) - train_data_v2.transform = valid_data.transform - valid_data = train_data_v2 - search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) - # data loader - train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) - logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} - else: - config_path = 'configs/nas-benchmark/algos/R-EA.config' - config = load_config(config_path, None, logger) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} + if xargs.dataset == "cifar10": + dataname = "cifar10-valid" + else: + dataname = xargs.dataset + if xargs.data_path is not None: + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + split_Fpath = "configs/nas-benchmark/cifar-split.txt" + cifar_split = load_config(split_Fpath, None, None) + train_split, valid_split = cifar_split.train, cifar_split.valid + logger.log("Load split file from {:}".format(split_Fpath)) + config_path = "configs/nas-benchmark/algos/R-EA.config" + config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger) + # To split data + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), + num_workers=xargs.workers, + pin_memory=True, + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), + num_workers=xargs.workers, + pin_memory=True, + ) + logger.log( + "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(train_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) + extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader} + else: + config_path = "configs/nas-benchmark/algos/R-EA.config" + config = load_config(config_path, None, logger) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) + extra_info = {"config": config, "train_loader": None, "valid_loader": None} - search_space = get_search_spaces('cell', xargs.search_space_name) - random_arch = random_architecture_func(xargs.max_nodes, search_space) - mutate_arch = mutate_arch_func(search_space) - #x =random_arch() ; y = mutate_arch(x) - x_start_time = time.time() - logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) - logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) - history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info, dataname) - logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_cost, time.time()-x_start_time)) - best_arch = max(history, key=lambda i: i.accuracy) - best_arch = best_arch.arch - logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) - - info = nas_bench.query_by_arch(best_arch, '200') - if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) - else : logger.log('{:}'.format(info)) - logger.log('-'*100) - logger.close() - return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) - + search_space = get_search_spaces("cell", xargs.search_space_name) + random_arch = random_architecture_func(xargs.max_nodes, search_space) + mutate_arch = mutate_arch_func(search_space) + # x =random_arch() ; y = mutate_arch(x) + x_start_time = time.time() + logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench)) + logger.log("-" * 30 + " start searching with the time budget of {:} s".format(xargs.time_budget)) + history, total_cost = regularized_evolution( + xargs.ea_cycles, + xargs.ea_population, + xargs.ea_sample_size, + xargs.time_budget, + random_arch, + mutate_arch, + nas_bench if args.ea_fast_by_api else None, + extra_info, + dataname, + ) + logger.log( + "{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).".format( + time_string(), len(history), total_cost, time.time() - x_start_time + ) + ) + best_arch = max(history, key=lambda i: i.accuracy) + best_arch = best_arch.arch + logger.log("{:} best arch is {:}".format(time_string(), best_arch)) + + info = nas_bench.query_by_arch(best_arch, "200") + if info is None: + logger.log("Did not find this architecture : {:}.".format(best_arch)) + else: + logger.log("{:}".format(info)) + logger.log("-" * 100) + logger.close() + return logger.log_dir, nas_bench.query_index_by_arch(best_arch) -if __name__ == '__main__': - parser = argparse.ArgumentParser("Regularized Evolution Algorithm") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # channels and number-of-cells - parser.add_argument('--search_space_name', type=str, help='The search space name.') - parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.') - parser.add_argument('--ea_population', type=int, help='The population size in EA.') - parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.') - parser.add_argument('--ea_fast_by_api', type=int, help='Use our API to speed up the experiments or not.') - parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') - args = parser.parse_args() - #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - args.ea_fast_by_api = args.ea_fast_by_api > 0 +if __name__ == "__main__": + parser = argparse.ArgumentParser("Regularized Evolution Algorithm") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # channels and number-of-cells + parser.add_argument("--search_space_name", type=str, help="The search space name.") + parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.") + parser.add_argument("--ea_population", type=int, help="The population size in EA.") + parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.") + parser.add_argument("--ea_fast_by_api", type=int, help="Use our API to speed up the experiments or not.") + parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).") + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() + # if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + args.ea_fast_by_api = args.ea_fast_by_api > 0 - if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): - nas_bench = None - else: - print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) - nas_bench = API(args.arch_nas_dataset) - if args.rand_seed < 0: - save_dir, all_indexes, num = None, [], 500 - for i in range(num): - print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) - args.rand_seed = random.randint(1, 100000) - save_dir, index = main(args, nas_bench) - all_indexes.append( index ) - torch.save(all_indexes, save_dir / 'results.pth') - else: - main(args, nas_bench) + if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): + nas_bench = None + else: + print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset)) + nas_bench = API(args.arch_nas_dataset) + if args.rand_seed < 0: + save_dir, all_indexes, num = None, [], 500 + for i in range(num): + print("{:} : {:03d}/{:03d}".format(time_string(), i, num)) + args.rand_seed = random.randint(1, 100000) + save_dir, index = main(args, nas_bench) + all_indexes.append(index) + torch.save(all_indexes, save_dir / "results.pth") + else: + main(args, nas_bench) diff --git a/exps/algos/SETN.py b/exps/algos/SETN.py index 038a65d..c6e23bc 100644 --- a/exps/algos/SETN.py +++ b/exps/algos/SETN.py @@ -9,274 +9,363 @@ from copy import deepcopy import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces -from nas_201_api import NASBench201API as API +from datasets import get_datasets, get_nas_search_loaders +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_cell_based_tiny_net, get_search_spaces +from nas_201_api import NASBench201API as API def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): - data_time, batch_time = AverageMeter(), AverageMeter() - base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - end = time.time() - network.train() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): - scheduler.update(None, 1.0 * step / len(xloader)) - base_targets = base_targets.cuda(non_blocking=True) - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - - # update the weights - sampled_arch = network.module.dync_genotype(True) - network.module.set_cal_mode('dynamic', sampled_arch) - #network.module.set_cal_mode( 'urs' ) - network.zero_grad() - _, logits = network(base_inputs) - base_loss = criterion(logits, base_targets) - base_loss.backward() - w_optimizer.step() - # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) - base_losses.update(base_loss.item(), base_inputs.size(0)) - base_top1.update (base_prec1.item(), base_inputs.size(0)) - base_top5.update (base_prec5.item(), base_inputs.size(0)) - - # update the architecture-weight - network.module.set_cal_mode( 'joint' ) - network.zero_grad() - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - arch_loss.backward() - a_optimizer.step() - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) + data_time, batch_time = AverageMeter(), AverageMeter() + base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() end = time.time() + network.train() + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + scheduler.update(None, 1.0 * step / len(xloader)) + base_targets = base_targets.cuda(non_blocking=True) + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) - if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)) - Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) - Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5) - Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5) - logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr) - #print (nn.functional.softmax(network.module.arch_parameters, dim=-1)) - #print (network.module.arch_parameters) - return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg + # update the weights + sampled_arch = network.module.dync_genotype(True) + network.module.set_cal_mode("dynamic", sampled_arch) + # network.module.set_cal_mode( 'urs' ) + network.zero_grad() + _, logits = network(base_inputs) + base_loss = criterion(logits, base_targets) + base_loss.backward() + w_optimizer.step() + # record + base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_losses.update(base_loss.item(), base_inputs.size(0)) + base_top1.update(base_prec1.item(), base_inputs.size(0)) + base_top5.update(base_prec5.item(), base_inputs.size(0)) + + # update the architecture-weight + network.module.set_cal_mode("joint") + network.zero_grad() + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + arch_loss.backward() + a_optimizer.step() + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if step % print_freq == 0 or step + 1 == len(xloader): + Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( + batch_time=batch_time, data_time=data_time + ) + Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=base_losses, top1=base_top1, top5=base_top5 + ) + Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( + loss=arch_losses, top1=arch_top1, top5=arch_top5 + ) + logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) + # print (nn.functional.softmax(network.module.arch_parameters, dim=-1)) + # print (network.module.arch_parameters) + return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg def get_best_arch(xloader, network, n_samples): - with torch.no_grad(): - network.eval() - archs, valid_accs = network.module.return_topK(n_samples), [] - #print ('obtain the top-{:} architectures'.format(n_samples)) - loader_iter = iter(xloader) - for i, sampled_arch in enumerate(archs): - network.module.set_cal_mode('dynamic', sampled_arch) - try: - inputs, targets = next(loader_iter) - except: + with torch.no_grad(): + network.eval() + archs, valid_accs = network.module.return_topK(n_samples), [] + # print ('obtain the top-{:} architectures'.format(n_samples)) loader_iter = iter(xloader) - inputs, targets = next(loader_iter) + for i, sampled_arch in enumerate(archs): + network.module.set_cal_mode("dynamic", sampled_arch) + try: + inputs, targets = next(loader_iter) + except: + loader_iter = iter(xloader) + inputs, targets = next(loader_iter) - _, logits = network(inputs) - val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) + _, logits = network(inputs) + val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) - valid_accs.append(val_top1.item()) + valid_accs.append(val_top1.item()) - best_idx = np.argmax(valid_accs) - best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] - return best_arch, best_valid_acc + best_idx = np.argmax(valid_accs) + best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] + return best_arch, best_valid_acc def valid_func(xloader, network, criterion): - data_time, batch_time = AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - end = time.time() - with torch.no_grad(): - network.eval() - for step, (arch_inputs, arch_targets) in enumerate(xloader): - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - # prediction - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - return arch_losses.avg, arch_top1.avg, arch_top5.avg + data_time, batch_time = AverageMeter(), AverageMeter() + arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() + end = time.time() + with torch.no_grad(): + network.eval() + for step, (arch_inputs, arch_targets) in enumerate(xloader): + arch_targets = arch_targets.cuda(non_blocking=True) + # measure data loading time + data_time.update(time.time() - end) + # prediction + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + # record + arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_losses.update(arch_loss.item(), arch_inputs.size(0)) + arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) + arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + return arch_losses.avg, arch_top1.avg, arch_top5.avg def main(xargs): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) - search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \ - (config.batch_size, config.test_batch_size), xargs.workers) - logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + search_loader, _, valid_loader = get_nas_search_loaders( + train_data, + valid_data, + xargs.dataset, + "configs/nas-benchmark/", + (config.batch_size, config.test_batch_size), + xargs.workers, + ) + logger.log( + "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(search_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - search_space = get_search_spaces('cell', xargs.search_space_name) - if xargs.model_config is None: - model_config = dict2config( - dict(name='SETN', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num, - space=search_space, affine=False, track_running_stats=bool(xargs.track_running_stats)), None) - else: - model_config = load_config(xargs.model_config, dict(num_classes=class_num, space=search_space, affine=False, - track_running_stats=bool(xargs.track_running_stats)), None) - logger.log('search space : {:}'.format(search_space)) - search_model = get_cell_based_tiny_net(model_config) - - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) - a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) - logger.log('w-optimizer : {:}'.format(w_optimizer)) - logger.log('a-optimizer : {:}'.format(a_optimizer)) - logger.log('w-scheduler : {:}'.format(w_scheduler)) - logger.log('criterion : {:}'.format(criterion)) - flop, param = get_model_infos(search_model, xshape) - logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) - logger.log('search-space : {:}'.format(search_space)) - if xargs.arch_nas_dataset is None: - api = None - else: - api = API(xargs.arch_nas_dataset) - logger.log('{:} create API = {:} done'.format(time_string(), api)) + search_space = get_search_spaces("cell", xargs.search_space_name) + if xargs.model_config is None: + model_config = dict2config( + dict( + name="SETN", + C=xargs.channel, + N=xargs.num_cells, + max_nodes=xargs.max_nodes, + num_classes=class_num, + space=search_space, + affine=False, + track_running_stats=bool(xargs.track_running_stats), + ), + None, + ) + else: + model_config = load_config( + xargs.model_config, + dict( + num_classes=class_num, + space=search_space, + affine=False, + track_running_stats=bool(xargs.track_running_stats), + ), + None, + ) + logger.log("search space : {:}".format(search_space)) + search_model = get_cell_based_tiny_net(model_config) - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() + w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) + a_optimizer = torch.optim.Adam( + search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay + ) + logger.log("w-optimizer : {:}".format(w_optimizer)) + logger.log("a-optimizer : {:}".format(a_optimizer)) + logger.log("w-scheduler : {:}".format(w_scheduler)) + logger.log("criterion : {:}".format(criterion)) + flop, param = get_model_infos(search_model, xshape) + logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param)) + logger.log("search-space : {:}".format(search_space)) + if xargs.arch_nas_dataset is None: + api = None + else: + api = API(xargs.arch_nas_dataset) + logger.log("{:} create API = {:} done".format(time_string(), api)) - if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) - last_info = torch.load(last_info) - start_epoch = last_info['epoch'] - checkpoint = torch.load(last_info['last_checkpoint']) - genotypes = checkpoint['genotypes'] - valid_accuracies = checkpoint['valid_accuracies'] - search_model.load_state_dict( checkpoint['search_model'] ) - w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) - w_optimizer.load_state_dict ( checkpoint['w_optimizer'] ) - a_optimizer.load_state_dict ( checkpoint['a_optimizer'] ) - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - init_genotype, _ = get_best_arch(valid_loader, network, xargs.select_num) - start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: init_genotype} + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() - # start training - start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup - for epoch in range(start_epoch, total_epoch): - w_scheduler.update(epoch, 0.0) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) - epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) - logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) + if last_info.exists(): # automatically resume from previous checkpoint + logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + last_info = torch.load(last_info) + start_epoch = last_info["epoch"] + checkpoint = torch.load(last_info["last_checkpoint"]) + genotypes = checkpoint["genotypes"] + valid_accuracies = checkpoint["valid_accuracies"] + search_model.load_state_dict(checkpoint["search_model"]) + w_scheduler.load_state_dict(checkpoint["w_scheduler"]) + w_optimizer.load_state_dict(checkpoint["w_optimizer"]) + a_optimizer.load_state_dict(checkpoint["a_optimizer"]) + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + ) + else: + logger.log("=> do not find the last-info file : {:}".format(last_info)) + init_genotype, _ = get_best_arch(valid_loader, network, xargs.select_num) + start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: init_genotype} - search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ - = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) - search_time.update(time.time() - start_time) - logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) - logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5)) + # start training + start_time, search_time, epoch_time, total_epoch = ( + time.time(), + AverageMeter(), + AverageMeter(), + config.epochs + config.warmup, + ) + for epoch in range(start_epoch, total_epoch): + w_scheduler.update(epoch, 0.0) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) + logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr()))) - genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) - network.module.set_cal_mode('dynamic', genotype) - valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) - logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype)) - #search_model.set_cal_mode('urs') - #valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) - #logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) - #search_model.set_cal_mode('joint') - #valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) - #logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) - #search_model.set_cal_mode('select') - #valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) - #logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) - # check the best accuracy - valid_accuracies[epoch] = valid_a_top1 + search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func( + search_loader, + network, + criterion, + w_scheduler, + w_optimizer, + a_optimizer, + epoch_str, + xargs.print_freq, + logger, + ) + search_time.update(time.time() - start_time) + logger.log( + "[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( + epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum + ) + ) + logger.log( + "[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( + epoch_str, search_a_loss, search_a_top1, search_a_top5 + ) + ) - genotypes[epoch] = genotype - logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) - # save checkpoint - save_path = save_checkpoint({'epoch' : epoch + 1, - 'args' : deepcopy(xargs), - 'search_model': search_model.state_dict(), - 'w_optimizer' : w_optimizer.state_dict(), - 'a_optimizer' : a_optimizer.state_dict(), - 'w_scheduler' : w_scheduler.state_dict(), - 'genotypes' : genotypes, - 'valid_accuracies' : valid_accuracies}, - model_base_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch + 1, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) - with torch.no_grad(): - logger.log('{:}'.format(search_model.show_alphas())) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) - # measure elapsed time - epoch_time.update(time.time() - start_time) + genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) + network.module.set_cal_mode("dynamic", genotype) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) + logger.log( + "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format( + epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype + ) + ) + # search_model.set_cal_mode('urs') + # valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) + # logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) + # search_model.set_cal_mode('joint') + # valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) + # logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) + # search_model.set_cal_mode('select') + # valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) + # logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) + # check the best accuracy + valid_accuracies[epoch] = valid_a_top1 + + genotypes[epoch] = genotype + logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(xargs), + "search_model": search_model.state_dict(), + "w_optimizer": w_optimizer.state_dict(), + "a_optimizer": a_optimizer.state_dict(), + "w_scheduler": w_scheduler.state_dict(), + "genotypes": genotypes, + "valid_accuracies": valid_accuracies, + }, + model_base_path, + logger, + ) + last_info = save_checkpoint( + { + "epoch": epoch + 1, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + with torch.no_grad(): + logger.log("{:}".format(search_model.show_alphas())) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + # the final post procedure : count the time start_time = time.time() + genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) + search_time.update(time.time() - start_time) + network.module.set_cal_mode("dynamic", genotype) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) + logger.log("Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(genotype, valid_a_top1)) - # the final post procedure : count the time - start_time = time.time() - genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) - search_time.update(time.time() - start_time) - network.module.set_cal_mode('dynamic', genotype) - valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) - logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1)) - - logger.log('\n' + '-'*100) - # check the performance from the architecture dataset - logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype)) - if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '200') )) - logger.close() - + logger.log("\n" + "-" * 100) + # check the performance from the architecture dataset + logger.log("SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(total_epoch, search_time.sum, genotype)) + if api is not None: + logger.log("{:}".format(api.query_by_arch(genotype, "200"))) + logger.close() -if __name__ == '__main__': - parser = argparse.ArgumentParser("SETN") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # channels and number-of-cells - parser.add_argument('--search_space_name', type=str, help='The search space name.') - parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.') - parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') - parser.add_argument('--config_path', type=str, help='The path of the configuration.') - # architecture leraning rate - parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') - parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - main(args) +if __name__ == "__main__": + parser = argparse.ArgumentParser("SETN") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # channels and number-of-cells + parser.add_argument("--search_space_name", type=str, help="The search space name.") + parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument("--select_num", type=int, help="The number of selected architectures to evaluate.") + parser.add_argument( + "--track_running_stats", + type=int, + choices=[0, 1], + help="Whether use track_running_stats or not in the BN layer.", + ) + parser.add_argument("--config_path", type=str, help="The path of the configuration.") + # architecture leraning rate + parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") + parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + main(args) diff --git a/exps/algos/reinforce.py b/exps/algos/reinforce.py index 8dbed40..2e8f629 100644 --- a/exps/algos/reinforce.py +++ b/exps/algos/reinforce.py @@ -10,212 +10,245 @@ from pathlib import Path import torch import torch.nn as nn from torch.distributions import Categorical -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from nas_201_api import NASBench201API as API -from models import CellStructure, get_search_spaces +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from nas_201_api import NASBench201API as API +from models import CellStructure, get_search_spaces from R_EA import train_and_eval class Policy(nn.Module): + def __init__(self, max_nodes, search_space): + super(Policy, self).__init__() + self.max_nodes = max_nodes + self.search_space = deepcopy(search_space) + self.edge2index = {} + for i in range(1, max_nodes): + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + self.edge2index[node_str] = len(self.edge2index) + self.arch_parameters = nn.Parameter(1e-3 * torch.randn(len(self.edge2index), len(search_space))) - def __init__(self, max_nodes, search_space): - super(Policy, self).__init__() - self.max_nodes = max_nodes - self.search_space = deepcopy(search_space) - self.edge2index = {} - for i in range(1, max_nodes): - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - self.edge2index[ node_str ] = len(self.edge2index) - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(len(self.edge2index), len(search_space)) ) + def generate_arch(self, actions): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + op_name = self.search_space[actions[self.edge2index[node_str]]] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return CellStructure(genotypes) - def generate_arch(self, actions): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_name = self.search_space[ actions[ self.edge2index[ node_str ] ] ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return CellStructure( genotypes ) + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[self.edge2index[node_str]] + op_name = self.search_space[weights.argmax().item()] + xlist.append((op_name, j)) + genotypes.append(tuple(xlist)) + return CellStructure(genotypes) - def genotype(self): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self.search_space[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return CellStructure( genotypes ) - - def forward(self): - alphas = nn.functional.softmax(self.arch_parameters, dim=-1) - return alphas + def forward(self): + alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + return alphas class ExponentialMovingAverage(object): - """Class that maintains an exponential moving average.""" + """Class that maintains an exponential moving average.""" - def __init__(self, momentum): - self._numerator = 0 - self._denominator = 0 - self._momentum = momentum + def __init__(self, momentum): + self._numerator = 0 + self._denominator = 0 + self._momentum = momentum - def update(self, value): - self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value - self._denominator = self._momentum * self._denominator + (1 - self._momentum) + def update(self, value): + self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value + self._denominator = self._momentum * self._denominator + (1 - self._momentum) - def value(self): - """Return the current value of the moving average""" - return self._numerator / self._denominator + def value(self): + """Return the current value of the moving average""" + return self._numerator / self._denominator def select_action(policy): - probs = policy() - m = Categorical(probs) - action = m.sample() - #policy.saved_log_probs.append(m.log_prob(action)) - return m.log_prob(action), action.cpu().tolist() + probs = policy() + m = Categorical(probs) + action = m.sample() + # policy.saved_log_probs.append(m.log_prob(action)) + return m.log_prob(action), action.cpu().tolist() def main(xargs, nas_bench): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.set_num_threads(xargs.workers) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) - if xargs.dataset == 'cifar10': - dataname = 'cifar10-valid' - else: - dataname = xargs.dataset - if xargs.data_path is not None: - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - split_Fpath = 'configs/nas-benchmark/cifar-split.txt' - cifar_split = load_config(split_Fpath, None, None) - train_split, valid_split = cifar_split.train, cifar_split.valid - logger.log('Load split file from {:}'.format(split_Fpath)) - config_path = 'configs/nas-benchmark/algos/R-EA.config' - config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) - # To split data - train_data_v2 = deepcopy(train_data) - train_data_v2.transform = valid_data.transform - valid_data = train_data_v2 - search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) - # data loader - train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) - logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} - else: - config_path = 'configs/nas-benchmark/algos/R-EA.config' - config = load_config(config_path, None, logger) - extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - - - search_space = get_search_spaces('cell', xargs.search_space_name) - policy = Policy(xargs.max_nodes, search_space) - optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate) - #optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate) - eps = np.finfo(np.float32).eps.item() - baseline = ExponentialMovingAverage(xargs.EMA_momentum) - logger.log('policy : {:}'.format(policy)) - logger.log('optimizer : {:}'.format(optimizer)) - logger.log('eps : {:}'.format(eps)) + if xargs.dataset == "cifar10": + dataname = "cifar10-valid" + else: + dataname = xargs.dataset + if xargs.data_path is not None: + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + split_Fpath = "configs/nas-benchmark/cifar-split.txt" + cifar_split = load_config(split_Fpath, None, None) + train_split, valid_split = cifar_split.train, cifar_split.valid + logger.log("Load split file from {:}".format(split_Fpath)) + config_path = "configs/nas-benchmark/algos/R-EA.config" + config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger) + # To split data + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), + num_workers=xargs.workers, + pin_memory=True, + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), + num_workers=xargs.workers, + pin_memory=True, + ) + logger.log( + "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( + xargs.dataset, len(train_loader), len(valid_loader), config.batch_size + ) + ) + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) + extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader} + else: + config_path = "configs/nas-benchmark/algos/R-EA.config" + config = load_config(config_path, None, logger) + extra_info = {"config": config, "train_loader": None, "valid_loader": None} + logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - # nas dataset load - logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) + search_space = get_search_spaces("cell", xargs.search_space_name) + policy = Policy(xargs.max_nodes, search_space) + optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate) + # optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate) + eps = np.finfo(np.float32).eps.item() + baseline = ExponentialMovingAverage(xargs.EMA_momentum) + logger.log("policy : {:}".format(policy)) + logger.log("optimizer : {:}".format(optimizer)) + logger.log("eps : {:}".format(eps)) - # REINFORCE - # attempts = 0 - x_start_time = time.time() - logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget)) - total_steps, total_costs, trace = 0, 0, [] - #for istep in range(xargs.RL_steps): - while total_costs < xargs.time_budget: - start_time = time.time() - log_prob, action = select_action( policy ) - arch = policy.generate_arch( action ) - reward, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname) - trace.append( (reward, arch) ) - # accumulate time - if total_costs + cost_time < xargs.time_budget: - total_costs += cost_time - else: break + # nas dataset load + logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench)) - baseline.update(reward) - # calculate loss - policy_loss = ( -log_prob * (reward - baseline.value()) ).sum() - optimizer.zero_grad() - policy_loss.backward() - optimizer.step() - # accumulate time - total_costs += time.time() - start_time - total_steps += 1 - logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype())) - #logger.log('----> {:}'.format(policy.arch_parameters)) - #logger.log('') + # REINFORCE + # attempts = 0 + x_start_time = time.time() + logger.log("Will start searching with time budget of {:} s.".format(xargs.time_budget)) + total_steps, total_costs, trace = 0, 0, [] + # for istep in range(xargs.RL_steps): + while total_costs < xargs.time_budget: + start_time = time.time() + log_prob, action = select_action(policy) + arch = policy.generate_arch(action) + reward, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname) + trace.append((reward, arch)) + # accumulate time + if total_costs + cost_time < xargs.time_budget: + total_costs += cost_time + else: + break - # best_arch = policy.genotype() # first version - best_arch = max(trace, key=lambda x: x[0])[1] - logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time)) - info = nas_bench.query_by_arch(best_arch, '200') - if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) - else : logger.log('{:}'.format(info)) - logger.log('-'*100) - logger.close() - return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) - + baseline.update(reward) + # calculate loss + policy_loss = (-log_prob * (reward - baseline.value())).sum() + optimizer.zero_grad() + policy_loss.backward() + optimizer.step() + # accumulate time + total_costs += time.time() - start_time + total_steps += 1 + logger.log( + "step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}".format( + total_steps, baseline.value(), policy_loss.item(), policy.genotype() + ) + ) + # logger.log('----> {:}'.format(policy.arch_parameters)) + # logger.log('') + + # best_arch = policy.genotype() # first version + best_arch = max(trace, key=lambda x: x[0])[1] + logger.log( + "REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).".format( + total_steps, total_costs, time.time() - x_start_time + ) + ) + info = nas_bench.query_by_arch(best_arch, "200") + if info is None: + logger.log("Did not find this architecture : {:}.".format(best_arch)) + else: + logger.log("{:}".format(info)) + logger.log("-" * 100) + logger.close() + return logger.log_dir, nas_bench.query_index_by_arch(best_arch) -if __name__ == '__main__': - parser = argparse.ArgumentParser("The REINFORCE Algorithm") - parser.add_argument('--data_path', type=str, help='Path to dataset') - parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') - # channels and number-of-cells - parser.add_argument('--search_space_name', type=str, help='The search space name.') - parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') - parser.add_argument('--channel', type=int, help='The number of channels.') - parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') - parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.') - #parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.') - parser.add_argument('--EMA_momentum', type=float, help='The momentum value for EMA.') - parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).') - # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') - parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') - args = parser.parse_args() - #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): - nas_bench = None - else: - print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) - nas_bench = API(args.arch_nas_dataset) - if args.rand_seed < 0: - save_dir, all_indexes, num = None, [], 500 - for i in range(num): - print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) - args.rand_seed = random.randint(1, 100000) - save_dir, index = main(args, nas_bench) - all_indexes.append( index ) - torch.save(all_indexes, save_dir / 'results.pth') - else: - main(args, nas_bench) +if __name__ == "__main__": + parser = argparse.ArgumentParser("The REINFORCE Algorithm") + parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument( + "--dataset", + type=str, + choices=["cifar10", "cifar100", "ImageNet16-120"], + help="Choose between Cifar10/100 and ImageNet-16.", + ) + # channels and number-of-cells + parser.add_argument("--search_space_name", type=str, help="The search space name.") + parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") + parser.add_argument("--channel", type=int, help="The number of channels.") + parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument("--learning_rate", type=float, help="The learning rate for REINFORCE.") + # parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.') + parser.add_argument("--EMA_momentum", type=float, help="The momentum value for EMA.") + parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).") + # log + parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") + parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + ) + parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() + # if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): + nas_bench = None + else: + print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset)) + nas_bench = API(args.arch_nas_dataset) + if args.rand_seed < 0: + save_dir, all_indexes, num = None, [], 500 + for i in range(num): + print("{:} : {:03d}/{:03d}".format(time_string(), i, num)) + args.rand_seed = random.randint(1, 100000) + save_dir, index = main(args, nas_bench) + all_indexes.append(index) + torch.save(all_indexes, save_dir / "results.pth") + else: + main(args, nas_bench) diff --git a/exps/basic-eval.py b/exps/basic-eval.py index ecf2253..f0e2509 100644 --- a/exps/basic-eval.py +++ b/exps/basic-eval.py @@ -2,67 +2,83 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # ##################################################### import os, sys, time, torch, random, argparse -from PIL import ImageFile +from PIL import ImageFile + ImageFile.LOAD_TRUNCATED_IMAGES = True -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config -from procedures import get_procedures, get_optim_scheduler -from datasets import get_datasets -from models import obtain_model -from utils import get_model_infos -from log_utils import PrintLogger, time_string +from procedures import get_procedures, get_optim_scheduler +from datasets import get_datasets +from models import obtain_model +from utils import get_model_infos +from log_utils import PrintLogger, time_string -assert torch.cuda.is_available(), 'torch.cuda is not available' +assert torch.cuda.is_available(), "torch.cuda is not available" def main(args): - assert os.path.isdir ( args.data_path ) , 'invalid data-path : {:}'.format(args.data_path) - assert os.path.isfile( args.checkpoint ), 'invalid checkpoint : {:}'.format(args.checkpoint) + assert os.path.isdir(args.data_path), "invalid data-path : {:}".format(args.data_path) + assert os.path.isfile(args.checkpoint), "invalid checkpoint : {:}".format(args.checkpoint) - checkpoint = torch.load( args.checkpoint ) - xargs = checkpoint['args'] - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, args.data_path, xargs.cutout_length) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=xargs.batch_size, shuffle=False, num_workers=xargs.workers, pin_memory=True) + checkpoint = torch.load(args.checkpoint) + xargs = checkpoint["args"] + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, args.data_path, xargs.cutout_length) + valid_loader = torch.utils.data.DataLoader( + valid_data, batch_size=xargs.batch_size, shuffle=False, num_workers=xargs.workers, pin_memory=True + ) - logger = PrintLogger() - model_config = dict2config(checkpoint['model-config'], logger) - base_model = obtain_model(model_config) - flop, param = get_model_infos(base_model, xshape) - logger.log('model ====>>>>:\n{:}'.format(base_model)) - logger.log('model information : {:}'.format(base_model.get_message())) - logger.log('-'*50) - logger.log('Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G'.format(param, flop, flop/1e3)) - logger.log('-'*50) - logger.log('valid_data : {:}'.format(valid_data)) - optim_config = dict2config(checkpoint['optim-config'], logger) - _, _, criterion = get_optim_scheduler(base_model.parameters(), optim_config) - logger.log('criterion : {:}'.format(criterion)) - base_model.load_state_dict( checkpoint['base-model'] ) - _, valid_func = get_procedures(xargs.procedure) - logger.log('initialize the CNN done, evaluate it using {:}'.format(valid_func)) - network = torch.nn.DataParallel(base_model).cuda() - - try: - valid_loss, valid_acc1, valid_acc5 = valid_func(valid_loader, network, criterion, optim_config, 'pure-evaluation', xargs.print_freq_eval, logger) - except: - _, valid_func = get_procedures('basic') - valid_loss, valid_acc1, valid_acc5 = valid_func(valid_loader, network, criterion, optim_config, 'pure-evaluation', xargs.print_freq_eval, logger) - - num_bytes = torch.cuda.max_memory_cached( next(network.parameters()).device ) * 1.0 - logger.log('***{:s}*** EVALUATION loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f}, error@1 = {:.2f}, error@5 = {:.2f}'.format(time_string(), valid_loss, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5)) - logger.log('[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]'.format(next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9)) - logger.close() + logger = PrintLogger() + model_config = dict2config(checkpoint["model-config"], logger) + base_model = obtain_model(model_config) + flop, param = get_model_infos(base_model, xshape) + logger.log("model ====>>>>:\n{:}".format(base_model)) + logger.log("model information : {:}".format(base_model.get_message())) + logger.log("-" * 50) + logger.log("Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(param, flop, flop / 1e3)) + logger.log("-" * 50) + logger.log("valid_data : {:}".format(valid_data)) + optim_config = dict2config(checkpoint["optim-config"], logger) + _, _, criterion = get_optim_scheduler(base_model.parameters(), optim_config) + logger.log("criterion : {:}".format(criterion)) + base_model.load_state_dict(checkpoint["base-model"]) + _, valid_func = get_procedures(xargs.procedure) + logger.log("initialize the CNN done, evaluate it using {:}".format(valid_func)) + network = torch.nn.DataParallel(base_model).cuda() + + try: + valid_loss, valid_acc1, valid_acc5 = valid_func( + valid_loader, network, criterion, optim_config, "pure-evaluation", xargs.print_freq_eval, logger + ) + except: + _, valid_func = get_procedures("basic") + valid_loss, valid_acc1, valid_acc5 = valid_func( + valid_loader, network, criterion, optim_config, "pure-evaluation", xargs.print_freq_eval, logger + ) + + num_bytes = torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 + logger.log( + "***{:s}*** EVALUATION loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f}, error@1 = {:.2f}, error@5 = {:.2f}".format( + time_string(), valid_loss, valid_acc1, valid_acc5, 100 - valid_acc1, 100 - valid_acc5 + ) + ) + logger.log( + "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format( + next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9 + ) + ) + logger.close() -if __name__ == '__main__': - parser = argparse.ArgumentParser("Evaluate-CNN") - parser.add_argument('--data_path', type=str, help='Path to dataset.') - parser.add_argument('--checkpoint', type=str, help='Choose between Cifar10/100 and ImageNet.') - args = parser.parse_args() - main(args) +if __name__ == "__main__": + parser = argparse.ArgumentParser("Evaluate-CNN") + parser.add_argument("--data_path", type=str, help="Path to dataset.") + parser.add_argument("--checkpoint", type=str, help="Choose between Cifar10/100 and ImageNet.") + args = parser.parse_args() + main(args) diff --git a/exps/basic-main.py b/exps/basic-main.py index ecdf2d9..82f756d 100644 --- a/exps/basic-main.py +++ b/exps/basic-main.py @@ -2,166 +2,219 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # ##################################################### import sys, time, torch, random, argparse -from PIL import ImageFile +from PIL import ImageFile + ImageFile.LOAD_TRUNCATED_IMAGES = True -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, obtain_basic_args as obtain_args -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint -from procedures import get_optim_scheduler, get_procedures -from datasets import get_datasets -from models import obtain_model +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint +from procedures import get_optim_scheduler, get_procedures +from datasets import get_datasets +from models import obtain_model from nas_infer_model import obtain_nas_infer_model -from utils import get_model_infos -from log_utils import AverageMeter, time_string, convert_secs2time +from utils import get_model_infos +from log_utils import AverageMeter, time_string, convert_secs2time def main(args): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = True - #torch.backends.cudnn.deterministic = True - torch.set_num_threads( args.workers ) - - prepare_seed(args.rand_seed) - logger = prepare_logger(args) - - train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) - train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) - # get configures - model_config = load_config(args.model_config, {'class_num': class_num}, logger) - optim_config = load_config(args.optim_config, {'class_num': class_num}, logger) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + torch.set_num_threads(args.workers) - if args.model_source == 'normal': - base_model = obtain_model(model_config) - elif args.model_source == 'nas': - base_model = obtain_nas_infer_model(model_config, args.extra_model_path) - elif args.model_source == 'autodl-searched': - base_model = obtain_model(model_config, args.extra_model_path) - else: - raise ValueError('invalid model-source : {:}'.format(args.model_source)) - flop, param = get_model_infos(base_model, xshape) - logger.log('model ====>>>>:\n{:}'.format(base_model)) - logger.log('model information : {:}'.format(base_model.get_message())) - logger.log('-'*50) - logger.log('Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G'.format(param, flop, flop/1e3)) - logger.log('-'*50) - logger.log('train_data : {:}'.format(train_data)) - logger.log('valid_data : {:}'.format(valid_data)) - optimizer, scheduler, criterion = get_optim_scheduler(base_model.parameters(), optim_config) - logger.log('optimizer : {:}'.format(optimizer)) - logger.log('scheduler : {:}'.format(scheduler)) - logger.log('criterion : {:}'.format(criterion)) - - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() + prepare_seed(args.rand_seed) + logger = prepare_logger(args) - if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) - last_infox = torch.load(last_info) - start_epoch = last_infox['epoch'] + 1 - last_checkpoint_path = last_infox['last_checkpoint'] - if not last_checkpoint_path.exists(): - logger.log('Does not find {:}, try another path'.format(last_checkpoint_path)) - last_checkpoint_path = last_info.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name - checkpoint = torch.load( last_checkpoint_path ) - base_model.load_state_dict( checkpoint['base-model'] ) - scheduler.load_state_dict ( checkpoint['scheduler'] ) - optimizer.load_state_dict ( checkpoint['optimizer'] ) - valid_accuracies = checkpoint['valid_accuracies'] - max_bytes = checkpoint['max_bytes'] - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) - elif args.resume is not None: - assert Path(args.resume).exists(), 'Can not find the resume file : {:}'.format(args.resume) - checkpoint = torch.load( args.resume ) - start_epoch = checkpoint['epoch'] + 1 - base_model.load_state_dict( checkpoint['base-model'] ) - scheduler.load_state_dict ( checkpoint['scheduler'] ) - optimizer.load_state_dict ( checkpoint['optimizer'] ) - valid_accuracies = checkpoint['valid_accuracies'] - max_bytes = checkpoint['max_bytes'] - logger.log("=> loading checkpoint from '{:}' start with {:}-th epoch.".format(args.resume, start_epoch)) - elif args.init_model is not None: - assert Path(args.init_model).exists(), 'Can not find the initialization file : {:}'.format(args.init_model) - checkpoint = torch.load( args.init_model ) - base_model.load_state_dict( checkpoint['base-model'] ) - start_epoch, valid_accuracies, max_bytes = 0, {'best': -1}, {} - logger.log('=> initialize the model from {:}'.format( args.init_model )) - else: - logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, max_bytes = 0, {'best': -1}, {} + train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) + train_loader = torch.utils.data.DataLoader( + train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True + ) + valid_loader = torch.utils.data.DataLoader( + valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True + ) + # get configures + model_config = load_config(args.model_config, {"class_num": class_num}, logger) + optim_config = load_config(args.optim_config, {"class_num": class_num}, logger) - train_func, valid_func = get_procedures(args.procedure) - - total_epoch = optim_config.epochs + optim_config.warmup - # Main Training and Evaluation Loop - start_time = time.time() - epoch_time = AverageMeter() - for epoch in range(start_epoch, total_epoch): - scheduler.update(epoch, 0.0) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch), True) ) - epoch_str = 'epoch={:03d}/{:03d}'.format(epoch, total_epoch) - LRs = scheduler.get_lr() - find_best = False - # set-up drop-out ratio - if hasattr(base_model, 'update_drop_path'): base_model.update_drop_path(model_config.drop_path_prob * epoch / total_epoch) - logger.log('\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler)) - - # train for one epoch - train_loss, train_acc1, train_acc5 = train_func(train_loader, network, criterion, scheduler, optimizer, optim_config, epoch_str, args.print_freq, logger) - # log the results - logger.log('***{:s}*** TRAIN [{:}] loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}'.format(time_string(), epoch_str, train_loss, train_acc1, train_acc5)) + if args.model_source == "normal": + base_model = obtain_model(model_config) + elif args.model_source == "nas": + base_model = obtain_nas_infer_model(model_config, args.extra_model_path) + elif args.model_source == "autodl-searched": + base_model = obtain_model(model_config, args.extra_model_path) + else: + raise ValueError("invalid model-source : {:}".format(args.model_source)) + flop, param = get_model_infos(base_model, xshape) + logger.log("model ====>>>>:\n{:}".format(base_model)) + logger.log("model information : {:}".format(base_model.get_message())) + logger.log("-" * 50) + logger.log("Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(param, flop, flop / 1e3)) + logger.log("-" * 50) + logger.log("train_data : {:}".format(train_data)) + logger.log("valid_data : {:}".format(valid_data)) + optimizer, scheduler, criterion = get_optim_scheduler(base_model.parameters(), optim_config) + logger.log("optimizer : {:}".format(optimizer)) + logger.log("scheduler : {:}".format(scheduler)) + logger.log("criterion : {:}".format(criterion)) - # evaluate the performance - if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): - logger.log('-'*150) - valid_loss, valid_acc1, valid_acc5 = valid_func(valid_loader, network, criterion, optim_config, epoch_str, args.print_freq_eval, logger) - valid_accuracies[epoch] = valid_acc1 - logger.log('***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}'.format(time_string(), epoch_str, valid_loss, valid_acc1, valid_acc5, valid_accuracies['best'], 100-valid_accuracies['best'])) - if valid_acc1 > valid_accuracies['best']: - valid_accuracies['best'] = valid_acc1 - find_best = True - logger.log('Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.'.format(epoch, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5, model_best_path)) - num_bytes = torch.cuda.max_memory_cached( next(network.parameters()).device ) * 1.0 - logger.log('[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]'.format(next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9)) - max_bytes[epoch] = num_bytes - if epoch % 10 == 0: torch.cuda.empty_cache() + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() - # save checkpoint - save_path = save_checkpoint({ - 'epoch' : epoch, - 'args' : deepcopy(args), - 'max_bytes' : deepcopy(max_bytes), - 'FLOP' : flop, - 'PARAM' : param, - 'valid_accuracies': deepcopy(valid_accuracies), - 'model-config' : model_config._asdict(), - 'optim-config' : optim_config._asdict(), - 'base-model' : base_model.state_dict(), - 'scheduler' : scheduler.state_dict(), - 'optimizer' : optimizer.state_dict(), - }, model_base_path, logger) - if find_best: copy_checkpoint(model_base_path, model_best_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) + if last_info.exists(): # automatically resume from previous checkpoint + logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + last_infox = torch.load(last_info) + start_epoch = last_infox["epoch"] + 1 + last_checkpoint_path = last_infox["last_checkpoint"] + if not last_checkpoint_path.exists(): + logger.log("Does not find {:}, try another path".format(last_checkpoint_path)) + last_checkpoint_path = last_info.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name + checkpoint = torch.load(last_checkpoint_path) + base_model.load_state_dict(checkpoint["base-model"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + valid_accuracies = checkpoint["valid_accuracies"] + max_bytes = checkpoint["max_bytes"] + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + ) + elif args.resume is not None: + assert Path(args.resume).exists(), "Can not find the resume file : {:}".format(args.resume) + checkpoint = torch.load(args.resume) + start_epoch = checkpoint["epoch"] + 1 + base_model.load_state_dict(checkpoint["base-model"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + valid_accuracies = checkpoint["valid_accuracies"] + max_bytes = checkpoint["max_bytes"] + logger.log("=> loading checkpoint from '{:}' start with {:}-th epoch.".format(args.resume, start_epoch)) + elif args.init_model is not None: + assert Path(args.init_model).exists(), "Can not find the initialization file : {:}".format(args.init_model) + checkpoint = torch.load(args.init_model) + base_model.load_state_dict(checkpoint["base-model"]) + start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} + logger.log("=> initialize the model from {:}".format(args.init_model)) + else: + logger.log("=> do not find the last-info file : {:}".format(last_info)) + start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} - # measure elapsed time - epoch_time.update(time.time() - start_time) + train_func, valid_func = get_procedures(args.procedure) + + total_epoch = optim_config.epochs + optim_config.warmup + # Main Training and Evaluation Loop start_time = time.time() + epoch_time = AverageMeter() + for epoch in range(start_epoch, total_epoch): + scheduler.update(epoch, 0.0) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch), True)) + epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch) + LRs = scheduler.get_lr() + find_best = False + # set-up drop-out ratio + if hasattr(base_model, "update_drop_path"): + base_model.update_drop_path(model_config.drop_path_prob * epoch / total_epoch) + logger.log( + "\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}".format( + time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler + ) + ) - logger.log('\n' + '-'*200) - logger.log('Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}'.format(convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e6, logger.path('info'))) - logger.log('-'*200 + '\n') - logger.close() + # train for one epoch + train_loss, train_acc1, train_acc5 = train_func( + train_loader, network, criterion, scheduler, optimizer, optim_config, epoch_str, args.print_freq, logger + ) + # log the results + logger.log( + "***{:s}*** TRAIN [{:}] loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}".format( + time_string(), epoch_str, train_loss, train_acc1, train_acc5 + ) + ) + + # evaluate the performance + if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): + logger.log("-" * 150) + valid_loss, valid_acc1, valid_acc5 = valid_func( + valid_loader, network, criterion, optim_config, epoch_str, args.print_freq_eval, logger + ) + valid_accuracies[epoch] = valid_acc1 + logger.log( + "***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}".format( + time_string(), + epoch_str, + valid_loss, + valid_acc1, + valid_acc5, + valid_accuracies["best"], + 100 - valid_accuracies["best"], + ) + ) + if valid_acc1 > valid_accuracies["best"]: + valid_accuracies["best"] = valid_acc1 + find_best = True + logger.log( + "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format( + epoch, valid_acc1, valid_acc5, 100 - valid_acc1, 100 - valid_acc5, model_best_path + ) + ) + num_bytes = torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 + logger.log( + "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format( + next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9 + ) + ) + max_bytes[epoch] = num_bytes + if epoch % 10 == 0: + torch.cuda.empty_cache() + + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch, + "args": deepcopy(args), + "max_bytes": deepcopy(max_bytes), + "FLOP": flop, + "PARAM": param, + "valid_accuracies": deepcopy(valid_accuracies), + "model-config": model_config._asdict(), + "optim-config": optim_config._asdict(), + "base-model": base_model.state_dict(), + "scheduler": scheduler.state_dict(), + "optimizer": optimizer.state_dict(), + }, + model_base_path, + logger, + ) + if find_best: + copy_checkpoint(model_base_path, model_best_path, logger) + last_info = save_checkpoint( + { + "epoch": epoch, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("\n" + "-" * 200) + logger.log( + "Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}".format( + convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e6, logger.path("info") + ) + ) + logger.log("-" * 200 + "\n") + logger.close() -if __name__ == '__main__': - args = obtain_args() - main(args) +if __name__ == "__main__": + args = obtain_args() + main(args) diff --git a/exps/experimental/example-nas-bench.py b/exps/experimental/example-nas-bench.py index 35e0ff6..79fbde2 100644 --- a/exps/experimental/example-nas-bench.py +++ b/exps/experimental/example-nas-bench.py @@ -12,39 +12,43 @@ from pathlib import Path from collections import OrderedDict import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from nas_201_api import NASBench201API from log_utils import time_string from models import get_cell_based_tiny_net from utils import weight_watcher -if __name__ == '__main__': - parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") - parser.add_argument('--api_path' , type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.') - parser.add_argument('--archive_path', type=str, default=None, help='The path to the NAS-Bench-201 weight dir.') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") + parser.add_argument( + "--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file and weight dir." + ) + parser.add_argument("--archive_path", type=str, default=None, help="The path to the NAS-Bench-201 weight dir.") + args = parser.parse_args() - meta_file = Path(args.api_path) - weight_dir = Path(args.archive_path) - assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) - assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir) + meta_file = Path(args.api_path) + weight_dir = Path(args.archive_path) + assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) + assert weight_dir.exists() and weight_dir.is_dir(), "invalid path for weight dir : {:}".format(weight_dir) + api = NASBench201API(meta_file, verbose=True) - api = NASBench201API(meta_file, verbose=True) + arch_index = 3 # query the 3-th architecture + api.reload(weight_dir, arch_index) # reload the data of 3-th from archive dir - arch_index = 3 # query the 3-th architecture - api.reload(weight_dir, arch_index) # reload the data of 3-th from archive dir + data = "cifar10" # query the info from CIFAR-10 + config = api.get_net_config(arch_index, data) + net = get_cell_based_tiny_net(config) + meta_info = api.query_meta_info_by_index(arch_index, hp="200") # all info about this architecture + params = meta_info.get_net_param(data, 888) - data = 'cifar10' # query the info from CIFAR-10 - config = api.get_net_config(arch_index, data) - net = get_cell_based_tiny_net(config) - meta_info = api.query_meta_info_by_index(arch_index, hp='200') # all info about this architecture - params = meta_info.get_net_param(data, 888) - - net.load_state_dict(params) - _, summary = weight_watcher.analyze(net, alphas=False) - print('The summary of {:}-th architecture:\n{:}'.format(arch_index, summary)) + net.load_state_dict(params) + _, summary = weight_watcher.analyze(net, alphas=False) + print("The summary of {:}-th architecture:\n{:}".format(arch_index, summary)) diff --git a/exps/experimental/test-flops.py b/exps/experimental/test-flops.py index 73df113..a804cb3 100644 --- a/exps/experimental/test-flops.py +++ b/exps/experimental/test-flops.py @@ -2,23 +2,27 @@ import sys, time, random, argparse from copy import deepcopy import torchvision.models as models from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from utils import get_model_infos -#from models.ImageNet_MobileNetV2 import MobileNetV2 + +# from models.ImageNet_MobileNetV2 import MobileNetV2 from torchvision.models.mobilenet import MobileNetV2 + def main(width_mult): - # model = MobileNetV2(1001, width_mult, 32, 1280, 'InvertedResidual', 0.2) - model = MobileNetV2(width_mult=width_mult) - print(model) - flops, params = get_model_infos(model, (2, 3, 224, 224)) - print('FLOPs : {:}'.format(flops)) - print('Params : {:}'.format(params)) - print('-'*50) + # model = MobileNetV2(1001, width_mult, 32, 1280, 'InvertedResidual', 0.2) + model = MobileNetV2(width_mult=width_mult) + print(model) + flops, params = get_model_infos(model, (2, 3, 224, 224)) + print("FLOPs : {:}".format(flops)) + print("Params : {:}".format(params)) + print("-" * 50) -if __name__ == '__main__': - main(1.0) - main(1.4) +if __name__ == "__main__": + main(1.0) + main(1.4) diff --git a/exps/experimental/test-nas-plot.py b/exps/experimental/test-nas-plot.py index 6fdecd4..7a129d1 100644 --- a/exps/experimental/test-nas-plot.py +++ b/exps/experimental/test-nas-plot.py @@ -5,110 +5,148 @@ from copy import deepcopy import torch import numpy as np from collections import OrderedDict -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from nas_201_api import NASBench201API as API + def test_nas_api(): - from nas_201_api import ArchResults - xdata = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth') - for key in ['full', 'less']: - print ('\n------------------------- {:} -------------------------'.format(key)) - archRes = ArchResults.create_from_state_dict(xdata[key]) - print(archRes) - print(archRes.arch_idx_str()) - print(archRes.get_dataset_names()) - print(archRes.get_comput_costs('cifar10-valid')) - # get the metrics - print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False)) - print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) - print(archRes.query('cifar10-valid', 777)) + from nas_201_api import ArchResults + + xdata = torch.load( + "/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth" + ) + for key in ["full", "less"]: + print("\n------------------------- {:} -------------------------".format(key)) + archRes = ArchResults.create_from_state_dict(xdata[key]) + print(archRes) + print(archRes.arch_idx_str()) + print(archRes.get_dataset_names()) + print(archRes.get_comput_costs("cifar10-valid")) + # get the metrics + print(archRes.get_metrics("cifar10-valid", "x-valid", None, False)) + print(archRes.get_metrics("cifar10-valid", "x-valid", None, True)) + print(archRes.query("cifar10-valid", 777)) -OPS = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3'] -COLORS = ['chartreuse' , 'cyan' , 'navyblue', 'chocolate1'] +OPS = ["skip-connect", "conv-1x1", "conv-3x3", "pool-3x3"] +COLORS = ["chartreuse", "cyan", "navyblue", "chocolate1"] + def plot(filename): - from graphviz import Digraph - g = Digraph( - format='png', - edge_attr=dict(fontsize='20', fontname="times"), - node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"), - engine='dot') - g.body.extend(['rankdir=LR']) + from graphviz import Digraph - steps = 5 - for i in range(0, steps): - if i == 0: - g.node(str(i), fillcolor='darkseagreen2') - elif i+1 == steps: - g.node(str(i), fillcolor='palegoldenrod') - else: g.node(str(i), fillcolor='lightblue') + g = Digraph( + format="png", + edge_attr=dict(fontsize="20", fontname="times"), + node_attr=dict( + style="filled", + shape="rect", + align="center", + fontsize="20", + height="0.5", + width="0.5", + penwidth="2", + fontname="times", + ), + engine="dot", + ) + g.body.extend(["rankdir=LR"]) - for i in range(1, steps): - for xin in range(i): - op_i = random.randint(0, len(OPS)-1) - #g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i]) - g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i]) - #import pdb; pdb.set_trace() - g.render(filename, cleanup=True, view=False) + steps = 5 + for i in range(0, steps): + if i == 0: + g.node(str(i), fillcolor="darkseagreen2") + elif i + 1 == steps: + g.node(str(i), fillcolor="palegoldenrod") + else: + g.node(str(i), fillcolor="lightblue") + + for i in range(1, steps): + for xin in range(i): + op_i = random.randint(0, len(OPS) - 1) + # g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i]) + g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i]) + # import pdb; pdb.set_trace() + g.render(filename, cleanup=True, view=False) def test_auto_grad(): - class Net(torch.nn.Module): - def __init__(self, iS): - super(Net, self).__init__() - self.layer = torch.nn.Linear(iS, 1) - def forward(self, inputs): - outputs = self.layer(inputs) - outputs = torch.exp(outputs) - return outputs.mean() - net = Net(10) - inputs = torch.rand(256, 10) - loss = net(inputs) - first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True) - first_order_grads = torch.cat([x.view(-1) for x in first_order_grads]) - second_order_grads = [] - for grads in first_order_grads: - s_grads = torch.autograd.grad(grads, net.parameters()) - second_order_grads.append( s_grads ) + class Net(torch.nn.Module): + def __init__(self, iS): + super(Net, self).__init__() + self.layer = torch.nn.Linear(iS, 1) + + def forward(self, inputs): + outputs = self.layer(inputs) + outputs = torch.exp(outputs) + return outputs.mean() + + net = Net(10) + inputs = torch.rand(256, 10) + loss = net(inputs) + first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True) + first_order_grads = torch.cat([x.view(-1) for x in first_order_grads]) + second_order_grads = [] + for grads in first_order_grads: + s_grads = torch.autograd.grad(grads, net.parameters()) + second_order_grads.append(s_grads) def test_one_shot_model(ckpath, use_train): - from models import get_cell_based_tiny_net, get_search_spaces - from datasets import get_datasets, SearchDataset - from config_utils import load_config, dict2config - from utils.nas_utils import evaluate_one_shot - use_train = int(use_train) > 0 - #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth' - #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth' - print ('ckpath : {:}'.format(ckpath)) - ckp = torch.load(ckpath) - xargs = ckp['args'] - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - #config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None) - config = load_config('./configs/nas-benchmark/algos/DARTS.config', {'class_num': class_num, 'xshape': xshape}, None) - if xargs.dataset == 'cifar10': - cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None) - xvalid_data = deepcopy(train_data) - xvalid_data.transform = valid_data.transform - valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True) - else: raise ValueError('invalid dataset : {:}'.format(xargs.dataseet)) - search_space = get_search_spaces('cell', xargs.search_space_name) - model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells, - 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space, - 'affine' : False, 'track_running_stats': True}, None) - search_model = get_cell_based_tiny_net(model_config) - search_model.load_state_dict( ckp['search_model'] ) - search_model = search_model.cuda() - api = API('/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth') - archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train) + from models import get_cell_based_tiny_net, get_search_spaces + from datasets import get_datasets, SearchDataset + from config_utils import load_config, dict2config + from utils.nas_utils import evaluate_one_shot + + use_train = int(use_train) > 0 + # ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth' + # ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth' + print("ckpath : {:}".format(ckpath)) + ckp = torch.load(ckpath) + xargs = ckp["args"] + train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + # config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None) + config = load_config("./configs/nas-benchmark/algos/DARTS.config", {"class_num": class_num, "xshape": xshape}, None) + if xargs.dataset == "cifar10": + cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None) + xvalid_data = deepcopy(train_data) + xvalid_data.transform = valid_data.transform + valid_loader = torch.utils.data.DataLoader( + xvalid_data, + batch_size=2048, + sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), + num_workers=12, + pin_memory=True, + ) + else: + raise ValueError("invalid dataset : {:}".format(xargs.dataseet)) + search_space = get_search_spaces("cell", xargs.search_space_name) + model_config = dict2config( + { + "name": "SETN", + "C": xargs.channel, + "N": xargs.num_cells, + "max_nodes": xargs.max_nodes, + "num_classes": class_num, + "space": search_space, + "affine": False, + "track_running_stats": True, + }, + None, + ) + search_model = get_cell_based_tiny_net(model_config) + search_model.load_state_dict(ckp["search_model"]) + search_model = search_model.cuda() + api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth") + archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train) -if __name__ == '__main__': - #test_nas_api() - #for i in range(200): plot('{:04d}'.format(i)) - #test_auto_grad() - test_one_shot_model(sys.argv[1], sys.argv[2]) +if __name__ == "__main__": + # test_nas_api() + # for i in range(200): plot('{:04d}'.format(i)) + # test_auto_grad() + test_one_shot_model(sys.argv[1], sys.argv[2]) diff --git a/exps/experimental/test-resnest.py b/exps/experimental/test-resnest.py index df1df8a..eeeef81 100644 --- a/exps/experimental/test-resnest.py +++ b/exps/experimental/test-resnest.py @@ -4,24 +4,28 @@ # python exps/experimental/test-resnest.py ##################################################### import sys, time, torch, random, argparse -from PIL import ImageFile +from PIL import ImageFile + ImageFile.LOAD_TRUNCATED_IMAGES = True -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from utils import get_model_infos +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from utils import get_model_infos -torch.hub.list('zhanghang1989/ResNeSt', force_reload=True) +torch.hub.list("zhanghang1989/ResNeSt", force_reload=True) -for model_name, xshape in [('resnest50', (1,3,224,224)), - ('resnest101', (1,3,256,256)), - ('resnest200', (1,3,320,320)), - ('resnest269', (1,3,416,416))]: - # net = torch.hub.load('zhanghang1989/ResNeSt', model_name, pretrained=True) - net = torch.hub.load('zhanghang1989/ResNeSt', model_name, pretrained=False) - print('Model : {:}, input shape : {:}'.format(model_name, xshape)) - flops, param = get_model_infos(net, xshape) - print('flops : {:.3f}M'.format(flops)) - print('params : {:.3f}M'.format(param)) +for model_name, xshape in [ + ("resnest50", (1, 3, 224, 224)), + ("resnest101", (1, 3, 256, 256)), + ("resnest200", (1, 3, 320, 320)), + ("resnest269", (1, 3, 416, 416)), +]: + # net = torch.hub.load('zhanghang1989/ResNeSt', model_name, pretrained=True) + net = torch.hub.load("zhanghang1989/ResNeSt", model_name, pretrained=False) + print("Model : {:}, input shape : {:}".format(model_name, xshape)) + flops, param = get_model_infos(net, xshape) + print("flops : {:.3f}M".format(flops)) + print("params : {:.3f}M".format(param)) diff --git a/exps/experimental/test-ww-bench.py b/exps/experimental/test-ww-bench.py index 6cd3847..58d9ffc 100644 --- a/exps/experimental/test-ww-bench.py +++ b/exps/experimental/test-ww-bench.py @@ -15,10 +15,13 @@ from pathlib import Path from collections import OrderedDict import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from log_utils import time_string from nats_bench import create from models import get_cell_based_tiny_net @@ -38,111 +41,125 @@ def tostr(accdict, norms): return ' '.join(xstr) """ + def evaluate(api, weight_dir, data: str): - print('\nEvaluate dataset={:}'.format(data)) - process = psutil.Process(os.getpid()) - norms, accuracies = [], [] - ok, total = 0, 5000 - for idx in range(total): - arch_index = api.random() - api.reload(weight_dir, arch_index) - # compute the weight watcher results - config = api.get_net_config(arch_index, data) - net = get_cell_based_tiny_net(config) - meta_info = api.query_meta_info_by_index(arch_index, hp='200' if api.search_space_name == 'topology' else '90') - params = meta_info.get_net_param(data, 888 if api.search_space_name == 'topology' else 777) - with torch.no_grad(): - net.load_state_dict(params) - _, summary = weight_watcher.analyze(net, alphas=False) - if 'lognorm' not in summary: + print("\nEvaluate dataset={:}".format(data)) + process = psutil.Process(os.getpid()) + norms, accuracies = [], [] + ok, total = 0, 5000 + for idx in range(total): + arch_index = api.random() + api.reload(weight_dir, arch_index) + # compute the weight watcher results + config = api.get_net_config(arch_index, data) + net = get_cell_based_tiny_net(config) + meta_info = api.query_meta_info_by_index(arch_index, hp="200" if api.search_space_name == "topology" else "90") + params = meta_info.get_net_param(data, 888 if api.search_space_name == "topology" else 777) + with torch.no_grad(): + net.load_state_dict(params) + _, summary = weight_watcher.analyze(net, alphas=False) + if "lognorm" not in summary: + api.clear_params(arch_index, None) + del net + continue + continue + cur_norm = -summary["lognorm"] api.clear_params(arch_index, None) - del net ; continue - continue - cur_norm = -summary['lognorm'] - api.clear_params(arch_index, None) - if math.isnan(cur_norm): - del net, meta_info - continue - else: - ok += 1 - norms.append(cur_norm) - # query the accuracy - info = meta_info.get_metrics(data, 'ori-test', iepoch=None, is_random=888 if api.search_space_name == 'topology' else 777) - accuracies.append(info['accuracy']) - del net, meta_info - # print the information - if idx % 20 == 0: - gc.collect() - print('{:} {:04d}_{:04d}/{:04d} ({:.2f} MB memory)'.format(time_string(), ok, idx, total, process.memory_info().rss / 1e6)) - return norms, accuracies + if math.isnan(cur_norm): + del net, meta_info + continue + else: + ok += 1 + norms.append(cur_norm) + # query the accuracy + info = meta_info.get_metrics( + data, "ori-test", iepoch=None, is_random=888 if api.search_space_name == "topology" else 777 + ) + accuracies.append(info["accuracy"]) + del net, meta_info + # print the information + if idx % 20 == 0: + gc.collect() + print( + "{:} {:04d}_{:04d}/{:04d} ({:.2f} MB memory)".format( + time_string(), ok, idx, total, process.memory_info().rss / 1e6 + ) + ) + return norms, accuracies def main(search_space, meta_file: str, weight_dir, save_dir, xdata): - save_dir.mkdir(parents=True, exist_ok=True) - api = create(meta_file, search_space, verbose=False) - datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'] - print(time_string() + ' ' + '='*50) - for data in datasets: - hps = api.avaliable_hps - for hp in hps: - nums = api.statistics(data, hp=hp) - total = sum([k*v for k, v in nums.items()]) - print('Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).'.format(hp, data, total, nums)) - print(time_string() + ' ' + '='*50) + save_dir.mkdir(parents=True, exist_ok=True) + api = create(meta_file, search_space, verbose=False) + datasets = ["cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"] + print(time_string() + " " + "=" * 50) + for data in datasets: + hps = api.avaliable_hps + for hp in hps: + nums = api.statistics(data, hp=hp) + total = sum([k * v for k, v in nums.items()]) + print("Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format(hp, data, total, nums)) + print(time_string() + " " + "=" * 50) - norms, accuracies = evaluate(api, weight_dir, xdata) + norms, accuracies = evaluate(api, weight_dir, xdata) - indexes = list(range(len(norms))) - norm_indexes = sorted(indexes, key=lambda i: norms[i]) - accy_indexes = sorted(indexes, key=lambda i: accuracies[i]) - labels = [] - for index in norm_indexes: - labels.append(accy_indexes.index(index)) + indexes = list(range(len(norms))) + norm_indexes = sorted(indexes, key=lambda i: norms[i]) + accy_indexes = sorted(indexes, key=lambda i: accuracies[i]) + labels = [] + for index in norm_indexes: + labels.append(accy_indexes.index(index)) - dpi, width, height = 200, 1400, 800 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 18, 12 - resnet_scale, resnet_alpha = 120, 0.5 + dpi, width, height = 200, 1400, 800 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 18, 12 + resnet_scale, resnet_alpha = 120, 0.5 - fig = plt.figure(figsize=figsize) - ax = fig.add_subplot(111) - plt.xlim(min(indexes), max(indexes)) - plt.ylim(min(indexes), max(indexes)) - # plt.ylabel('y').set_rotation(30) - plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical') - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize) - ax.scatter(indexes, labels , marker='*', s=0.5, c='tab:red' , alpha=0.8) - ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue' , alpha=0.8) - ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='Test accuracy') - ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='Weight watcher') - plt.grid(zorder=0) - ax.set_axisbelow(True) - plt.legend(loc=0, fontsize=LegendFontsize) - ax.set_xlabel('architecture ranking sorted by the test accuracy ', fontsize=LabelSize) - ax.set_ylabel('architecture ranking computed by weight watcher', fontsize=LabelSize) - save_path = (save_dir / '{:}-{:}-test-ww.pdf'.format(search_space, xdata)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (save_dir / '{:}-{:}-test-ww.png'.format(search_space, xdata)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - - print('{:} finish this test.'.format(time_string())) + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xlim(min(indexes), max(indexes)) + plt.ylim(min(indexes), max(indexes)) + # plt.ylabel('y').set_rotation(30) + plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical") + plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize) + ax.scatter(indexes, labels, marker="*", s=0.5, c="tab:red", alpha=0.8) + ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) + ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Test accuracy") + ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="Weight watcher") + plt.grid(zorder=0) + ax.set_axisbelow(True) + plt.legend(loc=0, fontsize=LegendFontsize) + ax.set_xlabel("architecture ranking sorted by the test accuracy ", fontsize=LabelSize) + ax.set_ylabel("architecture ranking computed by weight watcher", fontsize=LabelSize) + save_path = (save_dir / "{:}-{:}-test-ww.pdf".format(search_space, xdata)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (save_dir / "{:}-{:}-test-ww.png".format(search_space, xdata)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + + print("{:} finish this test.".format(time_string())) -if __name__ == '__main__': - parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") - parser.add_argument('--save_dir', type=str, default='./output/vis-nas-bench/', help='The base-name of folder to save checkpoints and log.') - parser.add_argument('--search_space', type=str, default=None, choices=['tss', 'sss'], help='The search space.') - parser.add_argument('--base_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.') - parser.add_argument('--dataset' , type=str, default=None, help='.') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") + parser.add_argument( + "--save_dir", + type=str, + default="./output/vis-nas-bench/", + help="The base-name of folder to save checkpoints and log.", + ) + parser.add_argument("--search_space", type=str, default=None, choices=["tss", "sss"], help="The search space.") + parser.add_argument( + "--base_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file and weight dir." + ) + parser.add_argument("--dataset", type=str, default=None, help=".") + args = parser.parse_args() - save_dir = Path(args.save_dir) - save_dir.mkdir(parents=True, exist_ok=True) - meta_file = Path(args.base_path + '.pth') - weight_dir = Path(args.base_path + '-full') - assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) - assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir) - - main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset) + save_dir = Path(args.save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + meta_file = Path(args.base_path + ".pth") + weight_dir = Path(args.base_path + "-full") + assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) + assert weight_dir.exists() and weight_dir.is_dir(), "invalid path for weight dir : {:}".format(weight_dir) + main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset) diff --git a/exps/experimental/test-ww.py b/exps/experimental/test-ww.py index dec4e34..626a273 100644 --- a/exps/experimental/test-ww.py +++ b/exps/experimental/test-ww.py @@ -2,31 +2,33 @@ import sys, time, random, argparse from copy import deepcopy import torchvision.models as models from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from utils import weight_watcher def main(): - # model = models.vgg19_bn(pretrained=True) - # _, summary = weight_watcher.analyze(model, alphas=False) - # for key, value in summary.items(): - # print('{:10s} : {:}'.format(key, value)) + # model = models.vgg19_bn(pretrained=True) + # _, summary = weight_watcher.analyze(model, alphas=False) + # for key, value in summary.items(): + # print('{:10s} : {:}'.format(key, value)) - _, summary = weight_watcher.analyze(models.vgg13(pretrained=True), alphas=False) - print('vgg-13 : {:}'.format(summary['lognorm'])) - _, summary = weight_watcher.analyze(models.vgg13_bn(pretrained=True), alphas=False) - print('vgg-13-BN : {:}'.format(summary['lognorm'])) - _, summary = weight_watcher.analyze(models.vgg16(pretrained=True), alphas=False) - print('vgg-16 : {:}'.format(summary['lognorm'])) - _, summary = weight_watcher.analyze(models.vgg16_bn(pretrained=True), alphas=False) - print('vgg-16-BN : {:}'.format(summary['lognorm'])) - _, summary = weight_watcher.analyze(models.vgg19(pretrained=True), alphas=False) - print('vgg-19 : {:}'.format(summary['lognorm'])) - _, summary = weight_watcher.analyze(models.vgg19_bn(pretrained=True), alphas=False) - print('vgg-19-BN : {:}'.format(summary['lognorm'])) + _, summary = weight_watcher.analyze(models.vgg13(pretrained=True), alphas=False) + print("vgg-13 : {:}".format(summary["lognorm"])) + _, summary = weight_watcher.analyze(models.vgg13_bn(pretrained=True), alphas=False) + print("vgg-13-BN : {:}".format(summary["lognorm"])) + _, summary = weight_watcher.analyze(models.vgg16(pretrained=True), alphas=False) + print("vgg-16 : {:}".format(summary["lognorm"])) + _, summary = weight_watcher.analyze(models.vgg16_bn(pretrained=True), alphas=False) + print("vgg-16-BN : {:}".format(summary["lognorm"])) + _, summary = weight_watcher.analyze(models.vgg19(pretrained=True), alphas=False) + print("vgg-19 : {:}".format(summary["lognorm"])) + _, summary = weight_watcher.analyze(models.vgg19_bn(pretrained=True), alphas=False) + print("vgg-19-BN : {:}".format(summary["lognorm"])) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/exps/experimental/vis-nats-bench-algos.py b/exps/experimental/vis-nats-bench-algos.py index 9fb2e37..8e53f98 100644 --- a/exps/experimental/vis-nats-bench-algos.py +++ b/exps/experimental/vis-nats-bench-algos.py @@ -11,122 +11,133 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict, OrderedDict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from nats_bench import create from log_utils import time_string -def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): - ss_dir = '{:}-{:}'.format(root_dir, search_space) - alg2name, alg2path = OrderedDict(), OrderedDict() - alg2name['REA'] = 'R-EA-SS3' - alg2name['REINFORCE'] = 'REINFORCE-0.01' - alg2name['RANDOM'] = 'RANDOM' - alg2name['BOHB'] = 'BOHB' - for alg, name in alg2name.items(): - alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth') - assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg]) - alg2data = OrderedDict() - for alg, path in alg2path.items(): - data = torch.load(path) - for index, info in data.items(): - info['time_w_arch'] = [(x, y) for x, y in zip(info['all_total_times'], info['all_archs'])] - for j, arch in enumerate(info['all_archs']): - assert arch != -1, 'invalid arch from {:} {:} {:} ({:}, {:})'.format(alg, search_space, dataset, index, j) - alg2data[alg] = data - return alg2data +def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): + ss_dir = "{:}-{:}".format(root_dir, search_space) + alg2name, alg2path = OrderedDict(), OrderedDict() + alg2name["REA"] = "R-EA-SS3" + alg2name["REINFORCE"] = "REINFORCE-0.01" + alg2name["RANDOM"] = "RANDOM" + alg2name["BOHB"] = "BOHB" + for alg, name in alg2name.items(): + alg2path[alg] = os.path.join(ss_dir, dataset, name, "results.pth") + assert os.path.isfile(alg2path[alg]), "invalid path : {:}".format(alg2path[alg]) + alg2data = OrderedDict() + for alg, path in alg2path.items(): + data = torch.load(path) + for index, info in data.items(): + info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])] + for j, arch in enumerate(info["all_archs"]): + assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( + alg, search_space, dataset, index, j + ) + alg2data[alg] = data + return alg2data def query_performance(api, data, dataset, ticket): - results, is_size_space = [], api.search_space_name == 'size' - for i, info in data.items(): - time_w_arch = sorted(info['time_w_arch'], key=lambda x: abs(x[0]-ticket)) - time_a, arch_a = time_w_arch[0] - time_b, arch_b = time_w_arch[1] - info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy'] - interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b - results.append(interplate) - return sum(results) / len(results) + results, is_size_space = [], api.search_space_name == "size" + for i, info in data.items(): + time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) + time_a, arch_a = time_w_arch[0] + time_b, arch_b = time_w_arch[1] + info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"] + interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / ( + time_b - time_a + ) * accuracy_b + results.append(interplate) + return sum(results) / len(results) -y_min_s = {('cifar10', 'tss'): 90, - ('cifar10', 'sss'): 92, - ('cifar100', 'tss'): 65, - ('cifar100', 'sss'): 65, - ('ImageNet16-120', 'tss'): 36, - ('ImageNet16-120', 'sss'): 40} +y_min_s = { + ("cifar10", "tss"): 90, + ("cifar10", "sss"): 92, + ("cifar100", "tss"): 65, + ("cifar100", "sss"): 65, + ("ImageNet16-120", "tss"): 36, + ("ImageNet16-120", "sss"): 40, +} -y_max_s = {('cifar10', 'tss'): 94.5, - ('cifar10', 'sss'): 93.3, - ('cifar100', 'tss'): 72, - ('cifar100', 'sss'): 70, - ('ImageNet16-120', 'tss'): 44, - ('ImageNet16-120', 'sss'): 46} +y_max_s = { + ("cifar10", "tss"): 94.5, + ("cifar10", "sss"): 93.3, + ("cifar100", "tss"): 72, + ("cifar100", "sss"): 70, + ("ImageNet16-120", "tss"): 44, + ("ImageNet16-120", "sss"): 46, +} + +name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} -name2label = {'cifar10': 'CIFAR-10', - 'cifar100': 'CIFAR-100', - 'ImageNet16-120': 'ImageNet-16-120'} def visualize_curve(api, vis_save_dir, search_space, max_time): - vis_save_dir = vis_save_dir.resolve() - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + vis_save_dir.mkdir(parents=True, exist_ok=True) - dpi, width, height = 250, 5200, 1400 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 16, 16 + dpi, width, height = 250, 5200, 1400 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 16, 16 - def sub_plot_fn(ax, dataset): - alg2data = fetch_data(search_space=search_space, dataset=dataset) - alg2accuracies = OrderedDict() - total_tickets = 150 - time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)] - colors = ['b', 'g', 'c', 'm', 'y'] - ax.set_xlim(0, 200) - ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) - for idx, (alg, data) in enumerate(alg2data.items()): - print('plot alg : {:}'.format(alg)) - accuracies = [] - for ticket in time_tickets: - accuracy = query_performance(api, data, dataset, ticket) - accuracies.append(accuracy) - alg2accuracies[alg] = accuracies - ax.plot([x/100 for x in time_tickets], accuracies, c=colors[idx], label='{:}'.format(alg)) - ax.set_xlabel('Estimated wall-clock time (1e2 seconds)', fontsize=LabelSize) - ax.set_ylabel('Test accuracy on {:}'.format(name2label[dataset]), fontsize=LabelSize) - ax.set_title('Searching results on {:}'.format(name2label[dataset]), fontsize=LabelSize+4) - ax.legend(loc=4, fontsize=LegendFontsize) + def sub_plot_fn(ax, dataset): + alg2data = fetch_data(search_space=search_space, dataset=dataset) + alg2accuracies = OrderedDict() + total_tickets = 150 + time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)] + colors = ["b", "g", "c", "m", "y"] + ax.set_xlim(0, 200) + ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) + for idx, (alg, data) in enumerate(alg2data.items()): + print("plot alg : {:}".format(alg)) + accuracies = [] + for ticket in time_tickets: + accuracy = query_performance(api, data, dataset, ticket) + accuracies.append(accuracy) + alg2accuracies[alg] = accuracies + ax.plot([x / 100 for x in time_tickets], accuracies, c=colors[idx], label="{:}".format(alg)) + ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize) + ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize) + ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4) + ax.legend(loc=4, fontsize=LegendFontsize) - fig, axs = plt.subplots(1, 3, figsize=figsize) - datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] - for dataset, ax in zip(datasets, axs): - sub_plot_fn(ax, dataset) - print('sub-plot {:} on {:} done.'.format(dataset, search_space)) - save_path = (vis_save_dir / '{:}-curve.png'.format(search_space)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + fig, axs = plt.subplots(1, 3, figsize=figsize) + datasets = ["cifar10", "cifar100", "ImageNet16-120"] + for dataset, ax in zip(datasets, axs): + sub_plot_fn(ax, dataset) + print("sub-plot {:} on {:} done.".format(dataset, search_space)) + save_path = (vis_save_dir / "{:}-curve.png".format(search_space)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') - parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') - parser.add_argument('--max_time', type=float, default=20000, help='The maximum time budget.') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." + ) + parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") + parser.add_argument("--max_time", type=float, default=20000, help="The maximum time budget.") + args = parser.parse_args() - save_dir = Path(args.save_dir) + save_dir = Path(args.save_dir) - api = create(None, args.search_space, verbose=False) - visualize_curve(api, save_dir, args.search_space, args.max_time) + api = create(None, args.search_space, verbose=False) + visualize_curve(api, save_dir, args.search_space, args.max_time) diff --git a/exps/experimental/vis-nats-bench-ws.py b/exps/experimental/vis-nats-bench-ws.py index de4a22a..c3715b3 100644 --- a/exps/experimental/vis-nats-bench-ws.py +++ b/exps/experimental/vis-nats-bench-ws.py @@ -11,132 +11,143 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict, OrderedDict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from nats_bench import create from log_utils import time_string # def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'): -def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARM0.3'): - ss_dir = '{:}-{:}'.format(root_dir, search_space) - alg2name, alg2path = OrderedDict(), OrderedDict() - seeds = [777, 888, 999] - print('\n[fetch data] from {:} on {:}'.format(search_space, dataset)) - if search_space == 'tss': - alg2name['GDAS'] = 'gdas-affine0_BN0-None' - alg2name['RSPS'] = 'random-affine0_BN0-None' - alg2name['DARTS (1st)'] = 'darts-v1-affine0_BN0-None' - alg2name['DARTS (2nd)'] = 'darts-v2-affine0_BN0-None' - alg2name['ENAS'] = 'enas-affine0_BN0-None' - alg2name['SETN'] = 'setn-affine0_BN0-None' - else: - # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) - # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) - # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) - alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix) - alg2name['masking + Gumbel-Softmax'] = 'mask_gumbel-affine0_BN0-AWD0.001{:}'.format(suffix) - alg2name['masking + sampling'] = 'mask_rl-affine0_BN0-AWD0.0{:}'.format(suffix) - for alg, name in alg2name.items(): - alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth') - alg2data = OrderedDict() - for alg, path in alg2path.items(): - alg2data[alg], ok_num = [], 0 - for seed in seeds: - xpath = path.format(seed) - if os.path.isfile(xpath): - ok_num += 1 - else: - print('This is an invalid path : {:}'.format(xpath)) - continue - data = torch.load(xpath, map_location=torch.device('cpu')) - data = torch.load(data['last_checkpoint'], map_location=torch.device('cpu')) - alg2data[alg].append(data['genotypes']) - print('This algorithm : {:} has {:} valid ckps.'.format(alg, ok_num)) - assert ok_num > 0, 'Must have at least 1 valid ckps.' - return alg2data +def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"): + ss_dir = "{:}-{:}".format(root_dir, search_space) + alg2name, alg2path = OrderedDict(), OrderedDict() + seeds = [777, 888, 999] + print("\n[fetch data] from {:} on {:}".format(search_space, dataset)) + if search_space == "tss": + alg2name["GDAS"] = "gdas-affine0_BN0-None" + alg2name["RSPS"] = "random-affine0_BN0-None" + alg2name["DARTS (1st)"] = "darts-v1-affine0_BN0-None" + alg2name["DARTS (2nd)"] = "darts-v2-affine0_BN0-None" + alg2name["ENAS"] = "enas-affine0_BN0-None" + alg2name["SETN"] = "setn-affine0_BN0-None" + else: + # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) + # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) + # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) + alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(suffix) + alg2name["masking + Gumbel-Softmax"] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix) + alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix) + for alg, name in alg2name.items(): + alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth") + alg2data = OrderedDict() + for alg, path in alg2path.items(): + alg2data[alg], ok_num = [], 0 + for seed in seeds: + xpath = path.format(seed) + if os.path.isfile(xpath): + ok_num += 1 + else: + print("This is an invalid path : {:}".format(xpath)) + continue + data = torch.load(xpath, map_location=torch.device("cpu")) + data = torch.load(data["last_checkpoint"], map_location=torch.device("cpu")) + alg2data[alg].append(data["genotypes"]) + print("This algorithm : {:} has {:} valid ckps.".format(alg, ok_num)) + assert ok_num > 0, "Must have at least 1 valid ckps." + return alg2data -y_min_s = {('cifar10', 'tss'): 90, - ('cifar10', 'sss'): 92, - ('cifar100', 'tss'): 65, - ('cifar100', 'sss'): 65, - ('ImageNet16-120', 'tss'): 36, - ('ImageNet16-120', 'sss'): 40} +y_min_s = { + ("cifar10", "tss"): 90, + ("cifar10", "sss"): 92, + ("cifar100", "tss"): 65, + ("cifar100", "sss"): 65, + ("ImageNet16-120", "tss"): 36, + ("ImageNet16-120", "sss"): 40, +} -y_max_s = {('cifar10', 'tss'): 94.5, - ('cifar10', 'sss'): 93.3, - ('cifar100', 'tss'): 72, - ('cifar100', 'sss'): 70, - ('ImageNet16-120', 'tss'): 44, - ('ImageNet16-120', 'sss'): 46} +y_max_s = { + ("cifar10", "tss"): 94.5, + ("cifar10", "sss"): 93.3, + ("cifar100", "tss"): 72, + ("cifar100", "sss"): 70, + ("ImageNet16-120", "tss"): 44, + ("ImageNet16-120", "sss"): 46, +} + +name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} -name2label = {'cifar10': 'CIFAR-10', - 'cifar100': 'CIFAR-100', - 'ImageNet16-120': 'ImageNet-16-120'} def visualize_curve(api, vis_save_dir, search_space): - vis_save_dir = vis_save_dir.resolve() - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + vis_save_dir.mkdir(parents=True, exist_ok=True) - dpi, width, height = 250, 5200, 1400 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 16, 16 + dpi, width, height = 250, 5200, 1400 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 16, 16 - def sub_plot_fn(ax, dataset): - alg2data = fetch_data(search_space=search_space, dataset=dataset) - alg2accuracies = OrderedDict() - epochs = 100 - colors = ['b', 'g', 'c', 'm', 'y', 'r'] - ax.set_xlim(0, epochs) - # ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) - for idx, (alg, data) in enumerate(alg2data.items()): - print('plot alg : {:}'.format(alg)) - xs, accuracies = [], [] - for iepoch in range(epochs + 1): - try: - structures, accs = [_[iepoch-1] for _ in data], [] - except: - raise ValueError('This alg {:} on {:} has invalid checkpoints.'.format(alg, dataset)) - for structure in structures: - info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False) - accs.append(info['test-accuracy']) - accuracies.append(sum(accs)/len(accs)) - xs.append(iepoch) - alg2accuracies[alg] = accuracies - ax.plot(xs, accuracies, c=colors[idx], label='{:}'.format(alg)) - ax.set_xlabel('The searching epoch', fontsize=LabelSize) - ax.set_ylabel('Test accuracy on {:}'.format(name2label[dataset]), fontsize=LabelSize) - ax.set_title('Searching results on {:}'.format(name2label[dataset]), fontsize=LabelSize+4) - ax.legend(loc=4, fontsize=LegendFontsize) + def sub_plot_fn(ax, dataset): + alg2data = fetch_data(search_space=search_space, dataset=dataset) + alg2accuracies = OrderedDict() + epochs = 100 + colors = ["b", "g", "c", "m", "y", "r"] + ax.set_xlim(0, epochs) + # ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) + for idx, (alg, data) in enumerate(alg2data.items()): + print("plot alg : {:}".format(alg)) + xs, accuracies = [], [] + for iepoch in range(epochs + 1): + try: + structures, accs = [_[iepoch - 1] for _ in data], [] + except: + raise ValueError("This alg {:} on {:} has invalid checkpoints.".format(alg, dataset)) + for structure in structures: + info = api.get_more_info( + structure, dataset=dataset, hp=90 if api.search_space_name == "size" else 200, is_random=False + ) + accs.append(info["test-accuracy"]) + accuracies.append(sum(accs) / len(accs)) + xs.append(iepoch) + alg2accuracies[alg] = accuracies + ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg)) + ax.set_xlabel("The searching epoch", fontsize=LabelSize) + ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize) + ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4) + ax.legend(loc=4, fontsize=LegendFontsize) - fig, axs = plt.subplots(1, 3, figsize=figsize) - datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] - for dataset, ax in zip(datasets, axs): - sub_plot_fn(ax, dataset) - print('sub-plot {:} on {:} done.'.format(dataset, search_space)) - save_path = (vis_save_dir / '{:}-ws-curve.png'.format(search_space)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + fig, axs = plt.subplots(1, 3, figsize=figsize) + datasets = ["cifar10", "cifar100", "ImageNet16-120"] + for dataset, ax in zip(datasets, axs): + sub_plot_fn(ax, dataset) + print("sub-plot {:} on {:} done.".format(dataset, search_space)) + save_path = (vis_save_dir / "{:}-ws-curve.png".format(search_space)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') - parser.add_argument('--search_space', type=str, default='tss', choices=['tss', 'sss'], help='Choose the search space.') - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." + ) + parser.add_argument( + "--search_space", type=str, default="tss", choices=["tss", "sss"], help="Choose the search space." + ) + args = parser.parse_args() - save_dir = Path(args.save_dir) + save_dir = Path(args.save_dir) - api = create(None, args.search_space, fast_mode=True, verbose=False) - visualize_curve(api, save_dir, args.search_space) + api = create(None, args.search_space, fast_mode=True, verbose=False) + visualize_curve(api, save_dir, args.search_space) diff --git a/exps/experimental/visualize-nas-bench-x.py b/exps/experimental/visualize-nas-bench-x.py index 70412e6..a8e4786 100644 --- a/exps/experimental/visualize-nas-bench-x.py +++ b/exps/experimental/visualize-nas-bench-x.py @@ -10,16 +10,18 @@ import numpy as np from typing import List, Text, Dict, Any from shutil import copyfile from collections import defaultdict -from copy import deepcopy +from copy import deepcopy from pathlib import Path import matplotlib import seaborn as sns -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from log_utils import time_string from models import get_cell_based_tiny_net @@ -27,382 +29,577 @@ from nats_bench import create def visualize_info(api, vis_save_dir, indicator): - vis_save_dir = vis_save_dir.resolve() - # print ('{:} start to visualize {:} information'.format(time_string(), api)) - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + # print ('{:} start to visualize {:} information'.format(time_string(), api)) + vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator) - cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator) - imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator) - cifar010_info = torch.load(cifar010_cache_path) - cifar100_info = torch.load(cifar100_cache_path) - imagenet_info = torch.load(imagenet_cache_path) - indexes = list(range(len(cifar010_info['params']))) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_info = torch.load(cifar010_cache_path) + cifar100_info = torch.load(cifar100_cache_path) + imagenet_info = torch.load(imagenet_cache_path) + indexes = list(range(len(cifar010_info["params"]))) - print ('{:} start to visualize relative ranking'.format(time_string())) + print("{:} start to visualize relative ranking".format(time_string())) - cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i]) - cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i]) - imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i]) + cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info["test_accs"][i]) + cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info["test_accs"][i]) + imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info["test_accs"][i]) - cifar100_labels, imagenet_labels = [], [] - for idx in cifar010_ord_indexes: - cifar100_labels.append( cifar100_ord_indexes.index(idx) ) - imagenet_labels.append( imagenet_ord_indexes.index(idx) ) - print ('{:} prepare data done.'.format(time_string())) + cifar100_labels, imagenet_labels = [], [] + for idx in cifar010_ord_indexes: + cifar100_labels.append(cifar100_ord_indexes.index(idx)) + imagenet_labels.append(imagenet_ord_indexes.index(idx)) + print("{:} prepare data done.".format(time_string())) - dpi, width, height = 200, 1400, 800 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 18, 12 - resnet_scale, resnet_alpha = 120, 0.5 + dpi, width, height = 200, 1400, 800 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 18, 12 + resnet_scale, resnet_alpha = 120, 0.5 - fig = plt.figure(figsize=figsize) - ax = fig.add_subplot(111) - plt.xlim(min(indexes), max(indexes)) - plt.ylim(min(indexes), max(indexes)) - # plt.ylabel('y').set_rotation(30) - plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical') - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize) - ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8) - ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8) - ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8) - ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10') - ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100') - ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120') - plt.grid(zorder=0) - ax.set_axisbelow(True) - plt.legend(loc=0, fontsize=LegendFontsize) - ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize) - ax.set_ylabel('architecture ranking', fontsize=LabelSize) - save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + plt.xlim(min(indexes), max(indexes)) + plt.ylim(min(indexes), max(indexes)) + # plt.ylabel('y').set_rotation(30) + plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical") + plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize) + ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) + ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8) + ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) + ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10") + ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100") + ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120") + plt.grid(zorder=0) + ax.set_axisbelow(True) + plt.legend(loc=0, fontsize=LegendFontsize) + ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize) + ax.set_ylabel("architecture ranking", fontsize=LabelSize) + save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) def visualize_sss_info(api, dataset, vis_save_dir): - vis_save_dir = vis_save_dir.resolve() - print ('{:} start to visualize {:} information'.format(time_string(), dataset)) - vis_save_dir.mkdir(parents=True, exist_ok=True) - cache_file_path = vis_save_dir / '{:}-cache-sss-info.pth'.format(dataset) - if not cache_file_path.exists(): - print ('Do not find cache file : {:}'.format(cache_file_path)) - params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] - for index in range(len(api)): - cost_info = api.get_cost_info(index, dataset, hp='90') - params.append(cost_info['params']) - flops.append(cost_info['flops']) - # accuracy - info = api.get_more_info(index, dataset, hp='90', is_random=False) - train_accs.append(info['train-accuracy']) - test_accs.append(info['test-accuracy']) - if dataset == 'cifar10': - info = api.get_more_info(index, 'cifar10-valid', hp='90', is_random=False) - valid_accs.append(info['valid-accuracy']) - else: - valid_accs.append(info['valid-accuracy']) - info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs} - torch.save(info, cache_file_path) - else: - print ('Find cache file : {:}'.format(cache_file_path)) - info = torch.load(cache_file_path) - params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'] - print ('{:} collect data done.'.format(time_string())) + vis_save_dir = vis_save_dir.resolve() + print("{:} start to visualize {:} information".format(time_string(), dataset)) + vis_save_dir.mkdir(parents=True, exist_ok=True) + cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset) + if not cache_file_path.exists(): + print("Do not find cache file : {:}".format(cache_file_path)) + params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] + for index in range(len(api)): + cost_info = api.get_cost_info(index, dataset, hp="90") + params.append(cost_info["params"]) + flops.append(cost_info["flops"]) + # accuracy + info = api.get_more_info(index, dataset, hp="90", is_random=False) + train_accs.append(info["train-accuracy"]) + test_accs.append(info["test-accuracy"]) + if dataset == "cifar10": + info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False) + valid_accs.append(info["valid-accuracy"]) + else: + valid_accs.append(info["valid-accuracy"]) + info = { + "params": params, + "flops": flops, + "train_accs": train_accs, + "valid_accs": valid_accs, + "test_accs": test_accs, + } + torch.save(info, cache_file_path) + else: + print("Find cache file : {:}".format(cache_file_path)) + info = torch.load(cache_file_path) + params, flops, train_accs, valid_accs, test_accs = ( + info["params"], + info["flops"], + info["train_accs"], + info["valid_accs"], + info["test_accs"], + ) + print("{:} collect data done.".format(time_string())) - pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64'] - pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid] - largest_indexes = [api.query_index_by_arch('64:64:64:64:64')] + pyramid = [ + "8:16:32:48:64", + "8:8:16:32:48", + "8:8:16:16:32", + "8:8:16:16:48", + "8:8:16:16:64", + "16:16:32:32:64", + "32:32:64:64:64", + ] + pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid] + largest_indexes = [api.query_index_by_arch("64:64:64:64:64")] - indexes = list(range(len(params))) - dpi, width, height = 250, 8500, 1300 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 24, 24 - # resnet_scale, resnet_alpha = 120, 0.5 - xscale, xalpha = 120, 0.8 + indexes = list(range(len(params))) + dpi, width, height = 250, 8500, 1300 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 24, 24 + # resnet_scale, resnet_alpha = 120, 0.5 + xscale, xalpha = 120, 0.8 - fig, axs = plt.subplots(1, 4, figsize=figsize) - # ax1, ax2, ax3, ax4, ax5 = axs - for ax in axs: - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f')) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax2, ax3, ax4, ax5 = axs - # ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5)) - # ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') - # ax1.set_xlabel('architecture ID', fontsize=LabelSize) - # ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize) + fig, axs = plt.subplots(1, 4, figsize=figsize) + # ax1, ax2, ax3, ax4, ax5 = axs + for ax in axs: + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f")) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax2, ax3, ax4, ax5 = axs + # ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5)) + # ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') + # ax1.set_xlabel('architecture ID', fontsize=LabelSize) + # ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax2.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue') - ax2.scatter([params[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax2.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax2.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax2.legend(loc=4, fontsize=LegendFontsize) + ax2.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") + ax2.scatter( + [params[x] for x in pyramid_indexes], + [train_accs[x] for x in pyramid_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="Pyramid Structure", + alpha=xalpha, + ) + ax2.scatter( + [params[x] for x in largest_indexes], + [train_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax2.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax2.set_ylabel("train accuracy (%)", fontsize=LabelSize) + ax2.legend(loc=4, fontsize=LegendFontsize) - ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue') - ax3.scatter([params[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax3.legend(loc=4, fontsize=LegendFontsize) + ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") + ax3.scatter( + [params[x] for x in pyramid_indexes], + [test_accs[x] for x in pyramid_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="Pyramid Structure", + alpha=xalpha, + ) + ax3.scatter( + [params[x] for x in largest_indexes], + [test_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize) + ax3.legend(loc=4, fontsize=LegendFontsize) - ax4.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue') - ax4.scatter([flops[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax4.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - ax4.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax4.legend(loc=4, fontsize=LegendFontsize) + ax4.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue") + ax4.scatter( + [flops[x] for x in pyramid_indexes], + [train_accs[x] for x in pyramid_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="Pyramid Structure", + alpha=xalpha, + ) + ax4.scatter( + [flops[x] for x in largest_indexes], + [train_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize) + ax4.set_ylabel("train accuracy (%)", fontsize=LabelSize) + ax4.legend(loc=4, fontsize=LegendFontsize) - ax5.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue') - ax5.scatter([flops[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha) - ax5.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax5.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - ax5.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax5.legend(loc=4, fontsize=LegendFontsize) + ax5.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue") + ax5.scatter( + [flops[x] for x in pyramid_indexes], + [test_accs[x] for x in pyramid_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="Pyramid Structure", + alpha=xalpha, + ) + ax5.scatter( + [flops[x] for x in largest_indexes], + [test_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax5.set_xlabel("#FLOPs (M)", fontsize=LabelSize) + ax5.set_ylabel("test accuracy (%)", fontsize=LabelSize) + ax5.legend(loc=4, fontsize=LegendFontsize) - save_path = vis_save_dir / 'sss-{:}.png'.format(dataset) - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + save_path = vis_save_dir / "sss-{:}.png".format(dataset) + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") def visualize_tss_info(api, dataset, vis_save_dir): - vis_save_dir = vis_save_dir.resolve() - print ('{:} start to visualize {:} information'.format(time_string(), dataset)) - vis_save_dir.mkdir(parents=True, exist_ok=True) - cache_file_path = vis_save_dir / '{:}-cache-tss-info.pth'.format(dataset) - if not cache_file_path.exists(): - print ('Do not find cache file : {:}'.format(cache_file_path)) - params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] - for index in range(len(api)): - cost_info = api.get_cost_info(index, dataset, hp='12') - params.append(cost_info['params']) - flops.append(cost_info['flops']) - # accuracy - info = api.get_more_info(index, dataset, hp='200', is_random=False) - train_accs.append(info['train-accuracy']) - test_accs.append(info['test-accuracy']) - if dataset == 'cifar10': - info = api.get_more_info(index, 'cifar10-valid', hp='200', is_random=False) - valid_accs.append(info['valid-accuracy']) - else: - valid_accs.append(info['valid-accuracy']) - print('') - info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs} - torch.save(info, cache_file_path) - else: - print ('Find cache file : {:}'.format(cache_file_path)) - info = torch.load(cache_file_path) - params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'] - print ('{:} collect data done.'.format(time_string())) + vis_save_dir = vis_save_dir.resolve() + print("{:} start to visualize {:} information".format(time_string(), dataset)) + vis_save_dir.mkdir(parents=True, exist_ok=True) + cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset) + if not cache_file_path.exists(): + print("Do not find cache file : {:}".format(cache_file_path)) + params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] + for index in range(len(api)): + cost_info = api.get_cost_info(index, dataset, hp="12") + params.append(cost_info["params"]) + flops.append(cost_info["flops"]) + # accuracy + info = api.get_more_info(index, dataset, hp="200", is_random=False) + train_accs.append(info["train-accuracy"]) + test_accs.append(info["test-accuracy"]) + if dataset == "cifar10": + info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False) + valid_accs.append(info["valid-accuracy"]) + else: + valid_accs.append(info["valid-accuracy"]) + print("") + info = { + "params": params, + "flops": flops, + "train_accs": train_accs, + "valid_accs": valid_accs, + "test_accs": test_accs, + } + torch.save(info, cache_file_path) + else: + print("Find cache file : {:}".format(cache_file_path)) + info = torch.load(cache_file_path) + params, flops, train_accs, valid_accs, test_accs = ( + info["params"], + info["flops"], + info["train_accs"], + info["valid_accs"], + info["test_accs"], + ) + print("{:} collect data done.".format(time_string())) - resnet = ['|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|'] - resnet_indexes = [api.query_index_by_arch(x) for x in resnet] - largest_indexes = [api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|')] + resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"] + resnet_indexes = [api.query_index_by_arch(x) for x in resnet] + largest_indexes = [ + api.query_index_by_arch( + "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|" + ) + ] - indexes = list(range(len(params))) - dpi, width, height = 250, 8500, 1300 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 24, 24 - # resnet_scale, resnet_alpha = 120, 0.5 - xscale, xalpha = 120, 0.8 + indexes = list(range(len(params))) + dpi, width, height = 250, 8500, 1300 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 24, 24 + # resnet_scale, resnet_alpha = 120, 0.5 + xscale, xalpha = 120, 0.8 - fig, axs = plt.subplots(1, 4, figsize=figsize) - # ax1, ax2, ax3, ax4, ax5 = axs - for ax in axs: - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f')) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - ax2, ax3, ax4, ax5 = axs - # ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5)) - # ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') - # ax1.set_xlabel('architecture ID', fontsize=LabelSize) - # ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize) + fig, axs = plt.subplots(1, 4, figsize=figsize) + # ax1, ax2, ax3, ax4, ax5 = axs + for ax in axs: + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f")) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + ax2, ax3, ax4, ax5 = axs + # ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5)) + # ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue') + # ax1.set_xlabel('architecture ID', fontsize=LabelSize) + # ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax2.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue') - ax2.scatter([params[x] for x in resnet_indexes] , [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax2.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax2.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax2.legend(loc=4, fontsize=LegendFontsize) + ax2.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") + ax2.scatter( + [params[x] for x in resnet_indexes], + [train_accs[x] for x in resnet_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="ResNet", + alpha=xalpha, + ) + ax2.scatter( + [params[x] for x in largest_indexes], + [train_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax2.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax2.set_ylabel("train accuracy (%)", fontsize=LabelSize) + ax2.legend(loc=4, fontsize=LegendFontsize) - ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue') - ax3.scatter([params[x] for x in resnet_indexes] , [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize) - ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax3.legend(loc=4, fontsize=LegendFontsize) + ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") + ax3.scatter( + [params[x] for x in resnet_indexes], + [test_accs[x] for x in resnet_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="ResNet", + alpha=xalpha, + ) + ax3.scatter( + [params[x] for x in largest_indexes], + [test_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize) + ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize) + ax3.legend(loc=4, fontsize=LegendFontsize) - ax4.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue') - ax4.scatter([flops[x] for x in resnet_indexes], [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax4.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - ax4.set_ylabel('train accuracy (%)', fontsize=LabelSize) - ax4.legend(loc=4, fontsize=LegendFontsize) + ax4.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue") + ax4.scatter( + [flops[x] for x in resnet_indexes], + [train_accs[x] for x in resnet_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="ResNet", + alpha=xalpha, + ) + ax4.scatter( + [flops[x] for x in largest_indexes], + [train_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize) + ax4.set_ylabel("train accuracy (%)", fontsize=LabelSize) + ax4.legend(loc=4, fontsize=LegendFontsize) - ax5.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue') - ax5.scatter([flops[x] for x in resnet_indexes], [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha) - ax5.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha) - ax5.set_xlabel('#FLOPs (M)', fontsize=LabelSize) - ax5.set_ylabel('test accuracy (%)', fontsize=LabelSize) - ax5.legend(loc=4, fontsize=LegendFontsize) + ax5.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue") + ax5.scatter( + [flops[x] for x in resnet_indexes], + [test_accs[x] for x in resnet_indexes], + marker="*", + s=xscale, + c="tab:orange", + label="ResNet", + alpha=xalpha, + ) + ax5.scatter( + [flops[x] for x in largest_indexes], + [test_accs[x] for x in largest_indexes], + marker="x", + s=xscale, + c="tab:green", + label="Largest Candidate", + alpha=xalpha, + ) + ax5.set_xlabel("#FLOPs (M)", fontsize=LabelSize) + ax5.set_ylabel("test accuracy (%)", fontsize=LabelSize) + ax5.legend(loc=4, fontsize=LegendFontsize) - save_path = vis_save_dir / 'tss-{:}.png'.format(dataset) - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + save_path = vis_save_dir / "tss-{:}.png".format(dataset) + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") def visualize_rank_info(api, vis_save_dir, indicator): - vis_save_dir = vis_save_dir.resolve() - # print ('{:} start to visualize {:} information'.format(time_string(), api)) - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + # print ('{:} start to visualize {:} information'.format(time_string(), api)) + vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator) - cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator) - imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator) - cifar010_info = torch.load(cifar010_cache_path) - cifar100_info = torch.load(cifar100_cache_path) - imagenet_info = torch.load(imagenet_cache_path) - indexes = list(range(len(cifar010_info['params']))) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_info = torch.load(cifar010_cache_path) + cifar100_info = torch.load(cifar100_cache_path) + imagenet_info = torch.load(imagenet_cache_path) + indexes = list(range(len(cifar010_info["params"]))) - print ('{:} start to visualize relative ranking'.format(time_string())) + print("{:} start to visualize relative ranking".format(time_string())) - dpi, width, height = 250, 3800, 1200 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 14, 14 + dpi, width, height = 250, 3800, 1200 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 14, 14 - fig, axs = plt.subplots(1, 3, figsize=figsize) - ax1, ax2, ax3 = axs + fig, axs = plt.subplots(1, 3, figsize=figsize) + ax1, ax2, ax3 = axs - def get_labels(info): - ord_test_indexes = sorted(indexes, key=lambda i: info['test_accs'][i]) - ord_valid_indexes = sorted(indexes, key=lambda i: info['valid_accs'][i]) - labels = [] - for idx in ord_test_indexes: - labels.append(ord_valid_indexes.index(idx)) - return labels + def get_labels(info): + ord_test_indexes = sorted(indexes, key=lambda i: info["test_accs"][i]) + ord_valid_indexes = sorted(indexes, key=lambda i: info["valid_accs"][i]) + labels = [] + for idx in ord_test_indexes: + labels.append(ord_valid_indexes.index(idx)) + return labels - def plot_ax(labels, ax, name): - for tick in ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - for tick in ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize) - tick.label.set_rotation(90) - ax.set_xlim(min(indexes), max(indexes)) - ax.set_ylim(min(indexes), max(indexes)) - ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes)//3)) - ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes)//5)) - ax.scatter(indexes, labels , marker='^', s=0.5, c='tab:green', alpha=0.8) - ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue' , alpha=0.8) - ax.scatter([-1], [-1], marker='^', s=100, c='tab:green' , label='{:} test'.format(name)) - ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='{:} validation'.format(name)) - ax.legend(loc=4, fontsize=LegendFontsize) - ax.set_xlabel('ranking on the {:} validation'.format(name), fontsize=LabelSize) - ax.set_ylabel('architecture ranking', fontsize=LabelSize) - labels = get_labels(cifar010_info) - plot_ax(labels, ax1, 'CIFAR-10') - labels = get_labels(cifar100_info) - plot_ax(labels, ax2, 'CIFAR-100') - labels = get_labels(imagenet_info) - plot_ax(labels, ax3, 'ImageNet-16-120') + def plot_ax(labels, ax, name): + for tick in ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + for tick in ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize) + tick.label.set_rotation(90) + ax.set_xlim(min(indexes), max(indexes)) + ax.set_ylim(min(indexes), max(indexes)) + ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 3)) + ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) + ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8) + ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) + ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)) + ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name)) + ax.legend(loc=4, fontsize=LegendFontsize) + ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize) + ax.set_ylabel("architecture ranking", fontsize=LabelSize) - save_path = (vis_save_dir / '{:}-same-relative-rank.pdf'.format(indicator)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') - save_path = (vis_save_dir / '{:}-same-relative-rank.png'.format(indicator)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + labels = get_labels(cifar010_info) + plot_ax(labels, ax1, "CIFAR-10") + labels = get_labels(cifar100_info) + plot_ax(labels, ax2, "CIFAR-100") + labels = get_labels(imagenet_info) + plot_ax(labels, ax3, "ImageNet-16-120") + + save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") + save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") def calculate_correlation(*vectors): - matrix = [] - for i, vectori in enumerate(vectors): - x = [] - for j, vectorj in enumerate(vectors): - x.append( np.corrcoef(vectori, vectorj)[0,1] ) - matrix.append( x ) - return np.array(matrix) + matrix = [] + for i, vectori in enumerate(vectors): + x = [] + for j, vectorj in enumerate(vectors): + x.append(np.corrcoef(vectori, vectorj)[0, 1]) + matrix.append(x) + return np.array(matrix) def visualize_all_rank_info(api, vis_save_dir, indicator): - vis_save_dir = vis_save_dir.resolve() - # print ('{:} start to visualize {:} information'.format(time_string(), api)) - vis_save_dir.mkdir(parents=True, exist_ok=True) + vis_save_dir = vis_save_dir.resolve() + # print ('{:} start to visualize {:} information'.format(time_string(), api)) + vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator) - cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator) - imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator) - cifar010_info = torch.load(cifar010_cache_path) - cifar100_info = torch.load(cifar100_cache_path) - imagenet_info = torch.load(imagenet_cache_path) - indexes = list(range(len(cifar010_info['params']))) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_info = torch.load(cifar010_cache_path) + cifar100_info = torch.load(cifar100_cache_path) + imagenet_info = torch.load(imagenet_cache_path) + indexes = list(range(len(cifar010_info["params"]))) - print ('{:} start to visualize relative ranking'.format(time_string())) - + print("{:} start to visualize relative ranking".format(time_string())) - dpi, width, height = 250, 3200, 1400 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize = 14, 14 + dpi, width, height = 250, 3200, 1400 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 14, 14 - fig, axs = plt.subplots(1, 2, figsize=figsize) - ax1, ax2 = axs + fig, axs = plt.subplots(1, 2, figsize=figsize) + ax1, ax2 = axs - sns_size = 15 - CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs']) - - sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5, ax=ax1, - xticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'], - yticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T']) - - selected_indexes, acc_bar = [], 92 - for i, acc in enumerate(cifar010_info['test_accs']): - if acc > acc_bar: selected_indexes.append( i ) - cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ] - cifar010_test_accs = np.array(cifar010_info['test_accs']) [ selected_indexes ] - cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ] - cifar100_test_accs = np.array(cifar100_info['test_accs']) [ selected_indexes ] - imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ] - imagenet_test_accs = np.array(imagenet_info['test_accs']) [ selected_indexes ] - CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs) - - sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5, ax=ax2, - xticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'], - yticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T']) - ax1.set_title('Correlation coefficient over ALL candidates') - ax2.set_title('Correlation coefficient over candidates with accuracy > {:}%'.format(acc_bar)) - save_path = (vis_save_dir / '{:}-all-relative-rank.png'.format(indicator)).resolve() - fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') - print ('{:} save into {:}'.format(time_string(), save_path)) - plt.close('all') + sns_size = 15 + CoRelMatrix = calculate_correlation( + cifar010_info["valid_accs"], + cifar010_info["test_accs"], + cifar100_info["valid_accs"], + cifar100_info["test_accs"], + imagenet_info["valid_accs"], + imagenet_info["test_accs"], + ) + + sns.heatmap( + CoRelMatrix, + annot=True, + annot_kws={"size": sns_size}, + fmt=".3f", + linewidths=0.5, + ax=ax1, + xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], + yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], + ) + + selected_indexes, acc_bar = [], 92 + for i, acc in enumerate(cifar010_info["test_accs"]): + if acc > acc_bar: + selected_indexes.append(i) + cifar010_valid_accs = np.array(cifar010_info["valid_accs"])[selected_indexes] + cifar010_test_accs = np.array(cifar010_info["test_accs"])[selected_indexes] + cifar100_valid_accs = np.array(cifar100_info["valid_accs"])[selected_indexes] + cifar100_test_accs = np.array(cifar100_info["test_accs"])[selected_indexes] + imagenet_valid_accs = np.array(imagenet_info["valid_accs"])[selected_indexes] + imagenet_test_accs = np.array(imagenet_info["test_accs"])[selected_indexes] + CoRelMatrix = calculate_correlation( + cifar010_valid_accs, + cifar010_test_accs, + cifar100_valid_accs, + cifar100_test_accs, + imagenet_valid_accs, + imagenet_test_accs, + ) + + sns.heatmap( + CoRelMatrix, + annot=True, + annot_kws={"size": sns_size}, + fmt=".3f", + linewidths=0.5, + ax=ax2, + xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], + yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], + ) + ax1.set_title("Correlation coefficient over ALL candidates") + ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)) + save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") + print("{:} save into {:}".format(time_string(), save_path)) + plt.close("all") -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench', help='Folder to save checkpoints and log.') - # use for train the model - args = parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log." + ) + # use for train the model + args = parser.parse_args() - to_save_dir = Path(args.save_dir) + to_save_dir = Path(args.save_dir) - datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] - api201 = create(None, 'tss', verbose=True) - for xdata in datasets: - visualize_tss_info(api201, xdata, to_save_dir) + datasets = ["cifar10", "cifar100", "ImageNet16-120"] + api201 = create(None, "tss", verbose=True) + for xdata in datasets: + visualize_tss_info(api201, xdata, to_save_dir) - api_sss = create(None, 'size', verbose=True) - for xdata in datasets: - visualize_sss_info(api_sss, xdata, to_save_dir) + api_sss = create(None, "size", verbose=True) + for xdata in datasets: + visualize_sss_info(api_sss, xdata, to_save_dir) - visualize_info(None, to_save_dir, 'tss') - visualize_info(None, to_save_dir, 'sss') - visualize_rank_info(None, to_save_dir, 'tss') - visualize_rank_info(None, to_save_dir, 'sss') + visualize_info(None, to_save_dir, "tss") + visualize_info(None, to_save_dir, "sss") + visualize_rank_info(None, to_save_dir, "tss") + visualize_rank_info(None, to_save_dir, "sss") - visualize_all_rank_info(None, to_save_dir, 'tss') - visualize_all_rank_info(None, to_save_dir, 'sss') + visualize_all_rank_info(None, to_save_dir, "tss") + visualize_all_rank_info(None, to_save_dir, "sss") diff --git a/exps/prepare.py b/exps/prepare.py index 9179d89..8e66f85 100644 --- a/exps/prepare.py +++ b/exps/prepare.py @@ -4,74 +4,78 @@ import sys, time, torch, random, argparse from collections import defaultdict import os.path as osp -from PIL import ImageFile +from PIL import ImageFile + ImageFile.LOAD_TRUNCATED_IMAGES = True -from copy import deepcopy +from copy import deepcopy from pathlib import Path import torchvision import torchvision.datasets as dset -lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -parser = argparse.ArgumentParser(description='Prepare splits for searching', formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument('--name' , type=str, help='The dataset name.') -parser.add_argument('--root' , type=str, help='The directory to the dataset.') -parser.add_argument('--save' , type=str, help='The save path.') -parser.add_argument('--ratio', type=float, help='The save path.') +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +parser = argparse.ArgumentParser( + description="Prepare splits for searching", formatter_class=argparse.ArgumentDefaultsHelpFormatter +) +parser.add_argument("--name", type=str, help="The dataset name.") +parser.add_argument("--root", type=str, help="The directory to the dataset.") +parser.add_argument("--save", type=str, help="The save path.") +parser.add_argument("--ratio", type=float, help="The save path.") args = parser.parse_args() + def main(): - save_path = Path(args.save) - save_dir = save_path.parent - name = args.name - save_dir.mkdir(parents=True, exist_ok=True) - assert not save_path.exists(), '{:} already exists'.format(save_path) - print ('torchvision version : {:}'.format(torchvision.__version__)) + save_path = Path(args.save) + save_dir = save_path.parent + name = args.name + save_dir.mkdir(parents=True, exist_ok=True) + assert not save_path.exists(), "{:} already exists".format(save_path) + print("torchvision version : {:}".format(torchvision.__version__)) - if name == 'cifar10': - dataset = dset.CIFAR10 (args.root, train=True) - elif name == 'cifar100': - dataset = dset.CIFAR100(args.root, train=True) - elif name == 'imagenet-1k': - dataset = dset.ImageFolder(osp.join(args.root, 'train')) - else: raise TypeError("Unknow dataset : {:}".format(name)) + if name == "cifar10": + dataset = dset.CIFAR10(args.root, train=True) + elif name == "cifar100": + dataset = dset.CIFAR100(args.root, train=True) + elif name == "imagenet-1k": + dataset = dset.ImageFolder(osp.join(args.root, "train")) + else: + raise TypeError("Unknow dataset : {:}".format(name)) - if hasattr(dataset, 'targets'): - targets = dataset.targets - elif hasattr(dataset, 'train_labels'): - targets = dataset.train_labels - elif hasattr(dataset, 'imgs'): - targets = [x[1] for x in dataset.imgs] - else: - raise ValueError('invalid pattern') - print ('There are {:} samples in this dataset.'.format( len(targets) )) + if hasattr(dataset, "targets"): + targets = dataset.targets + elif hasattr(dataset, "train_labels"): + targets = dataset.train_labels + elif hasattr(dataset, "imgs"): + targets = [x[1] for x in dataset.imgs] + else: + raise ValueError("invalid pattern") + print("There are {:} samples in this dataset.".format(len(targets))) - class2index = defaultdict(list) - train, valid = [], [] - random.seed(111) - for index, cls in enumerate(targets): - class2index[cls].append( index ) - classes = sorted( list(class2index.keys()) ) - for cls in classes: - xlist = class2index[cls] - xtrain = random.sample(xlist, int(len(xlist)*args.ratio)) - xvalid = list(set(xlist) - set(xtrain)) - train += xtrain - valid += xvalid - train.sort() - valid.sort() - ## for statistics - class2numT, class2numV = defaultdict(int), defaultdict(int) - for index in train: - class2numT[ targets[index] ] += 1 - for index in valid: - class2numV[ targets[index] ] += 1 - class2numT, class2numV = dict(class2numT), dict(class2numV) - torch.save({'train': train, - 'valid': valid, - 'class2numTrain': class2numT, - 'class2numValid': class2numV}, save_path) - print ('-'*80) + class2index = defaultdict(list) + train, valid = [], [] + random.seed(111) + for index, cls in enumerate(targets): + class2index[cls].append(index) + classes = sorted(list(class2index.keys())) + for cls in classes: + xlist = class2index[cls] + xtrain = random.sample(xlist, int(len(xlist) * args.ratio)) + xvalid = list(set(xlist) - set(xtrain)) + train += xtrain + valid += xvalid + train.sort() + valid.sort() + ## for statistics + class2numT, class2numV = defaultdict(int), defaultdict(int) + for index in train: + class2numT[targets[index]] += 1 + for index in valid: + class2numV[targets[index]] += 1 + class2numT, class2numV = dict(class2numT), dict(class2numV) + torch.save({"train": train, "valid": valid, "class2numTrain": class2numT, "class2numValid": class2numV}, save_path) + print("-" * 80) -if __name__ == '__main__': - main() + +if __name__ == "__main__": + main() diff --git a/exps/search-shape.py b/exps/search-shape.py index 8eb5e4c..ed52ba7 100644 --- a/exps/search-shape.py +++ b/exps/search-shape.py @@ -2,201 +2,293 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # ##################################################### import sys, time, torch, random, argparse -from PIL import ImageFile -from os import path as osp +from PIL import ImageFile +from os import path as osp + ImageFile.LOAD_TRUNCATED_IMAGES = True import numpy as np -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() -print ('lib_dir : {:}'.format(lib_dir)) -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +print("lib_dir : {:}".format(lib_dir)) +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, configure2str, obtain_search_single_args as obtain_args -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint -from procedures import get_optim_scheduler, get_procedures -from datasets import get_datasets, SearchDataset -from models import obtain_search_model, obtain_model, change_key -from utils import get_model_infos -from log_utils import AverageMeter, time_string, convert_secs2time +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint +from procedures import get_optim_scheduler, get_procedures +from datasets import get_datasets, SearchDataset +from models import obtain_search_model, obtain_model, change_key +from utils import get_model_infos +from log_utils import AverageMeter, time_string, convert_secs2time def main(args): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = True - #torch.backends.cudnn.deterministic = True - torch.set_num_threads( args.workers ) - - prepare_seed(args.rand_seed) - logger = prepare_logger(args) - - # prepare dataset - train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) - #train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + torch.set_num_threads(args.workers) - split_file_path = Path(args.split_path) - assert split_file_path.exists(), '{:} does not exist'.format(split_file_path) - split_info = torch.load(split_file_path) + prepare_seed(args.rand_seed) + logger = prepare_logger(args) - train_split, valid_split = split_info['train'], split_info['valid'] - assert len( set(train_split).intersection( set(valid_split) ) ) == 0, 'There should be 0 element that belongs to both train and valid' - assert len(train_split) + len(valid_split) == len(train_data), '{:} + {:} vs {:}'.format(len(train_split), len(valid_split), len(train_data)) - search_dataset = SearchDataset(args.dataset, train_data, train_split, valid_split) - - search_train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), pin_memory=True, num_workers=args.workers) - search_valid_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), pin_memory=True, num_workers=args.workers) - search_loader = torch.utils.data.DataLoader(search_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, sampler=None) - # get configures - model_config = load_config(args.model_config, {'class_num': class_num, 'search_mode': args.search_shape}, logger) + # prepare dataset + train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) + # train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) + valid_loader = torch.utils.data.DataLoader( + valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True + ) - # obtain the model - search_model = obtain_search_model(model_config) - MAX_FLOP, param = get_model_infos(search_model, xshape) - optim_config = load_config(args.optim_config, {'class_num': class_num, 'FLOP': MAX_FLOP}, logger) - logger.log('Model Information : {:}'.format(search_model.get_message())) - logger.log('MAX_FLOP = {:} M'.format(MAX_FLOP)) - logger.log('Params = {:} M'.format(param)) - logger.log('train_data : {:}'.format(train_data)) - logger.log('search-data: {:}'.format(search_dataset)) - logger.log('search_train_loader : {:} samples'.format( len(train_split) )) - logger.log('search_valid_loader : {:} samples'.format( len(valid_split) )) - base_optimizer, scheduler, criterion = get_optim_scheduler(search_model.base_parameters(), optim_config) - arch_optimizer = torch.optim.Adam(search_model.arch_parameters(), lr=optim_config.arch_LR, betas=(0.5, 0.999), weight_decay=optim_config.arch_decay) - logger.log('base-optimizer : {:}'.format(base_optimizer)) - logger.log('arch-optimizer : {:}'.format(arch_optimizer)) - logger.log('scheduler : {:}'.format(scheduler)) - logger.log('criterion : {:}'.format(criterion)) - - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() + split_file_path = Path(args.split_path) + assert split_file_path.exists(), "{:} does not exist".format(split_file_path) + split_info = torch.load(split_file_path) - # load checkpoint - if last_info.exists() or (args.resume is not None and osp.isfile(args.resume)): # automatically resume from previous checkpoint - if args.resume is not None and osp.isfile(args.resume): - resume_path = Path(args.resume) - elif last_info.exists(): - resume_path = last_info - else: raise ValueError('Something is wrong.') - logger.log("=> loading checkpoint of the last-info '{:}' start".format(resume_path)) - checkpoint = torch.load(resume_path) - if 'last_checkpoint' in checkpoint: - last_checkpoint_path = checkpoint['last_checkpoint'] - if not last_checkpoint_path.exists(): - logger.log('Does not find {:}, try another path'.format(last_checkpoint_path)) - last_checkpoint_path = resume_path.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name - assert last_checkpoint_path.exists(), 'can not find the checkpoint from {:}'.format(last_checkpoint_path) - checkpoint = torch.load( last_checkpoint_path ) - start_epoch = checkpoint['epoch'] + 1 - search_model.load_state_dict( checkpoint['search_model'] ) - scheduler.load_state_dict ( checkpoint['scheduler'] ) - base_optimizer.load_state_dict ( checkpoint['base_optimizer'] ) - arch_optimizer.load_state_dict ( checkpoint['arch_optimizer'] ) - valid_accuracies = checkpoint['valid_accuracies'] - arch_genotypes = checkpoint['arch_genotypes'] - discrepancies = checkpoint['discrepancies'] - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(resume_path, start_epoch)) - else: - logger.log("=> do not find the last-info file : {:} or resume : {:}".format(last_info, args.resume)) - start_epoch, valid_accuracies, arch_genotypes, discrepancies = 0, {'best': -1}, {}, {} + train_split, valid_split = split_info["train"], split_info["valid"] + assert ( + len(set(train_split).intersection(set(valid_split))) == 0 + ), "There should be 0 element that belongs to both train and valid" + assert len(train_split) + len(valid_split) == len(train_data), "{:} + {:} vs {:}".format( + len(train_split), len(valid_split), len(train_data) + ) + search_dataset = SearchDataset(args.dataset, train_data, train_split, valid_split) - # main procedure - train_func, valid_func = get_procedures(args.procedure) - total_epoch = optim_config.epochs + optim_config.warmup - start_time, epoch_time = time.time(), AverageMeter() - for epoch in range(start_epoch, total_epoch): - scheduler.update(epoch, 0.0) - search_model.set_tau(args.gumbel_tau_max, args.gumbel_tau_min, epoch*1.0/total_epoch) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch), True) ) - epoch_str = 'epoch={:03d}/{:03d}'.format(epoch, total_epoch) - LRs = scheduler.get_lr() - find_best = False - - logger.log('\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}, tau={:}, FLOP={:.2f}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler, search_model.tau, MAX_FLOP)) + search_train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=args.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), + pin_memory=True, + num_workers=args.workers, + ) + search_valid_loader = torch.utils.data.DataLoader( + train_data, + batch_size=args.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), + pin_memory=True, + num_workers=args.workers, + ) + search_loader = torch.utils.data.DataLoader( + search_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, + pin_memory=True, + sampler=None, + ) + # get configures + model_config = load_config(args.model_config, {"class_num": class_num, "search_mode": args.search_shape}, logger) - # train for one epoch - train_base_loss, train_arch_loss, train_acc1, train_acc5 = train_func(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, \ - {'epoch-str' : epoch_str, 'FLOP-exp': MAX_FLOP * args.FLOP_ratio, - 'FLOP-weight': args.FLOP_weight, 'FLOP-tolerant': MAX_FLOP * args.FLOP_tolerant}, args.print_freq, logger) - # log the results - logger.log('***{:s}*** TRAIN [{:}] base-loss = {:.6f}, arch-loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}'.format(time_string(), epoch_str, train_base_loss, train_arch_loss, train_acc1, train_acc5)) - cur_FLOP, genotype = search_model.get_flop('genotype', model_config._asdict(), None) - arch_genotypes[epoch] = genotype - arch_genotypes['last'] = genotype - logger.log('[{:}] genotype : {:}'.format(epoch_str, genotype)) - arch_info, discrepancy = search_model.get_arch_info() - logger.log(arch_info) - discrepancies[epoch] = discrepancy - logger.log('[{:}] FLOP : {:.2f} MB, ratio : {:.4f}, Expected-ratio : {:.4f}, Discrepancy : {:.3f}'.format(epoch_str, cur_FLOP, cur_FLOP/MAX_FLOP, args.FLOP_ratio, np.mean(discrepancy))) + # obtain the model + search_model = obtain_search_model(model_config) + MAX_FLOP, param = get_model_infos(search_model, xshape) + optim_config = load_config(args.optim_config, {"class_num": class_num, "FLOP": MAX_FLOP}, logger) + logger.log("Model Information : {:}".format(search_model.get_message())) + logger.log("MAX_FLOP = {:} M".format(MAX_FLOP)) + logger.log("Params = {:} M".format(param)) + logger.log("train_data : {:}".format(train_data)) + logger.log("search-data: {:}".format(search_dataset)) + logger.log("search_train_loader : {:} samples".format(len(train_split))) + logger.log("search_valid_loader : {:} samples".format(len(valid_split))) + base_optimizer, scheduler, criterion = get_optim_scheduler(search_model.base_parameters(), optim_config) + arch_optimizer = torch.optim.Adam( + search_model.arch_parameters(), + lr=optim_config.arch_LR, + betas=(0.5, 0.999), + weight_decay=optim_config.arch_decay, + ) + logger.log("base-optimizer : {:}".format(base_optimizer)) + logger.log("arch-optimizer : {:}".format(arch_optimizer)) + logger.log("scheduler : {:}".format(scheduler)) + logger.log("criterion : {:}".format(criterion)) - #if cur_FLOP/MAX_FLOP > args.FLOP_ratio: - # init_flop_weight = init_flop_weight * args.FLOP_decay - #else: - # init_flop_weight = init_flop_weight / args.FLOP_decay - - # evaluate the performance - if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): - logger.log('-'*150) - valid_loss, valid_acc1, valid_acc5 = valid_func(search_valid_loader, network, criterion, epoch_str, args.print_freq_eval, logger) - valid_accuracies[epoch] = valid_acc1 - logger.log('***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}'.format(time_string(), epoch_str, valid_loss, valid_acc1, valid_acc5, valid_accuracies['best'], 100-valid_accuracies['best'])) - if valid_acc1 > valid_accuracies['best']: - valid_accuracies['best'] = valid_acc1 - arch_genotypes['best'] = genotype - find_best = True - logger.log('Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.'.format(epoch, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5, model_best_path)) + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() - # save checkpoint - save_path = save_checkpoint({ - 'epoch' : epoch, - 'args' : deepcopy(args), - 'valid_accuracies': deepcopy(valid_accuracies), - 'model-config' : model_config._asdict(), - 'optim-config' : optim_config._asdict(), - 'search_model' : search_model.state_dict(), - 'scheduler' : scheduler.state_dict(), - 'base_optimizer': base_optimizer.state_dict(), - 'arch_optimizer': arch_optimizer.state_dict(), - 'arch_genotypes': arch_genotypes, - 'discrepancies' : discrepancies, - }, model_base_path, logger) - if find_best: copy_checkpoint(model_base_path, model_best_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) + # load checkpoint + if last_info.exists() or ( + args.resume is not None and osp.isfile(args.resume) + ): # automatically resume from previous checkpoint + if args.resume is not None and osp.isfile(args.resume): + resume_path = Path(args.resume) + elif last_info.exists(): + resume_path = last_info + else: + raise ValueError("Something is wrong.") + logger.log("=> loading checkpoint of the last-info '{:}' start".format(resume_path)) + checkpoint = torch.load(resume_path) + if "last_checkpoint" in checkpoint: + last_checkpoint_path = checkpoint["last_checkpoint"] + if not last_checkpoint_path.exists(): + logger.log("Does not find {:}, try another path".format(last_checkpoint_path)) + last_checkpoint_path = resume_path.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name + assert last_checkpoint_path.exists(), "can not find the checkpoint from {:}".format(last_checkpoint_path) + checkpoint = torch.load(last_checkpoint_path) + start_epoch = checkpoint["epoch"] + 1 + search_model.load_state_dict(checkpoint["search_model"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + base_optimizer.load_state_dict(checkpoint["base_optimizer"]) + arch_optimizer.load_state_dict(checkpoint["arch_optimizer"]) + valid_accuracies = checkpoint["valid_accuracies"] + arch_genotypes = checkpoint["arch_genotypes"] + discrepancies = checkpoint["discrepancies"] + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(resume_path, start_epoch) + ) + else: + logger.log("=> do not find the last-info file : {:} or resume : {:}".format(last_info, args.resume)) + start_epoch, valid_accuracies, arch_genotypes, discrepancies = 0, {"best": -1}, {}, {} - # measure elapsed time - epoch_time.update(time.time() - start_time) - start_time = time.time() - + # main procedure + train_func, valid_func = get_procedures(args.procedure) + total_epoch = optim_config.epochs + optim_config.warmup + start_time, epoch_time = time.time(), AverageMeter() + for epoch in range(start_epoch, total_epoch): + scheduler.update(epoch, 0.0) + search_model.set_tau(args.gumbel_tau_max, args.gumbel_tau_min, epoch * 1.0 / total_epoch) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch), True)) + epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch) + LRs = scheduler.get_lr() + find_best = False - logger.log('') - logger.log('-'*100) - last_config_path = logger.path('log') / 'seed-{:}-last.config'.format(args.rand_seed) - configure2str(arch_genotypes['last'], str(last_config_path)) - logger.log('save the last config int {:} :\n{:}'.format(last_config_path, arch_genotypes['last'])) + logger.log( + "\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}, tau={:}, FLOP={:.2f}".format( + time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler, search_model.tau, MAX_FLOP + ) + ) - best_arch, valid_acc = arch_genotypes['best'], valid_accuracies['best'] - for key, config in arch_genotypes.items(): - if key == 'last': continue - FLOP_ratio = config['estimated_FLOP'] / MAX_FLOP - if abs(FLOP_ratio - args.FLOP_ratio) <= args.FLOP_tolerant: - if valid_acc < valid_accuracies[key]: - best_arch, valid_acc = config, valid_accuracies[key] - print('Best-Arch : {:}\nRatio={:}, Valid-ACC={:}'.format(best_arch, best_arch['estimated_FLOP'] / MAX_FLOP, valid_acc)) - best_config_path = logger.path('log') / 'seed-{:}-best.config'.format(args.rand_seed) - configure2str(best_arch, str(best_config_path)) - logger.log('save the last config int {:} :\n{:}'.format(best_config_path, best_arch)) - logger.log('\n' + '-'*200) - logger.log('Finish training/validation in {:}, and save final checkpoint into {:}'.format(convert_secs2time(epoch_time.sum, True), logger.path('info'))) - logger.close() + # train for one epoch + train_base_loss, train_arch_loss, train_acc1, train_acc5 = train_func( + search_loader, + network, + criterion, + scheduler, + base_optimizer, + arch_optimizer, + optim_config, + { + "epoch-str": epoch_str, + "FLOP-exp": MAX_FLOP * args.FLOP_ratio, + "FLOP-weight": args.FLOP_weight, + "FLOP-tolerant": MAX_FLOP * args.FLOP_tolerant, + }, + args.print_freq, + logger, + ) + # log the results + logger.log( + "***{:s}*** TRAIN [{:}] base-loss = {:.6f}, arch-loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}".format( + time_string(), epoch_str, train_base_loss, train_arch_loss, train_acc1, train_acc5 + ) + ) + cur_FLOP, genotype = search_model.get_flop("genotype", model_config._asdict(), None) + arch_genotypes[epoch] = genotype + arch_genotypes["last"] = genotype + logger.log("[{:}] genotype : {:}".format(epoch_str, genotype)) + arch_info, discrepancy = search_model.get_arch_info() + logger.log(arch_info) + discrepancies[epoch] = discrepancy + logger.log( + "[{:}] FLOP : {:.2f} MB, ratio : {:.4f}, Expected-ratio : {:.4f}, Discrepancy : {:.3f}".format( + epoch_str, cur_FLOP, cur_FLOP / MAX_FLOP, args.FLOP_ratio, np.mean(discrepancy) + ) + ) + + # if cur_FLOP/MAX_FLOP > args.FLOP_ratio: + # init_flop_weight = init_flop_weight * args.FLOP_decay + # else: + # init_flop_weight = init_flop_weight / args.FLOP_decay + + # evaluate the performance + if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): + logger.log("-" * 150) + valid_loss, valid_acc1, valid_acc5 = valid_func( + search_valid_loader, network, criterion, epoch_str, args.print_freq_eval, logger + ) + valid_accuracies[epoch] = valid_acc1 + logger.log( + "***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}".format( + time_string(), + epoch_str, + valid_loss, + valid_acc1, + valid_acc5, + valid_accuracies["best"], + 100 - valid_accuracies["best"], + ) + ) + if valid_acc1 > valid_accuracies["best"]: + valid_accuracies["best"] = valid_acc1 + arch_genotypes["best"] = genotype + find_best = True + logger.log( + "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format( + epoch, valid_acc1, valid_acc5, 100 - valid_acc1, 100 - valid_acc5, model_best_path + ) + ) + + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch, + "args": deepcopy(args), + "valid_accuracies": deepcopy(valid_accuracies), + "model-config": model_config._asdict(), + "optim-config": optim_config._asdict(), + "search_model": search_model.state_dict(), + "scheduler": scheduler.state_dict(), + "base_optimizer": base_optimizer.state_dict(), + "arch_optimizer": arch_optimizer.state_dict(), + "arch_genotypes": arch_genotypes, + "discrepancies": discrepancies, + }, + model_base_path, + logger, + ) + if find_best: + copy_checkpoint(model_base_path, model_best_path, logger) + last_info = save_checkpoint( + { + "epoch": epoch, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("") + logger.log("-" * 100) + last_config_path = logger.path("log") / "seed-{:}-last.config".format(args.rand_seed) + configure2str(arch_genotypes["last"], str(last_config_path)) + logger.log("save the last config int {:} :\n{:}".format(last_config_path, arch_genotypes["last"])) + + best_arch, valid_acc = arch_genotypes["best"], valid_accuracies["best"] + for key, config in arch_genotypes.items(): + if key == "last": + continue + FLOP_ratio = config["estimated_FLOP"] / MAX_FLOP + if abs(FLOP_ratio - args.FLOP_ratio) <= args.FLOP_tolerant: + if valid_acc < valid_accuracies[key]: + best_arch, valid_acc = config, valid_accuracies[key] + print( + "Best-Arch : {:}\nRatio={:}, Valid-ACC={:}".format(best_arch, best_arch["estimated_FLOP"] / MAX_FLOP, valid_acc) + ) + best_config_path = logger.path("log") / "seed-{:}-best.config".format(args.rand_seed) + configure2str(best_arch, str(best_config_path)) + logger.log("save the last config int {:} :\n{:}".format(best_config_path, best_arch)) + logger.log("\n" + "-" * 200) + logger.log( + "Finish training/validation in {:}, and save final checkpoint into {:}".format( + convert_secs2time(epoch_time.sum, True), logger.path("info") + ) + ) + logger.close() -if __name__ == '__main__': - args = obtain_args() - main(args) +if __name__ == "__main__": + args = obtain_args() + main(args) diff --git a/exps/search-transformable.py b/exps/search-transformable.py index 3b2a0eb..9b81a1f 100644 --- a/exps/search-transformable.py +++ b/exps/search-transformable.py @@ -4,214 +4,314 @@ # Network Pruning via Transformable Architecture Search, NeurIPS 2019 # ####################################################################### import sys, time, torch, random, argparse -from PIL import ImageFile -from os import path as osp +from PIL import ImageFile +from os import path as osp + ImageFile.LOAD_TRUNCATED_IMAGES = True import numpy as np -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import load_config, configure2str, obtain_search_args as obtain_args -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint -from procedures import get_optim_scheduler, get_procedures -from datasets import get_datasets, SearchDataset -from models import obtain_search_model, obtain_model, change_key -from utils import get_model_infos -from log_utils import AverageMeter, time_string, convert_secs2time +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint +from procedures import get_optim_scheduler, get_procedures +from datasets import get_datasets, SearchDataset +from models import obtain_search_model, obtain_model, change_key +from utils import get_model_infos +from log_utils import AverageMeter, time_string, convert_secs2time def main(args): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = True - #torch.backends.cudnn.deterministic = True - torch.set_num_threads( args.workers ) - - prepare_seed(args.rand_seed) - logger = prepare_logger(args) - - # prepare dataset - train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) - #train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) + assert torch.cuda.is_available(), "CUDA is not available." + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + torch.set_num_threads(args.workers) - split_file_path = Path(args.split_path) - assert split_file_path.exists(), '{:} does not exist'.format(split_file_path) - split_info = torch.load(split_file_path) + prepare_seed(args.rand_seed) + logger = prepare_logger(args) - train_split, valid_split = split_info['train'], split_info['valid'] - assert len( set(train_split).intersection( set(valid_split) ) ) == 0, 'There should be 0 element that belongs to both train and valid' - assert len(train_split) + len(valid_split) == len(train_data), '{:} + {:} vs {:}'.format(len(train_split), len(valid_split), len(train_data)) - search_dataset = SearchDataset(args.dataset, train_data, train_split, valid_split) - - search_train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), pin_memory=True, num_workers=args.workers) - search_valid_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), pin_memory=True, num_workers=args.workers) - search_loader = torch.utils.data.DataLoader(search_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, sampler=None) - # get configures - if args.ablation_num_select is None or args.ablation_num_select <= 0: - model_config = load_config(args.model_config, {'class_num': class_num, 'search_mode': 'shape'}, logger) - else: - model_config = load_config(args.model_config, {'class_num': class_num, 'search_mode': 'ablation', 'num_random_select': args.ablation_num_select}, logger) + # prepare dataset + train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) + # train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) + valid_loader = torch.utils.data.DataLoader( + valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True + ) - # obtain the model - search_model = obtain_search_model(model_config) - MAX_FLOP, param = get_model_infos(search_model, xshape) - optim_config = load_config(args.optim_config, {'class_num': class_num, 'FLOP': MAX_FLOP}, logger) - logger.log('Model Information : {:}'.format(search_model.get_message())) - logger.log('MAX_FLOP = {:} M'.format(MAX_FLOP)) - logger.log('Params = {:} M'.format(param)) - logger.log('train_data : {:}'.format(train_data)) - logger.log('search-data: {:}'.format(search_dataset)) - logger.log('search_train_loader : {:} samples'.format( len(train_split) )) - logger.log('search_valid_loader : {:} samples'.format( len(valid_split) )) - base_optimizer, scheduler, criterion = get_optim_scheduler(search_model.base_parameters(), optim_config) - arch_optimizer = torch.optim.Adam(search_model.arch_parameters(optim_config.arch_LR), lr=optim_config.arch_LR, betas=(0.5, 0.999), weight_decay=optim_config.arch_decay) - logger.log('base-optimizer : {:}'.format(base_optimizer)) - logger.log('arch-optimizer : {:}'.format(arch_optimizer)) - logger.log('scheduler : {:}'.format(scheduler)) - logger.log('criterion : {:}'.format(criterion)) - - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() + split_file_path = Path(args.split_path) + assert split_file_path.exists(), "{:} does not exist".format(split_file_path) + split_info = torch.load(split_file_path) - # load checkpoint - if last_info.exists() or (args.resume is not None and osp.isfile(args.resume)): # automatically resume from previous checkpoint - if args.resume is not None and osp.isfile(args.resume): - resume_path = Path(args.resume) - elif last_info.exists(): - resume_path = last_info - else: raise ValueError('Something is wrong.') - logger.log("=> loading checkpoint of the last-info '{:}' start".format(resume_path)) - checkpoint = torch.load(resume_path) - if 'last_checkpoint' in checkpoint: - last_checkpoint_path = checkpoint['last_checkpoint'] - if not last_checkpoint_path.exists(): - logger.log('Does not find {:}, try another path'.format(last_checkpoint_path)) - last_checkpoint_path = resume_path.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name - assert last_checkpoint_path.exists(), 'can not find the checkpoint from {:}'.format(last_checkpoint_path) - checkpoint = torch.load( last_checkpoint_path ) - start_epoch = checkpoint['epoch'] + 1 - #for key, value in checkpoint['search_model'].items(): - # print('K {:} = Shape={:}'.format(key, value.shape)) - search_model.load_state_dict( checkpoint['search_model'] ) - scheduler.load_state_dict ( checkpoint['scheduler'] ) - base_optimizer.load_state_dict ( checkpoint['base_optimizer'] ) - arch_optimizer.load_state_dict ( checkpoint['arch_optimizer'] ) - valid_accuracies = checkpoint['valid_accuracies'] - arch_genotypes = checkpoint['arch_genotypes'] - discrepancies = checkpoint['discrepancies'] - max_bytes = checkpoint['max_bytes'] - logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(resume_path, start_epoch)) - else: - logger.log("=> do not find the last-info file : {:} or resume : {:}".format(last_info, args.resume)) - start_epoch, valid_accuracies, arch_genotypes, discrepancies, max_bytes = 0, {'best': -1}, {}, {}, {} + train_split, valid_split = split_info["train"], split_info["valid"] + assert ( + len(set(train_split).intersection(set(valid_split))) == 0 + ), "There should be 0 element that belongs to both train and valid" + assert len(train_split) + len(valid_split) == len(train_data), "{:} + {:} vs {:}".format( + len(train_split), len(valid_split), len(train_data) + ) + search_dataset = SearchDataset(args.dataset, train_data, train_split, valid_split) - # main procedure - train_func, valid_func = get_procedures(args.procedure) - total_epoch = optim_config.epochs + optim_config.warmup - start_time, epoch_time = time.time(), AverageMeter() - for epoch in range(start_epoch, total_epoch): - scheduler.update(epoch, 0.0) - search_model.set_tau(args.gumbel_tau_max, args.gumbel_tau_min, epoch*1.0/total_epoch) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch), True) ) - epoch_str = 'epoch={:03d}/{:03d}'.format(epoch, total_epoch) - LRs = scheduler.get_lr() - find_best = False - - logger.log('\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}, tau={:}, FLOP={:.2f}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler, search_model.tau, MAX_FLOP)) + search_train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=args.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), + pin_memory=True, + num_workers=args.workers, + ) + search_valid_loader = torch.utils.data.DataLoader( + train_data, + batch_size=args.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), + pin_memory=True, + num_workers=args.workers, + ) + search_loader = torch.utils.data.DataLoader( + search_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, + pin_memory=True, + sampler=None, + ) + # get configures + if args.ablation_num_select is None or args.ablation_num_select <= 0: + model_config = load_config(args.model_config, {"class_num": class_num, "search_mode": "shape"}, logger) + else: + model_config = load_config( + args.model_config, + {"class_num": class_num, "search_mode": "ablation", "num_random_select": args.ablation_num_select}, + logger, + ) - # train for one epoch - train_base_loss, train_arch_loss, train_acc1, train_acc5 = train_func(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, \ - {'epoch-str' : epoch_str, 'FLOP-exp': MAX_FLOP * args.FLOP_ratio, - 'FLOP-weight': args.FLOP_weight, 'FLOP-tolerant': MAX_FLOP * args.FLOP_tolerant}, args.print_freq, logger) - # log the results - logger.log('***{:s}*** TRAIN [{:}] base-loss = {:.6f}, arch-loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}'.format(time_string(), epoch_str, train_base_loss, train_arch_loss, train_acc1, train_acc5)) - cur_FLOP, genotype = search_model.get_flop('genotype', model_config._asdict(), None) - arch_genotypes[epoch] = genotype - arch_genotypes['last'] = genotype - logger.log('[{:}] genotype : {:}'.format(epoch_str, genotype)) - # save the configuration - configure2str(genotype, str( logger.path('log') / 'seed-{:}-temp.config'.format(args.rand_seed) )) - arch_info, discrepancy = search_model.get_arch_info() - logger.log(arch_info) - discrepancies[epoch] = discrepancy - logger.log('[{:}] FLOP : {:.2f} MB, ratio : {:.4f}, Expected-ratio : {:.4f}, Discrepancy : {:.3f}'.format(epoch_str, cur_FLOP, cur_FLOP/MAX_FLOP, args.FLOP_ratio, np.mean(discrepancy))) + # obtain the model + search_model = obtain_search_model(model_config) + MAX_FLOP, param = get_model_infos(search_model, xshape) + optim_config = load_config(args.optim_config, {"class_num": class_num, "FLOP": MAX_FLOP}, logger) + logger.log("Model Information : {:}".format(search_model.get_message())) + logger.log("MAX_FLOP = {:} M".format(MAX_FLOP)) + logger.log("Params = {:} M".format(param)) + logger.log("train_data : {:}".format(train_data)) + logger.log("search-data: {:}".format(search_dataset)) + logger.log("search_train_loader : {:} samples".format(len(train_split))) + logger.log("search_valid_loader : {:} samples".format(len(valid_split))) + base_optimizer, scheduler, criterion = get_optim_scheduler(search_model.base_parameters(), optim_config) + arch_optimizer = torch.optim.Adam( + search_model.arch_parameters(optim_config.arch_LR), + lr=optim_config.arch_LR, + betas=(0.5, 0.999), + weight_decay=optim_config.arch_decay, + ) + logger.log("base-optimizer : {:}".format(base_optimizer)) + logger.log("arch-optimizer : {:}".format(arch_optimizer)) + logger.log("scheduler : {:}".format(scheduler)) + logger.log("criterion : {:}".format(criterion)) - #if cur_FLOP/MAX_FLOP > args.FLOP_ratio: - # init_flop_weight = init_flop_weight * args.FLOP_decay - #else: - # init_flop_weight = init_flop_weight / args.FLOP_decay - - # evaluate the performance - if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): - logger.log('-'*150) - valid_loss, valid_acc1, valid_acc5 = valid_func(search_valid_loader, network, criterion, epoch_str, args.print_freq_eval, logger) - valid_accuracies[epoch] = valid_acc1 - logger.log('***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}'.format(time_string(), epoch_str, valid_loss, valid_acc1, valid_acc5, valid_accuracies['best'], 100-valid_accuracies['best'])) - if valid_acc1 > valid_accuracies['best']: - valid_accuracies['best'] = valid_acc1 - arch_genotypes['best'] = genotype - find_best = True - logger.log('Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.'.format(epoch, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5, model_best_path)) - # log the GPU memory usage - #num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0 - num_bytes = torch.cuda.max_memory_cached( next(network.parameters()).device ) * 1.0 - logger.log('[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]'.format(next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9)) - max_bytes[epoch] = num_bytes + last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() - # save checkpoint - save_path = save_checkpoint({ - 'epoch' : epoch, - 'args' : deepcopy(args), - 'max_bytes' : deepcopy(max_bytes), - 'valid_accuracies': deepcopy(valid_accuracies), - 'model-config' : model_config._asdict(), - 'optim-config' : optim_config._asdict(), - 'search_model' : search_model.state_dict(), - 'scheduler' : scheduler.state_dict(), - 'base_optimizer': base_optimizer.state_dict(), - 'arch_optimizer': arch_optimizer.state_dict(), - 'arch_genotypes': arch_genotypes, - 'discrepancies' : discrepancies, - }, model_base_path, logger) - if find_best: copy_checkpoint(model_base_path, model_best_path, logger) - last_info = save_checkpoint({ - 'epoch': epoch, - 'args' : deepcopy(args), - 'last_checkpoint': save_path, - }, logger.path('info'), logger) + # load checkpoint + if last_info.exists() or ( + args.resume is not None and osp.isfile(args.resume) + ): # automatically resume from previous checkpoint + if args.resume is not None and osp.isfile(args.resume): + resume_path = Path(args.resume) + elif last_info.exists(): + resume_path = last_info + else: + raise ValueError("Something is wrong.") + logger.log("=> loading checkpoint of the last-info '{:}' start".format(resume_path)) + checkpoint = torch.load(resume_path) + if "last_checkpoint" in checkpoint: + last_checkpoint_path = checkpoint["last_checkpoint"] + if not last_checkpoint_path.exists(): + logger.log("Does not find {:}, try another path".format(last_checkpoint_path)) + last_checkpoint_path = resume_path.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name + assert last_checkpoint_path.exists(), "can not find the checkpoint from {:}".format(last_checkpoint_path) + checkpoint = torch.load(last_checkpoint_path) + start_epoch = checkpoint["epoch"] + 1 + # for key, value in checkpoint['search_model'].items(): + # print('K {:} = Shape={:}'.format(key, value.shape)) + search_model.load_state_dict(checkpoint["search_model"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + base_optimizer.load_state_dict(checkpoint["base_optimizer"]) + arch_optimizer.load_state_dict(checkpoint["arch_optimizer"]) + valid_accuracies = checkpoint["valid_accuracies"] + arch_genotypes = checkpoint["arch_genotypes"] + discrepancies = checkpoint["discrepancies"] + max_bytes = checkpoint["max_bytes"] + logger.log( + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(resume_path, start_epoch) + ) + else: + logger.log("=> do not find the last-info file : {:} or resume : {:}".format(last_info, args.resume)) + start_epoch, valid_accuracies, arch_genotypes, discrepancies, max_bytes = 0, {"best": -1}, {}, {}, {} - # measure elapsed time - epoch_time.update(time.time() - start_time) - start_time = time.time() - + # main procedure + train_func, valid_func = get_procedures(args.procedure) + total_epoch = optim_config.epochs + optim_config.warmup + start_time, epoch_time = time.time(), AverageMeter() + for epoch in range(start_epoch, total_epoch): + scheduler.update(epoch, 0.0) + search_model.set_tau(args.gumbel_tau_max, args.gumbel_tau_min, epoch * 1.0 / total_epoch) + need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch), True)) + epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch) + LRs = scheduler.get_lr() + find_best = False - logger.log('') - logger.log('-'*100) - last_config_path = logger.path('log') / 'seed-{:}-last.config'.format(args.rand_seed) - configure2str(arch_genotypes['last'], str(last_config_path)) - logger.log('save the last config int {:} :\n{:}'.format(last_config_path, arch_genotypes['last'])) + logger.log( + "\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}, tau={:}, FLOP={:.2f}".format( + time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler, search_model.tau, MAX_FLOP + ) + ) - best_arch, valid_acc = arch_genotypes['best'], valid_accuracies['best'] - for key, config in arch_genotypes.items(): - if key == 'last': continue - FLOP_ratio = config['estimated_FLOP'] / MAX_FLOP - if abs(FLOP_ratio - args.FLOP_ratio) <= args.FLOP_tolerant: - if valid_acc <= valid_accuracies[key]: - best_arch, valid_acc = config, valid_accuracies[key] - print('Best-Arch : {:}\nRatio={:}, Valid-ACC={:}'.format(best_arch, best_arch['estimated_FLOP'] / MAX_FLOP, valid_acc)) - best_config_path = logger.path('log') / 'seed-{:}-best.config'.format(args.rand_seed) - configure2str(best_arch, str(best_config_path)) - logger.log('save the last config int {:} :\n{:}'.format(best_config_path, best_arch)) - logger.log('\n' + '-'*200) - logger.log('Finish training/validation in {:} with Max-GPU-Memory of {:.2f} GB, and save final checkpoint into {:}'.format(convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e9, logger.path('info'))) - logger.close() + # train for one epoch + train_base_loss, train_arch_loss, train_acc1, train_acc5 = train_func( + search_loader, + network, + criterion, + scheduler, + base_optimizer, + arch_optimizer, + optim_config, + { + "epoch-str": epoch_str, + "FLOP-exp": MAX_FLOP * args.FLOP_ratio, + "FLOP-weight": args.FLOP_weight, + "FLOP-tolerant": MAX_FLOP * args.FLOP_tolerant, + }, + args.print_freq, + logger, + ) + # log the results + logger.log( + "***{:s}*** TRAIN [{:}] base-loss = {:.6f}, arch-loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}".format( + time_string(), epoch_str, train_base_loss, train_arch_loss, train_acc1, train_acc5 + ) + ) + cur_FLOP, genotype = search_model.get_flop("genotype", model_config._asdict(), None) + arch_genotypes[epoch] = genotype + arch_genotypes["last"] = genotype + logger.log("[{:}] genotype : {:}".format(epoch_str, genotype)) + # save the configuration + configure2str(genotype, str(logger.path("log") / "seed-{:}-temp.config".format(args.rand_seed))) + arch_info, discrepancy = search_model.get_arch_info() + logger.log(arch_info) + discrepancies[epoch] = discrepancy + logger.log( + "[{:}] FLOP : {:.2f} MB, ratio : {:.4f}, Expected-ratio : {:.4f}, Discrepancy : {:.3f}".format( + epoch_str, cur_FLOP, cur_FLOP / MAX_FLOP, args.FLOP_ratio, np.mean(discrepancy) + ) + ) + + # if cur_FLOP/MAX_FLOP > args.FLOP_ratio: + # init_flop_weight = init_flop_weight * args.FLOP_decay + # else: + # init_flop_weight = init_flop_weight / args.FLOP_decay + + # evaluate the performance + if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): + logger.log("-" * 150) + valid_loss, valid_acc1, valid_acc5 = valid_func( + search_valid_loader, network, criterion, epoch_str, args.print_freq_eval, logger + ) + valid_accuracies[epoch] = valid_acc1 + logger.log( + "***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}".format( + time_string(), + epoch_str, + valid_loss, + valid_acc1, + valid_acc5, + valid_accuracies["best"], + 100 - valid_accuracies["best"], + ) + ) + if valid_acc1 > valid_accuracies["best"]: + valid_accuracies["best"] = valid_acc1 + arch_genotypes["best"] = genotype + find_best = True + logger.log( + "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format( + epoch, valid_acc1, valid_acc5, 100 - valid_acc1, 100 - valid_acc5, model_best_path + ) + ) + # log the GPU memory usage + # num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0 + num_bytes = torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 + logger.log( + "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format( + next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9 + ) + ) + max_bytes[epoch] = num_bytes + + # save checkpoint + save_path = save_checkpoint( + { + "epoch": epoch, + "args": deepcopy(args), + "max_bytes": deepcopy(max_bytes), + "valid_accuracies": deepcopy(valid_accuracies), + "model-config": model_config._asdict(), + "optim-config": optim_config._asdict(), + "search_model": search_model.state_dict(), + "scheduler": scheduler.state_dict(), + "base_optimizer": base_optimizer.state_dict(), + "arch_optimizer": arch_optimizer.state_dict(), + "arch_genotypes": arch_genotypes, + "discrepancies": discrepancies, + }, + model_base_path, + logger, + ) + if find_best: + copy_checkpoint(model_base_path, model_best_path, logger) + last_info = save_checkpoint( + { + "epoch": epoch, + "args": deepcopy(args), + "last_checkpoint": save_path, + }, + logger.path("info"), + logger, + ) + + # measure elapsed time + epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("") + logger.log("-" * 100) + last_config_path = logger.path("log") / "seed-{:}-last.config".format(args.rand_seed) + configure2str(arch_genotypes["last"], str(last_config_path)) + logger.log("save the last config int {:} :\n{:}".format(last_config_path, arch_genotypes["last"])) + + best_arch, valid_acc = arch_genotypes["best"], valid_accuracies["best"] + for key, config in arch_genotypes.items(): + if key == "last": + continue + FLOP_ratio = config["estimated_FLOP"] / MAX_FLOP + if abs(FLOP_ratio - args.FLOP_ratio) <= args.FLOP_tolerant: + if valid_acc <= valid_accuracies[key]: + best_arch, valid_acc = config, valid_accuracies[key] + print( + "Best-Arch : {:}\nRatio={:}, Valid-ACC={:}".format(best_arch, best_arch["estimated_FLOP"] / MAX_FLOP, valid_acc) + ) + best_config_path = logger.path("log") / "seed-{:}-best.config".format(args.rand_seed) + configure2str(best_arch, str(best_config_path)) + logger.log("save the last config int {:} :\n{:}".format(best_config_path, best_arch)) + logger.log("\n" + "-" * 200) + logger.log( + "Finish training/validation in {:} with Max-GPU-Memory of {:.2f} GB, and save final checkpoint into {:}".format( + convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e9, logger.path("info") + ) + ) + logger.close() -if __name__ == '__main__': - args = obtain_args() - main(args) +if __name__ == "__main__": + args = obtain_args() + main(args) diff --git a/exps/show-dataset.py b/exps/show-dataset.py index 87510c3..2a79f33 100644 --- a/exps/show-dataset.py +++ b/exps/show-dataset.py @@ -7,41 +7,47 @@ ############################################################################## import os, sys, time, torch, random, argparse from typing import List, Text, Dict, Any -from PIL import ImageFile +from PIL import ImageFile + ImageFile.LOAD_TRUNCATED_IMAGES = True -from copy import deepcopy +from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from datasets import get_datasets from nats_bench import create def show_imagenet_16_120(dataset_dir=None): - if dataset_dir is None: - torch_home_dir = os.environ['TORCH_HOME'] if 'TORCH_HOME' in os.environ else os.path.join(os.environ['HOME'], '.torch') - dataset_dir = os.path.join(torch_home_dir, 'cifar.python', 'ImageNet16') - train_data, valid_data, xshape, class_num = get_datasets('ImageNet16-120', dataset_dir, -1) - split_info = load_config('configs/nas-benchmark/ImageNet16-120-split.txt', None, None) - print('=' * 10 + ' ImageNet-16-120 ' + '=' * 10) - print('Training Data: {:}'.format(train_data)) - print('Evaluation Data: {:}'.format(valid_data)) - print('Hold-out training: {:} images.'.format(len(split_info.train))) - print('Hold-out valid : {:} images.'.format(len(split_info.valid))) + if dataset_dir is None: + torch_home_dir = ( + os.environ["TORCH_HOME"] if "TORCH_HOME" in os.environ else os.path.join(os.environ["HOME"], ".torch") + ) + dataset_dir = os.path.join(torch_home_dir, "cifar.python", "ImageNet16") + train_data, valid_data, xshape, class_num = get_datasets("ImageNet16-120", dataset_dir, -1) + split_info = load_config("configs/nas-benchmark/ImageNet16-120-split.txt", None, None) + print("=" * 10 + " ImageNet-16-120 " + "=" * 10) + print("Training Data: {:}".format(train_data)) + print("Evaluation Data: {:}".format(valid_data)) + print("Hold-out training: {:} images.".format(len(split_info.train))) + print("Hold-out valid : {:} images.".format(len(split_info.valid))) -if __name__ == '__main__': - # show_imagenet_16_120() - api_nats_tss = create(None, 'tss', fast_mode=True, verbose=True) +if __name__ == "__main__": + # show_imagenet_16_120() + api_nats_tss = create(None, "tss", fast_mode=True, verbose=True) - valid_acc_12e = [] - test_acc_12e = [] - test_acc_200e = [] - for index in range(10000): - info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='12') - valid_acc_12e.append(info['valid-accuracy']) # the validation accuracy after training the model by 12 epochs - test_acc_12e.append(info['test-accuracy']) # the test accuracy after training the model by 12 epochs - info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='200') - test_acc_200e.append(info['test-accuracy']) # the test accuracy after training the model by 200 epochs (which I reported in the paper) + valid_acc_12e = [] + test_acc_12e = [] + test_acc_200e = [] + for index in range(10000): + info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="12") + valid_acc_12e.append(info["valid-accuracy"]) # the validation accuracy after training the model by 12 epochs + test_acc_12e.append(info["test-accuracy"]) # the test accuracy after training the model by 12 epochs + info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="200") + test_acc_200e.append( + info["test-accuracy"] + ) # the test accuracy after training the model by 200 epochs (which I reported in the paper) diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index 027a691..6228fd6 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -72,7 +72,7 @@ def retrieve_configs(): def main(xargs, exp_yaml): assert Path(exp_yaml).exists(), "{:} does not exist.".format(exp_yaml) - pprint('Run {:}'.format(xargs.alg)) + pprint("Run {:}".format(xargs.alg)) with open(exp_yaml) as fp: config = yaml.safe_load(fp) config = update_market(config, xargs.market) @@ -87,7 +87,11 @@ def main(xargs, exp_yaml): for irun in range(xargs.times): run_exp( - config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), '{:}-{:}'.format(xargs.save_dir, xargs.market) + config.get("task"), + dataset, + xargs.alg, + "recorder-{:02d}-{:02d}".format(irun, xargs.times), + "{:}-{:}".format(xargs.save_dir, xargs.market), ) @@ -97,7 +101,9 @@ if __name__ == "__main__": parser = argparse.ArgumentParser("Baselines") parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.") - parser.add_argument("--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator.") + parser.add_argument( + "--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator." + ) parser.add_argument("--times", type=int, default=10, help="The repeated run times.") parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.") diff --git a/exps/trading/organize_results.py b/exps/trading/organize_results.py index 76ff9bb..7da7c76 100644 --- a/exps/trading/organize_results.py +++ b/exps/trading/organize_results.py @@ -179,4 +179,3 @@ if __name__ == "__main__": all_info_dict.append(info_dict) info_dict = QResult.merge_dict(all_info_dict) compare_results(info_dict["heads"], info_dict["values"], info_dict["names"], space=15, verbose=True, sort_key=True) -