From 63c8bb9bc8a387aed31e88e6e6573457ec1fc2bd Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Thu, 18 Mar 2021 16:02:55 +0800
Subject: [PATCH] Add int search space

---
 .github/workflows/test.yml | 2 +
 exps/KD-main.py | 100 +++++--
 exps/NAS-Bench-201/check.py | 52 +++-
 exps/NAS-Bench-201/dist-setup.py | 4 +-
 exps/NAS-Bench-201/functions.py | 29 +-
 exps/NAS-Bench-201/main.py | 306 ++++++++++++++-----
 exps/NAS-Bench-201/show-best.py | 7 +-
 exps/NAS-Bench-201/statistics-v2.py | 287 +++++++++++++-----
 exps/NAS-Bench-201/statistics.py | 275 ++++++++++++++----
 exps/NAS-Bench-201/test-correlation.py | 63 +++-
 exps/NAS-Bench-201/visualize.py | 293 +++++++++++++++----
 exps/NATS-Bench/Analyze-time.py | 6 +-
 exps/NATS-Bench/draw-correlations.py | 45 ++-
 exps/NATS-Bench/draw-fig2_5.py | 93 ++++--
 exps/NATS-Bench/draw-fig6.py | 79 +++--
 exps/NATS-Bench/draw-fig7.py | 106 +++++--
 exps/NATS-Bench/draw-fig8.py | 61 +++-
 exps/NATS-Bench/draw-ranks.py | 71 ++++-
 exps/NATS-Bench/draw-table.py | 45 ++-
 exps/NATS-Bench/main-sss.py | 216 +++++++++++---
 exps/NATS-Bench/main-tss.py | 308 +++++++++++++++-----
 exps/NATS-Bench/sss-collect.py | 127 ++++++--
 exps/NATS-Bench/sss-file-manager.py | 24 +-
 exps/NATS-Bench/test-nats-api.py | 12 +-
 exps/NATS-Bench/tss-collect-patcher.py | 43 ++-
 exps/NATS-Bench/tss-collect.py | 203 ++++++++++---
 exps/NATS-Bench/tss-file-manager.py | 28 +-
 exps/NATS-algos/bohb.py | 93 ++++--
 exps/NATS-algos/random_wo_share.py | 47 ++-
 exps/NATS-algos/regularized_ea.py | 74 ++++-
 exps/NATS-algos/reinforce.py | 73 ++++-
 exps/NATS-algos/search-cell.py | 323 +++++++++++++++++----
 exps/NATS-algos/search-size.py | 217 +++++++++++---
 exps/algos/BOHB.py | 133 +++++++--
 exps/algos/DARTS-V1.py | 138 +++++++--
 exps/algos/DARTS-V2.py | 179 +++++++++---
 exps/algos/ENAS.py | 197 ++++++++++---
 exps/algos/GDAS.py | 149 ++++++++--
 exps/algos/RANDOM-NAS.py | 122 ++++++--
 exps/algos/RANDOM.py | 68 ++++-
 exps/algos/R_EA.py | 116 ++++++--
 exps/algos/SETN.py | 176 ++++++++---
 exps/algos/reinforce.py | 77 +++--
 exps/basic-eval.py | 57 +++-
 exps/basic-main.py | 111 +++++--
 exps/experimental/example-nas-bench.py | 20 +-
 exps/experimental/test-nas-plot.py | 26 +-
 exps/experimental/test-ww-bench.py | 51 +++-
 exps/experimental/vis-nats-bench-algos.py | 67 ++++-
 exps/experimental/vis-nats-bench-ws.py | 56 +++-
 exps/experimental/visualize-nas-bench-x.py | 94 ++++--
 exps/prepare.py | 13 +-
 exps/search-shape.py | 147 ++++++++--
 exps/search-transformable.py | 165 ++++++++---
 exps/show-dataset.py | 20 +-
 exps/trading/baselines.py | 31 +-
 exps/trading/organize_results.py | 45 ++-
 exps/trading/workflow_tt.py | 42 ++-
 lib/layers/drop.py | 304 +++++++++++--------
 lib/layers/mlp.py | 45 +--
 lib/layers/positional_embedding.py | 54 ++--
 lib/layers/super_mlp.py | 63 ++--
 lib/layers/super_module.py | 18 +-
 lib/layers/weight_init.py | 94 +++---
 lib/spaces/__init__.py | 1 +
 lib/spaces/basic_space.py | 26 ++
 tests/test_basic_space.py | 7 +
 67 files changed, 5150 insertions(+), 1474 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9b482ae..f8a25a0 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -30,7 +30,9 @@ jobs:
         python --version
         python -m black --version
         echo $PWD ; ls
+        python -m black ./exps -l 88 --check --diff --verbose
         python -m black ./tests -l 88 --check --diff --verbose
+        python -m black ./lib/layers -l 88 --check --diff --verbose
         python -m black ./lib/spaces -l 88 --check --diff --verbose
         python -m black ./lib/trade_models -l 88 --check --diff
--verbose diff --git a/exps/KD-main.py b/exps/KD-main.py index d130b91..abfd538 100644 --- a/exps/KD-main.py +++ b/exps/KD-main.py @@ -30,18 +30,32 @@ def main(args): prepare_seed(args.rand_seed) logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) + train_data, valid_data, xshape, class_num = get_datasets( + args.dataset, args.data_path, args.cutout_length + ) train_loader = torch.utils.data.DataLoader( - train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True + train_data, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, + pin_memory=True, ) valid_loader = torch.utils.data.DataLoader( - valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True + valid_data, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + pin_memory=True, ) # get configures model_config = load_config(args.model_config, {"class_num": class_num}, logger) optim_config = load_config( args.optim_config, - {"class_num": class_num, "KD_alpha": args.KD_alpha, "KD_temperature": args.KD_temperature}, + { + "class_num": class_num, + "KD_alpha": args.KD_alpha, + "KD_temperature": args.KD_temperature, + }, logger, ) @@ -55,20 +69,32 @@ def main(args): logger.log("Teacher ====>>>>:\n{:}".format(teacher_base)) logger.log("model information : {:}".format(base_model.get_message())) logger.log("-" * 50) - logger.log("Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(param, flop, flop / 1e3)) + logger.log( + "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format( + param, flop, flop / 1e3 + ) + ) logger.log("-" * 50) logger.log("train_data : {:}".format(train_data)) logger.log("valid_data : {:}".format(valid_data)) - optimizer, scheduler, criterion = get_optim_scheduler(base_model.parameters(), optim_config) + optimizer, scheduler, criterion = get_optim_scheduler( + base_model.parameters(), optim_config + ) logger.log("optimizer : {:}".format(optimizer)) logger.log("scheduler : {:}".format(scheduler)) logger.log("criterion : {:}".format(criterion)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(last_info) + ) last_info = torch.load(last_info) start_epoch = last_info["epoch"] + 1 checkpoint = torch.load(last_info["last_checkpoint"]) @@ -78,10 +104,14 @@ def main(args): valid_accuracies = checkpoint["valid_accuracies"] max_bytes = checkpoint["max_bytes"] logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + last_info, start_epoch + ) ) elif args.resume is not None: - assert Path(args.resume).exists(), "Can not find the resume file : {:}".format(args.resume) + assert Path(args.resume).exists(), "Can not find the resume file : {:}".format( + args.resume + ) checkpoint = torch.load(args.resume) start_epoch = checkpoint["epoch"] + 1 base_model.load_state_dict(checkpoint["base-model"]) @@ 
-89,9 +119,15 @@ def main(args): optimizer.load_state_dict(checkpoint["optimizer"]) valid_accuracies = checkpoint["valid_accuracies"] max_bytes = checkpoint["max_bytes"] - logger.log("=> loading checkpoint from '{:}' start with {:}-th epoch.".format(args.resume, start_epoch)) + logger.log( + "=> loading checkpoint from '{:}' start with {:}-th epoch.".format( + args.resume, start_epoch + ) + ) elif args.init_model is not None: - assert Path(args.init_model).exists(), "Can not find the initialization file : {:}".format(args.init_model) + assert Path( + args.init_model + ).exists(), "Can not find the initialization file : {:}".format(args.init_model) checkpoint = torch.load(args.init_model) base_model.load_state_dict(checkpoint["base-model"]) start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} @@ -108,7 +144,9 @@ def main(args): epoch_time = AverageMeter() for epoch in range(start_epoch, total_epoch): scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.avg * (total_epoch - epoch), True) + ) epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch) LRs = scheduler.get_lr() find_best = False @@ -143,7 +181,14 @@ def main(args): if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): logger.log("-" * 150) valid_loss, valid_acc1, valid_acc5 = valid_func( - valid_loader, teacher, network, criterion, optim_config, epoch_str, args.print_freq_eval, logger + valid_loader, + teacher, + network, + criterion, + optim_config, + epoch_str, + args.print_freq_eval, + logger, ) valid_accuracies[epoch] = valid_acc1 logger.log( @@ -162,13 +207,24 @@ def main(args): find_best = True logger.log( "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format( - epoch, valid_acc1, valid_acc5, 100 - valid_acc1, 100 - valid_acc5, model_best_path + epoch, + valid_acc1, + valid_acc5, + 100 - valid_acc1, + 100 - valid_acc5, + model_best_path, ) ) - num_bytes = torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 + num_bytes = ( + torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 + ) logger.log( "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format( - next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9 + next(network.parameters()).device, + int(num_bytes), + num_bytes / 1e3, + num_bytes / 1e6, + num_bytes / 1e9, ) ) max_bytes[epoch] = num_bytes @@ -210,10 +266,16 @@ def main(args): start_time = time.time() logger.log("\n" + "-" * 200) - logger.log("||| Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(param, flop, flop / 1e3)) + logger.log( + "||| Params={:.2f} MB, FLOPs={:.2f} M ... 
= {:.2f} G".format( + param, flop, flop / 1e3 + ) + ) logger.log( "Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}".format( - convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e6, logger.path("info") + convert_secs2time(epoch_time.sum, True), + max(v for k, v in max_bytes.items()) / 1e6, + logger.path("info"), ) ) logger.log("-" * 200 + "\n") diff --git a/exps/NAS-Bench-201/check.py b/exps/NAS-Bench-201/check.py index 4ef4153..f6929db 100644 --- a/exps/NAS-Bench-201/check.py +++ b/exps/NAS-Bench-201/check.py @@ -18,12 +18,16 @@ def check_files(save_dir, meta_file, basestr): meta_infos = torch.load(meta_file, map_location="cpu") meta_archs = meta_infos["archs"] meta_num_archs = meta_infos["total"] - assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( - meta_num_archs, len(meta_archs) - ) + assert meta_num_archs == len( + meta_archs + ), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs)) sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) - print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) + print( + "{:} find {:} directories used to save checkpoints".format( + time_string(), len(sub_model_dirs) + ) + ) subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 num_seeds = defaultdict(lambda: 0) @@ -34,21 +38,29 @@ def check_files(save_dir, meta_file, basestr): for checkpoint in xcheckpoints: temp_names = checkpoint.name.split("-") assert ( - len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed" + len(temp_names) == 4 + and temp_names[0] == "arch" + and temp_names[2] == "seed" ), "invalid checkpoint name : {:}".format(checkpoint.name) arch_indexes.add(temp_names[1]) subdir2archs[sub_dir] = sorted(list(arch_indexes)) num_evaluated_arch += len(arch_indexes) # count number of seeds for each architecture for arch_index in arch_indexes: - num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1 + num_seeds[ + len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))) + ] += 1 print( "There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).".format( num_evaluated_arch, meta_num_archs, sum(k * v for k, v in num_seeds.items()) ) ) for key in sorted(list(num_seeds.keys())): - print("There are {:5d} architectures that are evaluated {:} times.".format(num_seeds[key], key)) + print( + "There are {:5d} architectures that are evaluated {:} times.".format( + num_seeds[key], key + ) + ) dir2ckps, dir2ckp_exists = dict(), dict() start_time, epoch_time = time.time(), AverageMeter() @@ -62,12 +74,14 @@ def check_files(save_dir, meta_file, basestr): numrs = defaultdict(lambda: 0) all_checkpoints, all_ckp_exists = [], [] for arch_index in arch_indexes: - checkpoints = ["arch-{:}-seed-{:04d}.pth".format(arch_index, seed) for seed in seeds] + checkpoints = [ + "arch-{:}-seed-{:04d}.pth".format(arch_index, seed) for seed in seeds + ] ckp_exists = [(sub_dir / x).exists() for x in checkpoints] arch_index = int(arch_index) - assert 0 <= arch_index < len(meta_archs), "invalid arch-index {:} (not found in meta_archs)".format( - arch_index - ) + assert ( + 0 <= arch_index < len(meta_archs) + ), "invalid arch-index {:} (not found in meta_archs)".format(arch_index) all_checkpoints += checkpoints all_ckp_exists += ckp_exists numrs[sum(ckp_exists)] += 1 @@ -76,7 +90,9 @@ def check_files(save_dir, meta_file, 
basestr): # measure time epoch_time.update(time.time() - start_time) start_time = time.time() - numrstr = ", ".join(["{:}: {:03d}".format(x, numrs[x]) for x in sorted(numrs.keys())]) + numrstr = ", ".join( + ["{:}: {:03d}".format(x, numrs[x]) for x in sorted(numrs.keys())] + ) print( "{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}".format( time_string(), @@ -95,7 +111,8 @@ def check_files(save_dir, meta_file, basestr): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="NAS Benchmark 201", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="NAS Benchmark 201", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--base_save_dir", @@ -104,9 +121,14 @@ if __name__ == "__main__": help="The base-name of folder to save checkpoints and log.", ) parser.add_argument( - "--meta_path", type=str, default="./output/NAS-BENCH-201-4/meta-node-4.pth", help="The meta file path." + "--meta_path", + type=str, + default="./output/NAS-BENCH-201-4/meta-node-4.pth", + help="The meta file path.", + ) + parser.add_argument( + "--base_str", type=str, default="C16-N5", help="The basic string." ) - parser.add_argument("--base_str", type=str, default="C16-N5", help="The basic string.") args = parser.parse_args() save_dir = Path(args.base_save_dir) diff --git a/exps/NAS-Bench-201/dist-setup.py b/exps/NAS-Bench-201/dist-setup.py index 0103f3f..9ec7dab 100644 --- a/exps/NAS-Bench-201/dist-setup.py +++ b/exps/NAS-Bench-201/dist-setup.py @@ -10,7 +10,9 @@ from setuptools import setup def read(fname="README.md"): - with open(os.path.join(os.path.dirname(__file__), fname), encoding="utf-8") as cfile: + with open( + os.path.join(os.path.dirname(__file__), fname), encoding="utf-8" + ) as cfile: return cfile.read() diff --git a/exps/NAS-Bench-201/functions.py b/exps/NAS-Bench-201/functions.py index d6eed59..5ac92bb 100644 --- a/exps/NAS-Bench-201/functions.py +++ b/exps/NAS-Bench-201/functions.py @@ -76,7 +76,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode): return losses.avg, top1.avg, top5.avg, batch_time.sum -def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger): +def evaluate_for_seed( + arch_config, config, arch, train_loader, valid_loaders, seed, logger +): prepare_seed(seed) # random seed net = get_cell_based_tiny_net( @@ -94,14 +96,29 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, se # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) flop, param = get_model_infos(net, config.xshape) logger.log("Network : {:}".format(net.get_message()), False) - logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed)) + logger.log( + "{:} Seed-------------------------- {:} --------------------------".format( + time_string(), seed + ) + ) logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param)) # train and valid optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config) network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda() # start training - start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup - train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} + start_time, epoch_time, total_epoch = ( + time.time(), + AverageMeter(), + config.epochs + config.warmup, + ) + ( + train_losses, + train_acc1es, + 
train_acc5es, + valid_losses, + valid_acc1es, + valid_acc5es, + ) = ({}, {}, {}, {}, {}, {}) train_times, valid_times = {}, {} for epoch in range(total_epoch): scheduler.update(epoch, 0.0) @@ -126,7 +143,9 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, se # measure elapsed time epoch_time.update(time.time() - start_time) start_time = time.time() - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True) + ) logger.log( "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format( time_string(), diff --git a/exps/NAS-Bench-201/main.py b/exps/NAS-Bench-201/main.py index b0dbf46..19a68e6 100644 --- a/exps/NAS-Bench-201/main.py +++ b/exps/NAS-Bench-201/main.py @@ -22,7 +22,9 @@ from models import CellStructure, CellArchitectures, get_search_spaces from functions import evaluate_for_seed -def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger): +def evaluate_all_datasets( + arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger +): machine_info, arch_config = get_machine_info(), deepcopy(arch_config) all_infos = {"info": machine_info} all_dataset_keys = [] @@ -36,27 +38,39 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c config_path = "configs/nas-benchmark/LESS.config" else: config_path = "configs/nas-benchmark/CIFAR.config" - split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None) + split_info = load_config( + "configs/nas-benchmark/cifar-split.txt", None, None + ) elif dataset.startswith("ImageNet16"): if use_less: config_path = "configs/nas-benchmark/LESS.config" else: config_path = "configs/nas-benchmark/ImageNet-16.config" - split_info = load_config("configs/nas-benchmark/{:}-split.txt".format(dataset), None, None) + split_info = load_config( + "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None + ) else: raise ValueError("invalid dataset : {:}".format(dataset)) - config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger) + config = load_config( + config_path, {"class_num": class_num, "xshape": xshape}, logger + ) # check whether use splited validation set if bool(split): assert dataset == "cifar10" ValLoaders = { "ori-test": torch.utils.data.DataLoader( - valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + valid_data, + batch_size=config.batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, ) } assert len(train_data) == len(split_info.train) + len( split_info.valid - ), "invalid length : {:} vs {:} + {:}".format(len(train_data), len(split_info.train), len(split_info.valid)) + ), "invalid length : {:} vs {:} + {:}".format( + len(train_data), len(split_info.train), len(split_info.valid) + ) train_data_v2 = deepcopy(train_data) train_data_v2.transform = valid_data.transform valid_data = train_data_v2 @@ -79,47 +93,67 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c else: # data loader train_loader = torch.utils.data.DataLoader( - train_data, batch_size=config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + train_data, + batch_size=config.batch_size, + shuffle=True, + num_workers=workers, + pin_memory=True, ) valid_loader = 
torch.utils.data.DataLoader( - valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + valid_data, + batch_size=config.batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, ) if dataset == "cifar10": ValLoaders = {"ori-test": valid_loader} elif dataset == "cifar100": - cifar100_splits = load_config("configs/nas-benchmark/cifar100-test-split.txt", None, None) + cifar100_splits = load_config( + "configs/nas-benchmark/cifar100-test-split.txt", None, None + ) ValLoaders = { "ori-test": valid_loader, "x-valid": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + cifar100_splits.xvalid + ), num_workers=workers, pin_memory=True, ), "x-test": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + cifar100_splits.xtest + ), num_workers=workers, pin_memory=True, ), } elif dataset == "ImageNet16-120": - imagenet16_splits = load_config("configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None) + imagenet16_splits = load_config( + "configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None + ) ValLoaders = { "ori-test": valid_loader, "x-valid": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + imagenet16_splits.xvalid + ), num_workers=workers, pin_memory=True, ), "x-test": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + imagenet16_splits.xtest + ), num_workers=workers, pin_memory=True, ), @@ -132,13 +166,24 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c dataset_key = dataset_key + "-valid" logger.log( "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( - dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size + dataset_key, + len(train_data), + len(valid_data), + len(train_loader), + len(valid_loader), + config.batch_size, ) ) - logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)) + logger.log( + "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config) + ) for key, value in ValLoaders.items(): - logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))) - results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger) + logger.log( + "Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)) + ) + results = evaluate_for_seed( + arch_config, config, arch, train_loader, ValLoaders, seed, logger + ) all_infos[dataset_key] = results all_dataset_keys.append(dataset_key) all_infos["all_dataset_keys"] = all_dataset_keys @@ -146,7 +191,18 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c def main( - save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config + save_dir, + workers, + datasets, + xpaths, + splits, + use_less, + srange, + arch_index, + seeds, + 
cover_mode, + meta_info, + arch_config, ): assert torch.cuda.is_available(), "CUDA is not available." torch.backends.cudnn.enabled = True @@ -154,7 +210,9 @@ def main( torch.backends.cudnn.deterministic = True torch.set_num_threads(workers) - assert len(srange) == 2 and 0 <= srange[0] <= srange[1], "invalid srange : {:}".format(srange) + assert ( + len(srange) == 2 and 0 <= srange[0] <= srange[1] + ), "invalid srange : {:}".format(srange) if use_less: sub_dir = Path(save_dir) / "{:06d}-{:06d}-C{:}-N{:}-LESS".format( @@ -170,9 +228,9 @@ def main( assert srange[1] < meta_info["total"], "invalid range : {:}-{:} vs. {:}".format( srange[0], srange[1], meta_info["total"] ) - assert arch_index == -1 or srange[0] <= arch_index <= srange[1], "invalid range : {:} vs. {:} vs. {:}".format( - srange[0], arch_index, srange[1] - ) + assert ( + arch_index == -1 or srange[0] <= arch_index <= srange[1] + ), "invalid range : {:} vs. {:} vs. {:}".format(srange[0], arch_index, srange[1]) if arch_index == -1: to_evaluate_indexes = list(range(srange[0], srange[1] + 1)) else: @@ -200,7 +258,13 @@ def main( arch = all_archs[index] logger.log( "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}".format( - "-" * 15, i, len(to_evaluate_indexes), index, meta_info["total"], seeds, "-" * 15 + "-" * 15, + i, + len(to_evaluate_indexes), + index, + meta_info["total"], + seeds, + "-" * 15, ) ) # logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15)) @@ -212,10 +276,18 @@ def main( to_save_name = sub_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) if to_save_name.exists(): if cover_mode: - logger.log("Find existing file : {:}, remove it before evaluation".format(to_save_name)) + logger.log( + "Find existing file : {:}, remove it before evaluation".format( + to_save_name + ) + ) os.remove(str(to_save_name)) else: - logger.log("Find existing file : {:}, skip this evaluation".format(to_save_name)) + logger.log( + "Find existing file : {:}, skip this evaluation".format( + to_save_name + ) + ) has_continue = True continue results = evaluate_all_datasets( @@ -232,7 +304,13 @@ def main( torch.save(results, to_save_name) logger.log( "{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}".format( - "-" * 15, i, len(to_evaluate_indexes), index, meta_info["total"], seed, to_save_name + "-" * 15, + i, + len(to_evaluate_indexes), + index, + meta_info["total"], + seed, + to_save_name, ) ) # measure elapsed time @@ -242,7 +320,9 @@ def main( need_time = "Time Left: {:}".format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True) ) - logger.log("This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))) + logger.log( + "This arch costs : {:}".format(convert_secs2time(epoch_time.val, True)) + ) logger.log("{:}".format("*" * 100)) logger.log( "{:} {:74s} {:}".format( @@ -258,7 +338,9 @@ def main( logger.close() -def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config): +def train_single_model( + save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config +): assert torch.cuda.is_available(), "CUDA is not available." 
torch.backends.cudnn.enabled = True torch.backends.cudnn.deterministic = True @@ -269,19 +351,32 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se Path(save_dir) / "specifics" / "{:}-{:}-{:}-{:}".format( - "LESS" if use_less else "FULL", model_str, arch_config["channel"], arch_config["num_cells"] + "LESS" if use_less else "FULL", + model_str, + arch_config["channel"], + arch_config["num_cells"], ) ) logger = Logger(str(save_dir), 0, False) if model_str in CellArchitectures: arch = CellArchitectures[model_str] - logger.log("The model string is found in pre-defined architecture dict : {:}".format(model_str)) + logger.log( + "The model string is found in pre-defined architecture dict : {:}".format( + model_str + ) + ) else: try: arch = CellStructure.str2structure(model_str) except: - raise ValueError("Invalid model string : {:}. It can not be found or parsed.".format(model_str)) - assert arch.check_valid_op(get_search_spaces("cell", "full")), "{:} has the invalid op.".format(arch) + raise ValueError( + "Invalid model string : {:}. It can not be found or parsed.".format( + model_str + ) + ) + assert arch.check_valid_op( + get_search_spaces("cell", "full") + ), "{:} has the invalid op.".format(arch) logger.log("Start train-evaluate {:}".format(arch.tostr())) logger.log("arch_config : {:}".format(arch_config)) @@ -294,27 +389,55 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se ) to_save_name = save_dir / "seed-{:04d}.pth".format(seed) if to_save_name.exists(): - logger.log("Find the existing file {:}, directly load!".format(to_save_name)) + logger.log( + "Find the existing file {:}, directly load!".format(to_save_name) + ) checkpoint = torch.load(to_save_name) else: - logger.log("Does not find the existing file {:}, train and evaluate!".format(to_save_name)) + logger.log( + "Does not find the existing file {:}, train and evaluate!".format( + to_save_name + ) + ) checkpoint = evaluate_all_datasets( - arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger + arch, + datasets, + xpaths, + splits, + use_less, + seed, + arch_config, + workers, + logger, ) torch.save(checkpoint, to_save_name) # log information logger.log("{:}".format(checkpoint["info"])) all_dataset_keys = checkpoint["all_dataset_keys"] for dataset_key in all_dataset_keys: - logger.log("\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)) + logger.log( + "\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15) + ) dataset_info = checkpoint[dataset_key] # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] )) - logger.log("Flops = {:} MB, Params = {:} MB".format(dataset_info["flop"], dataset_info["param"])) + logger.log( + "Flops = {:} MB, Params = {:} MB".format( + dataset_info["flop"], dataset_info["param"] + ) + ) logger.log("config : {:}".format(dataset_info["config"])) - logger.log("Training State (finish) = {:}".format(dataset_info["finish-train"])) + logger.log( + "Training State (finish) = {:}".format(dataset_info["finish-train"]) + ) last_epoch = dataset_info["total_epoch"] - 1 - train_acc1es, train_acc5es = dataset_info["train_acc1es"], dataset_info["train_acc5es"] - valid_acc1es, valid_acc5es = dataset_info["valid_acc1es"], dataset_info["valid_acc5es"] + train_acc1es, train_acc5es = ( + dataset_info["train_acc1es"], + dataset_info["train_acc5es"], + ) + valid_acc1es, valid_acc5es = ( + dataset_info["valid_acc1es"], + dataset_info["valid_acc5es"], + ) logger.log( "Last Info : Train = Acc@1 {:.2f}% Acc@5 
{:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format( train_acc1es[last_epoch], @@ -328,7 +451,9 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se # measure elapsed time seed_time.update(time.time() - start_time) start_time = time.time() - need_time = "Time Left: {:}".format(convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True) + ) logger.log( "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} other procedures need {:}".format( _is, len(seeds), seed, need_time @@ -340,7 +465,11 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se def generate_meta_info(save_dir, max_node, divide=40): aa_nas_bench_ss = get_search_spaces("cell", "nas-bench-201") archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) - print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2))) + print( + "There are {:} archs vs {:}.".format( + len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2) + ) + ) random.seed(88) # please do not change this line for reproducibility random.shuffle(archs) @@ -352,10 +481,12 @@ def generate_meta_info(save_dir, max_node, divide=40): == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" ), "please check the 0-th architecture : {:}".format(archs[0]) assert ( - archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" + archs[9].tostr() + == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" ), "please check the 9-th architecture : {:}".format(archs[9]) assert ( - archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" + archs[123].tostr() + == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" ), "please check the 123-th architecture : {:}".format(archs[123]) total_arch = len(archs) @@ -374,11 +505,21 @@ def generate_meta_info(save_dir, max_node, divide=40): and valid_split[10] == 18 and valid_split[111] == 242 ), "{:} {:} {:} - {:} {:} {:}".format( - train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111] + train_split[0], + train_split[10], + train_split[111], + valid_split[0], + valid_split[10], + valid_split[111], ) splits = {num: {"train": train_split, "valid": valid_split}} - info = {"archs": [x.tostr() for x in archs], "total": total_arch, "max_node": max_node, "splits": splits} + info = { + "archs": [x.tostr() for x in archs], + "total": total_arch, + "max_node": max_node, + "splits": splits, + } save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) @@ -404,7 +545,11 @@ def generate_meta_info(save_dir, max_node, divide=40): start, xend - 1 ) ) - print("save the training script into {:} and {:}".format(script_name_full, script_name_less)) + print( + "save the training script into {:} and {:}".format( + script_name_full, script_name_less + ) + ) full_file.close() less_file.close() @@ -425,29 +570,56 @@ if __name__ == "__main__": # mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()] # parser = argparse.ArgumentParser(description='Algorithm-Agnostic NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser( - description="NAS-Bench-201", 
formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="NAS-Bench-201", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("--mode", type=str, required=True, help="The script mode.") - parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." + ) parser.add_argument("--max_node", type=int, help="The maximum node in a cell.") # use for train the model - parser.add_argument("--workers", type=int, default=8, help="number of data loading workers (default: 2)") - parser.add_argument("--srange", type=int, nargs="+", help="The range of models to be evaluated") parser.add_argument( - "--arch_index", type=int, default=-1, help="The architecture index to be evaluated (cover mode)." + "--workers", + type=int, + default=8, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--srange", type=int, nargs="+", help="The range of models to be evaluated" + ) + parser.add_argument( + "--arch_index", + type=int, + default=-1, + help="The architecture index to be evaluated (cover mode).", ) parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.") - parser.add_argument("--xpaths", type=str, nargs="+", help="The root path for this dataset.") - parser.add_argument("--splits", type=int, nargs="+", help="The root path for this dataset.") - parser.add_argument("--use_less", type=int, default=0, choices=[0, 1], help="Using the less-training-epoch config.") - parser.add_argument("--seeds", type=int, nargs="+", help="The range of models to be evaluated") + parser.add_argument( + "--xpaths", type=str, nargs="+", help="The root path for this dataset." + ) + parser.add_argument( + "--splits", type=int, nargs="+", help="The root path for this dataset." + ) + parser.add_argument( + "--use_less", + type=int, + default=0, + choices=[0, 1], + help="Using the less-training-epoch config.", + ) + parser.add_argument( + "--seeds", type=int, nargs="+", help="The range of models to be evaluated" + ) parser.add_argument("--channel", type=int, help="The number of channels.") - parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument( + "--num_cells", type=int, help="The number of cells in one stage." 
+ ) args = parser.parse_args() - assert args.mode in ["meta", "new", "cover"] or args.mode.startswith("specific-"), "invalid mode : {:}".format( - args.mode - ) + assert args.mode in ["meta", "new", "cover"] or args.mode.startswith( + "specific-" + ), "invalid mode : {:}".format(args.mode) if args.mode == "meta": generate_meta_info(args.save_dir, args.max_node) @@ -470,11 +642,15 @@ if __name__ == "__main__": assert meta_path.exists(), "{:} does not exist.".format(meta_path) meta_info = torch.load(meta_path) # check whether args is ok - assert len(args.srange) == 2 and args.srange[0] <= args.srange[1], "invalid length of srange args: {:}".format( - args.srange + assert ( + len(args.srange) == 2 and args.srange[0] <= args.srange[1] + ), "invalid length of srange args: {:}".format(args.srange) + assert len(args.seeds) > 0, "invalid length of seeds args: {:}".format( + args.seeds ) - assert len(args.seeds) > 0, "invalid length of seeds args: {:}".format(args.seeds) - assert len(args.datasets) == len(args.xpaths) == len(args.splits), "invalid infos : {:} vs {:} vs {:}".format( + assert ( + len(args.datasets) == len(args.xpaths) == len(args.splits) + ), "invalid infos : {:} vs {:} vs {:}".format( len(args.datasets), len(args.xpaths), len(args.splits) ) assert args.workers > 0, "invalid number of workers : {:}".format(args.workers) diff --git a/exps/NAS-Bench-201/show-best.py b/exps/NAS-Bench-201/show-best.py index 153d197..eb9e929 100644 --- a/exps/NAS-Bench-201/show-best.py +++ b/exps/NAS-Bench-201/show-best.py @@ -13,7 +13,12 @@ from nas_201_api import NASBench201API as API if __name__ == "__main__": parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") - parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.") + parser.add_argument( + "--api_path", + type=str, + default=None, + help="The path to the NAS-Bench-201 benchmark file.", + ) args = parser.parse_args() meta_file = Path(args.api_path) diff --git a/exps/NAS-Bench-201/statistics-v2.py b/exps/NAS-Bench-201/statistics-v2.py index dfeaa0b..d20a56d 100644 --- a/exps/NAS-Bench-201/statistics-v2.py +++ b/exps/NAS-Bench-201/statistics-v2.py @@ -19,7 +19,9 @@ from models import CellStructure, get_cell_based_tiny_net from nas_201_api import NASBench201API, ArchResults, ResultsCount from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders -api = NASBench201API("{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"])) +api = NASBench201API( + "{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"]) +) def create_result_count( @@ -55,33 +57,56 @@ def create_result_count( network.load_state_dict(xresult.get_net_param()) if "train_times" in results: # new version xresult.update_train_info( - results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"] + results["train_acc1es"], + results["train_acc5es"], + results["train_losses"], + results["train_times"], + ) + xresult.update_eval( + results["valid_acc1es"], results["valid_losses"], results["valid_times"] ) - xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"]) else: if dataset == "cifar10-valid": - xresult.update_OLD_eval("x-valid", results["valid_acc1es"], results["valid_losses"]) + xresult.update_OLD_eval( + "x-valid", results["valid_acc1es"], results["valid_losses"] + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda() ) - 
xresult.update_OLD_eval("ori-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_OLD_eval( + "ori-test", + {results["total_epoch"] - 1: top1}, + {results["total_epoch"] - 1: loss}, + ) xresult.update_latency(latencies) elif dataset == "cifar10": - xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + xresult.update_OLD_eval( + "ori-test", results["valid_acc1es"], results["valid_losses"] + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() ) xresult.update_latency(latencies) elif dataset == "cifar100" or dataset == "ImageNet16-120": - xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + xresult.update_OLD_eval( + "ori-test", results["valid_acc1es"], results["valid_losses"] + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda() ) - xresult.update_OLD_eval("x-valid", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_OLD_eval( + "x-valid", + {results["total_epoch"] - 1: top1}, + {results["total_epoch"] - 1: loss}, + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() ) - xresult.update_OLD_eval("x-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_OLD_eval( + "x-test", + {results["total_epoch"] - 1: top1}, + {results["total_epoch"] - 1: loss}, + ) xresult.update_latency(latencies) else: raise ValueError("invalid dataset name : {:}".format(dataset)) @@ -89,7 +114,11 @@ def create_result_count( def account_one_arch( - arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text], dataloader_dict: Dict[Text, Any] + arch_index: int, + arch_str: Text, + checkpoints: List[Text], + datasets: List[Text], + dataloader_dict: Dict[Text, Any], ) -> ArchResults: information = ArchResults(arch_index, arch_str) @@ -99,12 +128,18 @@ def account_one_arch( ok_dataset = 0 for dataset in datasets: if dataset not in checkpoint: - print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path)) + print( + "Can not find {:} in arch-{:} from {:}".format( + dataset, arch_index, checkpoint_path + ) + ) continue else: ok_dataset += 1 results = checkpoint[dataset] - assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( + assert results[ + "finish-train" + ], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( arch_index, used_seed, dataset, checkpoint_path ) arch_config = { @@ -114,17 +149,22 @@ def account_one_arch( "class_num": results["config"]["class_num"], } - xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict) + xresult = create_result_count( + used_seed, dataset, arch_config, results, dataloader_dict + ) information.update(dataset, int(used_seed), xresult) if ok_dataset == 0: raise ValueError("{:} does not find any data".format(checkpoint_path)) return information -def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch_info_less: ArchResults): +def correct_time_related_info( + arch_index: int, arch_info_full: ArchResults, arch_info_less: ArchResults +): # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth cifar010_latency = ( - api.get_latency(arch_index, "cifar10-valid", hp="200") + api.get_latency(arch_index, "cifar10", hp="200") + 
api.get_latency(arch_index, "cifar10-valid", hp="200") + + api.get_latency(arch_index, "cifar10", hp="200") ) / 2 arch_info_full.reset_latency("cifar10-valid", None, cifar010_latency) arch_info_full.reset_latency("cifar10", None, cifar010_latency) @@ -139,7 +179,9 @@ def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch arch_info_full.reset_latency("ImageNet16-120", None, image_latency) arch_info_less.reset_latency("ImageNet16-120", None, image_latency) - train_per_epoch_time = list(arch_info_less.query("cifar10-valid", 777).train_times.values()) + train_per_epoch_time = list( + arch_info_less.query("cifar10-valid", 777).train_times.values() + ) train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) eval_ori_test_time, eval_x_valid_time = [], [] for key, value in arch_info_less.query("cifar10-valid", 777).eval_times.items(): @@ -149,7 +191,9 @@ def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch eval_x_valid_time.append(value) else: raise ValueError("-- {:} --".format(key)) - eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time)) + eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float( + np.mean(eval_x_valid_time) + ) nums = { "ImageNet16-120-train": 151700, "ImageNet16-120-valid": 3000, @@ -162,36 +206,72 @@ def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch "cifar100-test": 10000, "cifar100-valid": 5000, } - eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums["cifar10-valid-valid"] + nums["cifar10-test"]) + eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / ( + nums["cifar10-valid-valid"] + nums["cifar10-test"] + ) for arch_info in [arch_info_less, arch_info_full]: arch_info.reset_pseudo_train_times( - "cifar10-valid", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-valid-train"] + "cifar10-valid", + None, + train_per_epoch_time + / nums["cifar10-valid-train"] + * nums["cifar10-valid-train"], ) arch_info.reset_pseudo_train_times( - "cifar10", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"] + "cifar10", + None, + train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"], ) arch_info.reset_pseudo_train_times( - "cifar100", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"] + "cifar100", + None, + train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"], ) arch_info.reset_pseudo_train_times( - "ImageNet16-120", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["ImageNet16-120-train"] + "ImageNet16-120", + None, + train_per_epoch_time + / nums["cifar10-valid-train"] + * nums["ImageNet16-120-train"], ) arch_info.reset_pseudo_eval_times( - "cifar10-valid", None, "x-valid", eval_per_sample * nums["cifar10-valid-valid"] - ) - arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"]) - arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"]) - arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"]) - arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"]) - arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"]) - arch_info.reset_pseudo_eval_times( - "ImageNet16-120", None, "x-valid", eval_per_sample * 
nums["ImageNet16-120-valid"] + "cifar10-valid", + None, + "x-valid", + eval_per_sample * nums["cifar10-valid-valid"], ) arch_info.reset_pseudo_eval_times( - "ImageNet16-120", None, "x-test", eval_per_sample * nums["ImageNet16-120-valid"] + "cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"] ) arch_info.reset_pseudo_eval_times( - "ImageNet16-120", None, "ori-test", eval_per_sample * nums["ImageNet16-120-test"] + "cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"] + ) + arch_info.reset_pseudo_eval_times( + "cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"] + ) + arch_info.reset_pseudo_eval_times( + "cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"] + ) + arch_info.reset_pseudo_eval_times( + "cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"] + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", + None, + "x-valid", + eval_per_sample * nums["ImageNet16-120-valid"], + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", + None, + "x-test", + eval_per_sample * nums["ImageNet16-120-valid"], + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", + None, + "ori-test", + eval_per_sample * nums["ImageNet16-120-test"], ) # arch_info_full.debug_test() # arch_info_less.debug_test() @@ -202,12 +282,16 @@ def simplify(save_dir, meta_file, basestr, target_dir): meta_infos = torch.load(meta_file, map_location="cpu") meta_archs = meta_infos["archs"] # a list of architecture strings meta_num_archs = meta_infos["total"] - assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( - meta_num_archs, len(meta_archs) - ) + assert meta_num_archs == len( + meta_archs + ), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs)) sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) - print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) + print( + "{:} find {:} directories used to save checkpoints".format( + time_string(), len(sub_model_dirs) + ) + ) subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 num_seeds = defaultdict(lambda: 0) @@ -217,14 +301,18 @@ def simplify(save_dir, meta_file, basestr, target_dir): for checkpoint in xcheckpoints: temp_names = checkpoint.name.split("-") assert ( - len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed" + len(temp_names) == 4 + and temp_names[0] == "arch" + and temp_names[2] == "seed" ), "invalid checkpoint name : {:}".format(checkpoint.name) arch_indexes.add(temp_names[1]) subdir2archs[sub_dir] = sorted(list(arch_indexes)) num_evaluated_arch += len(arch_indexes) # count number of seeds for each architecture for arch_index in arch_indexes: - num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1 + num_seeds[ + len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))) + ] += 1 print( "{:} There are {:5d} architectures that have been evaluated ({:} in total).".format( time_string(), num_evaluated_arch, meta_num_archs @@ -232,7 +320,9 @@ def simplify(save_dir, meta_file, basestr, target_dir): ) for key in sorted(list(num_seeds.keys())): print( - "{:} There are {:5d} architectures that are evaluated {:} times.".format(time_string(), num_seeds[key], key) + "{:} There are {:5d} architectures that are evaluated {:} times.".format( + time_string(), num_seeds[key], key + ) ) dataloader_dict = get_nas_bench_loaders(6) @@ -243,8 +333,15 @@ def simplify(save_dir, meta_file, basestr, 
target_dir): if not to_save_allarc.exists(): to_save_allarc.mkdir(parents=True, exist_ok=True) - assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format(target_dir) - arch2infos, datasets = {}, ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") + assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format( + target_dir + ) + arch2infos, datasets = {}, ( + "cifar10-valid", + "cifar10", + "cifar100", + "ImageNet16-120", + ) evaluated_indexes = set() target_full_dir = save_dir / target_dir target_less_dir = save_dir / "{:}-LESS".format(target_dir) @@ -253,30 +350,46 @@ def simplify(save_dir, meta_file, basestr, target_dir): end_time = time.time() arch_time = AverageMeter() for idx, arch_index in enumerate(arch_indexes): - checkpoints = list(target_full_dir.glob("arch-{:}-seed-*.pth".format(arch_index))) + checkpoints = list( + target_full_dir.glob("arch-{:}-seed-*.pth".format(arch_index)) + ) ckps_less = list(target_less_dir.glob("arch-{:}-seed-*.pth".format(arch_index))) # create the arch info for each architecture try: arch_info_full = account_one_arch( - arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict + arch_index, + meta_archs[int(arch_index)], + checkpoints, + datasets, + dataloader_dict, ) arch_info_less = account_one_arch( - arch_index, meta_archs[int(arch_index)], ckps_less, datasets, dataloader_dict + arch_index, + meta_archs[int(arch_index)], + ckps_less, + datasets, + dataloader_dict, ) num_seeds[len(checkpoints)] += 1 except: print("Loading {:} failed, : {:}".format(arch_index, checkpoints)) continue - assert int(arch_index) not in evaluated_indexes, "conflict arch-index : {:}".format(arch_index) - assert 0 <= int(arch_index) < len(meta_archs), "invalid arch-index {:} (not found in meta_archs)".format( - arch_index - ) + assert ( + int(arch_index) not in evaluated_indexes + ), "conflict arch-index : {:}".format(arch_index) + assert ( + 0 <= int(arch_index) < len(meta_archs) + ), "invalid arch-index {:} (not found in meta_archs)".format(arch_index) arch_info = {"full": arch_info_full, "less": arch_info_less} evaluated_indexes.add(int(arch_index)) arch2infos[int(arch_index)] = arch_info # to correct the latency and training_time info. 
- arch_info_full, arch_info_less = correct_time_related_info(int(arch_index), arch_info_full, arch_info_less) - to_save_data = OrderedDict(full=arch_info_full.state_dict(), less=arch_info_less.state_dict()) + arch_info_full, arch_info_less = correct_time_related_info( + int(arch_index), arch_info_full, arch_info_less + ) + to_save_data = OrderedDict( + full=arch_info_full.state_dict(), less=arch_info_less.state_dict() + ) torch.save(to_save_data, to_save_allarc / "{:}-FULL.pth".format(arch_index)) arch_info["full"].clear_params() arch_info["less"].clear_params() @@ -284,14 +397,19 @@ def simplify(save_dir, meta_file, basestr, target_dir): # measure elapsed time arch_time.update(time.time() - end_time) end_time = time.time() - need_time = "{:}".format(convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1), True)) + need_time = "{:}".format( + convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1), True) + ) print( "{:} {:} [{:03d}/{:03d}] : {:} still need {:}".format( time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time ) ) # measure time - xstrs = ["{:}:{:03d}".format(key, num_seeds[key]) for key in sorted(list(num_seeds.keys()))] + xstrs = [ + "{:}:{:03d}".format(key, num_seeds[key]) + for key in sorted(list(num_seeds.keys())) + ] print("{:} {:} done : {:}".format(time_string(), target_dir, xstrs)) final_infos = { "meta_archs": meta_archs, @@ -303,7 +421,9 @@ def simplify(save_dir, meta_file, basestr, target_dir): save_file_name = to_save_simply / "{:}.pth".format(target_dir) torch.save(final_infos, save_file_name) print( - "Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name) + "Save {:} / {:} architecture results into {:}.".format( + len(evaluated_indexes), meta_num_archs, save_file_name + ) ) @@ -311,12 +431,16 @@ def merge_all(save_dir, meta_file, basestr): meta_infos = torch.load(meta_file, map_location="cpu") meta_archs = meta_infos["archs"] meta_num_archs = meta_infos["total"] - assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( - meta_num_archs, len(meta_archs) - ) + assert meta_num_archs == len( + meta_archs + ), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs)) sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) - print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) + print( + "{:} find {:} directories used to save checkpoints".format( + time_string(), len(sub_model_dirs) + ) + ) for index, sub_dir in enumerate(sub_model_dirs): arch_info_files = sorted(list(sub_dir.glob("arch-*-seed-*.pth"))) print( @@ -330,11 +454,16 @@ def merge_all(save_dir, meta_file, basestr): ckp_path = sub_dir.parent / "simplifies" / "{:}.pth".format(sub_dir.name) if ckp_path.exists(): sub_ckps = torch.load(ckp_path, map_location="cpu") - assert sub_ckps["total_archs"] == meta_num_archs and sub_ckps["basestr"] == basestr + assert ( + sub_ckps["total_archs"] == meta_num_archs + and sub_ckps["basestr"] == basestr + ) xarch2infos = sub_ckps["arch2infos"] xevalindexs = sub_ckps["evaluated_indexes"] for eval_index in xevalindexs: - assert eval_index not in evaluated_indexes and eval_index not in arch2infos + assert ( + eval_index not in evaluated_indexes and eval_index not in arch2infos + ) # arch2infos[eval_index] = xarch2infos[eval_index].state_dict() arch2infos[eval_index] = { "full": xarch2infos[eval_index]["full"].state_dict(), @@ -351,7 +480,11 @@ def merge_all(save_dir, 
meta_file, basestr): # print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path)) evaluated_indexes = sorted(list(evaluated_indexes)) - print("Finally, there are {:} architectures that have been trained and evaluated.".format(len(evaluated_indexes))) + print( + "Finally, there are {:} architectures that have been trained and evaluated.".format( + len(evaluated_indexes) + ) + ) to_save_simply = save_dir / "simplifies" if not to_save_simply.exists(): @@ -365,16 +498,24 @@ def merge_all(save_dir, meta_file, basestr): save_file_name = to_save_simply / "{:}-final-infos.pth".format(basestr) torch.save(final_infos, save_file_name) print( - "Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name) + "Save {:} / {:} architecture results into {:}.".format( + len(evaluated_indexes), meta_num_archs, save_file_name + ) ) if __name__ == "__main__": parser = argparse.ArgumentParser( - description="NAS-BENCH-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="NAS-BENCH-201", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--mode", + type=str, + choices=["cal", "merge"], + help="The running mode for this script.", ) - parser.add_argument("--mode", type=str, choices=["cal", "merge"], help="The running mode for this script.") parser.add_argument( "--base_save_dir", type=str, @@ -382,16 +523,26 @@ if __name__ == "__main__": help="The base-name of folder to save checkpoints and log.", ) parser.add_argument("--target_dir", type=str, help="The target directory.") - parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.") - parser.add_argument("--channel", type=int, default=16, help="The number of channels.") - parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + parser.add_argument( + "--max_node", type=int, default=4, help="The maximum node in a cell." + ) + parser.add_argument( + "--channel", type=int, default=16, help="The number of channels." + ) + parser.add_argument( + "--num_cells", type=int, default=5, help="The number of cells in one stage." 
+ ) args = parser.parse_args() save_dir = Path(args.base_save_dir) meta_path = save_dir / "meta-node-{:}.pth".format(args.max_node) assert save_dir.exists(), "invalid save dir path : {:}".format(save_dir) assert meta_path.exists(), "invalid saved meta path : {:}".format(meta_path) - print("start the statistics of our nas-benchmark from {:} using {:}.".format(save_dir, args.target_dir)) + print( + "start the statistics of our nas-benchmark from {:} using {:}.".format( + save_dir, args.target_dir + ) + ) basestr = "C{:}-N{:}".format(args.channel, args.num_cells) if args.mode == "cal": diff --git a/exps/NAS-Bench-201/statistics.py b/exps/NAS-Bench-201/statistics.py index 80985b9..14fbc8a 100644 --- a/exps/NAS-Bench-201/statistics.py +++ b/exps/NAS-Bench-201/statistics.py @@ -48,33 +48,56 @@ def create_result_count(used_seed, dataset, arch_config, results, dataloader_dic network.load_state_dict(xresult.get_net_param()) if "train_times" in results: # new version xresult.update_train_info( - results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"] + results["train_acc1es"], + results["train_acc5es"], + results["train_losses"], + results["train_times"], + ) + xresult.update_eval( + results["valid_acc1es"], results["valid_losses"], results["valid_times"] ) - xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"]) else: if dataset == "cifar10-valid": - xresult.update_OLD_eval("x-valid", results["valid_acc1es"], results["valid_losses"]) + xresult.update_OLD_eval( + "x-valid", results["valid_acc1es"], results["valid_losses"] + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda() ) - xresult.update_OLD_eval("ori-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_OLD_eval( + "ori-test", + {results["total_epoch"] - 1: top1}, + {results["total_epoch"] - 1: loss}, + ) xresult.update_latency(latencies) elif dataset == "cifar10": - xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + xresult.update_OLD_eval( + "ori-test", results["valid_acc1es"], results["valid_losses"] + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() ) xresult.update_latency(latencies) elif dataset == "cifar100" or dataset == "ImageNet16-120": - xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + xresult.update_OLD_eval( + "ori-test", results["valid_acc1es"], results["valid_losses"] + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda() ) - xresult.update_OLD_eval("x-valid", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_OLD_eval( + "x-valid", + {results["total_epoch"] - 1: top1}, + {results["total_epoch"] - 1: loss}, + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() ) - xresult.update_OLD_eval("x-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_OLD_eval( + "x-test", + {results["total_epoch"] - 1: top1}, + {results["total_epoch"] - 1: loss}, + ) xresult.update_latency(latencies) else: raise ValueError("invalid dataset name : {:}".format(dataset)) @@ -88,11 +111,15 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic checkpoint = torch.load(checkpoint_path, map_location="cpu") 
used_seed = checkpoint_path.name.split("-")[-1].split(".")[0] for dataset in datasets: - assert dataset in checkpoint, "Can not find {:} in arch-{:} from {:}".format( + assert ( + dataset in checkpoint + ), "Can not find {:} in arch-{:} from {:}".format( dataset, arch_index, checkpoint_path ) results = checkpoint[dataset] - assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( + assert results[ + "finish-train" + ], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( arch_index, used_seed, dataset, checkpoint_path ) arch_config = { @@ -102,7 +129,9 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic "class_num": results["config"]["class_num"], } - xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict) + xresult = create_result_count( + used_seed, dataset, arch_config, results, dataloader_dict + ) information.update(dataset, int(used_seed), xresult) return information @@ -118,14 +147,29 @@ def GET_DataLoaders(workers): cifar_config = load_config(cifar_config_path, None, None) print("{:} Create data-loader for all datasets".format(time_string())) print("-" * 200) - TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets("cifar10", str(torch_dir / "cifar.python"), -1) + TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets( + "cifar10", str(torch_dir / "cifar.python"), -1 + ) print( "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format( len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num ) ) - cifar10_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None) - assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [ + cifar10_splits = load_config( + root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None + ) + assert cifar10_splits.train[:10] == [ + 0, + 5, + 7, + 11, + 13, + 15, + 16, + 17, + 20, + 24, + ] and cifar10_splits.valid[:10] == [ 1, 2, 3, @@ -141,7 +185,11 @@ def GET_DataLoaders(workers): temp_dataset.transform = VALID_CIFAR10.transform # data loader trainval_cifar10_loader = torch.utils.data.DataLoader( - TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + TRAIN_CIFAR10, + batch_size=cifar_config.batch_size, + shuffle=True, + num_workers=workers, + pin_memory=True, ) train_cifar10_loader = torch.utils.data.DataLoader( TRAIN_CIFAR10, @@ -158,7 +206,11 @@ def GET_DataLoaders(workers): pin_memory=True, ) test__cifar10_loader = torch.utils.data.DataLoader( - VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + VALID_CIFAR10, + batch_size=cifar_config.batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, ) print( "CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format( @@ -182,14 +234,29 @@ def GET_DataLoaders(workers): ) print("-" * 200) # CIFAR-100 - TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets("cifar100", str(torch_dir / "cifar.python"), -1) + TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets( + "cifar100", str(torch_dir / "cifar.python"), -1 + ) print( "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num ) ) - cifar100_splits = load_config(root_dir / "configs" / "nas-benchmark" / 
"cifar100-test-split.txt", None, None) - assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [ + cifar100_splits = load_config( + root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None + ) + assert cifar100_splits.xvalid[:10] == [ + 1, + 3, + 4, + 5, + 8, + 10, + 13, + 14, + 15, + 16, + ] and cifar100_splits.xtest[:10] == [ 0, 2, 6, @@ -202,7 +269,11 @@ def GET_DataLoaders(workers): 24, ] train_cifar100_loader = torch.utils.data.DataLoader( - TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + TRAIN_CIFAR100, + batch_size=cifar_config.batch_size, + shuffle=True, + num_workers=workers, + pin_memory=True, ) valid_cifar100_loader = torch.utils.data.DataLoader( VALID_CIFAR100, @@ -218,9 +289,15 @@ def GET_DataLoaders(workers): num_workers=workers, pin_memory=True, ) - print("CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader))) - print("CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))) - print("CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader))) + print( + "CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader)) + ) + print( + "CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)) + ) + print( + "CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader)) + ) print("-" * 200) imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config" @@ -233,8 +310,23 @@ def GET_DataLoaders(workers): len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num ) ) - imagenet_splits = load_config(root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", None, None) - assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [ + imagenet_splits = load_config( + root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", + None, + None, + ) + assert imagenet_splits.xvalid[:10] == [ + 1, + 2, + 3, + 6, + 7, + 8, + 9, + 12, + 16, + 18, + ] and imagenet_splits.xtest[:10] == [ 0, 4, 5, @@ -304,12 +396,16 @@ def simplify(save_dir, meta_file, basestr, target_dir): meta_archs = meta_infos["archs"] # a list of architecture strings meta_num_archs = meta_infos["total"] meta_max_node = meta_infos["max_node"] - assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( - meta_num_archs, len(meta_archs) - ) + assert meta_num_archs == len( + meta_archs + ), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs)) sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) - print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) + print( + "{:} find {:} directories used to save checkpoints".format( + time_string(), len(sub_model_dirs) + ) + ) subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 num_seeds = defaultdict(lambda: 0) @@ -319,14 +415,18 @@ def simplify(save_dir, meta_file, basestr, target_dir): for checkpoint in xcheckpoints: temp_names = checkpoint.name.split("-") assert ( - len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed" + len(temp_names) == 4 + and temp_names[0] == "arch" + and temp_names[2] == "seed" ), "invalid checkpoint name : {:}".format(checkpoint.name) arch_indexes.add(temp_names[1]) subdir2archs[sub_dir] = sorted(list(arch_indexes)) num_evaluated_arch += len(arch_indexes) # count number of 
seeds for each architecture for arch_index in arch_indexes: - num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1 + num_seeds[ + len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))) + ] += 1 print( "{:} There are {:5d} architectures that have been evaluated ({:} in total).".format( time_string(), num_evaluated_arch, meta_num_archs @@ -334,7 +434,9 @@ def simplify(save_dir, meta_file, basestr, target_dir): ) for key in sorted(list(num_seeds.keys())): print( - "{:} There are {:5d} architectures that are evaluated {:} times.".format(time_string(), num_seeds[key], key) + "{:} There are {:5d} architectures that are evaluated {:} times.".format( + time_string(), num_seeds[key], key + ) ) dataloader_dict = GET_DataLoaders(6) @@ -346,8 +448,15 @@ def simplify(save_dir, meta_file, basestr, target_dir): if not to_save_allarc.exists(): to_save_allarc.mkdir(parents=True, exist_ok=True) - assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format(target_dir) - arch2infos, datasets = {}, ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") + assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format( + target_dir + ) + arch2infos, datasets = {}, ( + "cifar10-valid", + "cifar10", + "cifar100", + "ImageNet16-120", + ) evaluated_indexes = set() target_directory = save_dir / target_dir target_less_dir = save_dir / "{:}-LESS".format(target_dir) @@ -356,24 +465,36 @@ def simplify(save_dir, meta_file, basestr, target_dir): end_time = time.time() arch_time = AverageMeter() for idx, arch_index in enumerate(arch_indexes): - checkpoints = list(target_directory.glob("arch-{:}-seed-*.pth".format(arch_index))) + checkpoints = list( + target_directory.glob("arch-{:}-seed-*.pth".format(arch_index)) + ) ckps_less = list(target_less_dir.glob("arch-{:}-seed-*.pth".format(arch_index))) # create the arch info for each architecture try: arch_info_full = account_one_arch( - arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict + arch_index, + meta_archs[int(arch_index)], + checkpoints, + datasets, + dataloader_dict, ) arch_info_less = account_one_arch( - arch_index, meta_archs[int(arch_index)], ckps_less, ["cifar10-valid"], dataloader_dict + arch_index, + meta_archs[int(arch_index)], + ckps_less, + ["cifar10-valid"], + dataloader_dict, ) num_seeds[len(checkpoints)] += 1 except: print("Loading {:} failed, : {:}".format(arch_index, checkpoints)) continue - assert int(arch_index) not in evaluated_indexes, "conflict arch-index : {:}".format(arch_index) - assert 0 <= int(arch_index) < len(meta_archs), "invalid arch-index {:} (not found in meta_archs)".format( - arch_index - ) + assert ( + int(arch_index) not in evaluated_indexes + ), "conflict arch-index : {:}".format(arch_index) + assert ( + 0 <= int(arch_index) < len(meta_archs) + ), "invalid arch-index {:} (not found in meta_archs)".format(arch_index) arch_info = {"full": arch_info_full, "less": arch_info_less} evaluated_indexes.add(int(arch_index)) arch2infos[int(arch_index)] = arch_info @@ -390,14 +511,19 @@ def simplify(save_dir, meta_file, basestr, target_dir): # measure elapsed time arch_time.update(time.time() - end_time) end_time = time.time() - need_time = "{:}".format(convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1), True)) + need_time = "{:}".format( + convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1), True) + ) print( "{:} {:} [{:03d}/{:03d}] : {:} still need {:}".format( time_string(), target_dir, idx, len(arch_indexes), arch_index, 
need_time ) ) # measure time - xstrs = ["{:}:{:03d}".format(key, num_seeds[key]) for key in sorted(list(num_seeds.keys()))] + xstrs = [ + "{:}:{:03d}".format(key, num_seeds[key]) + for key in sorted(list(num_seeds.keys())) + ] print("{:} {:} done : {:}".format(time_string(), target_dir, xstrs)) final_infos = { "meta_archs": meta_archs, @@ -409,7 +535,9 @@ def simplify(save_dir, meta_file, basestr, target_dir): save_file_name = to_save_simply / "{:}.pth".format(target_dir) torch.save(final_infos, save_file_name) print( - "Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name) + "Save {:} / {:} architecture results into {:}.".format( + len(evaluated_indexes), meta_num_archs, save_file_name + ) ) @@ -418,12 +546,16 @@ def merge_all(save_dir, meta_file, basestr): meta_archs = meta_infos["archs"] meta_num_archs = meta_infos["total"] meta_max_node = meta_infos["max_node"] - assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( - meta_num_archs, len(meta_archs) - ) + assert meta_num_archs == len( + meta_archs + ), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs)) sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) - print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) + print( + "{:} find {:} directories used to save checkpoints".format( + time_string(), len(sub_model_dirs) + ) + ) for index, sub_dir in enumerate(sub_model_dirs): arch_info_files = sorted(list(sub_dir.glob("arch-*-seed-*.pth"))) print( @@ -437,11 +569,16 @@ def merge_all(save_dir, meta_file, basestr): ckp_path = sub_dir.parent / "simplifies" / "{:}.pth".format(sub_dir.name) if ckp_path.exists(): sub_ckps = torch.load(ckp_path, map_location="cpu") - assert sub_ckps["total_archs"] == meta_num_archs and sub_ckps["basestr"] == basestr + assert ( + sub_ckps["total_archs"] == meta_num_archs + and sub_ckps["basestr"] == basestr + ) xarch2infos = sub_ckps["arch2infos"] xevalindexs = sub_ckps["evaluated_indexes"] for eval_index in xevalindexs: - assert eval_index not in evaluated_indexes and eval_index not in arch2infos + assert ( + eval_index not in evaluated_indexes and eval_index not in arch2infos + ) # arch2infos[eval_index] = xarch2infos[eval_index].state_dict() arch2infos[eval_index] = { "full": xarch2infos[eval_index]["full"].state_dict(), @@ -458,7 +595,11 @@ def merge_all(save_dir, meta_file, basestr): # print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path)) evaluated_indexes = sorted(list(evaluated_indexes)) - print("Finally, there are {:} architectures that have been trained and evaluated.".format(len(evaluated_indexes))) + print( + "Finally, there are {:} architectures that have been trained and evaluated.".format( + len(evaluated_indexes) + ) + ) to_save_simply = save_dir / "simplifies" if not to_save_simply.exists(): @@ -472,16 +613,24 @@ def merge_all(save_dir, meta_file, basestr): save_file_name = to_save_simply / "{:}-final-infos.pth".format(basestr) torch.save(final_infos, save_file_name) print( - "Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name) + "Save {:} / {:} architecture results into {:}.".format( + len(evaluated_indexes), meta_num_archs, save_file_name + ) ) if __name__ == "__main__": parser = argparse.ArgumentParser( - description="NAS-BENCH-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter + 
description="NAS-BENCH-201", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--mode", + type=str, + choices=["cal", "merge"], + help="The running mode for this script.", ) - parser.add_argument("--mode", type=str, choices=["cal", "merge"], help="The running mode for this script.") parser.add_argument( "--base_save_dir", type=str, @@ -489,16 +638,26 @@ if __name__ == "__main__": help="The base-name of folder to save checkpoints and log.", ) parser.add_argument("--target_dir", type=str, help="The target directory.") - parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.") - parser.add_argument("--channel", type=int, default=16, help="The number of channels.") - parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + parser.add_argument( + "--max_node", type=int, default=4, help="The maximum node in a cell." + ) + parser.add_argument( + "--channel", type=int, default=16, help="The number of channels." + ) + parser.add_argument( + "--num_cells", type=int, default=5, help="The number of cells in one stage." + ) args = parser.parse_args() save_dir = Path(args.base_save_dir) meta_path = save_dir / "meta-node-{:}.pth".format(args.max_node) assert save_dir.exists(), "invalid save dir path : {:}".format(save_dir) assert meta_path.exists(), "invalid saved meta path : {:}".format(meta_path) - print("start the statistics of our nas-benchmark from {:} using {:}.".format(save_dir, args.target_dir)) + print( + "start the statistics of our nas-benchmark from {:} using {:}.".format( + save_dir, args.target_dir + ) + ) basestr = "C{:}-N{:}".format(args.channel, args.num_cells) if args.mode == "cal": diff --git a/exps/NAS-Bench-201/test-correlation.py b/exps/NAS-Bench-201/test-correlation.py index aaf2e14..31225ec 100644 --- a/exps/NAS-Bench-201/test-correlation.py +++ b/exps/NAS-Bench-201/test-correlation.py @@ -25,7 +25,11 @@ def check_unique_arch(meta_file): def get_unique_matrix(archs, consider_zero): UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs] - print("{:} create unique-string ({:}/{:}) done".format(time_string(), len(set(UniquStrs)), len(UniquStrs))) + print( + "{:} create unique-string ({:}/{:}) done".format( + time_string(), len(set(UniquStrs)), len(UniquStrs) + ) + ) Unique2Index = dict() for index, xstr in enumerate(UniquStrs): if xstr not in Unique2Index: @@ -47,16 +51,32 @@ def check_unique_arch(meta_file): unique_num += 1 return sm_matrix, unique_ids, unique_num - print("There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs))) + print( + "There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs)) + ) sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None) - print("{:} There are {:} unique architectures (considering nothing).".format(time_string(), unique_num)) + print( + "{:} There are {:} unique architectures (considering nothing).".format( + time_string(), unique_num + ) + ) sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False) - print("{:} There are {:} unique architectures (not considering zero).".format(time_string(), unique_num)) + print( + "{:} There are {:} unique architectures (not considering zero).".format( + time_string(), unique_num + ) + ) sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True) - print("{:} There are {:} unique architectures (considering zero).".format(time_string(), unique_num)) + print( + "{:} There are {:} unique architectures (considering 
zero).".format( + time_string(), unique_num + ) + ) -def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False): +def check_cor_for_bandit( + meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False +): if isinstance(meta_file, API): api = meta_file else: @@ -69,7 +89,9 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n imagenet_test = [] imagenet_valid = [] for idx, arch in enumerate(api): - results = api.get_more_info(idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand) + results = api.get_more_info( + idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand + ) cifar10_currs.append(results["valid-accuracy"]) # --->>>>> results = api.get_more_info(idx, "cifar10-valid", None, False, is_rand) @@ -89,13 +111,23 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n cors = [] for basestr, xlist in zip( ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"], - [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test], + [ + cifar10_valid, + cifar10_test, + cifar100_valid, + cifar100_test, + imagenet_valid, + imagenet_test, + ], ): correlation = get_cor(cifar10_currs, xlist) if need_print: print( "With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format( - test_epoch, "012" if use_less_or_not else "200", basestr, correlation + test_epoch, + "012" if use_less_or_not else "200", + basestr, + correlation, ) ) cors.append(correlation) @@ -113,7 +145,11 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand): # xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] xstrs = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"] correlations = np.array(corrs) - print("------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(test_epoch, "012" if use_less_or_not else "200")) + print( + "------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format( + test_epoch, "012" if use_less_or_not else "200" + ) + ) for idx, xstr in enumerate(xstrs): print( "{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format( @@ -135,7 +171,12 @@ if __name__ == "__main__": default="./output/search-cell-nas-bench-201/visuals", help="The base-name of folder to save checkpoints and log.", ) - parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.") + parser.add_argument( + "--api_path", + type=str, + default=None, + help="The path to the NAS-Bench-201 benchmark file.", + ) args = parser.parse_args() vis_save_dir = Path(args.save_dir) diff --git a/exps/NAS-Bench-201/visualize.py b/exps/NAS-Bench-201/visualize.py index 8e91a92..588c17c 100644 --- a/exps/NAS-Bench-201/visualize.py +++ b/exps/NAS-Bench-201/visualize.py @@ -47,15 +47,21 @@ def visualize_relative_ranking(vis_save_dir): print("{:} start to visualize relative ranking".format(time_string())) # maximum accuracy with ResNet-level params 11472 x_010_accs = [ - cifar010_info["test_accs"][i] if cifar010_info["params"][i] <= cifar010_info["params"][11472] else -1 + cifar010_info["test_accs"][i] + if cifar010_info["params"][i] <= cifar010_info["params"][11472] + else -1 for i in indexes ] x_100_accs = [ - cifar100_info["test_accs"][i] if cifar100_info["params"][i] <= cifar100_info["params"][11472] else -1 + cifar100_info["test_accs"][i] + if cifar100_info["params"][i] <= cifar100_info["params"][11472] + else -1 for i in indexes ] x_img_accs = [ - imagenet_info["test_accs"][i] if 
imagenet_info["params"][i] <= imagenet_info["params"][11472] else -1 + imagenet_info["test_accs"][i] + if imagenet_info["params"][i] <= imagenet_info["params"][11472] + else -1 for i in indexes ] @@ -79,8 +85,15 @@ def visualize_relative_ranking(vis_save_dir): plt.xlim(min(indexes), max(indexes)) plt.ylim(min(indexes), max(indexes)) # plt.ylabel('y').set_rotation(0) - plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 6), fontsize=LegendFontsize, rotation="vertical") - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 6), fontsize=LegendFontsize) + plt.yticks( + np.arange(min(indexes), max(indexes), max(indexes) // 6), + fontsize=LegendFontsize, + rotation="vertical", + ) + plt.xticks( + np.arange(min(indexes), max(indexes), max(indexes) // 6), + fontsize=LegendFontsize, + ) # ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100') # ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8, label='ImageNet-16-120') # ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10') @@ -113,7 +126,9 @@ def visualize_relative_ranking(vis_save_dir): ) fig = plt.figure(figsize=figsize) plt.axis("off") - h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={"size": sns_size}, fmt=".3f", linewidths=0.5) + h = sns.heatmap( + CoRelMatrix, annot=True, annot_kws={"size": sns_size}, fmt=".3f", linewidths=0.5 + ) save_path = (vis_save_dir / "co-relation-all.pdf").resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") print("{:} save into {:}".format(time_string(), save_path)) @@ -142,8 +157,16 @@ def visualize_relative_ranking(vis_save_dir): ) fig = plt.figure(figsize=figsize) plt.axis("off") - h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={"size": sns_size}, fmt=".3f", linewidths=0.5) - save_path = (vis_save_dir / "co-relation-top-{:}.pdf".format(len(selected_indexes))).resolve() + h = sns.heatmap( + CoRelMatrix, + annot=True, + annot_kws={"size": sns_size}, + fmt=".3f", + linewidths=0.5, + ) + save_path = ( + vis_save_dir / "co-relation-top-{:}.pdf".format(len(selected_indexes)) + ).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") print("{:} save into {:}".format(time_string(), save_path)) plt.close("all") @@ -155,7 +178,14 @@ def visualize_info(meta_file, dataset, vis_save_dir): if not cache_file_path.exists(): print("Do not find cache file : {:}".format(cache_file_path)) nas_bench = API(str(meta_file)) - params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], [] + params, flops, train_accs, valid_accs, test_accs, otest_accs = ( + [], + [], + [], + [], + [], + [], + ) for index in range(len(nas_bench)): info = nas_bench.query_by_index(index, use_12epochs_result=False) resx = info.get_comput_costs(dataset) @@ -239,7 +269,13 @@ def visualize_info(meta_file, dataset, vis_save_dir): plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize) ax.scatter(params, valid_accs, marker="o", s=0.5, c="tab:blue") ax.scatter( - [resnet["params"]], [resnet["valid_acc"]], marker="*", s=resnet_scale, c="tab:orange", label="resnet", alpha=0.4 + [resnet["params"]], + [resnet["valid_acc"]], + marker="*", + s=resnet_scale, + c="tab:orange", + label="resnet", + alpha=0.4, ) plt.grid(zorder=0) ax.set_axisbelow(True) @@ -321,7 +357,10 @@ def visualize_info(meta_file, dataset, vis_save_dir): fig = plt.figure(figsize=figsize) ax = fig.add_subplot(111) plt.xlim(0, max(indexes)) - plt.xticks(np.arange(min(indexes), 
max(indexes), max(indexes) // 5), fontsize=LegendFontsize) + plt.xticks( + np.arange(min(indexes), max(indexes), max(indexes) // 5), + fontsize=LegendFontsize, + ) if dataset == "cifar10": plt.ylim(50, 100) plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize) @@ -357,7 +396,11 @@ def visualize_info(meta_file, dataset, vis_save_dir): def visualize_rank_over_time(meta_file, vis_save_dir): print("\n" + "-" * 150) vis_save_dir.mkdir(parents=True, exist_ok=True) - print("{:} start to visualize rank-over-time into {:}".format(time_string(), vis_save_dir)) + print( + "{:} start to visualize rank-over-time into {:}".format( + time_string(), vis_save_dir + ) + ) cache_file_path = vis_save_dir / "rank-over-time-cache-info.pth" if not cache_file_path.exists(): print("Do not find cache file : {:}".format(cache_file_path)) @@ -434,17 +477,26 @@ def visualize_rank_over_time(meta_file, vis_save_dir): plt.xlim(min(indexes), max(indexes)) plt.ylim(min(indexes), max(indexes)) plt.yticks( - np.arange(min(indexes), max(indexes), max(indexes) // 6), fontsize=LegendFontsize, rotation="vertical" + np.arange(min(indexes), max(indexes), max(indexes) // 6), + fontsize=LegendFontsize, + rotation="vertical", + ) + plt.xticks( + np.arange(min(indexes), max(indexes), max(indexes) // 6), + fontsize=LegendFontsize, ) - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 6), fontsize=LegendFontsize) ax.scatter(indexes, valid_ord_lbls, marker="^", s=0.5, c="tab:green", alpha=0.8) ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) - ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-10 validation") + ax.scatter( + [-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-10 validation" + ) ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10 test") plt.grid(zorder=0) ax.set_axisbelow(True) plt.legend(loc="upper left", fontsize=LegendFontsize) - ax.set_xlabel("architecture ranking in the final test accuracy", fontsize=LabelSize) + ax.set_xlabel( + "architecture ranking in the final test accuracy", fontsize=LabelSize + ) ax.set_ylabel("architecture ranking in the validation set", fontsize=LabelSize) save_path = (vis_save_dir / "time-{:03d}.pdf".format(sepoch)).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") @@ -464,7 +516,9 @@ def write_video(save_dir): # shape = (ximage.shape[1], ximage.shape[0]) shape = (1000, 1000) # writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 25, shape) - writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape) + writer = cv2.VideoWriter( + str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape + ) for idx, image in enumerate(images): ximage = cv2.imread(str(image)) _image = cv2.resize(ximage, shape) @@ -490,9 +544,13 @@ def plot_results_nas_v2(api, dataset_xset_a, dataset_xset_b, root, file_name, y_ accuracies = [] for x in all_indexes: info = api.arch2infos_full[x] - metrics = info.get_metrics(dataset_xset_a[0], dataset_xset_a[1], None, False) + metrics = info.get_metrics( + dataset_xset_a[0], dataset_xset_a[1], None, False + ) accuracies_A.append(metrics["accuracy"]) - metrics = info.get_metrics(dataset_xset_b[0], dataset_xset_b[1], None, False) + metrics = info.get_metrics( + dataset_xset_b[0], dataset_xset_b[1], None, False + ) accuracies_B.append(metrics["accuracy"]) accuracies.append((accuracies_A[-1], accuracies_B[-1])) if indexes is None: @@ -580,7 +638,14 @@ def plot_results_nas(api, dataset, xset, root, 
file_name, y_lims): plt.ylabel("The accuracy (%)", fontsize=LabelSize) for idx, legend in enumerate(legends): - plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle="-", label="{:}".format(legend), lw=2) + plt.plot( + indexes, + All_Accs[legend], + color=color_set[idx], + linestyle="-", + label="{:}".format(legend), + lw=2, + ) print( "{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}".format( legend, @@ -646,13 +711,19 @@ def just_show(api): return xresults for xkey in xpaths.keys(): - all_paths = ["{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]] + all_paths = [ + "{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey] + ] all_datas = [torch.load(xpath) for xpath in all_paths] accyss = [get_accs(xdatas) for xdatas in all_datas] accyss = np.array(accyss) print("\nxkey = {:}".format(xkey)) for i in range(accyss.shape[1]): - print("---->>>> {:.2f}$\\pm${:.2f}".format(accyss[:, i].mean(), accyss[:, i].std())) + print( + "---->>>> {:.2f}$\\pm${:.2f}".format( + accyss[:, i].mean(), accyss[:, i].std() + ) + ) print("\n{:}".format(get_accs(None, 11472))) # resnet pairs = [ @@ -665,10 +736,16 @@ def just_show(api): ] for dataset, metric_on_set in pairs: arch_index, highest_acc = api.find_best(dataset, metric_on_set) - print("[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}".format(dataset, metric_on_set, arch_index, highest_acc)) + print( + "[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}".format( + dataset, metric_on_set, arch_index, highest_acc + ) + ) -def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_lims, x_maxs): +def show_nas_sharing_w( + api, dataset, subset, vis_save_dir, sufix, file_name, y_lims, x_maxs +): color_set = ["r", "b", "g", "c", "m", "y", "k"] dpi, width, height = 300, 3400, 2600 LabelSize, LegendFontsize = 28, 28 @@ -685,12 +762,24 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_l plt.ylabel("The accuracy (%)", fontsize=LabelSize) xpaths = { - "RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/".format(sufix), - "DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/".format(sufix), - "DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/".format(sufix), - "GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/".format(sufix), - "SETN": "output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/".format(sufix), - "ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/".format(sufix), + "RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/".format( + sufix + ), + "DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/".format( + sufix + ), + "DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/".format( + sufix + ), + "GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/".format( + sufix + ), + "SETN": "output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/".format( + sufix + ), + "ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/".format( + sufix + ), } """ xseeds = {'RSPS' : [5349, 59613, 5983], @@ -713,16 +802,20 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_l def get_accs(xdata): epochs, xresults = xdata["epoch"], [] if -1 in xdata["genotypes"]: - metrics = api.arch2infos_full[api.query_index_by_arch(xdata["genotypes"][-1])].get_metrics( + metrics = api.arch2infos_full[ + 
api.query_index_by_arch(xdata["genotypes"][-1]) + ].get_metrics(dataset, subset, None, False) + else: + metrics = api.arch2infos_full[api.random()].get_metrics( dataset, subset, None, False ) - else: - metrics = api.arch2infos_full[api.random()].get_metrics(dataset, subset, None, False) xresults.append(metrics["accuracy"]) for iepoch in range(epochs): genotype = xdata["genotypes"][iepoch] index = api.query_index_by_arch(genotype) - metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False) + metrics = api.arch2infos_full[index].get_metrics( + dataset, subset, None, False + ) xresults.append(metrics["accuracy"]) return xresults @@ -735,7 +828,9 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_l for idx, method in enumerate(xxxstrs): xkey = method - all_paths = ["{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]] + all_paths = [ + "{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey] + ] all_datas = [torch.load(xpath, map_location="cpu") for xpath in all_paths] accyss = [get_accs(xdatas) for xdatas in all_datas] accyss = np.array(accyss) @@ -762,7 +857,9 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_l fig.savefig(str(save_path), dpi=dpi, bbox_inches="tight", format="pdf") -def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file_name, y_lims, x_maxs): +def show_nas_sharing_w_v2( + api, data_sub_a, data_sub_b, vis_save_dir, sufix, file_name, y_lims, x_maxs +): color_set = ["r", "b", "g", "c", "m", "y", "k"] dpi, width, height = 300, 3400, 2600 LabelSize, LegendFontsize = 28, 28 @@ -779,12 +876,24 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file plt.ylabel("The accuracy (%)", fontsize=LabelSize) xpaths = { - "RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/".format(sufix), - "DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/".format(sufix), - "DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/".format(sufix), - "GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/".format(sufix), - "SETN": "output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/".format(sufix), - "ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/".format(sufix), + "RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/".format( + sufix + ), + "DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/".format( + sufix + ), + "DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/".format( + sufix + ), + "GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/".format( + sufix + ), + "SETN": "output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/".format( + sufix + ), + "ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/".format( + sufix + ), } """ xseeds = {'RSPS' : [5349, 59613, 5983], @@ -807,16 +916,20 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file def get_accs(xdata, dataset, subset): epochs, xresults = xdata["epoch"], [] if -1 in xdata["genotypes"]: - metrics = api.arch2infos_full[api.query_index_by_arch(xdata["genotypes"][-1])].get_metrics( + metrics = api.arch2infos_full[ + api.query_index_by_arch(xdata["genotypes"][-1]) + ].get_metrics(dataset, subset, None, False) + else: + metrics = api.arch2infos_full[api.random()].get_metrics( dataset, subset, None, False ) - 
else: - metrics = api.arch2infos_full[api.random()].get_metrics(dataset, subset, None, False) xresults.append(metrics["accuracy"]) for iepoch in range(epochs): genotype = xdata["genotypes"][iepoch] index = api.query_index_by_arch(genotype) - metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False) + metrics = api.arch2infos_full[index].get_metrics( + dataset, subset, None, False + ) xresults.append(metrics["accuracy"]) return xresults @@ -829,10 +942,16 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file for idx, method in enumerate(xxxstrs): xkey = method - all_paths = ["{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]] + all_paths = [ + "{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey] + ] all_datas = [torch.load(xpath, map_location="cpu") for xpath in all_paths] - accyss_A = np.array([get_accs(xdatas, data_sub_a[0], data_sub_a[1]) for xdatas in all_datas]) - accyss_B = np.array([get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas]) + accyss_A = np.array( + [get_accs(xdatas, data_sub_a[0], data_sub_a[1]) for xdatas in all_datas] + ) + accyss_B = np.array( + [get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas] + ) epochs = list(range(accyss_A.shape[1])) for j, accyss in enumerate([accyss_A, accyss_B]): if x_maxs == 50: @@ -859,7 +978,9 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file ) setname = data_sub_a if j == 0 else data_sub_b print( - "{:} -- {:} ---- {:.2f}$\\pm${:.2f}".format(method, setname, accyss[:, -1].mean(), accyss[:, -1].std()) + "{:} -- {:} ---- {:.2f}$\\pm${:.2f}".format( + method, setname, accyss[:, -1].mean(), accyss[:, -1].std() + ) ) # plt.legend(loc=4, fontsize=LegendFontsize) plt.legend(loc=0, fontsize=LegendFontsize) @@ -871,7 +992,10 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file def show_reinforce(api, root, dataset, xset, file_name, y_lims): print("root-path={:}, dataset={:}, xset={:}".format(root, dataset, xset)) LRs = ["0.01", "0.02", "0.1", "0.2", "0.5"] - checkpoints = ["./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth".format(x) for x in LRs] + checkpoints = [ + "./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth".format(x) + for x in LRs + ] acc_lr_dict, indexes = {}, None for lr, checkpoint in zip(LRs, checkpoints): all_indexes, accuracies = torch.load(checkpoint, map_location="cpu"), [] @@ -882,7 +1006,11 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims): if indexes is None: indexes = list(range(len(accuracies))) acc_lr_dict[lr] = np.array(sorted(accuracies)) - print("LR={:.3f}, mean={:}, std={:}".format(float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std())) + print( + "LR={:.3f}, mean={:}, std={:}".format( + float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std() + ) + ) color_set = ["r", "b", "g", "c", "m", "y", "k"] dpi, width, height = 300, 3400, 2600 @@ -903,7 +1031,15 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims): legend = "LR={:.2f}".format(float(LR)) # color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.' 
color, linestyle = color_set[idx], "-" - plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8) + plt.plot( + indexes, + acc_lr_dict[LR], + color=color, + linestyle=linestyle, + label=legend, + lw=2, + alpha=0.8, + ) print( "{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}".format( legend, @@ -922,7 +1058,10 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims): def show_rea(api, root, dataset, xset, file_name, y_lims): print("root-path={:}, dataset={:}, xset={:}".format(root, dataset, xset)) SSs = [3, 5, 10] - checkpoints = ["./output/search-cell-nas-bench-201/R-EA-cifar10-SS{:}/results.pth".format(x) for x in SSs] + checkpoints = [ + "./output/search-cell-nas-bench-201/R-EA-cifar10-SS{:}/results.pth".format(x) + for x in SSs + ] acc_ss_dict, indexes = {}, None for ss, checkpoint in zip(SSs, checkpoints): all_indexes, accuracies = torch.load(checkpoint, map_location="cpu"), [] @@ -933,7 +1072,11 @@ def show_rea(api, root, dataset, xset, file_name, y_lims): if indexes is None: indexes = list(range(len(accuracies))) acc_ss_dict[ss] = np.array(sorted(accuracies)) - print("Sample-Size={:2d}, mean={:}, std={:}".format(ss, acc_ss_dict[ss].mean(), acc_ss_dict[ss].std())) + print( + "Sample-Size={:2d}, mean={:}, std={:}".format( + ss, acc_ss_dict[ss].mean(), acc_ss_dict[ss].std() + ) + ) color_set = ["r", "b", "g", "c", "m", "y", "k"] dpi, width, height = 300, 3400, 2600 @@ -954,7 +1097,15 @@ def show_rea(api, root, dataset, xset, file_name, y_lims): legend = "sample-size={:2d}".format(ss) # color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.' color, linestyle = color_set[idx], "-" - plt.plot(indexes, acc_ss_dict[ss], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8) + plt.plot( + indexes, + acc_ss_dict[ss], + color=color, + linestyle=linestyle, + label=legend, + lw=2, + alpha=0.8, + ) print( "{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}".format( legend, @@ -973,7 +1124,8 @@ def show_rea(api, root, dataset, xset, file_name, y_lims): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="NAS-Bench-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="NAS-Bench-201", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--save_dir", @@ -981,7 +1133,12 @@ if __name__ == "__main__": default="./output/search-cell-nas-bench-201/visuals", help="The base-name of folder to save checkpoints and log.", ) - parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.") + parser.add_argument( + "--api_path", + type=str, + default=None, + help="The path to the NAS-Bench-201 benchmark file.", + ) args = parser.parse_args() vis_save_dir = Path(args.save_dir) @@ -1066,9 +1223,25 @@ if __name__ == "__main__": ) show_nas_sharing_w( - api, "cifar10-valid", "x-valid", vis_save_dir, "BN0", "BN0-XX-CIFAR010-VALID.pdf", (0, 100, 10), 250 + api, + "cifar10-valid", + "x-valid", + vis_save_dir, + "BN0", + "BN0-XX-CIFAR010-VALID.pdf", + (0, 100, 10), + 250, + ) + show_nas_sharing_w( + api, + "cifar10", + "ori-test", + vis_save_dir, + "BN0", + "BN0-XX-CIFAR010-TEST.pdf", + (0, 100, 10), + 250, ) - show_nas_sharing_w(api, "cifar10", "ori-test", vis_save_dir, "BN0", "BN0-XX-CIFAR010-TEST.pdf", (0, 100, 10), 250) """ for x_maxs in [50, 250]: show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) diff --git a/exps/NATS-Bench/Analyze-time.py 
b/exps/NATS-Bench/Analyze-time.py index 4db9e05..46c0e73 100644 --- a/exps/NATS-Bench/Analyze-time.py +++ b/exps/NATS-Bench/Analyze-time.py @@ -30,7 +30,11 @@ def show_time(api, epoch=12): all_cifar10_time += cifar10_time all_cifar100_time += cifar100_time all_imagenet_time += imagenet_time - print("The total training time for CIFAR-10 (held-out train set) is {:} seconds".format(all_cifar10_time)) + print( + "The total training time for CIFAR-10 (held-out train set) is {:} seconds".format( + all_cifar10_time + ) + ) print( "The total training time for CIFAR-100 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format( all_cifar100_time, all_cifar100_time / all_cifar10_time diff --git a/exps/NATS-Bench/draw-correlations.py b/exps/NATS-Bench/draw-correlations.py index 500b088..25b578e 100644 --- a/exps/NATS-Bench/draw-correlations.py +++ b/exps/NATS-Bench/draw-correlations.py @@ -30,15 +30,28 @@ from log_utils import time_string def get_valid_test_acc(api, arch, dataset): is_size_space = api.search_space_name == "size" if dataset == "cifar10": - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) test_acc = xinfo["test-accuracy"] - xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, + dataset="cifar10-valid", + hp=90 if is_size_space else 200, + is_random=False, + ) valid_acc = xinfo["valid-accuracy"] else: - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) valid_acc = xinfo["valid-accuracy"] test_acc = xinfo["test-accuracy"] - return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc) + return ( + valid_acc, + test_acc, + "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc), + ) def compute_kendalltau(vectori, vectorj): @@ -61,9 +74,17 @@ if __name__ == "__main__": formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." 
+ "--save_dir", + type=str, + default="output/vis-nas-bench/nas-algos", + help="Folder to save checkpoints and log.", + ) + parser.add_argument( + "--search_space", + type=str, + choices=["tss", "sss"], + help="Choose the search space.", ) - parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") args = parser.parse_args() save_dir = Path(args.save_dir) @@ -77,9 +98,17 @@ if __name__ == "__main__": scores_1.append(valid_acc) scores_2.append(test_acc) correlation = compute_kendalltau(scores_1, scores_2) - print("The kendall tau correlation of {:} samples : {:}".format(len(indexes), correlation)) + print( + "The kendall tau correlation of {:} samples : {:}".format( + len(indexes), correlation + ) + ) correlation = compute_spearmanr(scores_1, scores_2) - print("The spearmanr correlation of {:} samples : {:}".format(len(indexes), correlation)) + print( + "The spearmanr correlation of {:} samples : {:}".format( + len(indexes), correlation + ) + ) # scores_1 = ['{:.2f}'.format(x) for x in scores_1] # scores_2 = ['{:.2f}'.format(x) for x in scores_2] # print(', '.join(scores_1)) diff --git a/exps/NATS-Bench/draw-fig2_5.py b/exps/NATS-Bench/draw-fig2_5.py index 771d2e0..f9eb8fd 100644 --- a/exps/NATS-Bench/draw-fig2_5.py +++ b/exps/NATS-Bench/draw-fig2_5.py @@ -35,9 +35,15 @@ def visualize_relative_info(api, vis_save_dir, indicator): # print ('{:} start to visualize {:} information'.format(time_string(), api)) vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) - cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) - imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar10", indicator + ) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar100", indicator + ) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "ImageNet16-120", indicator + ) cifar010_info = torch.load(cifar010_cache_path) cifar100_info = torch.load(cifar100_cache_path) imagenet_info = torch.load(imagenet_cache_path) @@ -65,8 +71,15 @@ def visualize_relative_info(api, vis_save_dir, indicator): plt.xlim(min(indexes), max(indexes)) plt.ylim(min(indexes), max(indexes)) # plt.ylabel('y').set_rotation(30) - plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical") - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize) + plt.yticks( + np.arange(min(indexes), max(indexes), max(indexes) // 3), + fontsize=LegendFontsize, + rotation="vertical", + ) + plt.xticks( + np.arange(min(indexes), max(indexes), max(indexes) // 5), + fontsize=LegendFontsize, + ) ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8) ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) @@ -102,7 +115,9 @@ def visualize_sss_info(api, dataset, vis_save_dir): train_accs.append(info["train-accuracy"]) test_accs.append(info["test-accuracy"]) if dataset == "cifar10": - info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False) + info = api.get_more_info( + index, "cifar10-valid", hp="90", is_random=False + ) valid_accs.append(info["valid-accuracy"]) else: 
valid_accs.append(info["valid-accuracy"]) @@ -263,7 +278,9 @@ def visualize_tss_info(api, dataset, vis_save_dir): train_accs.append(info["train-accuracy"]) test_accs.append(info["test-accuracy"]) if dataset == "cifar10": - info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False) + info = api.get_more_info( + index, "cifar10-valid", hp="200", is_random=False + ) valid_accs.append(info["valid-accuracy"]) else: valid_accs.append(info["valid-accuracy"]) @@ -288,7 +305,9 @@ def visualize_tss_info(api, dataset, vis_save_dir): ) print("{:} collect data done.".format(time_string())) - resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"] + resnet = [ + "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|" + ] resnet_indexes = [api.query_index_by_arch(x) for x in resnet] largest_indexes = [ api.query_index_by_arch( @@ -415,9 +434,15 @@ def visualize_rank_info(api, vis_save_dir, indicator): # print ('{:} start to visualize {:} information'.format(time_string(), api)) vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) - cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) - imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar10", indicator + ) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar100", indicator + ) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "ImageNet16-120", indicator + ) cifar010_info = torch.load(cifar010_cache_path) cifar100_info = torch.load(cifar100_cache_path) imagenet_info = torch.load(imagenet_cache_path) @@ -452,8 +477,17 @@ def visualize_rank_info(api, vis_save_dir, indicator): ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8) ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) - ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)) - ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name)) + ax.scatter( + [-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name) + ) + ax.scatter( + [-1], + [-1], + marker="o", + s=100, + c="tab:blue", + label="{:} validation".format(name), + ) ax.legend(loc=4, fontsize=LegendFontsize) ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize) ax.set_ylabel("architecture ranking", fontsize=LabelSize) @@ -465,9 +499,13 @@ def visualize_rank_info(api, vis_save_dir, indicator): labels = get_labels(imagenet_info) plot_ax(labels, ax3, "ImageNet-16-120") - save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve() + save_path = ( + vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator) + ).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") - save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve() + save_path = ( + vis_save_dir / "{:}-same-relative-rank.png".format(indicator) + ).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") print("{:} save into {:}".format(time_string(), save_path)) plt.close("all") @@ -496,9 +534,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator): # print ('{:} start to visualize 
{:} information'.format(time_string(), api)) vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) - cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) - imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar10", indicator + ) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar100", indicator + ) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "ImageNet16-120", indicator + ) cifar010_info = torch.load(cifar010_cache_path) cifar100_info = torch.load(cifar100_cache_path) imagenet_info = torch.load(imagenet_cache_path) @@ -564,7 +608,9 @@ def visualize_all_rank_info(api, vis_save_dir, indicator): yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], ) ax1.set_title("Correlation coefficient over ALL candidates") - ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)) + ax2.set_title( + "Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar) + ) save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") print("{:} save into {:}".format(time_string(), save_path)) @@ -572,9 +618,14 @@ def visualize_all_rank_info(api, vis_save_dir, indicator): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument( - "--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log." 
+ "--save_dir", + type=str, + default="output/vis-nas-bench", + help="Folder to save checkpoints and log.", ) # use for train the model args = parser.parse_args() diff --git a/exps/NATS-Bench/draw-fig6.py b/exps/NATS-Bench/draw-fig6.py index 12ca998..8b00ad8 100644 --- a/exps/NATS-Bench/draw-fig6.py +++ b/exps/NATS-Bench/draw-fig6.py @@ -43,7 +43,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): for alg, path in alg2path.items(): data = torch.load(path) for index, info in data.items(): - info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])] + info["time_w_arch"] = [ + (x, y) for x, y in zip(info["all_total_times"], info["all_archs"]) + ] for j, arch in enumerate(info["all_archs"]): assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( alg, search_space, dataset, index, j @@ -58,12 +60,16 @@ def query_performance(api, data, dataset, ticket): time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) time_a, arch_a = time_w_arch[0] time_b, arch_b = time_w_arch[1] - info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + info_a = api.get_more_info( + arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) + info_b = api.get_more_info( + arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"] - interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / ( - time_b - time_a - ) * accuracy_b + interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + ( + ticket - time_a + ) / (time_b - time_a) * accuracy_b results.append(interplate) # return sum(results) / len(results) return np.mean(results), np.std(results) @@ -74,12 +80,21 @@ def show_valid_test(api, data, dataset): for i, info in data.items(): time, arch = info["time_w_arch"][-1] if dataset == "cifar10": - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) test_accs.append(xinfo["test-accuracy"]) - xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, + dataset="cifar10-valid", + hp=90 if is_size_space else 200, + is_random=False, + ) valid_accs.append(xinfo["valid-accuracy"]) else: - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) valid_accs.append(xinfo["valid-accuracy"]) test_accs.append(xinfo["test-accuracy"]) valid_str = "{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs)) @@ -114,7 +129,11 @@ x_axis_s = { ("ImageNet16-120", "sss"): 600, } -name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} +name2label = { + "cifar10": "CIFAR-10", + "cifar100": "CIFAR-100", + "ImageNet16-120": "ImageNet-16-120", +} def visualize_curve(api, vis_save_dir, search_space): @@ -130,10 +149,14 @@ def visualize_curve(api, vis_save_dir, search_space): alg2data = fetch_data(search_space=search_space, dataset=dataset) alg2accuracies = OrderedDict() total_tickets = 150 - time_tickets = [float(i) / 
total_tickets * int(max_time) for i in range(total_tickets)] + time_tickets = [ + float(i) / total_tickets * int(max_time) for i in range(total_tickets) + ] colors = ["b", "g", "c", "m", "y"] ax.set_xlim(0, x_axis_s[(xdataset, search_space)]) - ax.set_ylim(y_min_s[(xdataset, search_space)], y_max_s[(xdataset, search_space)]) + ax.set_ylim( + y_min_s[(xdataset, search_space)], y_max_s[(xdataset, search_space)] + ) for idx, (alg, data) in enumerate(alg2data.items()): accuracies = [] for ticket in time_tickets: @@ -142,13 +165,25 @@ def visualize_curve(api, vis_save_dir, search_space): valid_str, test_str = show_valid_test(api, data, xdataset) # print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std)) print( - "{:} plot alg : {:10s} | validation = {:} | test = {:}".format(time_string(), alg, valid_str, test_str) + "{:} plot alg : {:10s} | validation = {:} | test = {:}".format( + time_string(), alg, valid_str, test_str + ) ) alg2accuracies[alg] = accuracies - ax.plot([x / 100 for x in time_tickets], accuracies, c=colors[idx], label="{:}".format(alg)) + ax.plot( + [x / 100 for x in time_tickets], + accuracies, + c=colors[idx], + label="{:}".format(alg), + ) ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize) - ax.set_ylabel("Test accuracy on {:}".format(name2label[xdataset]), fontsize=LabelSize) - ax.set_title("Searching results on {:}".format(name2label[xdataset]), fontsize=LabelSize + 4) + ax.set_ylabel( + "Test accuracy on {:}".format(name2label[xdataset]), fontsize=LabelSize + ) + ax.set_title( + "Searching results on {:}".format(name2label[xdataset]), + fontsize=LabelSize + 4, + ) ax.legend(loc=4, fontsize=LegendFontsize) fig, axs = plt.subplots(1, 3, figsize=figsize) @@ -174,9 +209,17 @@ if __name__ == "__main__": formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." 
+ "--save_dir", + type=str, + default="output/vis-nas-bench/nas-algos", + help="Folder to save checkpoints and log.", + ) + parser.add_argument( + "--search_space", + type=str, + choices=["tss", "sss"], + help="Choose the search space.", ) - parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") args = parser.parse_args() save_dir = Path(args.save_dir) diff --git a/exps/NATS-Bench/draw-fig7.py b/exps/NATS-Bench/draw-fig7.py index 77c01fe..0a98537 100644 --- a/exps/NATS-Bench/draw-fig7.py +++ b/exps/NATS-Bench/draw-fig7.py @@ -31,18 +31,33 @@ from log_utils import time_string def get_valid_test_acc(api, arch, dataset): is_size_space = api.search_space_name == "size" if dataset == "cifar10": - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) test_acc = xinfo["test-accuracy"] - xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, + dataset="cifar10-valid", + hp=90 if is_size_space else 200, + is_random=False, + ) valid_acc = xinfo["valid-accuracy"] else: - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) valid_acc = xinfo["valid-accuracy"] test_acc = xinfo["test-accuracy"] - return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc) + return ( + valid_acc, + test_acc, + "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc), + ) -def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"): +def fetch_data( + root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3" +): ss_dir = "{:}-{:}".format(root_dir, search_space) alg2name, alg2path = OrderedDict(), OrderedDict() seeds = [777, 888, 999] @@ -55,8 +70,12 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suf alg2name["ENAS"] = "enas-affine0_BN0-None" alg2name["SETN"] = "setn-affine0_BN0-None" else: - alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(suffix) - alg2name["masking + Gumbel-Softmax"] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix) + alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format( + suffix + ) + alg2name[ + "masking + Gumbel-Softmax" + ] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix) alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix) for alg, name in alg2name.items(): alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth") @@ -72,7 +91,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suf continue data = torch.load(xpath, map_location=torch.device("cpu")) try: - data = torch.load(data["last_checkpoint"], map_location=torch.device("cpu")) + data = torch.load( + data["last_checkpoint"], map_location=torch.device("cpu") + ) except: xpath = str(data["last_checkpoint"]).split("E100-") if len(xpath) == 2 and os.path.isfile(xpath[0] + xpath[1]): @@ -82,7 +103,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suf elif "tunas" in str(data["last_checkpoint"]): xpath = str(data["last_checkpoint"]).replace("tunas", "mask_rl") else: - raise ValueError("Invalid path: 
{:}".format(data["last_checkpoint"])) + raise ValueError( + "Invalid path: {:}".format(data["last_checkpoint"]) + ) data = torch.load(xpath, map_location=torch.device("cpu")) alg2data[alg].append(data["genotypes"]) print("This algorithm : {:} has {:} valid ckps.".format(alg, ok_num)) @@ -108,9 +131,18 @@ y_max_s = { ("ImageNet16-120", "sss"): 46, } -name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} +name2label = { + "cifar10": "CIFAR-10", + "cifar100": "CIFAR-100", + "ImageNet16-120": "ImageNet-16-120", +} -name2suffix = {("sss", "warm"): "-WARM0.3", ("sss", "none"): "-WARMNone", ("tss", "none"): None, ("tss", None): None} +name2suffix = { + ("sss", "warm"): "-WARM0.3", + ("sss", "none"): "-WARMNone", + ("tss", "none"): None, + ("tss", None): None, +} def visualize_curve(api, vis_save_dir, search_space, suffix): @@ -123,7 +155,11 @@ def visualize_curve(api, vis_save_dir, search_space, suffix): def sub_plot_fn(ax, dataset): print("{:} plot {:10s}".format(time_string(), dataset)) - alg2data = fetch_data(search_space=search_space, dataset=dataset, suffix=name2suffix[(search_space, suffix)]) + alg2data = fetch_data( + search_space=search_space, + dataset=dataset, + suffix=name2suffix[(search_space, suffix)], + ) alg2accuracies = OrderedDict() epochs = 100 colors = ["b", "g", "c", "m", "y", "r"] @@ -135,10 +171,17 @@ def visualize_curve(api, vis_save_dir, search_space, suffix): try: structures, accs = [_[iepoch - 1] for _ in data], [] except: - raise ValueError("This alg {:} on {:} has invalid checkpoints.".format(alg, dataset)) + raise ValueError( + "This alg {:} on {:} has invalid checkpoints.".format( + alg, dataset + ) + ) for structure in structures: info = api.get_more_info( - structure, dataset=dataset, hp=90 if api.search_space_name == "size" else 200, is_random=False + structure, + dataset=dataset, + hp=90 if api.search_space_name == "size" else 200, + is_random=False, ) accs.append(info["test-accuracy"]) accuracies.append(sum(accs) / len(accs)) @@ -146,17 +189,31 @@ def visualize_curve(api, vis_save_dir, search_space, suffix): alg2accuracies[alg] = accuracies ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg)) ax.set_xlabel("The searching epoch", fontsize=LabelSize) - ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize) - ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4) + ax.set_ylabel( + "Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize + ) + ax.set_title( + "Searching results on {:}".format(name2label[dataset]), + fontsize=LabelSize + 4, + ) structures, valid_accs, test_accs = [_[epochs - 1] for _ in data], [], [] - print("{:} plot alg : {:} -- final {:} architectures.".format(time_string(), alg, len(structures))) + print( + "{:} plot alg : {:} -- final {:} architectures.".format( + time_string(), alg, len(structures) + ) + ) for arch in structures: valid_acc, test_acc, _ = get_valid_test_acc(api, arch, dataset) test_accs.append(test_acc) valid_accs.append(valid_acc) print( "{:} plot alg : {:} -- validation: {:.2f}$\pm${:.2f} -- test: {:.2f}$\pm${:.2f}".format( - time_string(), alg, np.mean(valid_accs), np.std(valid_accs), np.mean(test_accs), np.std(test_accs) + time_string(), + alg, + np.mean(valid_accs), + np.std(valid_accs), + np.mean(test_accs), + np.std(test_accs), ) ) ax.legend(loc=4, fontsize=LegendFontsize) @@ -166,16 +223,23 @@ def visualize_curve(api, vis_save_dir, search_space, suffix): for dataset, ax in 
zip(datasets, axs): sub_plot_fn(ax, dataset) print("sub-plot {:} on {:} done.".format(dataset, search_space)) - save_path = (vis_save_dir / "{:}-ws-{:}-curve.png".format(search_space, suffix)).resolve() + save_path = ( + vis_save_dir / "{:}-ws-{:}-curve.png".format(search_space, suffix) + ).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") print("{:} save into {:}".format(time_string(), save_path)) plt.close("all") if __name__ == "__main__": - parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument( - "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." + "--save_dir", + type=str, + default="output/vis-nas-bench/nas-algos", + help="Folder to save checkpoints and log.", ) args = parser.parse_args() diff --git a/exps/NATS-Bench/draw-fig8.py b/exps/NATS-Bench/draw-fig8.py index bc81e19..8579aa5 100644 --- a/exps/NATS-Bench/draw-fig8.py +++ b/exps/NATS-Bench/draw-fig8.py @@ -28,7 +28,9 @@ from nats_bench import create from log_utils import time_string -plt.rcParams.update({"text.usetex": True, "font.family": "sans-serif", "font.sans-serif": ["Helvetica"]}) +plt.rcParams.update( + {"text.usetex": True, "font.family": "sans-serif", "font.sans-serif": ["Helvetica"]} +) ## for Palatino and other serif fonts use: plt.rcParams.update( { @@ -57,16 +59,22 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): raise ValueError("Unkonwn search space: {:}".format(search_space)) alg2all[r"REA ($\mathcal{H}^{0}$)"] = dict( - path=os.path.join(ss_dir, dataset + suffixes[0], "R-EA-SS3", "results.pth"), color="b", linestyle="-" + path=os.path.join(ss_dir, dataset + suffixes[0], "R-EA-SS3", "results.pth"), + color="b", + linestyle="-", ) alg2all[r"REA ({:})".format(hp)] = dict( - path=os.path.join(ss_dir, dataset + suffixes[1], "R-EA-SS3", "results.pth"), color="b", linestyle="--" + path=os.path.join(ss_dir, dataset + suffixes[1], "R-EA-SS3", "results.pth"), + color="b", + linestyle="--", ) for alg, xdata in alg2all.items(): data = torch.load(xdata["path"]) for index, info in data.items(): - info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])] + info["time_w_arch"] = [ + (x, y) for x, y in zip(info["all_total_times"], info["all_archs"]) + ] for j, arch in enumerate(info["all_archs"]): assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( alg, search_space, dataset, index, j @@ -81,12 +89,16 @@ def query_performance(api, data, dataset, ticket): time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) time_a, arch_a = time_w_arch[0] time_b, arch_b = time_w_arch[1] - info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + info_a = api.get_more_info( + arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) + info_b = api.get_more_info( + arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"] - interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / ( - time_b - time_a - ) * accuracy_b + interplate = (time_b - ticket) / (time_b - time_a) * 
accuracy_a + ( + ticket - time_a + ) / (time_b - time_a) * accuracy_b results.append(interplate) # return sum(results) / len(results) return np.mean(results), np.std(results) @@ -119,7 +131,11 @@ x_axis_s = { ("ImageNet16-120", "sss"): 600, } -name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} +name2label = { + "cifar10": "CIFAR-10", + "cifar100": "CIFAR-100", + "ImageNet16-120": "ImageNet-16-120", +} spaces2latex = { "tss": r"$\mathcal{S}_{t}$", @@ -149,7 +165,9 @@ def visualize_curve(api_dict, vis_save_dir): alg2data = fetch_data(search_space=search_space, dataset=dataset) alg2accuracies = OrderedDict() total_tickets = 200 - time_tickets = [float(i) / total_tickets * int(max_time) for i in range(total_tickets)] + time_tickets = [ + float(i) / total_tickets * int(max_time) for i in range(total_tickets) + ] ax.set_xlim(0, x_axis_s[(dataset, search_space)]) ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) for tick in ax.get_xticklabels(): @@ -162,16 +180,29 @@ def visualize_curve(api_dict, vis_save_dir): accuracies = [] for ticket in time_tickets: # import pdb; pdb.set_trace() - accuracy, accuracy_std = query_performance(api_dict[search_space], xdata["data"], dataset, ticket) + accuracy, accuracy_std = query_performance( + api_dict[search_space], xdata["data"], dataset, ticket + ) accuracies.append(accuracy) # print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std)) - print("{:} plot alg : {:10s} on {:}".format(time_string(), alg, search_space)) + print( + "{:} plot alg : {:10s} on {:}".format(time_string(), alg, search_space) + ) alg2accuracies[alg] = accuracies - ax.plot(time_tickets, accuracies, c=xdata["color"], linestyle=xdata["linestyle"], label="{:}".format(alg)) + ax.plot( + time_tickets, + accuracies, + c=xdata["color"], + linestyle=xdata["linestyle"], + label="{:}".format(alg), + ) ax.set_xlabel("Estimated wall-clock time", fontsize=LabelSize) ax.set_ylabel("Test accuracy", fontsize=LabelSize) ax.set_title( - r"Results on {:} over {:}".format(name2label[dataset], spaces2latex[search_space]), fontsize=LabelSize + r"Results on {:} over {:}".format( + name2label[dataset], spaces2latex[search_space] + ), + fontsize=LabelSize, ) ax.legend(loc=4, fontsize=LegendFontsize) diff --git a/exps/NATS-Bench/draw-ranks.py b/exps/NATS-Bench/draw-ranks.py index d3a8c02..76ee5fb 100644 --- a/exps/NATS-Bench/draw-ranks.py +++ b/exps/NATS-Bench/draw-ranks.py @@ -30,12 +30,20 @@ from models import get_cell_based_tiny_net from nats_bench import create -name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} +name2label = { + "cifar10": "CIFAR-10", + "cifar100": "CIFAR-100", + "ImageNet16-120": "ImageNet-16-120", +} def visualize_relative_info(vis_save_dir, search_space, indicator, topk): vis_save_dir = vis_save_dir.resolve() - print("{:} start to visualize {:} with top-{:} information".format(time_string(), search_space, topk)) + print( + "{:} start to visualize {:} with top-{:} information".format( + time_string(), search_space, topk + ) + ) vis_save_dir.mkdir(parents=True, exist_ok=True) cache_file_path = vis_save_dir / "cache-{:}-info.pth".format(search_space) datasets = ["cifar10", "cifar100", "ImageNet16-120"] @@ -46,8 +54,12 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk): all_info = OrderedDict() for dataset in datasets: info_less = api.get_more_info(index, dataset, hp="12", 
is_random=False) - info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False) - all_info[dataset] = dict(less=info_less["test-accuracy"], more=info_more["test-accuracy"]) + info_more = api.get_more_info( + index, dataset, hp=api.full_train_epochs, is_random=False + ) + all_info[dataset] = dict( + less=info_less["test-accuracy"], more=info_more["test-accuracy"] + ) all_infos[index] = all_info torch.save(all_infos, cache_file_path) print("{:} save all cache data into {:}".format(time_string(), cache_file_path)) @@ -80,12 +92,18 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk): for idx in selected_indexes: standard_scores.append( api.get_more_info( - idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=False + idx, + dataset, + hp=api.full_train_epochs if indicator == "more" else "12", + is_random=False, )["test-accuracy"] ) random_scores.append( api.get_more_info( - idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=True + idx, + dataset, + hp=api.full_train_epochs if indicator == "more" else "12", + is_random=True, )["test-accuracy"] ) indexes = list(range(len(selected_indexes))) @@ -105,11 +123,28 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk): ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) ax.scatter(indexes, random_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) - ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Average Over Multi-Trials") - ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="Randomly Selected Trial") + ax.scatter( + [-1], + [-1], + marker="o", + s=100, + c="tab:blue", + label="Average Over Multi-Trials", + ) + ax.scatter( + [-1], + [-1], + marker="^", + s=100, + c="tab:green", + label="Randomly Selected Trial", + ) coef, p = scipy.stats.kendalltau(standard_scores, random_scores) - ax.set_xlabel("architecture ranking in {:}".format(name2label[dataset]), fontsize=LabelSize) + ax.set_xlabel( + "architecture ranking in {:}".format(name2label[dataset]), + fontsize=LabelSize, + ) if dataset == "cifar10": ax.set_ylabel("architecture ranking", fontsize=LabelSize) ax.legend(loc=4, fontsize=LegendFontsize) @@ -117,17 +152,27 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk): for dataset, ax in zip(datasets, axs): rank_coef = sub_plot_fn(ax, dataset, indicator) - print("sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(dataset, search_space, rank_coef)) + print( + "sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format( + dataset, search_space, rank_coef + ) + ) - save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk)).resolve() + save_path = ( + vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk) + ).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") - save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk)).resolve() + save_path = ( + vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk) + ).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") print("Save into {:}".format(save_path)) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + 
description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument( "--save_dir", type=str, diff --git a/exps/NATS-Bench/draw-table.py b/exps/NATS-Bench/draw-table.py index 8a52a4f..d90ec06 100644 --- a/exps/NATS-Bench/draw-table.py +++ b/exps/NATS-Bench/draw-table.py @@ -42,7 +42,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): for alg, path in alg2path.items(): data = torch.load(path) for index, info in data.items(): - info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])] + info["time_w_arch"] = [ + (x, y) for x, y in zip(info["all_total_times"], info["all_archs"]) + ] for j, arch in enumerate(info["all_archs"]): assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( alg, search_space, dataset, index, j @@ -54,15 +56,28 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): def get_valid_test_acc(api, arch, dataset): is_size_space = api.search_space_name == "size" if dataset == "cifar10": - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) test_acc = xinfo["test-accuracy"] - xinfo = api.get_more_info(arch, dataset="cifar10-valid", hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, + dataset="cifar10-valid", + hp=90 if is_size_space else 200, + is_random=False, + ) valid_acc = xinfo["valid-accuracy"] else: - xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + xinfo = api.get_more_info( + arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) valid_acc = xinfo["valid-accuracy"] test_acc = xinfo["test-accuracy"] - return valid_acc, test_acc, "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc) + return ( + valid_acc, + test_acc, + "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc), + ) def show_valid_test(api, arch): @@ -84,8 +99,16 @@ def find_best_valid(api, dataset): best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0] print("-" * 50 + "{:10s}".format(dataset) + "-" * 50) - print("Best ({:}) architecture on validation: {:}".format(best_valid_index, api[best_valid_index])) - print("Best ({:}) architecture on test: {:}".format(best_test_index, api[best_test_index])) + print( + "Best ({:}) architecture on validation: {:}".format( + best_valid_index, api[best_valid_index] + ) + ) + print( + "Best ({:}) architecture on test: {:}".format( + best_test_index, api[best_test_index] + ) + ) _, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset) print("using validation ::: {:}".format(perf_str)) _, _, perf_str = get_valid_test_acc(api, best_test_index, dataset) @@ -130,10 +153,14 @@ def show_multi_trial(search_space): v_acc, t_acc = query_performance(api, x, xdataset, float(max_time)) valid_accs.append(v_acc) test_accs.append(t_acc) - valid_str = "{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs)) + valid_str = "{:.2f}$\pm${:.2f}".format( + np.mean(valid_accs), np.std(valid_accs) + ) test_str = "{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs)) print( - "{:} plot alg : {:10s} | validation = {:} | test = {:}".format(time_string(), alg, valid_str, test_str) + "{:} plot alg : {:10s} | validation = {:} | test = {:}".format( + time_string(), alg, valid_str, test_str + ) ) if search_space == "tss": diff --git 
a/exps/NATS-Bench/main-sss.py b/exps/NATS-Bench/main-sss.py index dcbd579..e63c076 100644 --- a/exps/NATS-Bench/main-sss.py +++ b/exps/NATS-Bench/main-sss.py @@ -51,23 +51,35 @@ def evaluate_all_datasets( train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) # load the configuration if dataset == "cifar10" or dataset == "cifar100": - split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None) + split_info = load_config( + "configs/nas-benchmark/cifar-split.txt", None, None + ) elif dataset.startswith("ImageNet16"): - split_info = load_config("configs/nas-benchmark/{:}-split.txt".format(dataset), None, None) + split_info = load_config( + "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None + ) else: raise ValueError("invalid dataset : {:}".format(dataset)) - config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) + config = load_config( + config_path, dict(class_num=class_num, xshape=xshape), logger + ) # check whether use the splitted validation set if bool(split): assert dataset == "cifar10" ValLoaders = { "ori-test": torch.utils.data.DataLoader( - valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + valid_data, + batch_size=config.batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, ) } assert len(train_data) == len(split_info.train) + len( split_info.valid - ), "invalid length : {:} vs {:} + {:}".format(len(train_data), len(split_info.train), len(split_info.valid)) + ), "invalid length : {:} vs {:} + {:}".format( + len(train_data), len(split_info.train), len(split_info.valid) + ) train_data_v2 = deepcopy(train_data) train_data_v2.transform = valid_data.transform valid_data = train_data_v2 @@ -90,47 +102,67 @@ def evaluate_all_datasets( else: # data loader train_loader = torch.utils.data.DataLoader( - train_data, batch_size=config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + train_data, + batch_size=config.batch_size, + shuffle=True, + num_workers=workers, + pin_memory=True, ) valid_loader = torch.utils.data.DataLoader( - valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + valid_data, + batch_size=config.batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, ) if dataset == "cifar10": ValLoaders = {"ori-test": valid_loader} elif dataset == "cifar100": - cifar100_splits = load_config("configs/nas-benchmark/cifar100-test-split.txt", None, None) + cifar100_splits = load_config( + "configs/nas-benchmark/cifar100-test-split.txt", None, None + ) ValLoaders = { "ori-test": valid_loader, "x-valid": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + cifar100_splits.xvalid + ), num_workers=workers, pin_memory=True, ), "x-test": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + cifar100_splits.xtest + ), num_workers=workers, pin_memory=True, ), } elif dataset == "ImageNet16-120": - imagenet16_splits = load_config("configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None) + imagenet16_splits = load_config( + "configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None + ) ValLoaders = { "ori-test": valid_loader, "x-valid": torch.utils.data.DataLoader( 
valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + imagenet16_splits.xvalid + ), num_workers=workers, pin_memory=True, ), "x-test": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + imagenet16_splits.xtest + ), num_workers=workers, pin_memory=True, ), @@ -143,19 +175,36 @@ def evaluate_all_datasets( dataset_key = dataset_key + "-valid" logger.log( "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( - dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size + dataset_key, + len(train_data), + len(valid_data), + len(train_loader), + len(valid_loader), + config.batch_size, ) ) - logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)) + logger.log( + "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config) + ) for key, value in ValLoaders.items(): - logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))) + logger.log( + "Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)) + ) # arch-index= 9930, arch=|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2| # this genotype is the architecture with the highest accuracy on CIFAR-100 validation set genotype = "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|" arch_config = dict2config( - dict(name="infer.shape.tiny", channels=channels, genotype=genotype, num_classes=class_num), None + dict( + name="infer.shape.tiny", + channels=channels, + genotype=genotype, + num_classes=class_num, + ), + None, + ) + results = bench_evaluate_for_seed( + arch_config, config, train_loader, ValLoaders, seed, logger ) - results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger) all_infos[dataset_key] = results all_dataset_keys.append(dataset_key) all_infos["all_dataset_keys"] = all_dataset_keys @@ -183,8 +232,12 @@ def main( logger.log("xargs : cover_mode = {:}".format(cover_mode)) logger.log("-" * 100) logger.log( - "Start evaluating range =: {:06d} - {:06d}".format(min(to_evaluate_indexes), max(to_evaluate_indexes)) - + "({:} in total) / {:06d} with cover-mode={:}".format(len(to_evaluate_indexes), len(nets), cover_mode) + "Start evaluating range =: {:06d} - {:06d}".format( + min(to_evaluate_indexes), max(to_evaluate_indexes) + ) + + "({:} in total) / {:06d} with cover-mode={:}".format( + len(to_evaluate_indexes), len(nets), cover_mode + ) ) for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): logger.log( @@ -199,7 +252,13 @@ def main( channelstr = nets[index] logger.log( "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}".format( - time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, "-" * 15 + time_string(), + i, + len(to_evaluate_indexes), + index, + len(nets), + seeds, + "-" * 15, ) ) logger.log("{:} {:} {:}".format("-" * 15, channelstr, "-" * 15)) @@ -210,17 +269,33 @@ def main( to_save_name = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) if to_save_name.exists(): if cover_mode: - logger.log("Find existing file : {:}, remove it before evaluation".format(to_save_name)) + logger.log( 
+ "Find existing file : {:}, remove it before evaluation".format( + to_save_name + ) + ) os.remove(str(to_save_name)) else: - logger.log("Find existing file : {:}, skip this evaluation".format(to_save_name)) + logger.log( + "Find existing file : {:}, skip this evaluation".format( + to_save_name + ) + ) has_continue = True continue - results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger) + results = evaluate_all_datasets( + channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger + ) torch.save(results, to_save_name) logger.log( "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}".format( - time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, to_save_name + time_string(), + i, + len(to_evaluate_indexes), + index, + len(nets), + seeds, + to_save_name, ) ) # measure elapsed time @@ -230,7 +305,9 @@ def main( need_time = "Time Left: {:}".format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True) ) - logger.log("This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))) + logger.log( + "This arch costs : {:}".format(convert_secs2time(epoch_time.val, True)) + ) logger.log("{:}".format("*" * 100)) logger.log( "{:} {:74s} {:}".format( @@ -277,16 +354,24 @@ def filter_indexes(xlist, mode, save_dir, seeds): SLURM_PROCID, SLURM_NTASKS = "SLURM_PROCID", "SLURM_NTASKS" if SLURM_PROCID in os.environ and SLURM_NTASKS in os.environ: # run on the slurm proc_id, ntasks = int(os.environ[SLURM_PROCID]), int(os.environ[SLURM_NTASKS]) - assert 0 <= proc_id < ntasks, "invalid proc_id {:} vs ntasks {:}".format(proc_id, ntasks) - scales = [int(float(i) / ntasks * len(all_indexes)) for i in range(ntasks)] + [len(all_indexes)] + assert 0 <= proc_id < ntasks, "invalid proc_id {:} vs ntasks {:}".format( + proc_id, ntasks + ) + scales = [int(float(i) / ntasks * len(all_indexes)) for i in range(ntasks)] + [ + len(all_indexes) + ] per_job = [] for i in range(ntasks): - xs, xe = min(max(scales[i], 0), len(all_indexes) - 1), min(max(scales[i + 1] - 1, 0), len(all_indexes) - 1) + xs, xe = min(max(scales[i], 0), len(all_indexes) - 1), min( + max(scales[i + 1] - 1, 0), len(all_indexes) - 1 + ) per_job.append((xs, xe)) for i, srange in enumerate(per_job): print(" -->> {:2d}/{:02d} : {:}".format(i, ntasks, srange)) current_range = per_job[proc_id] - all_indexes = [all_indexes[i] for i in range(current_range[0], current_range[1] + 1)] + all_indexes = [ + all_indexes[i] for i in range(current_range[0], current_range[1] + 1) + ] # set the device id device = proc_id % torch.cuda.device_count() torch.cuda.set_device(device) @@ -301,30 +386,67 @@ def filter_indexes(xlist, mode, save_dir, seeds): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="NATS-Bench (size search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="NATS-Bench (size search space)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--mode", type=str, required=True, choices=["new", "cover"], help="The script mode.") parser.add_argument( - "--save_dir", type=str, default="output/NATS-Bench-size", help="Folder to save checkpoints and log." 
+ "--mode", + type=str, + required=True, + choices=["new", "cover"], + help="The script mode.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="output/NATS-Bench-size", + help="Folder to save checkpoints and log.", + ) + parser.add_argument( + "--candidateC", + type=int, + nargs="+", + default=[8, 16, 24, 32, 40, 48, 56, 64], + help=".", + ) + parser.add_argument( + "--num_layers", type=int, default=5, help="The number of layers in a network." ) - parser.add_argument("--candidateC", type=int, nargs="+", default=[8, 16, 24, 32, 40, 48, 56, 64], help=".") - parser.add_argument("--num_layers", type=int, default=5, help="The number of layers in a network.") parser.add_argument("--check_N", type=int, default=32768, help="For safety.") # use for train the model - parser.add_argument("--workers", type=int, default=8, help="The number of data loading workers (default: 2)") - parser.add_argument("--srange", type=str, required=True, help="The range of models to be evaluated") - parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.") - parser.add_argument("--xpaths", type=str, nargs="+", help="The root path for this dataset.") - parser.add_argument("--splits", type=int, nargs="+", help="The root path for this dataset.") parser.add_argument( - "--hyper", type=str, default="12", choices=["01", "12", "90"], help="The tag for hyper-parameters." + "--workers", + type=int, + default=8, + help="The number of data loading workers (default: 2)", + ) + parser.add_argument( + "--srange", type=str, required=True, help="The range of models to be evaluated" + ) + parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.") + parser.add_argument( + "--xpaths", type=str, nargs="+", help="The root path for this dataset." + ) + parser.add_argument( + "--splits", type=int, nargs="+", help="The root path for this dataset." + ) + parser.add_argument( + "--hyper", + type=str, + default="12", + choices=["01", "12", "90"], + help="The tag for hyper-parameters.", + ) + parser.add_argument( + "--seeds", type=int, nargs="+", help="The range of models to be evaluated" ) - parser.add_argument("--seeds", type=int, nargs="+", help="The range of models to be evaluated") args = parser.parse_args() nets = traverse_net(args.candidateC, args.num_layers) if len(nets) != args.check_N: - raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)) + raise ValueError( + "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N) + ) opt_config = "./configs/nas-benchmark/hyper-opts/{:}E.config".format(args.hyper) if not os.path.isfile(opt_config): @@ -337,12 +459,16 @@ if __name__ == "__main__": raise ValueError("invalid length of seeds args: {:}".format(args.seeds)) if not (len(args.datasets) == len(args.xpaths) == len(args.splits)): raise ValueError( - "invalid infos : {:} vs {:} vs {:}".format(len(args.datasets), len(args.xpaths), len(args.splits)) + "invalid infos : {:} vs {:} vs {:}".format( + len(args.datasets), len(args.xpaths), len(args.splits) + ) ) if args.workers <= 0: raise ValueError("invalid number of workers : {:}".format(args.workers)) - target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds) + target_indexes = filter_indexes( + to_evaluate_indexes, args.mode, save_dir, args.seeds + ) assert torch.cuda.is_available(), "CUDA is not available." 
torch.backends.cudnn.enabled = True diff --git a/exps/NATS-Bench/main-tss.py b/exps/NATS-Bench/main-tss.py index d6e9231..b5ca5fb 100644 --- a/exps/NATS-Bench/main-tss.py +++ b/exps/NATS-Bench/main-tss.py @@ -57,23 +57,35 @@ def evaluate_all_datasets( train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) # load the configuration if dataset == "cifar10" or dataset == "cifar100": - split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None) + split_info = load_config( + "configs/nas-benchmark/cifar-split.txt", None, None + ) elif dataset.startswith("ImageNet16"): - split_info = load_config("configs/nas-benchmark/{:}-split.txt".format(dataset), None, None) + split_info = load_config( + "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None + ) else: raise ValueError("invalid dataset : {:}".format(dataset)) - config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) + config = load_config( + config_path, dict(class_num=class_num, xshape=xshape), logger + ) # check whether use splited validation set if bool(split): assert dataset == "cifar10" ValLoaders = { "ori-test": torch.utils.data.DataLoader( - valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + valid_data, + batch_size=config.batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, ) } assert len(train_data) == len(split_info.train) + len( split_info.valid - ), "invalid length : {:} vs {:} + {:}".format(len(train_data), len(split_info.train), len(split_info.valid)) + ), "invalid length : {:} vs {:} + {:}".format( + len(train_data), len(split_info.train), len(split_info.valid) + ) train_data_v2 = deepcopy(train_data) train_data_v2.transform = valid_data.transform valid_data = train_data_v2 @@ -96,47 +108,67 @@ def evaluate_all_datasets( else: # data loader train_loader = torch.utils.data.DataLoader( - train_data, batch_size=config.batch_size, shuffle=True, num_workers=workers, pin_memory=True + train_data, + batch_size=config.batch_size, + shuffle=True, + num_workers=workers, + pin_memory=True, ) valid_loader = torch.utils.data.DataLoader( - valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True + valid_data, + batch_size=config.batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, ) if dataset == "cifar10": ValLoaders = {"ori-test": valid_loader} elif dataset == "cifar100": - cifar100_splits = load_config("configs/nas-benchmark/cifar100-test-split.txt", None, None) + cifar100_splits = load_config( + "configs/nas-benchmark/cifar100-test-split.txt", None, None + ) ValLoaders = { "ori-test": valid_loader, "x-valid": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + cifar100_splits.xvalid + ), num_workers=workers, pin_memory=True, ), "x-test": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + cifar100_splits.xtest + ), num_workers=workers, pin_memory=True, ), } elif dataset == "ImageNet16-120": - imagenet16_splits = load_config("configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None) + imagenet16_splits = load_config( + "configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None + ) ValLoaders = { "ori-test": valid_loader, 
"x-valid": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + imagenet16_splits.xvalid + ), num_workers=workers, pin_memory=True, ), "x-test": torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest), + sampler=torch.utils.data.sampler.SubsetRandomSampler( + imagenet16_splits.xtest + ), num_workers=workers, pin_memory=True, ), @@ -149,12 +181,21 @@ def evaluate_all_datasets( dataset_key = dataset_key + "-valid" logger.log( "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( - dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size + dataset_key, + len(train_data), + len(valid_data), + len(train_loader), + len(valid_loader), + config.batch_size, ) ) - logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)) + logger.log( + "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config) + ) for key, value in ValLoaders.items(): - logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))) + logger.log( + "Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)) + ) arch_config = dict2config( dict( name="infer.tiny", @@ -165,7 +206,9 @@ def evaluate_all_datasets( ), None, ) - results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger) + results = bench_evaluate_for_seed( + arch_config, config, train_loader, ValLoaders, seed, logger + ) all_infos[dataset_key] = results all_dataset_keys.append(dataset_key) all_infos["all_dataset_keys"] = all_dataset_keys @@ -194,8 +237,12 @@ def main( logger.log("xargs : cover_mode = {:}".format(cover_mode)) logger.log("-" * 100) logger.log( - "Start evaluating range =: {:06d} - {:06d}".format(min(to_evaluate_indexes), max(to_evaluate_indexes)) - + "({:} in total) / {:06d} with cover-mode={:}".format(len(to_evaluate_indexes), len(nets), cover_mode) + "Start evaluating range =: {:06d} - {:06d}".format( + min(to_evaluate_indexes), max(to_evaluate_indexes) + ) + + "({:} in total) / {:06d} with cover-mode={:}".format( + len(to_evaluate_indexes), len(nets), cover_mode + ) ) for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): logger.log( @@ -210,7 +257,13 @@ def main( arch = nets[index] logger.log( "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}".format( - time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, "-" * 15 + time_string(), + i, + len(to_evaluate_indexes), + index, + len(nets), + seeds, + "-" * 15, ) ) logger.log("{:} {:} {:}".format("-" * 15, arch, "-" * 15)) @@ -221,10 +274,18 @@ def main( to_save_name = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) if to_save_name.exists(): if cover_mode: - logger.log("Find existing file : {:}, remove it before evaluation".format(to_save_name)) + logger.log( + "Find existing file : {:}, remove it before evaluation".format( + to_save_name + ) + ) os.remove(str(to_save_name)) else: - logger.log("Find existing file : {:}, skip this evaluation".format(to_save_name)) + logger.log( + "Find existing file : {:}, skip this evaluation".format( + to_save_name + ) + ) has_continue = True continue results = evaluate_all_datasets( @@ -241,7 +302,13 @@ def main( torch.save(results, 
to_save_name) logger.log( "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}".format( - time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, to_save_name + time_string(), + i, + len(to_evaluate_indexes), + index, + len(nets), + seeds, + to_save_name, ) ) # measure elapsed time @@ -251,7 +318,9 @@ def main( need_time = "Time Left: {:}".format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True) ) - logger.log("This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))) + logger.log( + "This arch costs : {:}".format(convert_secs2time(epoch_time.val, True)) + ) logger.log("{:}".format("*" * 100)) logger.log( "{:} {:74s} {:}".format( @@ -267,7 +336,9 @@ def main( logger.close() -def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config): +def train_single_model( + save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config +): assert torch.cuda.is_available(), "CUDA is not available." torch.backends.cudnn.enabled = True torch.backends.cudnn.deterministic = True @@ -278,19 +349,32 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se Path(save_dir) / "specifics" / "{:}-{:}-{:}-{:}".format( - "LESS" if use_less else "FULL", model_str, arch_config["channel"], arch_config["num_cells"] + "LESS" if use_less else "FULL", + model_str, + arch_config["channel"], + arch_config["num_cells"], ) ) logger = Logger(str(save_dir), 0, False) if model_str in CellArchitectures: arch = CellArchitectures[model_str] - logger.log("The model string is found in pre-defined architecture dict : {:}".format(model_str)) + logger.log( + "The model string is found in pre-defined architecture dict : {:}".format( + model_str + ) + ) else: try: arch = CellStructure.str2structure(model_str) except: - raise ValueError("Invalid model string : {:}. It can not be found or parsed.".format(model_str)) - assert arch.check_valid_op(get_search_spaces("cell", "full")), "{:} has the invalid op.".format(arch) + raise ValueError( + "Invalid model string : {:}. 
It can not be found or parsed.".format( + model_str + ) + ) + assert arch.check_valid_op( + get_search_spaces("cell", "full") + ), "{:} has the invalid op.".format(arch) logger.log("Start train-evaluate {:}".format(arch.tostr())) logger.log("arch_config : {:}".format(arch_config)) @@ -303,27 +387,55 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se ) to_save_name = save_dir / "seed-{:04d}.pth".format(seed) if to_save_name.exists(): - logger.log("Find the existing file {:}, directly load!".format(to_save_name)) + logger.log( + "Find the existing file {:}, directly load!".format(to_save_name) + ) checkpoint = torch.load(to_save_name) else: - logger.log("Does not find the existing file {:}, train and evaluate!".format(to_save_name)) + logger.log( + "Does not find the existing file {:}, train and evaluate!".format( + to_save_name + ) + ) checkpoint = evaluate_all_datasets( - arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger + arch, + datasets, + xpaths, + splits, + use_less, + seed, + arch_config, + workers, + logger, ) torch.save(checkpoint, to_save_name) # log information logger.log("{:}".format(checkpoint["info"])) all_dataset_keys = checkpoint["all_dataset_keys"] for dataset_key in all_dataset_keys: - logger.log("\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)) + logger.log( + "\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15) + ) dataset_info = checkpoint[dataset_key] # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] )) - logger.log("Flops = {:} MB, Params = {:} MB".format(dataset_info["flop"], dataset_info["param"])) + logger.log( + "Flops = {:} MB, Params = {:} MB".format( + dataset_info["flop"], dataset_info["param"] + ) + ) logger.log("config : {:}".format(dataset_info["config"])) - logger.log("Training State (finish) = {:}".format(dataset_info["finish-train"])) + logger.log( + "Training State (finish) = {:}".format(dataset_info["finish-train"]) + ) last_epoch = dataset_info["total_epoch"] - 1 - train_acc1es, train_acc5es = dataset_info["train_acc1es"], dataset_info["train_acc5es"] - valid_acc1es, valid_acc5es = dataset_info["valid_acc1es"], dataset_info["valid_acc5es"] + train_acc1es, train_acc5es = ( + dataset_info["train_acc1es"], + dataset_info["train_acc5es"], + ) + valid_acc1es, valid_acc5es = ( + dataset_info["valid_acc1es"], + dataset_info["valid_acc5es"], + ) logger.log( "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format( train_acc1es[last_epoch], @@ -337,7 +449,9 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se # measure elapsed time seed_time.update(time.time() - start_time) start_time = time.time() - need_time = "Time Left: {:}".format(convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True) + ) logger.log( "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} other procedures need {:}".format( _is, len(seeds), seed, need_time @@ -349,7 +463,11 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se def generate_meta_info(save_dir, max_node, divide=40): aa_nas_bench_ss = get_search_spaces("cell", "nas-bench-201") archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) - print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2))) + print( + "There are {:} archs vs 
{:}.".format( + len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2) + ) + ) random.seed(88) # please do not change this line for reproducibility random.shuffle(archs) @@ -361,10 +479,12 @@ def generate_meta_info(save_dir, max_node, divide=40): == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" ), "please check the 0-th architecture : {:}".format(archs[0]) assert ( - archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" + archs[9].tostr() + == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" ), "please check the 9-th architecture : {:}".format(archs[9]) assert ( - archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" + archs[123].tostr() + == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" ), "please check the 123-th architecture : {:}".format(archs[123]) total_arch = len(archs) @@ -383,11 +503,21 @@ def generate_meta_info(save_dir, max_node, divide=40): and valid_split[10] == 18 and valid_split[111] == 242 ), "{:} {:} {:} - {:} {:} {:}".format( - train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111] + train_split[0], + train_split[10], + train_split[111], + valid_split[0], + valid_split[10], + valid_split[111], ) splits = {num: {"train": train_split, "valid": valid_split}} - info = {"archs": [x.tostr() for x in archs], "total": total_arch, "max_node": max_node, "splits": splits} + info = { + "archs": [x.tostr() for x in archs], + "total": total_arch, + "max_node": max_node, + "splits": splits, + } save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) @@ -400,7 +530,11 @@ def generate_meta_info(save_dir, max_node, divide=40): def traverse_net(max_node): aa_nas_bench_ss = get_search_spaces("cell", "nats-bench") archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) - print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2))) + print( + "There are {:} archs vs {:}.".format( + len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2) + ) + ) random.seed(88) # please do not change this line for reproducibility random.shuffle(archs) @@ -409,10 +543,12 @@ def traverse_net(max_node): == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" ), "please check the 0-th architecture : {:}".format(archs[0]) assert ( - archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" + archs[9].tostr() + == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" ), "please check the 9-th architecture : {:}".format(archs[9]) assert ( - archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" + archs[123].tostr() + == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" ), "please check the 123-th architecture : {:}".format(archs[123]) return [x.tostr() for x in archs] @@ -439,32 +575,62 @@ def filter_indexes(xlist, mode, save_dir, seeds): if __name__ == "__main__": # mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()] parser = argparse.ArgumentParser( - description="NATS-Bench (topology search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="NATS-Bench (topology search space)", + 
formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("--mode", type=str, required=True, help="The script mode.") parser.add_argument( - "--save_dir", type=str, default="output/NATS-Bench-topology", help="Folder to save checkpoints and log." + "--save_dir", + type=str, + default="output/NATS-Bench-topology", + help="Folder to save checkpoints and log.", ) - parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell (please do not change it).") - # use for train the model - parser.add_argument("--workers", type=int, default=8, help="number of data loading workers (default: 2)") - parser.add_argument("--srange", type=str, required=True, help="The range of models to be evaluated") - parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.") - parser.add_argument("--xpaths", type=str, nargs="+", help="The root path for this dataset.") - parser.add_argument("--splits", type=int, nargs="+", help="The root path for this dataset.") parser.add_argument( - "--hyper", type=str, default="12", choices=["01", "12", "200"], help="The tag for hyper-parameters." + "--max_node", + type=int, + default=4, + help="The maximum node in a cell (please do not change it).", + ) + # use for train the model + parser.add_argument( + "--workers", + type=int, + default=8, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--srange", type=str, required=True, help="The range of models to be evaluated" + ) + parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.") + parser.add_argument( + "--xpaths", type=str, nargs="+", help="The root path for this dataset." + ) + parser.add_argument( + "--splits", type=int, nargs="+", help="The root path for this dataset." + ) + parser.add_argument( + "--hyper", + type=str, + default="12", + choices=["01", "12", "200"], + help="The tag for hyper-parameters.", ) - parser.add_argument("--seeds", type=int, nargs="+", help="The range of models to be evaluated") - parser.add_argument("--channel", type=int, default=16, help="The number of channels.") - parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + parser.add_argument( + "--seeds", type=int, nargs="+", help="The range of models to be evaluated" + ) + parser.add_argument( + "--channel", type=int, default=16, help="The number of channels." + ) + parser.add_argument( + "--num_cells", type=int, default=5, help="The number of cells in one stage." 
+ ) parser.add_argument("--check_N", type=int, default=15625, help="For safety.") args = parser.parse_args() - assert args.mode in ["meta", "new", "cover"] or args.mode.startswith("specific-"), "invalid mode : {:}".format( - args.mode - ) + assert args.mode in ["meta", "new", "cover"] or args.mode.startswith( + "specific-" + ), "invalid mode : {:}".format(args.mode) if args.mode == "meta": generate_meta_info(args.save_dir, args.max_node) @@ -485,7 +651,9 @@ if __name__ == "__main__": else: nets = traverse_net(args.max_node) if len(nets) != args.check_N: - raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)) + raise ValueError( + "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N) + ) opt_config = "./configs/nas-benchmark/hyper-opts/{:}E.config".format(args.hyper) if not os.path.isfile(opt_config): raise ValueError("{:} is not a file.".format(opt_config)) @@ -496,12 +664,16 @@ if __name__ == "__main__": raise ValueError("invalid length of seeds args: {:}".format(args.seeds)) if not (len(args.datasets) == len(args.xpaths) == len(args.splits)): raise ValueError( - "invalid infos : {:} vs {:} vs {:}".format(len(args.datasets), len(args.xpaths), len(args.splits)) + "invalid infos : {:} vs {:} vs {:}".format( + len(args.datasets), len(args.xpaths), len(args.splits) + ) ) if args.workers <= 0: raise ValueError("invalid number of workers : {:}".format(args.workers)) - target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds) + target_indexes = filter_indexes( + to_evaluate_indexes, args.mode, save_dir, args.seeds + ) assert torch.cuda.is_available(), "CUDA is not available." torch.backends.cudnn.enabled = True @@ -519,5 +691,9 @@ if __name__ == "__main__": opt_config, target_indexes, args.mode == "cover", - {"name": "infer.tiny", "channel": args.channel, "num_cells": args.num_cells}, + { + "name": "infer.tiny", + "channel": args.channel, + "num_cells": args.num_cells, + }, ) diff --git a/exps/NATS-Bench/sss-collect.py b/exps/NATS-Bench/sss-collect.py index dc8f650..b5ab2d4 100644 --- a/exps/NATS-Bench/sss-collect.py +++ b/exps/NATS-Bench/sss-collect.py @@ -31,24 +31,34 @@ from utils import get_md5_file NATS_SSS_BASE_NAME = "NATS-sss-v1_0" # 2020.08.28 -def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]) -> ArchResults: +def account_one_arch( + arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text] +) -> ArchResults: information = ArchResults(arch_index, arch_str) for checkpoint_path in checkpoints: try: checkpoint = torch.load(checkpoint_path, map_location="cpu") except: - raise ValueError("This checkpoint failed to be loaded : {:}".format(checkpoint_path)) + raise ValueError( + "This checkpoint failed to be loaded : {:}".format(checkpoint_path) + ) used_seed = checkpoint_path.name.split("-")[-1].split(".")[0] ok_dataset = 0 for dataset in datasets: if dataset not in checkpoint: - print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path)) + print( + "Can not find {:} in arch-{:} from {:}".format( + dataset, arch_index, checkpoint_path + ) + ) continue else: ok_dataset += 1 results = checkpoint[dataset] - assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( + assert results[ + "finish-train" + ], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( arch_index, used_seed, dataset, checkpoint_path ) arch_config = { @@ -71,13 +81,20 @@ def 
account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], d None, ) xresult.update_train_info( - results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"] + results["train_acc1es"], + results["train_acc5es"], + results["train_losses"], + results["train_times"], + ) + xresult.update_eval( + results["valid_acc1es"], results["valid_losses"], results["valid_times"] ) - xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"]) information.update(dataset, int(used_seed), xresult) if ok_dataset < len(datasets): raise ValueError( - "{:} does find enought data : {:} vs {:}".format(checkpoint_path, ok_dataset, len(datasets)) + "{:} does find enought data : {:} vs {:}".format( + checkpoint_path, ok_dataset, len(datasets) + ) ) return information @@ -107,7 +124,9 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]): arch_info.reset_latency("ImageNet16-120", None, image_latency) # CIFAR10 VALID - train_per_epoch_time = list(hp2info["01"].query("cifar10-valid", 777).train_times.values()) + train_per_epoch_time = list( + hp2info["01"].query("cifar10-valid", 777).train_times.values() + ) train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) eval_ori_test_time, eval_x_valid_time = [], [] for key, value in hp2info["01"].query("cifar10-valid", 777).eval_times.items(): @@ -121,11 +140,17 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]): eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time) for hp, arch_info in hp2info.items(): arch_info.reset_pseudo_train_times("cifar10-valid", None, train_per_epoch_time) - arch_info.reset_pseudo_eval_times("cifar10-valid", None, "x-valid", eval_x_valid_time) - arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_ori_test_time) + arch_info.reset_pseudo_eval_times( + "cifar10-valid", None, "x-valid", eval_x_valid_time + ) + arch_info.reset_pseudo_eval_times( + "cifar10-valid", None, "ori-test", eval_ori_test_time + ) # CIFAR10 - train_per_epoch_time = list(hp2info["01"].query("cifar10", 777).train_times.values()) + train_per_epoch_time = list( + hp2info["01"].query("cifar10", 777).train_times.values() + ) train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) eval_ori_test_time = [] for key, value in hp2info["01"].query("cifar10", 777).eval_times.items(): @@ -136,10 +161,14 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]): eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) for hp, arch_info in hp2info.items(): arch_info.reset_pseudo_train_times("cifar10", None, train_per_epoch_time) - arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_ori_test_time) + arch_info.reset_pseudo_eval_times( + "cifar10", None, "ori-test", eval_ori_test_time + ) # CIFAR100 - train_per_epoch_time = list(hp2info["01"].query("cifar100", 777).train_times.values()) + train_per_epoch_time = list( + hp2info["01"].query("cifar100", 777).train_times.values() + ) train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], [] for key, value in hp2info["01"].query("cifar100", 777).eval_times.items(): @@ -156,12 +185,18 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]): eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time) for hp, arch_info in hp2info.items(): arch_info.reset_pseudo_train_times("cifar100", None, train_per_epoch_time) - 
arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_x_valid_time) + arch_info.reset_pseudo_eval_times( + "cifar100", None, "x-valid", eval_x_valid_time + ) arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_x_test_time) - arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_ori_test_time) + arch_info.reset_pseudo_eval_times( + "cifar100", None, "ori-test", eval_ori_test_time + ) # ImageNet16-120 - train_per_epoch_time = list(hp2info["01"].query("ImageNet16-120", 777).train_times.values()) + train_per_epoch_time = list( + hp2info["01"].query("ImageNet16-120", 777).train_times.values() + ) train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], [] for key, value in hp2info["01"].query("ImageNet16-120", 777).eval_times.items(): @@ -178,9 +213,15 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]): eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time) for hp, arch_info in hp2info.items(): arch_info.reset_pseudo_train_times("ImageNet16-120", None, train_per_epoch_time) - arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "x-valid", eval_x_valid_time) - arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "x-test", eval_x_test_time) - arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "ori-test", eval_ori_test_time) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", None, "x-valid", eval_x_valid_time + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", None, "x-test", eval_x_test_time + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", None, "ori-test", eval_ori_test_time + ) return hp2info @@ -200,7 +241,9 @@ def simplify(save_dir, save_name, nets, total): seeds.add(seed) nums.append(len(xlist)) print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist))) - assert len(nets) == total == max(nums), "there are some missed files : {:} vs {:}".format(max(nums), total) + assert ( + len(nets) == total == max(nums) + ), "there are some missed files : {:} vs {:}".format(max(nums), total) print("{:} start simplify the checkpoint.".format(time_string())) datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") @@ -225,7 +268,10 @@ def simplify(save_dir, save_name, nets, total): for hp in hps: sub_save_dir = save_dir / "raw-data-{:}".format(hp) - ckps = [sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) for seed in seeds] + ckps = [ + sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) + for seed in seeds + ] ckps = [x for x in ckps if x.exists()] if len(ckps) == 0: raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp)) @@ -238,21 +284,31 @@ def simplify(save_dir, save_name, nets, total): hp2info["01"].clear_params() # to save some spaces... 
to_save_data = OrderedDict( - {"01": hp2info["01"].state_dict(), "12": hp2info["12"].state_dict(), "90": hp2info["90"].state_dict()} + { + "01": hp2info["01"].state_dict(), + "12": hp2info["12"].state_dict(), + "90": hp2info["90"].state_dict(), + } ) pickle_save(to_save_data, str(full_save_path)) for hp in hps: hp2info[hp].clear_params() to_save_data = OrderedDict( - {"01": hp2info["01"].state_dict(), "12": hp2info["12"].state_dict(), "90": hp2info["90"].state_dict()} + { + "01": hp2info["01"].state_dict(), + "12": hp2info["12"].state_dict(), + "90": hp2info["90"].state_dict(), + } ) pickle_save(to_save_data, str(simple_save_path)) arch2infos[index] = to_save_data # measure elapsed time arch_time.update(time.time() - end_time) end_time = time.time() - need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True)) + need_time = "{:}".format( + convert_secs2time(arch_time.avg * (total - index - 1), True) + ) # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) print("{:} {:} done.".format(time_string(), save_name)) final_infos = { @@ -297,7 +353,8 @@ def traverse_net(candidates: List[int], N: int): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="NATS-Bench (size search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="NATS-Bench (size search space)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--base_save_dir", @@ -305,15 +362,27 @@ if __name__ == "__main__": default="./output/NATS-Bench-size", help="The base-name of folder to save checkpoints and log.", ) - parser.add_argument("--candidateC", type=int, nargs="+", default=[8, 16, 24, 32, 40, 48, 56, 64], help=".") - parser.add_argument("--num_layers", type=int, default=5, help="The number of layers in a network.") + parser.add_argument( + "--candidateC", + type=int, + nargs="+", + default=[8, 16, 24, 32, 40, 48, 56, 64], + help=".", + ) + parser.add_argument( + "--num_layers", type=int, default=5, help="The number of layers in a network." + ) parser.add_argument("--check_N", type=int, default=32768, help="For safety.") - parser.add_argument("--save_name", type=str, default="process", help="The save directory.") + parser.add_argument( + "--save_name", type=str, default="process", help="The save directory." 
+ ) args = parser.parse_args() nets = traverse_net(args.candidateC, args.num_layers) if len(nets) != args.check_N: - raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)) + raise ValueError( + "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N) + ) save_dir = Path(args.base_save_dir) simplify(save_dir, args.save_name, nets, args.check_N) diff --git a/exps/NATS-Bench/sss-file-manager.py b/exps/NATS-Bench/sss-file-manager.py index 292a2ff..ce17652 100644 --- a/exps/NATS-Bench/sss-file-manager.py +++ b/exps/NATS-Bench/sss-file-manager.py @@ -54,7 +54,11 @@ def copy_data(source_dir, target_dir, meta_path): target_path = os.path.join(target_dir, file_name) if os.path.exists(source_path): s2t[source_path] = target_path - print("Map from {:} to {:}, find {:} missed ckps.".format(source_dir, target_dir, len(s2t))) + print( + "Map from {:} to {:}, find {:} missed ckps.".format( + source_dir, target_dir, len(s2t) + ) + ) for s, t in s2t.items(): copyfile(s, t) @@ -64,9 +68,18 @@ if __name__ == "__main__": description="NATS-Bench (size search space) file manager.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--mode", type=str, required=True, choices=["check", "copy"], help="The script mode.") parser.add_argument( - "--save_dir", type=str, default="output/NATS-Bench-size", help="Folder to save checkpoints and log." + "--mode", + type=str, + required=True, + choices=["check", "copy"], + help="The script mode.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="output/NATS-Bench-size", + help="Folder to save checkpoints and log.", ) parser.add_argument("--check_N", type=int, default=32768, help="For safety.") # use for train the model @@ -76,7 +89,10 @@ if __name__ == "__main__": for config in possible_configs: cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N) - torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), "{:}/meta-{:}.pth".format(args.save_dir, config)) + torch.save( + dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), + "{:}/meta-{:}.pth".format(args.save_dir, config), + ) elif args.mode == "copy": for config in possible_configs: cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) diff --git a/exps/NATS-Bench/test-nats-api.py b/exps/NATS-Bench/test-nats-api.py index 75ce7fd..3547ebb 100644 --- a/exps/NATS-Bench/test-nats-api.py +++ b/exps/NATS-Bench/test-nats-api.py @@ -91,14 +91,22 @@ if __name__ == "__main__": for fast_mode in [True, False]: for verbose in [True, False]: api_nats_tss = create(None, "tss", fast_mode=fast_mode, verbose=True) - print("{:} create with fast_mode={:} and verbose={:}".format(time_string(), fast_mode, verbose)) + print( + "{:} create with fast_mode={:} and verbose={:}".format( + time_string(), fast_mode, verbose + ) + ) test_api(api_nats_tss, False) del api_nats_tss gc.collect() for fast_mode in [True, False]: for verbose in [True, False]: - print("{:} create with fast_mode={:} and verbose={:}".format(time_string(), fast_mode, verbose)) + print( + "{:} create with fast_mode={:} and verbose={:}".format( + time_string(), fast_mode, verbose + ) + ) api_nats_sss = create(None, "size", fast_mode=fast_mode, verbose=True) print("{:} --->>> {:}".format(time_string(), api_nats_sss)) test_api(api_nats_sss, True) diff --git a/exps/NATS-Bench/tss-collect-patcher.py b/exps/NATS-Bench/tss-collect-patcher.py index 5b4b13f..6895aa9 100644 --- a/exps/NATS-Bench/tss-collect-patcher.py +++ 
b/exps/NATS-Bench/tss-collect-patcher.py @@ -50,7 +50,9 @@ def simplify(save_dir, save_name, nets, total, sup_config): seeds.add(seed) nums.append(len(xlist)) print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist))) - assert len(nets) == total == max(nums), "there are some missed files : {:} vs {:}".format(max(nums), total) + assert ( + len(nets) == total == max(nums) + ), "there are some missed files : {:} vs {:}".format(max(nums), total) print("{:} start simplify the checkpoint.".format(time_string())) datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") @@ -78,7 +80,9 @@ def simplify(save_dir, save_name, nets, total, sup_config): # measure elapsed time arch_time.update(time.time() - end_time) end_time = time.time() - need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True)) + need_time = "{:}".format( + convert_secs2time(arch_time.avg * (total - index - 1), True) + ) # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) print("{:} {:} done.".format(time_string(), save_name)) final_infos = { @@ -108,7 +112,11 @@ def simplify(save_dir, save_name, nets, total, sup_config): def traverse_net(max_node): aa_nas_bench_ss = get_search_spaces("cell", "nats-bench") archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) - print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2))) + print( + "There are {:} archs vs {:}.".format( + len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2) + ) + ) random.seed(88) # please do not change this line for reproducibility random.shuffle(archs) @@ -117,10 +125,12 @@ def traverse_net(max_node): == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" ), "please check the 0-th architecture : {:}".format(archs[0]) assert ( - archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" + archs[9].tostr() + == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" ), "please check the 9-th architecture : {:}".format(archs[9]) assert ( - archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" + archs[123].tostr() + == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" ), "please check the 123-th architecture : {:}".format(archs[123]) return [x.tostr() for x in archs] @@ -128,7 +138,8 @@ def traverse_net(max_node): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="NATS-Bench (topology search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="NATS-Bench (topology search space)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--base_save_dir", @@ -136,16 +147,26 @@ if __name__ == "__main__": default="./output/NATS-Bench-topology", help="The base-name of folder to save checkpoints and log.", ) - parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.") - parser.add_argument("--channel", type=int, default=16, help="The number of channels.") - parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + parser.add_argument( + "--max_node", type=int, default=4, help="The maximum node in a cell." + ) + parser.add_argument( + "--channel", type=int, default=16, help="The number of channels." 
+ ) + parser.add_argument( + "--num_cells", type=int, default=5, help="The number of cells in one stage." + ) parser.add_argument("--check_N", type=int, default=15625, help="For safety.") - parser.add_argument("--save_name", type=str, default="process", help="The save directory.") + parser.add_argument( + "--save_name", type=str, default="process", help="The save directory." + ) args = parser.parse_args() nets = traverse_net(args.max_node) if len(nets) != args.check_N: - raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)) + raise ValueError( + "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N) + ) save_dir = Path(args.base_save_dir) simplify( diff --git a/exps/NATS-Bench/tss-collect.py b/exps/NATS-Bench/tss-collect.py index 51017df..aee2a6b 100644 --- a/exps/NATS-Bench/tss-collect.py +++ b/exps/NATS-Bench/tss-collect.py @@ -32,7 +32,9 @@ from utils import get_md5_file from nas_201_api import NASBench201API -api = NASBench201API("{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"])) +api = NASBench201API( + "{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"]) +) NATS_TSS_BASE_NAME = "NATS-tss-v1_0" # 2020.08.28 @@ -68,35 +70,58 @@ def create_result_count( ) if "train_times" in results: # new version xresult.update_train_info( - results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"] + results["train_acc1es"], + results["train_acc5es"], + results["train_losses"], + results["train_times"], + ) + xresult.update_eval( + results["valid_acc1es"], results["valid_losses"], results["valid_times"] ) - xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"]) else: network = get_cell_based_tiny_net(net_config) network.load_state_dict(xresult.get_net_param()) if dataset == "cifar10-valid": - xresult.update_OLD_eval("x-valid", results["valid_acc1es"], results["valid_losses"]) + xresult.update_OLD_eval( + "x-valid", results["valid_acc1es"], results["valid_losses"] + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda() ) - xresult.update_OLD_eval("ori-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_OLD_eval( + "ori-test", + {results["total_epoch"] - 1: top1}, + {results["total_epoch"] - 1: loss}, + ) xresult.update_latency(latencies) elif dataset == "cifar10": - xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + xresult.update_OLD_eval( + "ori-test", results["valid_acc1es"], results["valid_losses"] + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() ) xresult.update_latency(latencies) elif dataset == "cifar100" or dataset == "ImageNet16-120": - xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"]) + xresult.update_OLD_eval( + "ori-test", results["valid_acc1es"], results["valid_losses"] + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda() ) - xresult.update_OLD_eval("x-valid", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss}) + xresult.update_OLD_eval( + "x-valid", + {results["total_epoch"] - 1: top1}, + {results["total_epoch"] - 1: loss}, + ) loss, top1, top5, latencies = pure_evaluate( dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() ) - xresult.update_OLD_eval("x-test", {results["total_epoch"] - 1: 
top1}, {results["total_epoch"] - 1: loss}) + xresult.update_OLD_eval( + "x-test", + {results["total_epoch"] - 1: top1}, + {results["total_epoch"] - 1: loss}, + ) xresult.update_latency(latencies) else: raise ValueError("invalid dataset name : {:}".format(dataset)) @@ -112,12 +137,18 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic ok_dataset = 0 for dataset in datasets: if dataset not in checkpoint: - print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path)) + print( + "Can not find {:} in arch-{:} from {:}".format( + dataset, arch_index, checkpoint_path + ) + ) continue else: ok_dataset += 1 results = checkpoint[dataset] - assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( + assert results[ + "finish-train" + ], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( arch_index, used_seed, dataset, checkpoint_path ) arch_config = { @@ -127,7 +158,9 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic "class_num": results["config"]["class_num"], } - xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict) + xresult = create_result_count( + used_seed, dataset, arch_config, results, dataloader_dict + ) information.update(dataset, int(used_seed), xresult) if ok_dataset == 0: raise ValueError("{:} does not find any data".format(checkpoint_path)) @@ -137,7 +170,8 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResults]): # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth cifar010_latency = ( - api.get_latency(arch_index, "cifar10-valid", hp="200") + api.get_latency(arch_index, "cifar10", hp="200") + api.get_latency(arch_index, "cifar10-valid", hp="200") + + api.get_latency(arch_index, "cifar10", hp="200") ) / 2 cifar100_latency = api.get_latency(arch_index, "cifar100", hp="200") image_latency = api.get_latency(arch_index, "ImageNet16-120", hp="200") @@ -147,7 +181,9 @@ def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResult arch_info.reset_latency("cifar100", None, cifar100_latency) arch_info.reset_latency("ImageNet16-120", None, image_latency) - train_per_epoch_time = list(arch_infos["12"].query("cifar10-valid", 777).train_times.values()) + train_per_epoch_time = list( + arch_infos["12"].query("cifar10-valid", 777).train_times.values() + ) train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) eval_ori_test_time, eval_x_valid_time = [], [] for key, value in arch_infos["12"].query("cifar10-valid", 777).eval_times.items(): @@ -157,7 +193,9 @@ def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResult eval_x_valid_time.append(value) else: raise ValueError("-- {:} --".format(key)) - eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time)) + eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float( + np.mean(eval_x_valid_time) + ) nums = { "ImageNet16-120-train": 151700, "ImageNet16-120-valid": 3000, @@ -170,36 +208,72 @@ def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResult "cifar100-test": 10000, "cifar100-valid": 5000, } - eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums["cifar10-valid-valid"] + nums["cifar10-test"]) + eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / ( + 
nums["cifar10-valid-valid"] + nums["cifar10-test"] + ) for hp, arch_info in arch_infos.items(): arch_info.reset_pseudo_train_times( - "cifar10-valid", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-valid-train"] + "cifar10-valid", + None, + train_per_epoch_time + / nums["cifar10-valid-train"] + * nums["cifar10-valid-train"], ) arch_info.reset_pseudo_train_times( - "cifar10", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"] + "cifar10", + None, + train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"], ) arch_info.reset_pseudo_train_times( - "cifar100", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"] + "cifar100", + None, + train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"], ) arch_info.reset_pseudo_train_times( - "ImageNet16-120", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["ImageNet16-120-train"] + "ImageNet16-120", + None, + train_per_epoch_time + / nums["cifar10-valid-train"] + * nums["ImageNet16-120-train"], ) arch_info.reset_pseudo_eval_times( - "cifar10-valid", None, "x-valid", eval_per_sample * nums["cifar10-valid-valid"] - ) - arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"]) - arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"]) - arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"]) - arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"]) - arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"]) - arch_info.reset_pseudo_eval_times( - "ImageNet16-120", None, "x-valid", eval_per_sample * nums["ImageNet16-120-valid"] + "cifar10-valid", + None, + "x-valid", + eval_per_sample * nums["cifar10-valid-valid"], ) arch_info.reset_pseudo_eval_times( - "ImageNet16-120", None, "x-test", eval_per_sample * nums["ImageNet16-120-valid"] + "cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"] ) arch_info.reset_pseudo_eval_times( - "ImageNet16-120", None, "ori-test", eval_per_sample * nums["ImageNet16-120-test"] + "cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"] + ) + arch_info.reset_pseudo_eval_times( + "cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"] + ) + arch_info.reset_pseudo_eval_times( + "cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"] + ) + arch_info.reset_pseudo_eval_times( + "cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"] + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", + None, + "x-valid", + eval_per_sample * nums["ImageNet16-120-valid"], + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", + None, + "x-test", + eval_per_sample * nums["ImageNet16-120-valid"], + ) + arch_info.reset_pseudo_eval_times( + "ImageNet16-120", + None, + "ori-test", + eval_per_sample * nums["ImageNet16-120-test"], ) return arch_infos @@ -220,7 +294,9 @@ def simplify(save_dir, save_name, nets, total, sup_config): seeds.add(seed) nums.append(len(xlist)) print(" [seed={:}] there are {:} checkpoints.".format(seed, len(xlist))) - assert len(nets) == total == max(nums), "there are some missed files : {:} vs {:}".format(max(nums), total) + assert ( + len(nets) == total == max(nums) + ), "there are some missed files : {:} vs {:}".format(max(nums), total) print("{:} start simplify the 
checkpoint.".format(time_string())) datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") @@ -236,7 +312,12 @@ def simplify(save_dir, save_name, nets, total, sup_config): arch2infos, evaluated_indexes = dict(), set() end_time, arch_time = time.time(), AverageMeter() # save the meta information - temp_final_infos = {"meta_archs": nets, "total_archs": total, "arch2infos": None, "evaluated_indexes": set()} + temp_final_infos = { + "meta_archs": nets, + "total_archs": total, + "arch2infos": None, + "evaluated_indexes": set(), + } pickle_save(temp_final_infos, str(full_save_dir / "meta.pickle")) pickle_save(temp_final_infos, str(simple_save_dir / "meta.pickle")) @@ -248,29 +329,40 @@ def simplify(save_dir, save_name, nets, total, sup_config): simple_save_path = simple_save_dir / "{:06d}.pickle".format(index) for hp in hps: sub_save_dir = save_dir / "raw-data-{:}".format(hp) - ckps = [sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) for seed in seeds] + ckps = [ + sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) + for seed in seeds + ] ckps = [x for x in ckps if x.exists()] if len(ckps) == 0: raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp)) - arch_info = account_one_arch(index, arch_str, ckps, datasets, dataloader_dict) + arch_info = account_one_arch( + index, arch_str, ckps, datasets, dataloader_dict + ) hp2info[hp] = arch_info hp2info = correct_time_related_info(index, hp2info) evaluated_indexes.add(index) - to_save_data = OrderedDict({"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()}) + to_save_data = OrderedDict( + {"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()} + ) pickle_save(to_save_data, str(full_save_path)) for hp in hps: hp2info[hp].clear_params() - to_save_data = OrderedDict({"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()}) + to_save_data = OrderedDict( + {"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()} + ) pickle_save(to_save_data, str(simple_save_path)) arch2infos[index] = to_save_data # measure elapsed time arch_time.update(time.time() - end_time) end_time = time.time() - need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True)) + need_time = "{:}".format( + convert_secs2time(arch_time.avg * (total - index - 1), True) + ) # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) print("{:} {:} done.".format(time_string(), save_name)) final_infos = { @@ -303,7 +395,11 @@ def simplify(save_dir, save_name, nets, total, sup_config): def traverse_net(max_node): aa_nas_bench_ss = get_search_spaces("cell", "nats-bench") archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) - print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2))) + print( + "There are {:} archs vs {:}.".format( + len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2) + ) + ) random.seed(88) # please do not change this line for reproducibility random.shuffle(archs) @@ -312,10 +408,12 @@ def traverse_net(max_node): == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" ), "please check the 0-th architecture : {:}".format(archs[0]) assert ( - archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" + archs[9].tostr() + == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" ), "please check the 9-th architecture : 
{:}".format(archs[9]) assert ( - archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" + archs[123].tostr() + == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" ), "please check the 123-th architecture : {:}".format(archs[123]) return [x.tostr() for x in archs] @@ -323,7 +421,8 @@ def traverse_net(max_node): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="NATS-Bench (topology search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="NATS-Bench (topology search space)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--base_save_dir", @@ -331,16 +430,26 @@ if __name__ == "__main__": default="./output/NATS-Bench-topology", help="The base-name of folder to save checkpoints and log.", ) - parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.") - parser.add_argument("--channel", type=int, default=16, help="The number of channels.") - parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + parser.add_argument( + "--max_node", type=int, default=4, help="The maximum node in a cell." + ) + parser.add_argument( + "--channel", type=int, default=16, help="The number of channels." + ) + parser.add_argument( + "--num_cells", type=int, default=5, help="The number of cells in one stage." + ) parser.add_argument("--check_N", type=int, default=15625, help="For safety.") - parser.add_argument("--save_name", type=str, default="process", help="The save directory.") + parser.add_argument( + "--save_name", type=str, default="process", help="The save directory." + ) args = parser.parse_args() nets = traverse_net(args.max_node) if len(nets) != args.check_N: - raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)) + raise ValueError( + "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N) + ) save_dir = Path(args.base_save_dir) simplify( diff --git a/exps/NATS-Bench/tss-file-manager.py b/exps/NATS-Bench/tss-file-manager.py index 10c94eb..6038c7b 100644 --- a/exps/NATS-Bench/tss-file-manager.py +++ b/exps/NATS-Bench/tss-file-manager.py @@ -53,7 +53,11 @@ def copy_data(source_dir, target_dir, meta_path): target_path = os.path.join(target_dir, file_name) if os.path.exists(source_path): s2t[source_path] = target_path - print("Map from {:} to {:}, find {:} missed ckps.".format(source_dir, target_dir, len(s2t))) + print( + "Map from {:} to {:}, find {:} missed ckps.".format( + source_dir, target_dir, len(s2t) + ) + ) for s, t in s2t.items(): copyfile(s, t) @@ -63,9 +67,18 @@ if __name__ == "__main__": description="NATS-Bench (topology search space) file manager.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--mode", type=str, required=True, choices=["check", "copy"], help="The script mode.") parser.add_argument( - "--save_dir", type=str, default="output/NATS-Bench-topology", help="Folder to save checkpoints and log." 
+ "--mode", + type=str, + required=True, + choices=["check", "copy"], + help="The script mode.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="output/NATS-Bench-topology", + help="Folder to save checkpoints and log.", ) parser.add_argument("--check_N", type=int, default=15625, help="For safety.") # use for train the model @@ -75,8 +88,13 @@ if __name__ == "__main__": if args.mode == "check": for config, possible_seeds in zip(possible_configs, possible_seedss): cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) - seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds) - torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), "{:}/meta-{:}.pth".format(args.save_dir, config)) + seed2ckps, miss2ckps = obtain_valid_ckp( + cur_save_dir, args.check_N, possible_seeds + ) + torch.save( + dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), + "{:}/meta-{:}.pth".format(args.save_dir, config), + ) elif args.mode == "copy": for config in possible_configs: cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) diff --git a/exps/NATS-algos/bohb.py b/exps/NATS-algos/bohb.py index e7840b1..e06f574 100644 --- a/exps/NATS-algos/bohb.py +++ b/exps/NATS-algos/bohb.py @@ -36,7 +36,9 @@ def get_topology_config_space(search_space, max_nodes=4): for i in range(1, max_nodes): for j in range(i): node_str = "{:}<-{:}".format(i, j) - cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) + cs.add_hyperparameter( + ConfigSpace.CategoricalHyperparameter(node_str, search_space) + ) return cs @@ -44,7 +46,9 @@ def get_size_config_space(search_space): cs = ConfigSpace.ConfigurationSpace() for ilayer in range(search_space["numbers"]): node_str = "layer-{:}".format(ilayer) - cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space["candidates"])) + cs.add_hyperparameter( + ConfigSpace.CategoricalHyperparameter(node_str, search_space["candidates"]) + ) return cs @@ -159,8 +163,14 @@ def main(xargs, api): current_best_index.append(api.query_index_by_arch(arch)) best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1] - logger.log("Best found configuration: {:} within {:.3f} s".format(best_arch, workers[0].total_times[-1])) - info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90") + logger.log( + "Best found configuration: {:} within {:.3f} s".format( + best_arch, workers[0].total_times[-1] + ) + ) + info = api.query_info_str_by_arch( + best_arch, "200" if xargs.search_space == "tss" else "90" + ) logger.log("{:}".format(info)) logger.log("-" * 100) logger.close() @@ -169,7 +179,9 @@ def main(xargs, api): if __name__ == "__main__": - parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale") + parser = argparse.ArgumentParser( + "BOHB: Robust and Efficient Hyperparameter Optimization at Scale" + ) parser.add_argument( "--dataset", type=str, @@ -177,35 +189,80 @@ if __name__ == "__main__": help="Choose between Cifar10/100 and ImageNet-16.", ) # general arg - parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") parser.add_argument( - "--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)." 
+ "--search_space", + type=str, + choices=["tss", "sss"], + help="Choose the search space.", + ) + parser.add_argument( + "--time_budget", + type=int, + default=20000, + help="The total time cost budge for searching (in seconds).", + ) + parser.add_argument( + "--loops_if_rand", type=int, default=500, help="The total runs for evaluation." ) - parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.") # BOHB parser.add_argument( - "--strategy", default="sampling", type=str, nargs="?", help="optimization strategy for the acquisition function" - ) - parser.add_argument("--min_bandwidth", default=0.3, type=float, nargs="?", help="minimum bandwidth for KDE") - parser.add_argument( - "--num_samples", default=64, type=int, nargs="?", help="number of samples for the acquisition function" + "--strategy", + default="sampling", + type=str, + nargs="?", + help="optimization strategy for the acquisition function", ) parser.add_argument( - "--random_fraction", default=0.33, type=float, nargs="?", help="fraction of random configurations" + "--min_bandwidth", + default=0.3, + type=float, + nargs="?", + help="minimum bandwidth for KDE", ) - parser.add_argument("--bandwidth_factor", default=3, type=int, nargs="?", help="factor multiplied to the bandwidth") parser.add_argument( - "--n_iters", default=300, type=int, nargs="?", help="number of iterations for optimization method" + "--num_samples", + default=64, + type=int, + nargs="?", + help="number of samples for the acquisition function", + ) + parser.add_argument( + "--random_fraction", + default=0.33, + type=float, + nargs="?", + help="fraction of random configurations", + ) + parser.add_argument( + "--bandwidth_factor", + default=3, + type=int, + nargs="?", + help="factor multiplied to the bandwidth", + ) + parser.add_argument( + "--n_iters", + default=300, + type=int, + nargs="?", + help="number of iterations for optimization method", ) # log - parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") + parser.add_argument( + "--save_dir", + type=str, + default="./output/search", + help="Folder to save checkpoints and log.", + ) parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") args = parser.parse_args() api = create(None, args.search_space, fast_mode=False, verbose=False) args.save_dir = os.path.join( - "{:}-{:}".format(args.save_dir, args.search_space), "{:}-T{:}".format(args.dataset, args.time_budget), "BOHB" + "{:}-{:}".format(args.save_dir, args.search_space), + "{:}-T{:}".format(args.dataset, args.time_budget), + "BOHB", ) print("save-dir : {:}".format(args.save_dir)) diff --git a/exps/NATS-algos/random_wo_share.py b/exps/NATS-algos/random_wo_share.py index 5c1d3ad..773ec0c 100644 --- a/exps/NATS-algos/random_wo_share.py +++ b/exps/NATS-algos/random_wo_share.py @@ -19,7 +19,13 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_search_spaces @@ -45,12 +51,16 @@ def main(xargs, api): current_best_index = [] while 
len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget: arch = random_arch() - accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, hp="12") + accuracy, _, _, total_cost = api.simulate_train_eval( + arch, xargs.dataset, hp="12" + ) total_time_cost.append(total_cost) history.append(arch) if best_arch is None or best_acc < accuracy: best_acc, best_arch = accuracy, arch - logger.log("[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy)) + logger.log( + "[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy) + ) current_best_index.append(api.query_index_by_arch(best_arch)) logger.log( "{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.".format( @@ -58,7 +68,9 @@ def main(xargs, api): ) ) - info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90") + info = api.query_info_str_by_arch( + best_arch, "200" if xargs.search_space == "tss" else "90" + ) logger.log("{:}".format(info)) logger.log("-" * 100) logger.close() @@ -73,21 +85,38 @@ if __name__ == "__main__": choices=["cifar10", "cifar100", "ImageNet16-120"], help="Choose between Cifar10/100 and ImageNet-16.", ) - parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") + parser.add_argument( + "--search_space", + type=str, + choices=["tss", "sss"], + help="Choose the search space.", + ) parser.add_argument( - "--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)." + "--time_budget", + type=int, + default=20000, + help="The total time cost budge for searching (in seconds).", + ) + parser.add_argument( + "--loops_if_rand", type=int, default=500, help="The total runs for evaluation." 
) - parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.") # log - parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") + parser.add_argument( + "--save_dir", + type=str, + default="./output/search", + help="Folder to save checkpoints and log.", + ) parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") args = parser.parse_args() api = create(None, args.search_space, fast_mode=True, verbose=False) args.save_dir = os.path.join( - "{:}-{:}".format(args.save_dir, args.search_space), "{:}-T{:}".format(args.dataset, args.time_budget), "RANDOM" + "{:}-{:}".format(args.save_dir, args.search_space), + "{:}-T{:}".format(args.dataset, args.time_budget), + "RANDOM", ) print("save-dir : {:}".format(args.save_dir)) diff --git a/exps/NATS-algos/regularized_ea.py b/exps/NATS-algos/regularized_ea.py index 86c21e3..6499c8a 100644 --- a/exps/NATS-algos/regularized_ea.py +++ b/exps/NATS-algos/regularized_ea.py @@ -23,7 +23,13 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import CellStructure, get_search_spaces @@ -103,7 +109,15 @@ def mutate_size_func(info): def regularized_evolution( - cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset + cycles, + population_size, + sample_size, + time_budget, + random_arch, + mutate_arch, + api, + use_proxy, + dataset, ): """Algorithm for regularized evolution (i.e. aging evolution). @@ -122,7 +136,10 @@ def regularized_evolution( """ population = collections.deque() api.reset_time() - history, total_time_cost = [], [] # Not used by the algorithm, only used to report results. + history, total_time_cost = ( + [], + [], + ) # Not used by the algorithm, only used to report results. current_best_index = [] # Initialize the population with random models. while len(population) < population_size: @@ -135,7 +152,9 @@ def regularized_evolution( population.append(model) history.append((model.accuracy, model.arch)) total_time_cost.append(total_cost) - current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1])) + current_best_index.append( + api.query_index_by_arch(max(history, key=lambda x: x[0])[1]) + ) # Carry out evolution in cycles. Each cycle produces a model and removes another. while total_time_cost[-1] < time_budget: @@ -160,7 +179,9 @@ def regularized_evolution( # Append the info population.append(child) history.append((child.accuracy, child.arch)) - current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1])) + current_best_index.append( + api.query_index_by_arch(max(history, key=lambda x: x[0])[1]) + ) total_time_cost.append(total_cost) # Remove the oldest model. 
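For readers following this hunk, the reformatted regularized_evolution function is the classic aging-evolution loop: initialize a FIFO population with random architectures, then repeatedly sample a subset, mutate the best member of the sample, append the child, and pop the oldest individual. The sketch below is a minimal, self-contained illustration of that loop only; toy_fitness and mutate here are hypothetical stand-ins for api.simulate_train_eval and the mutate_*_func helpers used in this file, not part of the repository.

import collections
import random


def toy_fitness(arch):
    # Hypothetical stand-in for api.simulate_train_eval(): a noisy pseudo-accuracy.
    return sum(arch) / len(arch) + random.uniform(-0.05, 0.05)


def mutate(arch):
    # Hypothetical stand-in for mutate_topology_func / mutate_size_func:
    # re-sample one randomly chosen position.
    child = list(arch)
    child[random.randrange(len(child))] = random.random()
    return tuple(child)


def aging_evolution(cycles, population_size, sample_size, dims=6):
    population = collections.deque()
    history = []
    # Initialize the population with random "architectures".
    while len(population) < population_size:
        arch = tuple(random.random() for _ in range(dims))
        member = (toy_fitness(arch), arch)
        population.append(member)
        history.append(member)
    # Evolve: sample, mutate the best sampled parent, then age out the oldest.
    for _ in range(cycles):
        samples = random.sample(list(population), sample_size)
        parent = max(samples, key=lambda x: x[0])
        child_arch = mutate(parent[1])
        child = (toy_fitness(child_arch), child_arch)
        population.append(child)
        history.append(child)
        population.popleft()  # aging: drop the oldest member, not the worst
    return max(history, key=lambda x: x[0])


if __name__ == "__main__":
    best_acc, best_arch = aging_evolution(cycles=200, population_size=10, sample_size=3)
    print("best pseudo-accuracy: {:.3f}".format(best_acc))

The detail worth noting, and the reason for the "Remove the oldest model" comment above, is population.popleft(): aging evolution discards the oldest individual rather than the worst one, which is what distinguishes regularized EA from plain tournament selection.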
@@ -183,7 +204,10 @@ def main(xargs, api): x_start_time = time.time() logger.log("{:} use api : {:}".format(time_string(), api)) - logger.log("-" * 30 + " start searching with the time budget of {:} s".format(xargs.time_budget)) + logger.log( + "-" * 30 + + " start searching with the time budget of {:} s".format(xargs.time_budget) + ) history, current_best_index, total_times = regularized_evolution( xargs.ea_cycles, xargs.ea_population, @@ -203,7 +227,9 @@ def main(xargs, api): best_arch = max(history, key=lambda x: x[0])[1] logger.log("{:} best arch is {:}".format(time_string(), best_arch)) - info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90") + info = api.query_info_str_by_arch( + best_arch, "200" if xargs.search_space == "tss" else "90" + ) logger.log("{:}".format(info)) logger.log("-" * 100) logger.close() @@ -218,19 +244,39 @@ if __name__ == "__main__": choices=["cifar10", "cifar100", "ImageNet16-120"], help="Choose between Cifar10/100 and ImageNet-16.", ) - parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") + parser.add_argument( + "--search_space", + type=str, + choices=["tss", "sss"], + help="Choose the search space.", + ) # hyperparameters for REA parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.") parser.add_argument("--ea_population", type=int, help="The population size in EA.") parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.") parser.add_argument( - "--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)." + "--time_budget", + type=int, + default=20000, + help="The total time cost budge for searching (in seconds).", + ) + parser.add_argument( + "--use_proxy", + type=int, + default=1, + help="Whether to use the proxy (H0) task or not.", ) - parser.add_argument("--use_proxy", type=int, default=1, help="Whether to use the proxy (H0) task or not.") # - parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.") + parser.add_argument( + "--loops_if_rand", type=int, default=500, help="The total runs for evaluation." 
+ ) # log - parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") + parser.add_argument( + "--save_dir", + type=str, + default="./output/search", + help="Folder to save checkpoints and log.", + ) parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") args = parser.parse_args() @@ -238,7 +284,9 @@ if __name__ == "__main__": args.save_dir = os.path.join( "{:}-{:}".format(args.save_dir, args.search_space), - "{:}-T{:}{:}".format(args.dataset, args.time_budget, "" if args.use_proxy > 0 else "-FULL"), + "{:}-T{:}{:}".format( + args.dataset, args.time_budget, "" if args.use_proxy > 0 else "-FULL" + ), "R-EA-SS{:}".format(args.ea_sample_size), ) print("save-dir : {:}".format(args.save_dir)) diff --git a/exps/NATS-algos/reinforce.py b/exps/NATS-algos/reinforce.py index 1280ff8..b6505e7 100644 --- a/exps/NATS-algos/reinforce.py +++ b/exps/NATS-algos/reinforce.py @@ -23,7 +23,13 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import CellStructure, get_search_spaces @@ -40,7 +46,9 @@ class PolicyTopology(nn.Module): for j in range(i): node_str = "{:}<-{:}".format(i, j) self.edge2index[node_str] = len(self.edge2index) - self.arch_parameters = nn.Parameter(1e-3 * torch.randn(len(self.edge2index), len(search_space))) + self.arch_parameters = nn.Parameter( + 1e-3 * torch.randn(len(self.edge2index), len(search_space)) + ) def generate_arch(self, actions): genotypes = [] @@ -76,7 +84,9 @@ class PolicySize(nn.Module): super(PolicySize, self).__init__() self.candidates = search_space["candidates"] self.numbers = search_space["numbers"] - self.arch_parameters = nn.Parameter(1e-3 * torch.randn(self.numbers, len(self.candidates))) + self.arch_parameters = nn.Parameter( + 1e-3 * torch.randn(self.numbers, len(self.candidates)) + ) def generate_arch(self, actions): channels = [str(self.candidates[i]) for i in actions] @@ -103,7 +113,9 @@ class ExponentialMovingAverage(object): self._momentum = momentum def update(self, value): - self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value + self._numerator = ( + self._momentum * self._numerator + (1 - self._momentum) * value + ) self._denominator = self._momentum * self._denominator + (1 - self._momentum) def value(self): @@ -143,14 +155,18 @@ def main(xargs, api): # REINFORCE x_start_time = time.time() - logger.log("Will start searching with time budget of {:} s.".format(xargs.time_budget)) + logger.log( + "Will start searching with time budget of {:} s.".format(xargs.time_budget) + ) total_steps, total_costs, trace = 0, [], [] current_best_index = [] while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget: start_time = time.time() log_prob, action = select_action(policy) arch = policy.generate_arch(action) - reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, hp="12") + reward, _, _, current_total_cost = api.simulate_train_eval( + arch, xargs.dataset, hp="12" + ) trace.append((reward, arch)) total_costs.append(current_total_cost) @@ 
-168,7 +184,9 @@ def main(xargs, api): ) ) # to analyze - current_best_index.append(api.query_index_by_arch(max(trace, key=lambda x: x[0])[1])) + current_best_index.append( + api.query_index_by_arch(max(trace, key=lambda x: x[0])[1]) + ) # best_arch = policy.genotype() # first version best_arch = max(trace, key=lambda x: x[0])[1] logger.log( @@ -176,7 +194,9 @@ def main(xargs, api): total_steps, total_costs[-1], time.time() - x_start_time ) ) - info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90") + info = api.query_info_str_by_arch( + best_arch, "200" if xargs.search_space == "tss" else "90" + ) logger.log("{:}".format(info)) logger.log("-" * 100) logger.close() @@ -193,17 +213,38 @@ if __name__ == "__main__": choices=["cifar10", "cifar100", "ImageNet16-120"], help="Choose between Cifar10/100 and ImageNet-16.", ) - parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") - parser.add_argument("--learning_rate", type=float, help="The learning rate for REINFORCE.") - parser.add_argument("--EMA_momentum", type=float, default=0.9, help="The momentum value for EMA.") parser.add_argument( - "--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)." + "--search_space", + type=str, + choices=["tss", "sss"], + help="Choose the search space.", ) - parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.") - # log - parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") parser.add_argument( - "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + "--learning_rate", type=float, help="The learning rate for REINFORCE." + ) + parser.add_argument( + "--EMA_momentum", type=float, default=0.9, help="The momentum value for EMA." + ) + parser.add_argument( + "--time_budget", + type=int, + default=20000, + help="The total time cost budge for searching (in seconds).", + ) + parser.add_argument( + "--loops_if_rand", type=int, default=500, help="The total runs for evaluation." 
+ ) + # log + parser.add_argument( + "--save_dir", + type=str, + default="./output/search", + help="Folder to save checkpoints and log.", + ) + parser.add_argument( + "--arch_nas_dataset", + type=str, + help="The path to load the architecture dataset (tiny-nas-benchmark).", ) parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") diff --git a/exps/NATS-algos/search-cell.py b/exps/NATS-algos/search-cell.py index fd51703..dc21021 100644 --- a/exps/NATS-algos/search-cell.py +++ b/exps/NATS-algos/search-cell.py @@ -37,7 +37,13 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import count_parameters_in_MB, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces @@ -49,7 +55,9 @@ def _concat(xs): return torch.cat([x.view(-1) for x in xs]) -def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2): +def _hessian_vector_product( + vector, network, criterion, base_inputs, base_targets, r=1e-2 +): R = r / _concat(vector).norm() for p, v in zip(network.weights, vector): p.data.add_(R, v) @@ -68,7 +76,15 @@ def _hessian_vector_product(vector, network, criterion, base_inputs, base_target return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)] -def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets): +def backward_step_unrolled( + network, + criterion, + base_inputs, + base_targets, + w_optimizer, + arch_inputs, + arch_targets, +): # _compute_unrolled_model _, logits = network(base_inputs) loss = criterion(logits, base_targets) @@ -80,7 +96,9 @@ def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_opti with torch.no_grad(): theta = _concat(network.weights) try: - moment = _concat(w_optimizer.state[v]["momentum_buffer"] for v in network.weights) + moment = _concat( + w_optimizer.state[v]["momentum_buffer"] for v in network.weights + ) moment = moment.mul_(momentum) except: moment = torch.zeros_like(theta) @@ -105,7 +123,9 @@ def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_opti dalpha = unrolled_model.arch_parameters.grad vector = [v.grad.data for v in unrolled_model.weights] - [implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets) + [implicit_grads] = _hessian_vector_product( + vector, network, criterion, base_inputs, base_targets + ) dalpha.data.sub_(LR, implicit_grads.data) @@ -116,13 +136,26 @@ def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_opti return unrolled_loss.detach(), unrolled_logits.detach() -def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, algo, logger): +def search_func( + xloader, + network, + criterion, + scheduler, + w_optimizer, + a_optimizer, + epoch_str, + print_freq, + algo, + logger, +): data_time, batch_time = AverageMeter(), AverageMeter() base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() arch_losses, 
arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() end = time.time() network.train() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( + xloader + ): scheduler.update(None, 1.0 * step / len(xloader)) base_inputs = base_inputs.cuda(non_blocking=True) arch_inputs = arch_inputs.cuda(non_blocking=True) @@ -155,7 +188,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer base_loss.backward() w_optimizer.step() # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_prec1, base_prec5 = obtain_accuracy( + logits.data, base_targets.data, topk=(1, 5) + ) base_losses.update(base_loss.item(), base_inputs.size(0)) base_top1.update(base_prec1.item(), base_inputs.size(0)) base_top5.update(base_prec5.item(), base_inputs.size(0)) @@ -174,7 +209,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer network.zero_grad() if algo == "darts-v2": arch_loss, logits = backward_step_unrolled( - network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets + network, + criterion, + base_inputs, + base_targets, + w_optimizer, + arch_inputs, + arch_targets, ) a_optimizer.step() elif algo == "random" or algo == "enas": @@ -187,7 +228,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer arch_loss.backward() a_optimizer.step() # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -197,7 +240,11 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer end = time.time() if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Sstr = ( + "*SEARCH* " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) @@ -208,14 +255,31 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer loss=arch_losses, top1=arch_top1, top5=arch_top5 ) logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) - return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg + return ( + base_losses.avg, + base_top1.avg, + base_top5.avg, + arch_losses.avg, + arch_top1.avg, + arch_top5.avg, + ) -def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger): +def train_controller( + xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger +): # config. (containing some necessary arg) # baseline: The baseline score (i.e. 
average val_acc) from the previous epoch data_time, batch_time = AverageMeter(), AverageMeter() - GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = ( + ( + GradnormMeter, + LossMeter, + ValAccMeter, + EntropyMeter, + BaselineMeter, + RewardMeter, + xend, + ) = ( AverageMeter(), AverageMeter(), AverageMeter(), @@ -255,7 +319,9 @@ def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoc if prev_baseline is None: baseline = val_top1 else: - baseline = prev_baseline - (1 - controller_bl_dec) * (prev_baseline - reward) + baseline = prev_baseline - (1 - controller_bl_dec) * ( + prev_baseline - reward + ) loss = -1 * log_prob * (reward - baseline) @@ -274,7 +340,9 @@ def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoc batch_time.update(time.time() - xend) xend = time.time() if (step + 1) % controller_num_aggregate == 0: - grad_norm = torch.nn.utils.clip_grad_norm_(network.controller.parameters(), 5.0) + grad_norm = torch.nn.utils.clip_grad_norm_( + network.controller.parameters(), 5.0 + ) GradnormMeter.update(grad_norm) optimizer.step() network.controller.zero_grad() @@ -283,13 +351,18 @@ def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoc Sstr = ( "*Train-Controller* " + time_string() - + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, controller_train_steps * controller_num_aggregate) + + " [{:}][{:03d}/{:03d}]".format( + epoch_str, step, controller_train_steps * controller_num_aggregate + ) ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})".format( - loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter + loss=LossMeter, + top1=ValAccMeter, + reward=RewardMeter, + basel=BaselineMeter, ) Estr = "Entropy={:.4f} ({:.4f})".format(EntropyMeter.val, EntropyMeter.avg) logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Estr) @@ -323,7 +396,9 @@ def get_best_arch(xloader, network, n_samples, algo): loader_iter = iter(xloader) inputs, targets = next(loader_iter) _, logits = network(inputs.cuda(non_blocking=True)) - val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) + val_top1, val_top5 = obtain_accuracy( + logits.cpu().data, targets.data, topk=(1, 5) + ) valid_accs.append(val_top1.item()) best_idx = np.argmax(valid_accs) best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] @@ -344,7 +419,9 @@ def valid_func(xloader, network, criterion, algo, logger): _, logits = network(arch_inputs.cuda(non_blocking=True)) arch_loss = criterion(logits, arch_targets) # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -363,11 +440,17 @@ def main(xargs): prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) if xargs.overwite_epochs is None: extra_info = 
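The controller baseline update in train_controller, baseline = prev_baseline - (1 - controller_bl_dec) * (prev_baseline - reward), is an exponential moving average written in incremental form. A short sketch with an assumed decay of 0.99 (the helper name is made up for illustration) shows the two forms agree:

def update_baseline(prev_baseline, reward, bl_dec=0.99):
    # Incremental form used in train_controller ...
    a = prev_baseline - (1 - bl_dec) * (prev_baseline - reward)
    # ... and the equivalent convex-combination form of an EMA.
    b = bl_dec * prev_baseline + (1 - bl_dec) * reward
    assert abs(a - b) < 1e-9
    return a

baseline = 0.50
for reward in (0.55, 0.60, 0.58):
    baseline = update_baseline(baseline, reward)
# The controller loss -log_prob * (reward - baseline) then favours architectures
# whose validation accuracy beats this slowly-moving baseline.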
{"class_num": class_num, "xshape": xshape} else: - extra_info = {"class_num": class_num, "xshape": xshape, "epochs": xargs.overwite_epochs} + extra_info = { + "class_num": class_num, + "xshape": xshape, + "epochs": xargs.overwite_epochs, + } config = load_config(xargs.config_path, extra_info, logger) search_loader, train_loader, valid_loader = get_nas_search_loaders( train_data, @@ -405,7 +488,9 @@ def main(xargs): search_model.set_algo(xargs.algo) logger.log("{:}".format(search_model)) - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) + w_optimizer, w_scheduler, criterion = get_optim_scheduler( + search_model.weights, config + ) a_optimizer = torch.optim.Adam( search_model.alphas, lr=xargs.arch_learning_rate, @@ -426,13 +511,23 @@ def main(xargs): api = None logger.log("{:} create API = {:} done".format(time_string(), api)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(last_info) + ) last_info = torch.load(last_info) start_epoch = last_info["epoch"] checkpoint = torch.load(last_info["last_checkpoint"]) @@ -444,11 +539,17 @@ def main(xargs): w_optimizer.load_state_dict(checkpoint["w_optimizer"]) a_optimizer.load_state_dict(checkpoint["a_optimizer"]) logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + last_info, start_epoch + ) ) else: logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: network.return_topK(1, True)[0]} + start_epoch, valid_accuracies, genotypes = ( + 0, + {"best": -1}, + {-1: network.return_topK(1, True)[0]}, + ) baseline = None # start training @@ -460,15 +561,35 @@ def main(xargs): ) for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.val * (total_epoch - epoch), True) + ) epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) - logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr()))) + logger.log( + "\n[Search the {:}-th epoch] {:}, LR={:}".format( + epoch_str, need_time, min(w_scheduler.get_lr()) + ) + ) network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate) if xargs.algo == "gdas": - network.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)) - logger.log("[RESET tau as : {:} and drop_path as {:}]".format(network.tau, network.drop_path)) - search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func( + network.set_tau( + xargs.tau_max + - 
(xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1) + ) + logger.log( + "[RESET tau as : {:} and drop_path as {:}]".format( + network.tau, network.drop_path + ) + ) + ( + search_w_loss, + search_w_top1, + search_w_top5, + search_a_loss, + search_a_top1, + search_a_top5, + ) = search_func( search_loader, network, criterion, @@ -493,7 +614,14 @@ def main(xargs): ) if xargs.algo == "enas": ctl_loss, ctl_acc, baseline, ctl_reward = train_controller( - valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger + valid_loader, + network, + criterion, + a_optimizer, + baseline, + epoch_str, + xargs.print_freq, + logger, ) logger.log( "[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}".format( @@ -501,7 +629,9 @@ def main(xargs): ) ) - genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo) + genotype, temp_accuracy = get_best_arch( + valid_loader, network, xargs.eval_candidate_num, xargs.algo + ) if xargs.algo == "setn" or xargs.algo == "enas": network.set_cal_mode("dynamic", genotype) elif xargs.algo == "gdas": @@ -512,8 +642,14 @@ def main(xargs): network.set_cal_mode("urs", None) else: raise ValueError("Invalid algorithm name : {:}".format(xargs.algo)) - logger.log("[{:}] - [get_best_arch] : {:} -> {:}".format(epoch_str, genotype, temp_accuracy)) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger) + logger.log( + "[{:}] - [get_best_arch] : {:} -> {:}".format( + epoch_str, genotype, temp_accuracy + ) + ) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( + valid_loader, network, criterion, xargs.algo, logger + ) logger.log( "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format( epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype @@ -522,7 +658,9 @@ def main(xargs): valid_accuracies[epoch] = valid_a_top1 genotypes[epoch] = genotype - logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + logger.log( + "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) + ) # save checkpoint save_path = save_checkpoint( { @@ -558,7 +696,9 @@ def main(xargs): # the final post procedure : count the time start_time = time.time() - genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo) + genotype, temp_accuracy = get_best_arch( + valid_loader, network, xargs.eval_candidate_num, xargs.algo + ) if xargs.algo == "setn" or xargs.algo == "enas": network.set_cal_mode("dynamic", genotype) elif xargs.algo == "gdas": @@ -571,8 +711,14 @@ def main(xargs): raise ValueError("Invalid algorithm name : {:}".format(xargs.algo)) search_time.update(time.time() - start_time) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger) - logger.log("Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(genotype, valid_a_top1)) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( + valid_loader, network, criterion, xargs.algo, logger + ) + logger.log( + "Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format( + genotype, valid_a_top1 + ) + ) logger.log("\n" + "-" * 100) # check the performance from the architecture dataset @@ -595,7 +741,13 @@ if __name__ == "__main__": choices=["cifar10", "cifar100", "ImageNet16-120"], help="Choose between Cifar10/100 and ImageNet-16.", ) - parser.add_argument("--search_space", type=str, default="tss", 
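For the GDAS branch, the Gumbel-Softmax temperature is annealed linearly from tau_max at the first epoch down to tau_min at the last one. A small sketch of that schedule, using the defaults from the argument parser (tau_max=10, tau_min=0.1) and an assumed 50-epoch search:

def gdas_tau(epoch, total_epoch, tau_max=10.0, tau_min=0.1):
    # Linear annealing: tau_max at epoch 0, tau_min at the final epoch.
    return tau_max - (tau_max - tau_min) * epoch / (total_epoch - 1)

total_epoch = 50  # assumed here; the real value comes from the loaded config
for epoch in (0, 25, 49):
    print(epoch, round(gdas_tau(epoch, total_epoch), 3))
# prints roughly 10.0, 4.949 and 0.1 -- a smaller tau makes the sampled operation weights closer to one-hot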
choices=["tss"], help="The search space name.") + parser.add_argument( + "--search_space", + type=str, + default="tss", + choices=["tss"], + help="The search space name.", + ) parser.add_argument( "--algo", type=str, @@ -603,18 +755,35 @@ if __name__ == "__main__": help="The search space name.", ) parser.add_argument( - "--use_api", type=int, default=1, choices=[0, 1], help="Whether use API or not (which will cost much memory)." + "--use_api", + type=int, + default=1, + choices=[0, 1], + help="Whether use API or not (which will cost much memory).", ) # FOR GDAS - parser.add_argument("--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax.") - parser.add_argument("--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax.") + parser.add_argument( + "--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax." + ) + parser.add_argument( + "--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax." + ) # channels and number-of-cells - parser.add_argument("--max_nodes", type=int, default=4, help="The maximum number of nodes.") - parser.add_argument("--channel", type=int, default=16, help="The number of channels.") - parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.") + parser.add_argument( + "--max_nodes", type=int, default=4, help="The maximum number of nodes." + ) + parser.add_argument( + "--channel", type=int, default=16, help="The number of channels." + ) + parser.add_argument( + "--num_cells", type=int, default=5, help="The number of cells in one stage." + ) # parser.add_argument( - "--eval_candidate_num", type=int, default=100, help="The number of selected architectures to evaluate." + "--eval_candidate_num", + type=int, + default=100, + help="The number of selected architectures to evaluate.", ) # parser.add_argument( @@ -625,7 +794,11 @@ if __name__ == "__main__": help="Whether use track_running_stats or not in the BN layer.", ) parser.add_argument( - "--affine", type=int, default=0, choices=[0, 1], help="Whether use affine=True or False in the BN layer." + "--affine", + type=int, + default=0, + choices=[0, 1], + help="Whether use affine=True or False in the BN layer.", ) parser.add_argument( "--config_path", @@ -634,17 +807,43 @@ if __name__ == "__main__": help="The path of configuration.", ) parser.add_argument( - "--overwite_epochs", type=int, help="The number of epochs to overwrite that value in config files." 
+ "--overwite_epochs", + type=int, + help="The number of epochs to overwrite that value in config files.", ) # architecture leraning rate - parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") - parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") - parser.add_argument("--arch_eps", type=float, default=1e-8, help="weight decay for arch encoding") + parser.add_argument( + "--arch_learning_rate", + type=float, + default=3e-4, + help="learning rate for arch encoding", + ) + parser.add_argument( + "--arch_weight_decay", + type=float, + default=1e-3, + help="weight decay for arch encoding", + ) + parser.add_argument( + "--arch_eps", type=float, default=1e-8, help="weight decay for arch encoding" + ) parser.add_argument("--drop_path_rate", type=float, help="The drop path rate.") # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") - parser.add_argument("--print_freq", type=int, default=200, help="print frequency (default: 200)") + parser.add_argument( + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./output/search", + help="Folder to save checkpoints and log.", + ) + parser.add_argument( + "--print_freq", type=int, default=200, help="print frequency (default: 200)" + ) parser.add_argument("--rand_seed", type=int, help="manual seed") args = parser.parse_args() if args.rand_seed is None or args.rand_seed < 0: @@ -653,14 +852,20 @@ if __name__ == "__main__": args.save_dir = os.path.join( "{:}-{:}".format(args.save_dir, args.search_space), args.dataset, - "{:}-affine{:}_BN{:}-{:}".format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate), + "{:}-affine{:}_BN{:}-{:}".format( + args.algo, args.affine, args.track_running_stats, args.drop_path_rate + ), ) else: args.save_dir = os.path.join( "{:}-{:}".format(args.save_dir, args.search_space), args.dataset, "{:}-affine{:}_BN{:}-E{:}-{:}".format( - args.algo, args.affine, args.track_running_stats, args.overwite_epochs, args.drop_path_rate + args.algo, + args.affine, + args.track_running_stats, + args.overwite_epochs, + args.drop_path_rate, ), ) diff --git a/exps/NATS-algos/search-size.py b/exps/NATS-algos/search-size.py index 1c978cb..c24db9d 100644 --- a/exps/NATS-algos/search-size.py +++ b/exps/NATS-algos/search-size.py @@ -38,7 +38,13 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import count_parameters_in_MB, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces @@ -55,7 +61,9 @@ class ExponentialMovingAverage(object): self._momentum = momentum def update(self, value): - self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value + self._numerator = ( + self._momentum * self._numerator + (1 - self._momentum) * value + ) self._denominator = 
self._momentum * self._denominator + (1 - self._momentum) @property @@ -85,7 +93,9 @@ def search_func( arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() end = time.time() network.train() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( + xloader + ): scheduler.update(None, 1.0 * step / len(xloader)) base_inputs = base_inputs.cuda(non_blocking=True) arch_inputs = arch_inputs.cuda(non_blocking=True) @@ -101,7 +111,9 @@ def search_func( base_loss.backward() w_optimizer.step() # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_prec1, base_prec5 = obtain_accuracy( + logits.data, base_targets.data, topk=(1, 5) + ) base_losses.update(base_loss.item(), base_inputs.size(0)) base_top1.update(base_prec1.item(), base_inputs.size(0)) base_top5.update(base_prec5.item(), base_inputs.size(0)) @@ -110,7 +122,9 @@ def search_func( network.zero_grad() a_optimizer.zero_grad() _, logits, log_probs = network(arch_inputs) - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) if algo == "mask_rl": with torch.no_grad(): RL_BASELINE_EMA.update(arch_prec1.item()) @@ -134,7 +148,11 @@ def search_func( end = time.time() if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Sstr = ( + "*SEARCH* " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) @@ -145,7 +163,14 @@ def search_func( loss=arch_losses, top1=arch_top1, top5=arch_top5 ) logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) - return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg + return ( + base_losses.avg, + base_top1.avg, + base_top5.avg, + arch_losses.avg, + arch_top1.avg, + arch_top5.avg, + ) def valid_func(xloader, network, criterion, logger): @@ -162,7 +187,9 @@ def valid_func(xloader, network, criterion, logger): _, logits, _ = network(arch_inputs.cuda(non_blocking=True)) arch_loss = criterion(logits, arch_targets) # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -181,11 +208,17 @@ def main(xargs): prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) if xargs.overwite_epochs is None: extra_info = {"class_num": class_num, "xshape": xshape} else: - extra_info = {"class_num": class_num, "xshape": xshape, "epochs": xargs.overwite_epochs} + extra_info = { + "class_num": class_num, + "xshape": xshape, + "epochs": xargs.overwite_epochs, + } config = load_config(xargs.config_path, extra_info, logger) search_loader, train_loader, valid_loader = get_nas_search_loaders( 
train_data, @@ -223,7 +256,9 @@ def main(xargs): search_model.set_algo(xargs.algo) logger.log("{:}".format(search_model)) - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) + w_optimizer, w_scheduler, criterion = get_optim_scheduler( + search_model.weights, config + ) a_optimizer = torch.optim.Adam( search_model.alphas, lr=xargs.arch_learning_rate, @@ -244,13 +279,23 @@ def main(xargs): api = None logger.log("{:} create API = {:} done".format(time_string(), api)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(last_info) + ) last_info = torch.load(last_info) start_epoch = last_info["epoch"] checkpoint = torch.load(last_info["last_checkpoint"]) @@ -261,7 +306,9 @@ def main(xargs): w_optimizer.load_state_dict(checkpoint["w_optimizer"]) a_optimizer.load_state_dict(checkpoint["a_optimizer"]) logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + last_info, start_epoch + ) ) else: logger.log("=> do not find the last-info file : {:}".format(last_info)) @@ -276,26 +323,47 @@ def main(xargs): ) for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.val * (total_epoch - epoch), True) + ) epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) - if xargs.warmup_ratio is None or xargs.warmup_ratio <= float(epoch) / total_epoch: + if ( + xargs.warmup_ratio is None + or xargs.warmup_ratio <= float(epoch) / total_epoch + ): enable_controller = True network.set_warmup_ratio(None) else: enable_controller = False - network.set_warmup_ratio(1.0 - float(epoch) / total_epoch / xargs.warmup_ratio) + network.set_warmup_ratio( + 1.0 - float(epoch) / total_epoch / xargs.warmup_ratio + ) logger.log( "\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}".format( - epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller + epoch_str, + need_time, + min(w_scheduler.get_lr()), + network.warmup_ratio, + enable_controller, ) ) if xargs.algo == "mask_gumbel" or xargs.algo == "tas": - network.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)) + network.set_tau( + xargs.tau_max + - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1) + ) logger.log("[RESET tau as : {:}]".format(network.tau)) - search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func( + ( + search_w_loss, + search_w_top1, + search_w_top5, + search_a_loss, + search_a_top1, + search_a_top5, + ) = search_func( 
search_loader, network, criterion, @@ -322,7 +390,9 @@ def main(xargs): genotype = network.genotype logger.log("[{:}] - [get_best_arch] : {:}".format(epoch_str, genotype)) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, logger) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( + valid_loader, network, criterion, logger + ) logger.log( "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format( epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype @@ -331,7 +401,9 @@ def main(xargs): valid_accuracies[epoch] = valid_a_top1 genotypes[epoch] = genotype - logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + logger.log( + "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) + ) # save checkpoint save_path = save_checkpoint( { @@ -369,8 +441,14 @@ def main(xargs): genotype = network.genotype search_time.update(time.time() - start_time) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, logger) - logger.log("Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(genotype, valid_a_top1)) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( + valid_loader, network, criterion, logger + ) + logger.log( + "Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format( + genotype, valid_a_top1 + ) + ) logger.log("\n" + "-" * 100) # check the performance from the architecture dataset @@ -393,8 +471,19 @@ if __name__ == "__main__": choices=["cifar10", "cifar100", "ImageNet16-120"], help="Choose between Cifar10/100 and ImageNet-16.", ) - parser.add_argument("--search_space", type=str, default="sss", choices=["sss"], help="The search space name.") - parser.add_argument("--algo", type=str, choices=["tas", "mask_gumbel", "mask_rl"], help="The search space name.") + parser.add_argument( + "--search_space", + type=str, + default="sss", + choices=["sss"], + help="The search space name.", + ) + parser.add_argument( + "--algo", + type=str, + choices=["tas", "mask_gumbel", "mask_rl"], + help="The search space name.", + ) parser.add_argument( "--genotype", type=str, @@ -402,13 +491,23 @@ if __name__ == "__main__": help="The genotype.", ) parser.add_argument( - "--use_api", type=int, default=1, choices=[0, 1], help="Whether use API or not (which will cost much memory)." + "--use_api", + type=int, + default=1, + choices=[0, 1], + help="Whether use API or not (which will cost much memory).", ) # FOR GDAS - parser.add_argument("--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax.") - parser.add_argument("--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax.") + parser.add_argument( + "--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax." + ) + parser.add_argument( + "--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax." + ) # FOR ALL - parser.add_argument("--warmup_ratio", type=float, help="The warmup ratio, if None, not use warmup.") + parser.add_argument( + "--warmup_ratio", type=float, help="The warmup ratio, if None, not use warmup." + ) # parser.add_argument( "--track_running_stats", @@ -418,7 +517,11 @@ if __name__ == "__main__": help="Whether use track_running_stats or not in the BN layer.", ) parser.add_argument( - "--affine", type=int, default=0, choices=[0, 1], help="Whether use affine=True or False in the BN layer." 
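The warmup gating in search-size.py keeps the architecture controller disabled until the fraction of training given by --warmup_ratio has elapsed, and passes the remaining warmup fraction to the network in the meantime. A standalone sketch of that schedule; the 50-epoch budget and the 0.5 ratio below are assumptions for illustration:

def controller_schedule(epoch, total_epoch, warmup_ratio):
    # Returns (enable_controller, warmup_ratio_passed_to_network), mirroring the
    # branch at the top of the epoch loop above.
    progress = float(epoch) / total_epoch
    if warmup_ratio is None or warmup_ratio <= progress:
        return True, None
    return False, 1.0 - progress / warmup_ratio

for epoch in (0, 10, 25, 40):
    print(epoch, controller_schedule(epoch, total_epoch=50, warmup_ratio=0.5))
# epoch 0 -> (False, 1.0); epoch 10 -> (False, ~0.6); epochs 25 and 40 -> (True, None)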
+ "--affine", + type=int, + default=0, + choices=[0, 1], + help="Whether use affine=True or False in the BN layer.", ) parser.add_argument( "--config_path", @@ -427,25 +530,57 @@ if __name__ == "__main__": help="The path of configuration.", ) parser.add_argument( - "--overwite_epochs", type=int, help="The number of epochs to overwrite that value in config files." + "--overwite_epochs", + type=int, + help="The number of epochs to overwrite that value in config files.", ) # architecture leraning rate - parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") - parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") - parser.add_argument("--arch_eps", type=float, default=1e-8, help="weight decay for arch encoding") + parser.add_argument( + "--arch_learning_rate", + type=float, + default=3e-4, + help="learning rate for arch encoding", + ) + parser.add_argument( + "--arch_weight_decay", + type=float, + default=1e-3, + help="weight decay for arch encoding", + ) + parser.add_argument( + "--arch_eps", type=float, default=1e-8, help="weight decay for arch encoding" + ) # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.") - parser.add_argument("--print_freq", type=int, default=200, help="print frequency (default: 200)") + parser.add_argument( + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./output/search", + help="Folder to save checkpoints and log.", + ) + parser.add_argument( + "--print_freq", type=int, default=200, help="print frequency (default: 200)" + ) parser.add_argument("--rand_seed", type=int, help="manual seed") args = parser.parse_args() if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) dirname = "{:}-affine{:}_BN{:}-AWD{:}-WARM{:}".format( - args.algo, args.affine, args.track_running_stats, args.arch_weight_decay, args.warmup_ratio + args.algo, + args.affine, + args.track_running_stats, + args.arch_weight_decay, + args.warmup_ratio, ) if args.overwite_epochs is not None: dirname = dirname + "-E{:}".format(args.overwite_epochs) - args.save_dir = os.path.join("{:}-{:}".format(args.save_dir, args.search_space), args.dataset, dirname) + args.save_dir = os.path.join( + "{:}-{:}".format(args.save_dir, args.search_space), args.dataset, dirname + ) main(args) diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py index 9d3f0a7..f6e3f49 100644 --- a/exps/algos/BOHB.py +++ b/exps/algos/BOHB.py @@ -35,7 +35,9 @@ def get_configuration_space(max_nodes, search_space): for i in range(1, max_nodes): for j in range(i): node_str = "{:}<-{:}".format(i, j) - cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) + cs.add_hyperparameter( + ConfigSpace.CategoricalHyperparameter(node_str, search_space) + ) return cs @@ -55,7 +57,15 @@ def config2structure_func(max_nodes): class MyWorker(Worker): - def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs): + def __init__( + self, + *args, + convert_func=None, + dataname=None, + nas_bench=None, + time_budget=None, + **kwargs + ): super().__init__(*args, **kwargs) self.convert_func = convert_func self._dataname = dataname @@ -70,7 +80,9 @@ class 
MyWorker(Worker): assert len(self.seen_archs) > 0 best_index, best_acc = -1, None for arch_index in self.seen_archs: - info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp="200", is_random=True) + info = self._nas_bench.get_more_info( + arch_index, self._dataname, None, hp="200", is_random=True + ) vacc = info["valid-accuracy"] if best_acc is None or best_acc < vacc: best_acc = vacc @@ -82,7 +94,9 @@ class MyWorker(Worker): start_time = time.time() structure = self.convert_func(config) arch_index = self._nas_bench.query_index_by_arch(structure) - info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp="200", is_random=True) + info = self._nas_bench.get_more_info( + arch_index, self._dataname, None, hp="200", is_random=True + ) cur_time = info["train-all-time"] + info["valid-per-time"] cur_vacc = info["valid-accuracy"] self.real_cost_time += time.time() - start_time @@ -101,7 +115,11 @@ class MyWorker(Worker): self.is_end = True return { "loss": 100, - "info": {"seen-arch": len(self.seen_archs), "sim-test-time": self.sim_cost_time, "current-arch": None}, + "info": { + "seen-arch": len(self.seen_archs), + "sim-test-time": self.sim_cost_time, + "current-arch": None, + }, } @@ -119,13 +137,17 @@ def main(xargs, nas_bench): else: dataname = xargs.dataset if xargs.data_path is not None: - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) split_Fpath = "configs/nas-benchmark/cifar-split.txt" cifar_split = load_config(split_Fpath, None, None) train_split, valid_split = cifar_split.train, cifar_split.valid logger.log("Load split file from {:}".format(split_Fpath)) config_path = "configs/nas-benchmark/algos/R-EA.config" - config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger) + config = load_config( + config_path, {"class_num": class_num, "xshape": xshape}, logger + ) # To split data train_data_v2 = deepcopy(train_data) train_data_v2.transform = valid_data.transform @@ -152,7 +174,11 @@ def main(xargs, nas_bench): ) ) logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader} + extra_info = { + "config": config, + "train_loader": train_loader, + "valid_loader": valid_loader, + } else: config_path = "configs/nas-benchmark/algos/R-EA.config" config = load_config(config_path, None, logger) @@ -213,7 +239,11 @@ def main(xargs, nas_bench): id2config = results.get_id2config_mapping() incumbent = results.get_incumbent_id() - logger.log("Best found configuration: {:} within {:.3f} s".format(id2config[incumbent]["config"], real_cost_time)) + logger.log( + "Best found configuration: {:} within {:.3f} s".format( + id2config[incumbent]["config"], real_cost_time + ) + ) best_arch = config2structure(id2config[incumbent]["config"]) info = nas_bench.query_by_arch(best_arch, "200") @@ -223,13 +253,19 @@ def main(xargs, nas_bench): logger.log("{:}".format(info)) logger.log("-" * 100) - logger.log("workers : {:.1f}s with {:} archs".format(workers[0].time_budget, len(workers[0].seen_archs))) + logger.log( + "workers : {:.1f}s with {:} archs".format( + workers[0].time_budget, len(workers[0].seen_archs) + ) + ) logger.close() return logger.log_dir, nas_bench.query_index_by_arch(best_arch), real_cost_time if __name__ == "__main__": - parser = argparse.ArgumentParser("BOHB: Robust and Efficient 
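In BOHB.py, get_configuration_space turns every cell edge "{i}<-{j}" into a categorical hyperparameter over the operation names, and config2structure maps a sampled configuration back into an architecture. A self-contained sketch of that encoding; the operation list and node count below are assumptions for illustration, and the final string follows the usual NAS-Bench-201 genotype format:

import ConfigSpace

ops = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
max_nodes = 4

cs = ConfigSpace.ConfigurationSpace()
for i in range(1, max_nodes):
    for j in range(i):
        node_str = "{:}<-{:}".format(i, j)
        cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, ops))

config = cs.sample_configuration()
nodes = []
for i in range(1, max_nodes):
    edges = ["{:}~{:}".format(config["{:}<-{:}".format(i, j)], j) for j in range(i)]
    nodes.append("|" + "|".join(edges) + "|")
print("+".join(nodes))  # e.g. |skip_connect~0|+|none~0|nor_conv_3x3~1|+|...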
Hyperparameter Optimization at Scale") + parser = argparse.ArgumentParser( + "BOHB: Robust and Efficient Hyperparameter Optimization at Scale" + ) parser.add_argument("--data_path", type=str, help="Path to dataset") parser.add_argument( "--dataset", @@ -241,28 +277,71 @@ if __name__ == "__main__": parser.add_argument("--search_space_name", type=str, help="The search space name.") parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") parser.add_argument("--channel", type=int, help="The number of channels.") - parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") - parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).") + parser.add_argument( + "--num_cells", type=int, help="The number of cells in one stage." + ) + parser.add_argument( + "--time_budget", + type=int, + help="The total time cost budge for searching (in seconds).", + ) # BOHB parser.add_argument( - "--strategy", default="sampling", type=str, nargs="?", help="optimization strategy for the acquisition function" - ) - parser.add_argument("--min_bandwidth", default=0.3, type=float, nargs="?", help="minimum bandwidth for KDE") - parser.add_argument( - "--num_samples", default=64, type=int, nargs="?", help="number of samples for the acquisition function" + "--strategy", + default="sampling", + type=str, + nargs="?", + help="optimization strategy for the acquisition function", ) parser.add_argument( - "--random_fraction", default=0.33, type=float, nargs="?", help="fraction of random configurations" + "--min_bandwidth", + default=0.3, + type=float, + nargs="?", + help="minimum bandwidth for KDE", ) - parser.add_argument("--bandwidth_factor", default=3, type=int, nargs="?", help="factor multiplied to the bandwidth") parser.add_argument( - "--n_iters", default=100, type=int, nargs="?", help="number of iterations for optimization method" + "--num_samples", + default=64, + type=int, + nargs="?", + help="number of samples for the acquisition function", + ) + parser.add_argument( + "--random_fraction", + default=0.33, + type=float, + nargs="?", + help="fraction of random configurations", + ) + parser.add_argument( + "--bandwidth_factor", + default=3, + type=int, + nargs="?", + help="factor multiplied to the bandwidth", + ) + parser.add_argument( + "--n_iters", + default=100, + type=int, + nargs="?", + help="number of iterations for optimization method", ) # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") parser.add_argument( - "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." 
+ ) + parser.add_argument( + "--arch_nas_dataset", + type=str, + help="The path to load the architecture dataset (tiny-nas-benchmark).", ) parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") parser.add_argument("--rand_seed", type=int, help="manual seed") @@ -271,7 +350,11 @@ if __name__ == "__main__": if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): nas_bench = None else: - print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset)) + print( + "{:} build NAS-Benchmark-API from {:}".format( + time_string(), args.arch_nas_dataset + ) + ) nas_bench = API(args.arch_nas_dataset) if args.rand_seed < 0: save_dir, all_indexes, num, all_times = None, [], 500, [] diff --git a/exps/algos/DARTS-V1.py b/exps/algos/DARTS-V1.py index 44df9ce..b1b409f 100644 --- a/exps/algos/DARTS-V1.py +++ b/exps/algos/DARTS-V1.py @@ -13,7 +13,13 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces @@ -21,14 +27,25 @@ from nas_201_api import NASBench201API as API def search_func( - xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger, gradient_clip + xloader, + network, + criterion, + scheduler, + w_optimizer, + a_optimizer, + epoch_str, + print_freq, + logger, + gradient_clip, ): data_time, batch_time = AverageMeter(), AverageMeter() base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() network.train() end = time.time() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( + xloader + ): scheduler.update(None, 1.0 * step / len(xloader)) base_targets = base_targets.cuda(non_blocking=True) arch_targets = arch_targets.cuda(non_blocking=True) @@ -44,7 +61,9 @@ def search_func( torch.nn.utils.clip_grad_norm_(network.parameters(), gradient_clip) w_optimizer.step() # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_prec1, base_prec5 = obtain_accuracy( + logits.data, base_targets.data, topk=(1, 5) + ) base_losses.update(base_loss.item(), base_inputs.size(0)) base_top1.update(base_prec1.item(), base_inputs.size(0)) base_top5.update(base_prec5.item(), base_inputs.size(0)) @@ -56,7 +75,9 @@ def search_func( arch_loss.backward() a_optimizer.step() # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -66,7 +87,11 @@ def search_func( end = time.time() if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = "*SEARCH* " + time_string() + " 
[{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Sstr = ( + "*SEARCH* " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) @@ -94,7 +119,9 @@ def valid_func(xloader, network, criterion): _, logits = network(arch_inputs) arch_loss = criterion(logits, arch_targets) # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -113,11 +140,20 @@ def main(xargs): prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) # config_path = 'configs/nas-benchmark/algos/DARTS.config' - config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + config = load_config( + xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger + ) search_loader, _, valid_loader = get_nas_search_loaders( - train_data, valid_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers + train_data, + valid_data, + xargs.dataset, + "configs/nas-benchmark/", + config.batch_size, + xargs.workers, ) logger.log( "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( @@ -155,9 +191,14 @@ def main(xargs): search_model = get_cell_based_tiny_net(model_config) logger.log("search-model :\n{:}".format(search_model)) - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) + w_optimizer, w_scheduler, criterion = get_optim_scheduler( + search_model.get_weights(), config + ) a_optimizer = torch.optim.Adam( - search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay + search_model.get_alphas(), + lr=xargs.arch_learning_rate, + betas=(0.5, 0.999), + weight_decay=xargs.arch_weight_decay, ) logger.log("w-optimizer : {:}".format(w_optimizer)) logger.log("a-optimizer : {:}".format(a_optimizer)) @@ -172,11 +213,17 @@ def main(xargs): api = API(xargs.arch_nas_dataset) logger.log("{:} create API = {:} done".format(time_string(), api)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(last_info) + ) last_info = torch.load(last_info) start_epoch = last_info["epoch"] checkpoint = torch.load(last_info["last_checkpoint"]) @@ -187,11 +234,17 @@ def main(xargs): w_optimizer.load_state_dict(checkpoint["w_optimizer"]) a_optimizer.load_state_dict(checkpoint["a_optimizer"]) logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, 
start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + last_info, start_epoch + ) ) else: logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: search_model.genotype()} + start_epoch, valid_accuracies, genotypes = ( + 0, + {"best": -1}, + {-1: search_model.genotype()}, + ) # start training start_time, search_time, epoch_time, total_epoch = ( @@ -202,9 +255,15 @@ def main(xargs): ) for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.val * (total_epoch - epoch), True) + ) epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) - logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr()))) + logger.log( + "\n[Search the {:}-th epoch] {:}, LR={:}".format( + epoch_str, need_time, min(w_scheduler.get_lr()) + ) + ) search_w_loss, search_w_top1, search_w_top5 = search_func( search_loader, @@ -224,7 +283,9 @@ def main(xargs): epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum ) ) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( + valid_loader, network, criterion + ) logger.log( "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 @@ -240,7 +301,9 @@ def main(xargs): find_best = False genotypes[epoch] = search_model.genotype() - logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + logger.log( + "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) + ) # save checkpoint save_path = save_checkpoint( { @@ -305,7 +368,9 @@ if __name__ == "__main__": parser.add_argument("--search_space_name", type=str, help="The search space name.") parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") parser.add_argument("--channel", type=int, help="The number of channels.") - parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument( + "--num_cells", type=int, help="The number of cells in one stage." + ) parser.add_argument( "--track_running_stats", type=int, @@ -320,13 +385,32 @@ if __name__ == "__main__": ) parser.add_argument("--gradient_clip", type=float, default=5, help="") # architecture leraning rate - parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") - parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") - # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") parser.add_argument( - "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (nas-benchmark)." 
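Most hunks around the argument parsers follow the same mechanical pattern: black explodes any call that no longer fits within the 88-character line length used throughout this patch, without changing behaviour. For example, a single long parser.add_argument line becomes the one-argument-per-line form:

# Before: a single call longer than the line-length limit.
parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.")

# After black: the call is exploded; when the arguments still do not fit on one
# wrapped line, each argument gets its own line with a trailing comma.
parser.add_argument(
    "--save_dir",
    type=str,
    default="./output/search",
    help="Folder to save checkpoints and log.",
)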
+ "--arch_learning_rate", + type=float, + default=3e-4, + help="learning rate for arch encoding", + ) + parser.add_argument( + "--arch_weight_decay", + type=float, + default=1e-3, + help="weight decay for arch encoding", + ) + # log + parser.add_argument( + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." + ) + parser.add_argument( + "--arch_nas_dataset", + type=str, + help="The path to load the architecture dataset (nas-benchmark).", ) parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") parser.add_argument("--rand_seed", type=int, help="manual seed") diff --git a/exps/algos/DARTS-V2.py b/exps/algos/DARTS-V2.py index 802bbdd..6739ffc 100644 --- a/exps/algos/DARTS-V2.py +++ b/exps/algos/DARTS-V2.py @@ -15,7 +15,13 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces @@ -26,7 +32,9 @@ def _concat(xs): return torch.cat([x.view(-1) for x in xs]) -def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2): +def _hessian_vector_product( + vector, network, criterion, base_inputs, base_targets, r=1e-2 +): R = r / _concat(vector).norm() for p, v in zip(network.module.get_weights(), vector): p.data.add_(R, v) @@ -45,7 +53,15 @@ def _hessian_vector_product(vector, network, criterion, base_inputs, base_target return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)] -def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets): +def backward_step_unrolled( + network, + criterion, + base_inputs, + base_targets, + w_optimizer, + arch_inputs, + arch_targets, +): # _compute_unrolled_model _, logits = network(base_inputs) loss = criterion(logits, base_targets) @@ -57,11 +73,17 @@ def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_opti with torch.no_grad(): theta = _concat(network.module.get_weights()) try: - moment = _concat(w_optimizer.state[v]["momentum_buffer"] for v in network.module.get_weights()) + moment = _concat( + w_optimizer.state[v]["momentum_buffer"] + for v in network.module.get_weights() + ) moment = moment.mul_(momentum) except: moment = torch.zeros_like(theta) - dtheta = _concat(torch.autograd.grad(loss, network.module.get_weights())) + WD * theta + dtheta = ( + _concat(torch.autograd.grad(loss, network.module.get_weights())) + + WD * theta + ) params = theta.sub(LR, moment + dtheta) unrolled_model = deepcopy(network) model_dict = unrolled_model.state_dict() @@ -82,7 +104,9 @@ def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_opti dalpha = unrolled_model.module.arch_parameters.grad vector = [v.grad.data for v in unrolled_model.module.get_weights()] - [implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets) + [implicit_grads] = _hessian_vector_product( + vector, network, criterion, base_inputs, 
base_targets + ) dalpha.data.sub_(LR, implicit_grads.data) @@ -93,13 +117,25 @@ def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_opti return unrolled_loss.detach(), unrolled_logits.detach() -def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): +def search_func( + xloader, + network, + criterion, + scheduler, + w_optimizer, + a_optimizer, + epoch_str, + print_freq, + logger, +): data_time, batch_time = AverageMeter(), AverageMeter() base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() network.train() end = time.time() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( + xloader + ): scheduler.update(None, 1.0 * step / len(xloader)) base_targets = base_targets.cuda(non_blocking=True) arch_targets = arch_targets.cuda(non_blocking=True) @@ -109,11 +145,19 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer # update the architecture-weight a_optimizer.zero_grad() arch_loss, arch_logits = backward_step_unrolled( - network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets + network, + criterion, + base_inputs, + base_targets, + w_optimizer, + arch_inputs, + arch_targets, ) a_optimizer.step() # record - arch_prec1, arch_prec5 = obtain_accuracy(arch_logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + arch_logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -126,7 +170,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer torch.nn.utils.clip_grad_norm_(network.parameters(), 5) w_optimizer.step() # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_prec1, base_prec5 = obtain_accuracy( + logits.data, base_targets.data, topk=(1, 5) + ) base_losses.update(base_loss.item(), base_inputs.size(0)) base_top1.update(base_prec1.item(), base_inputs.size(0)) base_top5.update(base_prec5.item(), base_inputs.size(0)) @@ -136,7 +182,11 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer end = time.time() if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Sstr = ( + "*SEARCH* " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) @@ -164,7 +214,9 @@ def valid_func(xloader, network, criterion): _, logits = network(arch_inputs) arch_loss = criterion(logits, arch_targets) # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -183,10 +235,19 @@ def main(xargs): prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - 
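
Note on the block above: backward_step_unrolled completes the second-order DARTS update by approximating the Hessian-vector product with a symmetric finite difference, exactly as _hessian_vector_product does in place on network.module.get_weights(): the shared weights are perturbed by +/- R*v (v being the validation gradients of the unrolled weights) and the two gradients of the training loss with respect to the architecture parameters are differenced and divided by 2R. A stand-alone sketch of that approximation, with loss_fn, weights, alphas and vector as placeholders rather than the repository's objects:

import torch

def hvp_finite_difference(loss_fn, weights, alphas, vector, r=1e-2):
    # `loss_fn()` must recompute the training loss so that it depends on both
    # the (perturbed) `weights` and the architecture parameters `alphas`.
    R = (r / torch.cat([v.reshape(-1) for v in vector]).norm()).item()
    with torch.no_grad():
        for w, v in zip(weights, vector):
            w.add_(v, alpha=R)                  # w + R * v
    grads_p = torch.autograd.grad(loss_fn(), alphas)
    with torch.no_grad():
        for w, v in zip(weights, vector):
            w.sub_(v, alpha=2 * R)              # w - R * v
    grads_n = torch.autograd.grad(loss_fn(), alphas)
    with torch.no_grad():
        for w, v in zip(weights, vector):
            w.add_(v, alpha=R)                  # restore the original weights
    return [(gp - gn).div_(2 * R) for gp, gn in zip(grads_p, grads_n)]
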
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) + config = load_config( + xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger + ) search_loader, _, valid_loader = get_nas_search_loaders( - train_data, valid_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers + train_data, + valid_data, + xargs.dataset, + "configs/nas-benchmark/", + config.batch_size, + xargs.workers, ) logger.log( "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( @@ -212,9 +273,14 @@ def main(xargs): search_model = get_cell_based_tiny_net(model_config) logger.log("search-model :\n{:}".format(search_model)) - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) + w_optimizer, w_scheduler, criterion = get_optim_scheduler( + search_model.get_weights(), config + ) a_optimizer = torch.optim.Adam( - search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay + search_model.get_alphas(), + lr=xargs.arch_learning_rate, + betas=(0.5, 0.999), + weight_decay=xargs.arch_weight_decay, ) logger.log("w-optimizer : {:}".format(w_optimizer)) logger.log("a-optimizer : {:}".format(a_optimizer)) @@ -229,11 +295,17 @@ def main(xargs): api = API(xargs.arch_nas_dataset) logger.log("{:} create API = {:} done".format(time_string(), api)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(last_info) + ) last_info = torch.load(last_info) start_epoch = last_info["epoch"] checkpoint = torch.load(last_info["last_checkpoint"]) @@ -244,11 +316,17 @@ def main(xargs): w_optimizer.load_state_dict(checkpoint["w_optimizer"]) a_optimizer.load_state_dict(checkpoint["a_optimizer"]) logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + last_info, start_epoch + ) ) else: logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: search_model.genotype()} + start_epoch, valid_accuracies, genotypes = ( + 0, + {"best": -1}, + {-1: search_model.genotype()}, + ) # start training start_time, search_time, epoch_time, total_epoch = ( @@ -259,10 +337,16 @@ def main(xargs): ) for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.val * (total_epoch - epoch), True) + ) epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) min_LR = min(w_scheduler.get_lr()) - logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min_LR)) + 
logger.log( + "\n[Search the {:}-th epoch] {:}, LR={:}".format( + epoch_str, need_time, min_LR + ) + ) search_w_loss, search_w_top1, search_w_top5 = search_func( search_loader, @@ -281,7 +365,9 @@ def main(xargs): epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum ) ) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( + valid_loader, network, criterion + ) logger.log( "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 @@ -297,7 +383,9 @@ def main(xargs): find_best = False genotypes[epoch] = search_model.genotype() - logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + logger.log( + "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) + ) # save checkpoint save_path = save_checkpoint( { @@ -331,7 +419,9 @@ def main(xargs): copy_checkpoint(model_base_path, model_best_path, logger) with torch.no_grad(): logger.log( - "arch-parameters :\n{:}".format(nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu()) + "arch-parameters :\n{:}".format( + nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() + ) ) if api is not None: logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) @@ -365,7 +455,9 @@ if __name__ == "__main__": parser.add_argument("--search_space_name", type=str, help="The search space name.") parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") parser.add_argument("--channel", type=int, help="The number of channels.") - parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument( + "--num_cells", type=int, help="The number of cells in one stage." + ) parser.add_argument( "--track_running_stats", type=int, @@ -373,13 +465,32 @@ if __name__ == "__main__": help="Whether use track_running_stats or not in the BN layer.", ) # architecture leraning rate - parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") - parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") - # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") parser.add_argument( - "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + "--arch_learning_rate", + type=float, + default=3e-4, + help="learning rate for arch encoding", + ) + parser.add_argument( + "--arch_weight_decay", + type=float, + default=1e-3, + help="weight decay for arch encoding", + ) + # log + parser.add_argument( + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." 
+ ) + parser.add_argument( + "--arch_nas_dataset", + type=str, + help="The path to load the architecture dataset (tiny-nas-benchmark).", ) parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") parser.add_argument("--rand_seed", type=int, help="manual seed") diff --git a/exps/algos/ENAS.py b/exps/algos/ENAS.py index 2a850bc..7b8dd1d 100644 --- a/exps/algos/ENAS.py +++ b/exps/algos/ENAS.py @@ -15,16 +15,37 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API -def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger): +def train_shared_cnn( + xloader, + shared_cnn, + controller, + criterion, + scheduler, + optimizer, + epoch_str, + print_freq, + logger, +): data_time, batch_time = AverageMeter(), AverageMeter() - losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time() + losses, top1s, top5s, xend = ( + AverageMeter(), + AverageMeter(), + AverageMeter(), + time.time(), + ) shared_cnn.train() controller.eval() @@ -56,7 +77,11 @@ def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, opti xend = time.time() if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = "*Train-Shared-CNN* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Sstr = ( + "*Train-Shared-CNN* " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) @@ -67,11 +92,29 @@ def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, opti return losses.avg, top1s.avg, top5s.avg -def train_controller(xloader, shared_cnn, controller, criterion, optimizer, config, epoch_str, print_freq, logger): +def train_controller( + xloader, + shared_cnn, + controller, + criterion, + optimizer, + config, + epoch_str, + print_freq, + logger, +): # config. (containing some necessary arg) # baseline: The baseline score (i.e. 
average val_acc) from the previous epoch data_time, batch_time = AverageMeter(), AverageMeter() - GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = ( + ( + GradnormMeter, + LossMeter, + ValAccMeter, + EntropyMeter, + BaselineMeter, + RewardMeter, + xend, + ) = ( AverageMeter(), AverageMeter(), AverageMeter(), @@ -106,7 +149,9 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf if config.baseline is None: baseline = val_top1 else: - baseline = config.baseline - (1 - config.ctl_bl_dec) * (config.baseline - reward) + baseline = config.baseline - (1 - config.ctl_bl_dec) * ( + config.baseline - reward + ) loss = -1 * log_prob * (reward - baseline) @@ -134,18 +179,29 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf Sstr = ( "*Train-Controller* " + time_string() - + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre) + + " [{:}][{:03d}/{:03d}]".format( + epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre + ) ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})".format( - loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter + loss=LossMeter, + top1=ValAccMeter, + reward=RewardMeter, + basel=BaselineMeter, ) Estr = "Entropy={:.4f} ({:.4f})".format(EntropyMeter.val, EntropyMeter.avg) logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Estr) - return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg, baseline.item() + return ( + LossMeter.avg, + ValAccMeter.avg, + BaselineMeter.avg, + RewardMeter.avg, + baseline.item(), + ) def get_best_arch(controller, shared_cnn, xloader, n_samples=10): @@ -164,7 +220,9 @@ def get_best_arch(controller, shared_cnn, xloader, n_samples=10): _, _, sampled_arch = controller() arch = shared_cnn.module.update_arch(sampled_arch) _, logits = shared_cnn(inputs) - val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) + val_top1, val_top5 = obtain_accuracy( + logits.cpu().data, targets.data, topk=(1, 5) + ) archs.append(arch) valid_accs.append(val_top1.item()) @@ -188,7 +246,9 @@ def valid_func(xloader, network, criterion): _, logits = network(arch_inputs) arch_loss = criterion(logits, arch_targets) # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -207,11 +267,20 @@ def main(xargs): prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + train_data, test_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) logger.log("use config from : {:}".format(xargs.config_path)) - config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + config = load_config( + xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger + ) _, train_loader, valid_loader = get_nas_search_loaders( - train_data, test_data, 
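
Note on train_controller above: it is a REINFORCE step in which the reward of each sampled architecture (its validation accuracy) is compared against an exponential-moving-average baseline governed by ctl_bl_dec, and the negative log-probability is scaled by that advantage. Condensed into a stand-alone helper (scalar placeholders; the original seeds the baseline with the first validation accuracy):

def reinforce_controller_step(log_prob, reward, baseline, bl_dec):
    # Moving-average baseline, mirroring
    # config.baseline - (1 - config.ctl_bl_dec) * (config.baseline - reward)
    if baseline is None:
        baseline = reward
    else:
        baseline = baseline - (1 - bl_dec) * (baseline - reward)
    loss = -1 * log_prob * (reward - baseline)   # policy-gradient loss
    return loss, baseline
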
xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers + train_data, + test_data, + xargs.dataset, + "configs/nas-benchmark/", + config.batch_size, + xargs.workers, ) # since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform) @@ -242,9 +311,14 @@ def main(xargs): shared_cnn = get_cell_based_tiny_net(model_config) controller = shared_cnn.create_controller() - w_optimizer, w_scheduler, criterion = get_optim_scheduler(shared_cnn.parameters(), config) + w_optimizer, w_scheduler, criterion = get_optim_scheduler( + shared_cnn.parameters(), config + ) a_optimizer = torch.optim.Adam( - controller.parameters(), lr=config.controller_lr, betas=config.controller_betas, eps=config.controller_eps + controller.parameters(), + lr=config.controller_lr, + betas=config.controller_betas, + eps=config.controller_eps, ) logger.log("w-optimizer : {:}".format(w_optimizer)) logger.log("a-optimizer : {:}".format(a_optimizer)) @@ -259,12 +333,22 @@ def main(xargs): else: api = API(xargs.arch_nas_dataset) logger.log("{:} create API = {:} done".format(time_string(), api)) - shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda() + shared_cnn, controller, criterion = ( + torch.nn.DataParallel(shared_cnn).cuda(), + controller.cuda(), + criterion.cuda(), + ) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(last_info) + ) last_info = torch.load(last_info) start_epoch = last_info["epoch"] checkpoint = torch.load(last_info["last_checkpoint"]) @@ -277,7 +361,9 @@ def main(xargs): w_optimizer.load_state_dict(checkpoint["w_optimizer"]) a_optimizer.load_state_dict(checkpoint["a_optimizer"]) logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + last_info, start_epoch + ) ) else: logger.log("=> do not find the last-info file : {:}".format(last_info)) @@ -292,7 +378,9 @@ def main(xargs): ) for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.val * (total_epoch - epoch), True) + ) epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) logger.log( "\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}".format( @@ -339,7 +427,13 @@ def main(xargs): search_time.update(time.time() - start_time) logger.log( "[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s".format( - epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum + epoch_str, + ctl_loss, + ctl_acc, + ctl_baseline, + ctl_reward, + baseline, + search_time.sum, ) ) best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader) @@ -356,7 +450,9 @@ def main(xargs): else: find_best = False - logger.log("<<<--->>> The {:}-th epoch : 
{:}".format(epoch_str, genotypes[epoch])) + logger.log( + "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) + ) # save checkpoint save_path = save_checkpoint( { @@ -397,18 +493,32 @@ def main(xargs): start_time = time.time() logger.log("\n" + "-" * 100) - logger.log("During searching, the best architecture is {:}".format(genotypes["best"])) + logger.log( + "During searching, the best architecture is {:}".format(genotypes["best"]) + ) logger.log("Its accuracy is {:.2f}%".format(valid_accuracies["best"])) - logger.log("Randomly select {:} architectures and select the best.".format(xargs.controller_num_samples)) + logger.log( + "Randomly select {:} architectures and select the best.".format( + xargs.controller_num_samples + ) + ) start_time = time.time() - final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples) + final_arch, _ = get_best_arch( + controller, shared_cnn, valid_loader, xargs.controller_num_samples + ) search_time.update(time.time() - start_time) shared_cnn.module.update_arch(final_arch) final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion) logger.log("The Selected Final Architecture : {:}".format(final_arch)) - logger.log("Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%".format(final_loss, final_top1, final_top5)) logger.log( - "ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(total_epoch, search_time.sum, final_arch) + "Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%".format( + final_loss, final_top1, final_top5 + ) + ) + logger.log( + "ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( + total_epoch, search_time.sum, final_arch + ) ) if api is not None: logger.log("{:}".format(api.query_by_arch(final_arch))) @@ -434,18 +544,35 @@ if __name__ == "__main__": parser.add_argument("--search_space_name", type=str, help="The search space name.") parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") parser.add_argument("--channel", type=int, help="The number of channels.") - parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") - parser.add_argument("--config_path", type=str, help="The config file to train ENAS.") + parser.add_argument( + "--num_cells", type=int, help="The number of cells in one stage." + ) + parser.add_argument( + "--config_path", type=str, help="The config file to train ENAS." + ) parser.add_argument("--controller_train_steps", type=int, help=".") parser.add_argument("--controller_num_aggregate", type=int, help=".") - parser.add_argument("--controller_entropy_weight", type=float, help="The weight for the entropy of the controller.") + parser.add_argument( + "--controller_entropy_weight", + type=float, + help="The weight for the entropy of the controller.", + ) parser.add_argument("--controller_bl_dec", type=float, help=".") parser.add_argument("--controller_num_samples", type=int, help=".") # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") parser.add_argument( - "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (nas-benchmark)." + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." 
+ ) + parser.add_argument( + "--arch_nas_dataset", + type=str, + help="The path to load the architecture dataset (nas-benchmark).", ) parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") parser.add_argument("--rand_seed", type=int, help="manual seed") diff --git a/exps/algos/GDAS.py b/exps/algos/GDAS.py index 758d052..bc760f3 100644 --- a/exps/algos/GDAS.py +++ b/exps/algos/GDAS.py @@ -13,20 +13,38 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API -def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): +def search_func( + xloader, + network, + criterion, + scheduler, + w_optimizer, + a_optimizer, + epoch_str, + print_freq, + logger, +): data_time, batch_time = AverageMeter(), AverageMeter() base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() network.train() end = time.time() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( + xloader + ): scheduler.update(None, 1.0 * step / len(xloader)) base_targets = base_targets.cuda(non_blocking=True) arch_targets = arch_targets.cuda(non_blocking=True) @@ -41,7 +59,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer torch.nn.utils.clip_grad_norm_(network.parameters(), 5) w_optimizer.step() # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_prec1, base_prec5 = obtain_accuracy( + logits.data, base_targets.data, topk=(1, 5) + ) base_losses.update(base_loss.item(), base_inputs.size(0)) base_top1.update(base_prec1.item(), base_inputs.size(0)) base_top5.update(base_prec5.item(), base_inputs.size(0)) @@ -53,7 +73,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer arch_loss.backward() a_optimizer.step() # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -63,7 +85,11 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer end = time.time() if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Sstr = ( + "*SEARCH* " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) @@ -74,7 +100,14 @@ def search_func(xloader, network, 
criterion, scheduler, w_optimizer, a_optimizer loss=arch_losses, top1=arch_top1, top5=arch_top5 ) logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) - return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg + return ( + base_losses.avg, + base_top1.avg, + base_top5.avg, + arch_losses.avg, + arch_top1.avg, + arch_top5.avg, + ) def main(xargs): @@ -86,11 +119,20 @@ def main(xargs): prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) # config_path = 'configs/nas-benchmark/algos/GDAS.config' - config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + config = load_config( + xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger + ) search_loader, _, valid_loader = get_nas_search_loaders( - train_data, valid_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers + train_data, + valid_data, + xargs.dataset, + "configs/nas-benchmark/", + config.batch_size, + xargs.workers, ) logger.log( "||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}".format( @@ -129,9 +171,14 @@ def main(xargs): logger.log("search-model :\n{:}".format(search_model)) logger.log("model-config : {:}".format(model_config)) - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) + w_optimizer, w_scheduler, criterion = get_optim_scheduler( + search_model.get_weights(), config + ) a_optimizer = torch.optim.Adam( - search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay + search_model.get_alphas(), + lr=xargs.arch_learning_rate, + betas=(0.5, 0.999), + weight_decay=xargs.arch_weight_decay, ) logger.log("w-optimizer : {:}".format(w_optimizer)) logger.log("a-optimizer : {:}".format(a_optimizer)) @@ -146,11 +193,17 @@ def main(xargs): api = API(xargs.arch_nas_dataset) logger.log("{:} create API = {:} done".format(time_string(), api)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(last_info) + ) last_info = torch.load(last_info) start_epoch = last_info["epoch"] checkpoint = torch.load(last_info["last_checkpoint"]) @@ -161,11 +214,17 @@ def main(xargs): w_optimizer.load_state_dict(checkpoint["w_optimizer"]) a_optimizer.load_state_dict(checkpoint["a_optimizer"]) logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + last_info, start_epoch + ) ) else: logger.log("=> do not find the last-info file : {:}".format(last_info)) - start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: search_model.genotype()} + start_epoch, valid_accuracies, genotypes = ( + 0, + {"best": -1}, + {-1: search_model.genotype()}, + ) # start training start_time, search_time, epoch_time, 
total_epoch = ( @@ -176,16 +235,27 @@ def main(xargs): ) for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.val * (total_epoch - epoch), True) + ) epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) - search_model.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)) + search_model.set_tau( + xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1) + ) logger.log( "\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}".format( epoch_str, need_time, search_model.get_tau(), min(w_scheduler.get_lr()) ) ) - search_w_loss, search_w_top1, search_w_top5, valid_a_loss, valid_a_top1, valid_a_top5 = search_func( + ( + search_w_loss, + search_w_top1, + search_w_top5, + valid_a_loss, + valid_a_top1, + valid_a_top5, + ) = search_func( search_loader, network, criterion, @@ -217,7 +287,9 @@ def main(xargs): find_best = False genotypes[epoch] = search_model.genotype() - logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + logger.log( + "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) + ) # save checkpoint save_path = save_checkpoint( { @@ -282,29 +354,52 @@ if __name__ == "__main__": parser.add_argument("--search_space_name", type=str, help="The search space name.") parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") parser.add_argument("--channel", type=int, help="The number of channels.") - parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument( + "--num_cells", type=int, help="The number of cells in one stage." + ) parser.add_argument( "--track_running_stats", type=int, choices=[0, 1], help="Whether use track_running_stats or not in the BN layer.", ) - parser.add_argument("--config_path", type=str, help="The path of the configuration.") + parser.add_argument( + "--config_path", type=str, help="The path of the configuration." + ) parser.add_argument( "--model_config", type=str, help="The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.", ) # architecture leraning rate - parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") - parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") + parser.add_argument( + "--arch_learning_rate", + type=float, + default=3e-4, + help="learning rate for arch encoding", + ) + parser.add_argument( + "--arch_weight_decay", + type=float, + default=1e-3, + help="weight decay for arch encoding", + ) parser.add_argument("--tau_min", type=float, help="The minimum tau for Gumbel") parser.add_argument("--tau_max", type=float, help="The maximum tau for Gumbel") # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") parser.add_argument( - "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." 
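
Note on the set_tau call above: it linearly anneals the Gumbel-softmax temperature from --tau_max at the first epoch down to --tau_min at the last one. The same arithmetic as a stand-alone schedule (illustrative values in the comment are examples, not the repository defaults):

def gdas_tau(epoch, total_epoch, tau_max, tau_min):
    # Linear decay: tau_max at epoch 0, tau_min at epoch total_epoch - 1.
    return tau_max - (tau_max - tau_min) * epoch / (total_epoch - 1)

# e.g. tau_max=10, tau_min=0.1, total_epoch=250: the temperature starts at
# 10.0 for epoch 0 and reaches 0.1 at epoch 249.
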
+ ) + parser.add_argument( + "--arch_nas_dataset", + type=str, + help="The path to load the architecture dataset (tiny-nas-benchmark).", ) parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") parser.add_argument("--rand_seed", type=int, help="manual seed") diff --git a/exps/algos/RANDOM-NAS.py b/exps/algos/RANDOM-NAS.py index d6b8a3a..90ffd44 100644 --- a/exps/algos/RANDOM-NAS.py +++ b/exps/algos/RANDOM-NAS.py @@ -15,19 +15,29 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API -def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger): +def search_func( + xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger +): data_time, batch_time = AverageMeter(), AverageMeter() base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() network.train() end = time.time() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( + xloader + ): scheduler.update(None, 1.0 * step / len(xloader)) base_targets = base_targets.cuda(non_blocking=True) arch_targets = arch_targets.cuda(non_blocking=True) @@ -43,7 +53,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, nn.utils.clip_grad_norm_(network.parameters(), 5) w_optimizer.step() # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_prec1, base_prec5 = obtain_accuracy( + logits.data, base_targets.data, topk=(1, 5) + ) base_losses.update(base_loss.item(), base_inputs.size(0)) base_top1.update(base_prec1.item(), base_inputs.size(0)) base_top5.update(base_prec5.item(), base_inputs.size(0)) @@ -53,7 +65,11 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, end = time.time() if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Sstr = ( + "*SEARCH* " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) @@ -80,7 +96,9 @@ def valid_func(xloader, network, criterion): _, logits = network(arch_inputs) arch_loss = criterion(logits, arch_targets) # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -105,7 +123,9 @@ def search_find_best(xloader, network, n_samples): inputs, targets = next(loader_iter) _, logits = network(inputs) - 
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) + val_top1, val_top5 = obtain_accuracy( + logits.cpu().data, targets.data, topk=(1, 5) + ) archs.append(arch) valid_accs.append(val_top1.item()) @@ -124,8 +144,12 @@ def main(xargs): prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) + config = load_config( + xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger + ) search_loader, _, valid_loader = get_nas_search_loaders( train_data, valid_data, @@ -157,7 +181,9 @@ def main(xargs): ) search_model = get_cell_based_tiny_net(model_config) - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config) + w_optimizer, w_scheduler, criterion = get_optim_scheduler( + search_model.parameters(), config + ) logger.log("w-optimizer : {:}".format(w_optimizer)) logger.log("w-scheduler : {:}".format(w_scheduler)) logger.log("criterion : {:}".format(criterion)) @@ -167,11 +193,17 @@ def main(xargs): api = API(xargs.arch_nas_dataset) logger.log("{:} create API = {:} done".format(time_string(), api)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(last_info) + ) last_info = torch.load(last_info) start_epoch = last_info["epoch"] checkpoint = torch.load(last_info["last_checkpoint"]) @@ -181,7 +213,9 @@ def main(xargs): w_scheduler.load_state_dict(checkpoint["w_scheduler"]) w_optimizer.load_state_dict(checkpoint["w_optimizer"]) logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + last_info, start_epoch + ) ) else: logger.log("=> do not find the last-info file : {:}".format(last_info)) @@ -196,13 +230,26 @@ def main(xargs): ) for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.val * (total_epoch - epoch), True) + ) epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) - logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr()))) + logger.log( + "\n[Search the {:}-th epoch] {:}, LR={:}".format( + epoch_str, need_time, min(w_scheduler.get_lr()) + ) + ) # selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num) search_w_loss, search_w_top1, search_w_top5 = search_func( - search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger + search_loader, + network, + criterion, + w_scheduler, + w_optimizer, + epoch_str, + xargs.print_freq, + logger, ) search_time.update(time.time() - start_time) 
logger.log( @@ -210,14 +257,22 @@ def main(xargs): epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum ) ) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( + valid_loader, network, criterion + ) logger.log( "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 ) ) - cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num) - logger.log("[{:}] find-the-best : {:}, accuracy@1={:.2f}%".format(epoch_str, cur_arch, cur_valid_acc)) + cur_arch, cur_valid_acc = search_find_best( + valid_loader, network, xargs.select_num + ) + logger.log( + "[{:}] find-the-best : {:}, accuracy@1={:.2f}%".format( + epoch_str, cur_arch, cur_valid_acc + ) + ) genotypes[epoch] = cur_arch # check the best accuracy valid_accuracies[epoch] = valid_a_top1 @@ -289,11 +344,19 @@ if __name__ == "__main__": ) # channels and number-of-cells parser.add_argument("--search_space_name", type=str, help="The search space name.") - parser.add_argument("--config_path", type=str, help="The path to the configuration.") + parser.add_argument( + "--config_path", type=str, help="The path to the configuration." + ) parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") parser.add_argument("--channel", type=int, help="The number of channels.") - parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") - parser.add_argument("--select_num", type=int, help="The number of selected architectures to evaluate.") + parser.add_argument( + "--num_cells", type=int, help="The number of cells in one stage." + ) + parser.add_argument( + "--select_num", + type=int, + help="The number of selected architectures to evaluate.", + ) parser.add_argument( "--track_running_stats", type=int, @@ -301,10 +364,19 @@ if __name__ == "__main__": help="Whether use track_running_stats or not in the BN layer.", ) # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") parser.add_argument( - "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." 
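
Note on search_find_best above: after each epoch it scores --select_num randomly sampled architectures with the shared weights, using a single validation batch per candidate, and keeps the most accurate one. A minimal sketch of that selection loop; sample_arch stands in for however the supernet draws a random architecture, and top-1 accuracy is computed inline instead of via obtain_accuracy:

import torch

def find_best_by_sampling(valid_loader, network, n_samples, sample_arch):
    network.eval()
    archs, accs = [], []
    loader_iter = iter(valid_loader)
    with torch.no_grad():
        for _ in range(n_samples):
            arch = sample_arch(network)          # draw a random architecture
            try:
                inputs, targets = next(loader_iter)
            except StopIteration:                # restart the loader if exhausted
                loader_iter = iter(valid_loader)
                inputs, targets = next(loader_iter)
            _, logits = network(inputs)
            top1 = (logits.argmax(dim=1).cpu() == targets).float().mean().item()
            archs.append(arch)
            accs.append(top1)
    best = max(range(n_samples), key=lambda i: accs[i])
    return archs[best], accs[best]
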
+ ) + parser.add_argument( + "--arch_nas_dataset", + type=str, + help="The path to load the architecture dataset (tiny-nas-benchmark).", ) parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") parser.add_argument("--rand_seed", type=int, help="manual seed") diff --git a/exps/algos/RANDOM.py b/exps/algos/RANDOM.py index 150f6ec..8bf9b92 100644 --- a/exps/algos/RANDOM.py +++ b/exps/algos/RANDOM.py @@ -13,7 +13,13 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_search_spaces @@ -35,13 +41,17 @@ def main(xargs, nas_bench): else: dataname = xargs.dataset if xargs.data_path is not None: - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) split_Fpath = "configs/nas-benchmark/cifar-split.txt" cifar_split = load_config(split_Fpath, None, None) train_split, valid_split = cifar_split.train, cifar_split.valid logger.log("Load split file from {:}".format(split_Fpath)) config_path = "configs/nas-benchmark/algos/R-EA.config" - config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger) + config = load_config( + config_path, {"class_num": class_num, "xshape": xshape}, logger + ) # To split data train_data_v2 = deepcopy(train_data) train_data_v2.transform = valid_data.transform @@ -68,7 +78,11 @@ def main(xargs, nas_bench): ) ) logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader} + extra_info = { + "config": config, + "train_loader": train_loader, + "valid_loader": valid_loader, + } else: config_path = "configs/nas-benchmark/algos/R-EA.config" config = load_config(config_path, None, logger) @@ -91,10 +105,17 @@ def main(xargs, nas_bench): history.append(arch) if best_arch is None or best_acc < accuracy: best_acc, best_arch = accuracy, arch - logger.log("[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy)) + logger.log( + "[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy) + ) logger.log( "{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).".format( - time_string(), best_arch, best_acc, len(history), total_time_cost, time.time() - x_start_time + time_string(), + best_arch, + best_acc, + len(history), + total_time_cost, + time.time() - x_start_time, ) ) @@ -121,14 +142,29 @@ if __name__ == "__main__": parser.add_argument("--search_space_name", type=str, help="The search space name.") parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") parser.add_argument("--channel", type=int, help="The number of channels.") - parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") - # parser.add_argument('--random_num', type=int, help='The number of random selected architectures.') - parser.add_argument("--time_budget", type=int, help="The total time cost budge for 
searching (in seconds).") - # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") parser.add_argument( - "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + "--num_cells", type=int, help="The number of cells in one stage." + ) + # parser.add_argument('--random_num', type=int, help='The number of random selected architectures.') + parser.add_argument( + "--time_budget", + type=int, + help="The total time cost budge for searching (in seconds).", + ) + # log + parser.add_argument( + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." + ) + parser.add_argument( + "--arch_nas_dataset", + type=str, + help="The path to load the architecture dataset (tiny-nas-benchmark).", ) parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") parser.add_argument("--rand_seed", type=int, help="manual seed") @@ -137,7 +173,11 @@ if __name__ == "__main__": if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): nas_bench = None else: - print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset)) + print( + "{:} build NAS-Benchmark-API from {:}".format( + time_string(), args.arch_nas_dataset + ) + ) nas_bench = API(args.arch_nas_dataset) if args.rand_seed < 0: save_dir, all_indexes, num = None, [], 500 diff --git a/exps/algos/R_EA.py b/exps/algos/R_EA.py index d6920cf..fdf79a1 100644 --- a/exps/algos/R_EA.py +++ b/exps/algos/R_EA.py @@ -15,7 +15,13 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from nas_201_api import NASBench201API as API @@ -38,13 +44,20 @@ class Model(object): # In this case, the LR schedular is converged. # For use_012_epoch_training = False, the architecture is planed to be trained for 200 epochs, but we early stop its procedure. 
# -def train_and_eval(arch, nas_bench, extra_info, dataname="cifar10-valid", use_012_epoch_training=True): +def train_and_eval( + arch, nas_bench, extra_info, dataname="cifar10-valid", use_012_epoch_training=True +): if use_012_epoch_training and nas_bench is not None: arch_index = nas_bench.query_index_by_arch(arch) assert arch_index >= 0, "can not find this arch : {:}".format(arch) - info = nas_bench.get_more_info(arch_index, dataname, iepoch=None, hp="12", is_random=True) - valid_acc, time_cost = info["valid-accuracy"], info["train-all-time"] + info["valid-per-time"] + info = nas_bench.get_more_info( + arch_index, dataname, iepoch=None, hp="12", is_random=True + ) + valid_acc, time_cost = ( + info["valid-accuracy"], + info["train-all-time"] + info["valid-per-time"], + ) # _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs elif not use_012_epoch_training and nas_bench is not None: # Please contact me if you want to use the following logic, because it has some potential issues. @@ -52,7 +65,9 @@ def train_and_eval(arch, nas_bench, extra_info, dataname="cifar10-valid", use_01 # It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details) arch_index, nepoch = nas_bench.query_index_by_arch(arch), 25 assert arch_index >= 0, "can not find this arch : {:}".format(arch) - xoinfo = nas_bench.get_more_info(arch_index, "cifar10-valid", iepoch=None, hp="12") + xoinfo = nas_bench.get_more_info( + arch_index, "cifar10-valid", iepoch=None, hp="12" + ) xocost = nas_bench.get_cost_info(arch_index, "cifar10-valid", hp="200") info = nas_bench.get_more_info( arch_index, dataname, nepoch, hp="200", is_random=True @@ -85,9 +100,15 @@ def train_and_eval(arch, nas_bench, extra_info, dataname="cifar10-valid", use_01 * cost["latency"] ) try: - valid_acc, time_cost = info["valid-accuracy"], estimated_train_cost + estimated_valid_cost + valid_acc, time_cost = ( + info["valid-accuracy"], + estimated_train_cost + estimated_valid_cost, + ) except: - valid_acc, time_cost = info["valtest-accuracy"], estimated_train_cost + estimated_valid_cost + valid_acc, time_cost = ( + info["valtest-accuracy"], + estimated_train_cost + estimated_valid_cost, + ) else: # train a model from scratch. raise ValueError("NOT IMPLEMENT YET") @@ -131,7 +152,15 @@ def mutate_arch_func(op_names): def regularized_evolution( - cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info, dataname + cycles, + population_size, + sample_size, + time_budget, + random_arch, + mutate_arch, + nas_bench, + extra_info, + dataname, ): """Algorithm for regularized evolution (i.e. aging evolution). @@ -149,13 +178,18 @@ def regularized_evolution( during the evolution experiment. """ population = collections.deque() - history, total_time_cost = [], 0 # Not used by the algorithm, only used to report results. + history, total_time_cost = ( + [], + 0, + ) # Not used by the algorithm, only used to report results. # Initialize the population with random models. 
while len(population) < population_size: model = Model() model.arch = random_arch() - model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info, dataname) + model.accuracy, time_cost = train_and_eval( + model.arch, nas_bench, extra_info, dataname + ) population.append(model) history.append(model) total_time_cost += time_cost @@ -180,7 +214,9 @@ def regularized_evolution( child = Model() child.arch = mutate_arch(parent.arch) total_time_cost += time.time() - start_time - child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info, dataname) + child.accuracy, time_cost = train_and_eval( + child.arch, nas_bench, extra_info, dataname + ) if total_time_cost + time_cost > time_budget: # return return history, total_time_cost else: @@ -207,13 +243,17 @@ def main(xargs, nas_bench): else: dataname = xargs.dataset if xargs.data_path is not None: - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) split_Fpath = "configs/nas-benchmark/cifar-split.txt" cifar_split = load_config(split_Fpath, None, None) train_split, valid_split = cifar_split.train, cifar_split.valid logger.log("Load split file from {:}".format(split_Fpath)) config_path = "configs/nas-benchmark/algos/R-EA.config" - config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger) + config = load_config( + config_path, {"class_num": class_num, "xshape": xshape}, logger + ) # To split data train_data_v2 = deepcopy(train_data) train_data_v2.transform = valid_data.transform @@ -240,7 +280,11 @@ def main(xargs, nas_bench): ) ) logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader} + extra_info = { + "config": config, + "train_loader": train_loader, + "valid_loader": valid_loader, + } else: config_path = "configs/nas-benchmark/algos/R-EA.config" config = load_config(config_path, None, logger) @@ -253,7 +297,10 @@ def main(xargs, nas_bench): # x =random_arch() ; y = mutate_arch(x) x_start_time = time.time() logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench)) - logger.log("-" * 30 + " start searching with the time budget of {:} s".format(xargs.time_budget)) + logger.log( + "-" * 30 + + " start searching with the time budget of {:} s".format(xargs.time_budget) + ) history, total_cost = regularized_evolution( xargs.ea_cycles, xargs.ea_population, @@ -297,17 +344,36 @@ if __name__ == "__main__": parser.add_argument("--search_space_name", type=str, help="The search space name.") parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") parser.add_argument("--channel", type=int, help="The number of channels.") - parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") + parser.add_argument( + "--num_cells", type=int, help="The number of cells in one stage." 
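
Note on regularized_evolution above: it is the aging-evolution loop of Real et al., which seeds a queue with randomly generated architectures, then repeatedly tournament-selects a parent from a small random sample, mutates it into a child, evaluates the child on the benchmark, and retires the oldest member, charging every evaluation's simulated cost against --time_budget. A compact sketch under those assumptions; evaluate stands in for train_and_eval and returns (accuracy, time_cost), and the Model wrapper plus wall-clock bookkeeping are omitted:

import collections
import random

def aging_evolution(population_size, sample_size, time_budget,
                    random_arch, mutate_arch, evaluate):
    population = collections.deque()            # oldest architecture on the left
    history, total_cost = [], 0.0
    while len(population) < population_size:    # seed with random architectures
        arch = random_arch()
        acc, cost = evaluate(arch)
        population.append((arch, acc))
        history.append((arch, acc))
        total_cost += cost
    while total_cost < time_budget:
        samples = [random.choice(population) for _ in range(sample_size)]
        parent = max(samples, key=lambda pair: pair[1])   # tournament selection
        child = mutate_arch(parent[0])
        acc, cost = evaluate(child)
        if total_cost + cost > time_budget:     # stop once the budget is spent
            break
        total_cost += cost
        population.append((child, acc))
        history.append((child, acc))
        population.popleft()                    # age out the oldest member
    return max(history, key=lambda pair: pair[1]), history, total_cost
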
+ ) parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.") parser.add_argument("--ea_population", type=int, help="The population size in EA.") parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.") - parser.add_argument("--ea_fast_by_api", type=int, help="Use our API to speed up the experiments or not.") - parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).") - # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") parser.add_argument( - "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + "--ea_fast_by_api", + type=int, + help="Use our API to speed up the experiments or not.", + ) + parser.add_argument( + "--time_budget", + type=int, + help="The total time cost budge for searching (in seconds).", + ) + # log + parser.add_argument( + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." + ) + parser.add_argument( + "--arch_nas_dataset", + type=str, + help="The path to load the architecture dataset (tiny-nas-benchmark).", ) parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") @@ -318,7 +384,11 @@ if __name__ == "__main__": if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): nas_bench = None else: - print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset)) + print( + "{:} build NAS-Benchmark-API from {:}".format( + time_string(), args.arch_nas_dataset + ) + ) nas_bench = API(args.arch_nas_dataset) if args.rand_seed < 0: save_dir, all_indexes, num = None, [], 500 diff --git a/exps/algos/SETN.py b/exps/algos/SETN.py index c6e23bc..3e048b9 100644 --- a/exps/algos/SETN.py +++ b/exps/algos/SETN.py @@ -15,20 +15,38 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, get_nas_search_loaders -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API -def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): +def search_func( + xloader, + network, + criterion, + scheduler, + w_optimizer, + a_optimizer, + epoch_str, + print_freq, + logger, +): data_time, batch_time = AverageMeter(), AverageMeter() base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() end = time.time() network.train() - for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): + for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( + xloader + ): scheduler.update(None, 1.0 * step / len(xloader)) base_targets = 
base_targets.cuda(non_blocking=True) arch_targets = arch_targets.cuda(non_blocking=True) @@ -45,7 +63,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer base_loss.backward() w_optimizer.step() # record - base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) + base_prec1, base_prec5 = obtain_accuracy( + logits.data, base_targets.data, topk=(1, 5) + ) base_losses.update(base_loss.item(), base_inputs.size(0)) base_top1.update(base_prec1.item(), base_inputs.size(0)) base_top5.update(base_prec5.item(), base_inputs.size(0)) @@ -58,7 +78,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer arch_loss.backward() a_optimizer.step() # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -68,7 +90,11 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer end = time.time() if step % print_freq == 0 or step + 1 == len(xloader): - Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + Sstr = ( + "*SEARCH* " + + time_string() + + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) + ) Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( batch_time=batch_time, data_time=data_time ) @@ -81,7 +107,14 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) # print (nn.functional.softmax(network.module.arch_parameters, dim=-1)) # print (network.module.arch_parameters) - return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg + return ( + base_losses.avg, + base_top1.avg, + base_top5.avg, + arch_losses.avg, + arch_top1.avg, + arch_top5.avg, + ) def get_best_arch(xloader, network, n_samples): @@ -99,7 +132,9 @@ def get_best_arch(xloader, network, n_samples): inputs, targets = next(loader_iter) _, logits = network(inputs) - val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) + val_top1, val_top5 = obtain_accuracy( + logits.cpu().data, targets.data, topk=(1, 5) + ) valid_accs.append(val_top1.item()) @@ -122,7 +157,9 @@ def valid_func(xloader, network, criterion): _, logits = network(arch_inputs) arch_loss = criterion(logits, arch_targets) # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) + arch_prec1, arch_prec5 = obtain_accuracy( + logits.data, arch_targets.data, topk=(1, 5) + ) arch_losses.update(arch_loss.item(), arch_inputs.size(0)) arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) @@ -141,8 +178,12 @@ def main(xargs): prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) + config = load_config( + xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger + ) search_loader, _, valid_loader = 
get_nas_search_loaders( train_data, valid_data, @@ -187,9 +228,14 @@ def main(xargs): logger.log("search space : {:}".format(search_space)) search_model = get_cell_based_tiny_net(model_config) - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) + w_optimizer, w_scheduler, criterion = get_optim_scheduler( + search_model.get_weights(), config + ) a_optimizer = torch.optim.Adam( - search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay + search_model.get_alphas(), + lr=xargs.arch_learning_rate, + betas=(0.5, 0.999), + weight_decay=xargs.arch_weight_decay, ) logger.log("w-optimizer : {:}".format(w_optimizer)) logger.log("a-optimizer : {:}".format(a_optimizer)) @@ -204,11 +250,17 @@ def main(xargs): api = API(xargs.arch_nas_dataset) logger.log("{:} create API = {:} done".format(time_string(), api)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(last_info) + ) last_info = torch.load(last_info) start_epoch = last_info["epoch"] checkpoint = torch.load(last_info["last_checkpoint"]) @@ -219,7 +271,9 @@ def main(xargs): w_optimizer.load_state_dict(checkpoint["w_optimizer"]) a_optimizer.load_state_dict(checkpoint["a_optimizer"]) logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + last_info, start_epoch + ) ) else: logger.log("=> do not find the last-info file : {:}".format(last_info)) @@ -235,11 +289,24 @@ def main(xargs): ) for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.val * (total_epoch - epoch), True) + ) epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) - logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr()))) + logger.log( + "\n[Search the {:}-th epoch] {:}, LR={:}".format( + epoch_str, need_time, min(w_scheduler.get_lr()) + ) + ) - search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func( + ( + search_w_loss, + search_w_top1, + search_w_top5, + search_a_loss, + search_a_top1, + search_a_top5, + ) = search_func( search_loader, network, criterion, @@ -264,7 +331,9 @@ def main(xargs): genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) network.module.set_cal_mode("dynamic", genotype) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( + valid_loader, network, criterion + ) logger.log( "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format( epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype @@ -283,7 +352,9 @@ def main(xargs): valid_accuracies[epoch] = valid_a_top1 genotypes[epoch] = genotype - 
logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])) + logger.log( + "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) + ) # save checkpoint save_path = save_checkpoint( { @@ -321,12 +392,22 @@ def main(xargs): genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) search_time.update(time.time() - start_time) network.module.set_cal_mode("dynamic", genotype) - valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion) - logger.log("Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(genotype, valid_a_top1)) + valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( + valid_loader, network, criterion + ) + logger.log( + "Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format( + genotype, valid_a_top1 + ) + ) logger.log("\n" + "-" * 100) # check the performance from the architecture dataset - logger.log("SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(total_epoch, search_time.sum, genotype)) + logger.log( + "SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( + total_epoch, search_time.sum, genotype + ) + ) if api is not None: logger.log("{:}".format(api.query_by_arch(genotype, "200"))) logger.close() @@ -345,23 +426,50 @@ if __name__ == "__main__": parser.add_argument("--search_space_name", type=str, help="The search space name.") parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") parser.add_argument("--channel", type=int, help="The number of channels.") - parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.") - parser.add_argument("--select_num", type=int, help="The number of selected architectures to evaluate.") + parser.add_argument( + "--num_cells", type=int, help="The number of cells in one stage." + ) + parser.add_argument( + "--select_num", + type=int, + help="The number of selected architectures to evaluate.", + ) parser.add_argument( "--track_running_stats", type=int, choices=[0, 1], help="Whether use track_running_stats or not in the BN layer.", ) - parser.add_argument("--config_path", type=str, help="The path of the configuration.") - # architecture leraning rate - parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding") - parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding") - # log - parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)") - parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.") parser.add_argument( - "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)." + "--config_path", type=str, help="The path of the configuration." + ) + # architecture leraning rate + parser.add_argument( + "--arch_learning_rate", + type=float, + default=3e-4, + help="learning rate for arch encoding", + ) + parser.add_argument( + "--arch_weight_decay", + type=float, + default=1e-3, + help="weight decay for arch encoding", + ) + # log + parser.add_argument( + "--workers", + type=int, + default=2, + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." 
+ ) + parser.add_argument( + "--arch_nas_dataset", + type=str, + help="The path to load the architecture dataset (tiny-nas-benchmark).", ) parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") parser.add_argument("--rand_seed", type=int, help="manual seed") diff --git a/exps/algos/reinforce.py b/exps/algos/reinforce.py index 2e8f629..0122aaa 100644 --- a/exps/algos/reinforce.py +++ b/exps/algos/reinforce.py @@ -16,7 +16,13 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, + get_optim_scheduler, +) from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from nas_201_api import NASBench201API as API @@ -34,7 +40,9 @@ class Policy(nn.Module): for j in range(i): node_str = "{:}<-{:}".format(i, j) self.edge2index[node_str] = len(self.edge2index) - self.arch_parameters = nn.Parameter(1e-3 * torch.randn(len(self.edge2index), len(search_space))) + self.arch_parameters = nn.Parameter( + 1e-3 * torch.randn(len(self.edge2index), len(search_space)) + ) def generate_arch(self, actions): genotypes = [] @@ -74,7 +82,9 @@ class ExponentialMovingAverage(object): self._momentum = momentum def update(self, value): - self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value + self._numerator = ( + self._momentum * self._numerator + (1 - self._momentum) * value + ) self._denominator = self._momentum * self._denominator + (1 - self._momentum) def value(self): @@ -104,13 +114,17 @@ def main(xargs, nas_bench): else: dataname = xargs.dataset if xargs.data_path is not None: - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) split_Fpath = "configs/nas-benchmark/cifar-split.txt" cifar_split = load_config(split_Fpath, None, None) train_split, valid_split = cifar_split.train, cifar_split.valid logger.log("Load split file from {:}".format(split_Fpath)) config_path = "configs/nas-benchmark/algos/R-EA.config" - config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger) + config = load_config( + config_path, {"class_num": class_num, "xshape": xshape}, logger + ) # To split data train_data_v2 = deepcopy(train_data) train_data_v2.transform = valid_data.transform @@ -137,7 +151,11 @@ def main(xargs, nas_bench): ) ) logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) - extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader} + extra_info = { + "config": config, + "train_loader": train_loader, + "valid_loader": valid_loader, + } else: config_path = "configs/nas-benchmark/algos/R-EA.config" config = load_config(config_path, None, logger) @@ -160,7 +178,9 @@ def main(xargs, nas_bench): # REINFORCE # attempts = 0 x_start_time = time.time() - logger.log("Will start searching with time budget of {:} s.".format(xargs.time_budget)) + logger.log( + "Will start searching with time budget of {:} s.".format(xargs.time_budget) + ) total_steps, total_costs, trace = 0, 0, [] # for istep in range(xargs.RL_steps): while total_costs < xargs.time_budget: @@ -222,16 
+242,35 @@ if __name__ == "__main__":
     parser.add_argument("--search_space_name", type=str, help="The search space name.")
     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
     parser.add_argument("--channel", type=int, help="The number of channels.")
-    parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
-    parser.add_argument("--learning_rate", type=float, help="The learning rate for REINFORCE.")
-    # parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.')
-    parser.add_argument("--EMA_momentum", type=float, help="The momentum value for EMA.")
-    parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).")
-    # log
-    parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
-    parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
     parser.add_argument(
-        "--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
+        "--num_cells", type=int, help="The number of cells in one stage."
+    )
+    parser.add_argument(
+        "--learning_rate", type=float, help="The learning rate for REINFORCE."
+    )
+    # parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.')
+    parser.add_argument(
+        "--EMA_momentum", type=float, help="The momentum value for EMA."
+    )
+    parser.add_argument(
+        "--time_budget",
+        type=int,
+        help="The total time cost budget for searching (in seconds).",
+    )
+    # log
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=2,
+        help="number of data loading workers (default: 2)",
+    )
+    parser.add_argument(
+        "--save_dir", type=str, help="Folder to save checkpoints and log."
+    )
+    parser.add_argument(
+        "--arch_nas_dataset",
+        type=str,
+        help="The path to load the architecture dataset (tiny-nas-benchmark).",
     )
     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
@@ -240,7 +279,11 @@ if __name__ == "__main__":
     if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
         nas_bench = None
     else:
-        print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset))
+        print(
+            "{:} build NAS-Benchmark-API from {:}".format(
+                time_string(), args.arch_nas_dataset
+            )
+        )
         nas_bench = API(args.arch_nas_dataset)
     if args.rand_seed < 0:
         save_dir, all_indexes, num = None, [], 500
diff --git a/exps/basic-eval.py b/exps/basic-eval.py
index f0e2509..fdac9d1 100644
--- a/exps/basic-eval.py
+++ b/exps/basic-eval.py
@@ -24,14 +24,24 @@ assert torch.cuda.is_available(), "torch.cuda is not available"


 def main(args):
-    assert os.path.isdir(args.data_path), "invalid data-path : {:}".format(args.data_path)
-    assert os.path.isfile(args.checkpoint), "invalid checkpoint : {:}".format(args.checkpoint)
+    assert os.path.isdir(args.data_path), "invalid data-path : {:}".format(
+        args.data_path
+    )
+    assert os.path.isfile(args.checkpoint), "invalid checkpoint : {:}".format(
+        args.checkpoint
+    )
     checkpoint = torch.load(args.checkpoint)
     xargs = checkpoint["args"]

-    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, args.data_path, xargs.cutout_length)
+    train_data, valid_data, xshape, class_num = get_datasets(
+        xargs.dataset, args.data_path, xargs.cutout_length
+    )
     valid_loader = torch.utils.data.DataLoader(
-        valid_data, batch_size=xargs.batch_size, shuffle=False, num_workers=xargs.workers,
pin_memory=True + valid_data, + batch_size=xargs.batch_size, + shuffle=False, + num_workers=xargs.workers, + pin_memory=True, ) logger = PrintLogger() @@ -41,7 +51,11 @@ def main(args): logger.log("model ====>>>>:\n{:}".format(base_model)) logger.log("model information : {:}".format(base_model.get_message())) logger.log("-" * 50) - logger.log("Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(param, flop, flop / 1e3)) + logger.log( + "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format( + param, flop, flop / 1e3 + ) + ) logger.log("-" * 50) logger.log("valid_data : {:}".format(valid_data)) optim_config = dict2config(checkpoint["optim-config"], logger) @@ -54,23 +68,44 @@ def main(args): try: valid_loss, valid_acc1, valid_acc5 = valid_func( - valid_loader, network, criterion, optim_config, "pure-evaluation", xargs.print_freq_eval, logger + valid_loader, + network, + criterion, + optim_config, + "pure-evaluation", + xargs.print_freq_eval, + logger, ) except: _, valid_func = get_procedures("basic") valid_loss, valid_acc1, valid_acc5 = valid_func( - valid_loader, network, criterion, optim_config, "pure-evaluation", xargs.print_freq_eval, logger + valid_loader, + network, + criterion, + optim_config, + "pure-evaluation", + xargs.print_freq_eval, + logger, ) num_bytes = torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 logger.log( "***{:s}*** EVALUATION loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f}, error@1 = {:.2f}, error@5 = {:.2f}".format( - time_string(), valid_loss, valid_acc1, valid_acc5, 100 - valid_acc1, 100 - valid_acc5 + time_string(), + valid_loss, + valid_acc1, + valid_acc5, + 100 - valid_acc1, + 100 - valid_acc5, ) ) logger.log( "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format( - next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9 + next(network.parameters()).device, + int(num_bytes), + num_bytes / 1e3, + num_bytes / 1e6, + num_bytes / 1e9, ) ) logger.close() @@ -79,6 +114,8 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser("Evaluate-CNN") parser.add_argument("--data_path", type=str, help="Path to dataset.") - parser.add_argument("--checkpoint", type=str, help="Choose between Cifar10/100 and ImageNet.") + parser.add_argument( + "--checkpoint", type=str, help="Choose between Cifar10/100 and ImageNet." 
+ ) args = parser.parse_args() main(args) diff --git a/exps/basic-main.py b/exps/basic-main.py index 82f756d..cc3b6d3 100644 --- a/exps/basic-main.py +++ b/exps/basic-main.py @@ -31,12 +31,22 @@ def main(args): prepare_seed(args.rand_seed) logger = prepare_logger(args) - train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) + train_data, valid_data, xshape, class_num = get_datasets( + args.dataset, args.data_path, args.cutout_length + ) train_loader = torch.utils.data.DataLoader( - train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True + train_data, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, + pin_memory=True, ) valid_loader = torch.utils.data.DataLoader( - valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True + valid_data, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + pin_memory=True, ) # get configures model_config = load_config(args.model_config, {"class_num": class_num}, logger) @@ -54,26 +64,44 @@ def main(args): logger.log("model ====>>>>:\n{:}".format(base_model)) logger.log("model information : {:}".format(base_model.get_message())) logger.log("-" * 50) - logger.log("Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(param, flop, flop / 1e3)) + logger.log( + "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format( + param, flop, flop / 1e3 + ) + ) logger.log("-" * 50) logger.log("train_data : {:}".format(train_data)) logger.log("valid_data : {:}".format(valid_data)) - optimizer, scheduler, criterion = get_optim_scheduler(base_model.parameters(), optim_config) + optimizer, scheduler, criterion = get_optim_scheduler( + base_model.parameters(), optim_config + ) logger.log("optimizer : {:}".format(optimizer)) logger.log("scheduler : {:}".format(scheduler)) logger.log("criterion : {:}".format(criterion)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() if last_info.exists(): # automatically resume from previous checkpoint - logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(last_info) + ) last_infox = torch.load(last_info) start_epoch = last_infox["epoch"] + 1 last_checkpoint_path = last_infox["last_checkpoint"] if not last_checkpoint_path.exists(): - logger.log("Does not find {:}, try another path".format(last_checkpoint_path)) - last_checkpoint_path = last_info.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name + logger.log( + "Does not find {:}, try another path".format(last_checkpoint_path) + ) + last_checkpoint_path = ( + last_info.parent + / last_checkpoint_path.parent.name + / last_checkpoint_path.name + ) checkpoint = torch.load(last_checkpoint_path) base_model.load_state_dict(checkpoint["base-model"]) scheduler.load_state_dict(checkpoint["scheduler"]) @@ -81,10 +109,14 @@ def main(args): valid_accuracies = checkpoint["valid_accuracies"] max_bytes = checkpoint["max_bytes"] logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + last_info, start_epoch 
+ ) ) elif args.resume is not None: - assert Path(args.resume).exists(), "Can not find the resume file : {:}".format(args.resume) + assert Path(args.resume).exists(), "Can not find the resume file : {:}".format( + args.resume + ) checkpoint = torch.load(args.resume) start_epoch = checkpoint["epoch"] + 1 base_model.load_state_dict(checkpoint["base-model"]) @@ -92,9 +124,15 @@ def main(args): optimizer.load_state_dict(checkpoint["optimizer"]) valid_accuracies = checkpoint["valid_accuracies"] max_bytes = checkpoint["max_bytes"] - logger.log("=> loading checkpoint from '{:}' start with {:}-th epoch.".format(args.resume, start_epoch)) + logger.log( + "=> loading checkpoint from '{:}' start with {:}-th epoch.".format( + args.resume, start_epoch + ) + ) elif args.init_model is not None: - assert Path(args.init_model).exists(), "Can not find the initialization file : {:}".format(args.init_model) + assert Path( + args.init_model + ).exists(), "Can not find the initialization file : {:}".format(args.init_model) checkpoint = torch.load(args.init_model) base_model.load_state_dict(checkpoint["base-model"]) start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {} @@ -111,13 +149,17 @@ def main(args): epoch_time = AverageMeter() for epoch in range(start_epoch, total_epoch): scheduler.update(epoch, 0.0) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch), True)) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.avg * (total_epoch - epoch), True) + ) epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch) LRs = scheduler.get_lr() find_best = False # set-up drop-out ratio if hasattr(base_model, "update_drop_path"): - base_model.update_drop_path(model_config.drop_path_prob * epoch / total_epoch) + base_model.update_drop_path( + model_config.drop_path_prob * epoch / total_epoch + ) logger.log( "\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}".format( time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler @@ -126,7 +168,15 @@ def main(args): # train for one epoch train_loss, train_acc1, train_acc5 = train_func( - train_loader, network, criterion, scheduler, optimizer, optim_config, epoch_str, args.print_freq, logger + train_loader, + network, + criterion, + scheduler, + optimizer, + optim_config, + epoch_str, + args.print_freq, + logger, ) # log the results logger.log( @@ -139,7 +189,13 @@ def main(args): if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): logger.log("-" * 150) valid_loss, valid_acc1, valid_acc5 = valid_func( - valid_loader, network, criterion, optim_config, epoch_str, args.print_freq_eval, logger + valid_loader, + network, + criterion, + optim_config, + epoch_str, + args.print_freq_eval, + logger, ) valid_accuracies[epoch] = valid_acc1 logger.log( @@ -158,13 +214,24 @@ def main(args): find_best = True logger.log( "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format( - epoch, valid_acc1, valid_acc5, 100 - valid_acc1, 100 - valid_acc5, model_best_path + epoch, + valid_acc1, + valid_acc5, + 100 - valid_acc1, + 100 - valid_acc5, + model_best_path, ) ) - num_bytes = torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 + num_bytes = ( + torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 + ) logger.log( "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format( - next(network.parameters()).device, int(num_bytes), num_bytes / 
1e3, num_bytes / 1e6, num_bytes / 1e9 + next(network.parameters()).device, + int(num_bytes), + num_bytes / 1e3, + num_bytes / 1e6, + num_bytes / 1e9, ) ) max_bytes[epoch] = num_bytes @@ -208,7 +275,9 @@ def main(args): logger.log("\n" + "-" * 200) logger.log( "Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}".format( - convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e6, logger.path("info") + convert_secs2time(epoch_time.sum, True), + max(v for k, v in max_bytes.items()) / 1e6, + logger.path("info"), ) ) logger.log("-" * 200 + "\n") diff --git a/exps/experimental/example-nas-bench.py b/exps/experimental/example-nas-bench.py index 79fbde2..aae09ab 100644 --- a/exps/experimental/example-nas-bench.py +++ b/exps/experimental/example-nas-bench.py @@ -28,15 +28,25 @@ from utils import weight_watcher if __name__ == "__main__": parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") parser.add_argument( - "--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file and weight dir." + "--api_path", + type=str, + default=None, + help="The path to the NAS-Bench-201 benchmark file and weight dir.", + ) + parser.add_argument( + "--archive_path", + type=str, + default=None, + help="The path to the NAS-Bench-201 weight dir.", ) - parser.add_argument("--archive_path", type=str, default=None, help="The path to the NAS-Bench-201 weight dir.") args = parser.parse_args() meta_file = Path(args.api_path) weight_dir = Path(args.archive_path) assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) - assert weight_dir.exists() and weight_dir.is_dir(), "invalid path for weight dir : {:}".format(weight_dir) + assert ( + weight_dir.exists() and weight_dir.is_dir() + ), "invalid path for weight dir : {:}".format(weight_dir) api = NASBench201API(meta_file, verbose=True) @@ -46,7 +56,9 @@ if __name__ == "__main__": data = "cifar10" # query the info from CIFAR-10 config = api.get_net_config(arch_index, data) net = get_cell_based_tiny_net(config) - meta_info = api.query_meta_info_by_index(arch_index, hp="200") # all info about this architecture + meta_info = api.query_meta_info_by_index( + arch_index, hp="200" + ) # all info about this architecture params = meta_info.get_net_param(data, 888) net.load_state_dict(params) diff --git a/exps/experimental/test-nas-plot.py b/exps/experimental/test-nas-plot.py index 7a129d1..f46f186 100644 --- a/exps/experimental/test-nas-plot.py +++ b/exps/experimental/test-nas-plot.py @@ -69,7 +69,13 @@ def plot(filename): for xin in range(i): op_i = random.randint(0, len(OPS) - 1) # g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i]) - g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i]) + g.edge( + str(xin), + str(i), + label=OPS[op_i], + color=COLORS[op_i], + fillcolor=COLORS[op_i], + ) # import pdb; pdb.set_trace() g.render(filename, cleanup=True, view=False) @@ -88,7 +94,9 @@ def test_auto_grad(): net = Net(10) inputs = torch.rand(256, 10) loss = net(inputs) - first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True) + first_order_grads = torch.autograd.grad( + loss, net.parameters(), retain_graph=True, create_graph=True + ) first_order_grads = torch.cat([x.view(-1) for x in first_order_grads]) second_order_grads = [] for grads in first_order_grads: @@ -108,9 +116,15 @@ def test_one_shot_model(ckpath, use_train): print("ckpath : {:}".format(ckpath)) ckp = 
torch.load(ckpath) xargs = ckp["args"] - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) + train_data, valid_data, xshape, class_num = get_datasets( + xargs.dataset, xargs.data_path, -1 + ) # config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None) - config = load_config("./configs/nas-benchmark/algos/DARTS.config", {"class_num": class_num, "xshape": xshape}, None) + config = load_config( + "./configs/nas-benchmark/algos/DARTS.config", + {"class_num": class_num, "xshape": xshape}, + None, + ) if xargs.dataset == "cifar10": cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None) xvalid_data = deepcopy(train_data) @@ -142,7 +156,9 @@ def test_one_shot_model(ckpath, use_train): search_model.load_state_dict(ckp["search_model"]) search_model = search_model.cuda() api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth") - archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train) + archs, probs, accuracies = evaluate_one_shot( + search_model, valid_loader, api, use_train + ) if __name__ == "__main__": diff --git a/exps/experimental/test-ww-bench.py b/exps/experimental/test-ww-bench.py index 58d9ffc..21d7668 100644 --- a/exps/experimental/test-ww-bench.py +++ b/exps/experimental/test-ww-bench.py @@ -53,8 +53,12 @@ def evaluate(api, weight_dir, data: str): # compute the weight watcher results config = api.get_net_config(arch_index, data) net = get_cell_based_tiny_net(config) - meta_info = api.query_meta_info_by_index(arch_index, hp="200" if api.search_space_name == "topology" else "90") - params = meta_info.get_net_param(data, 888 if api.search_space_name == "topology" else 777) + meta_info = api.query_meta_info_by_index( + arch_index, hp="200" if api.search_space_name == "topology" else "90" + ) + params = meta_info.get_net_param( + data, 888 if api.search_space_name == "topology" else 777 + ) with torch.no_grad(): net.load_state_dict(params) _, summary = weight_watcher.analyze(net, alphas=False) @@ -73,7 +77,10 @@ def evaluate(api, weight_dir, data: str): norms.append(cur_norm) # query the accuracy info = meta_info.get_metrics( - data, "ori-test", iepoch=None, is_random=888 if api.search_space_name == "topology" else 777 + data, + "ori-test", + iepoch=None, + is_random=888 if api.search_space_name == "topology" else 777, ) accuracies.append(info["accuracy"]) del net, meta_info @@ -98,7 +105,11 @@ def main(search_space, meta_file: str, weight_dir, save_dir, xdata): for hp in hps: nums = api.statistics(data, hp=hp) total = sum([k * v for k, v in nums.items()]) - print("Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format(hp, data, total, nums)) + print( + "Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format( + hp, data, total, nums + ) + ) print(time_string() + " " + "=" * 50) norms, accuracies = evaluate(api, weight_dir, xdata) @@ -120,8 +131,15 @@ def main(search_space, meta_file: str, weight_dir, save_dir, xdata): plt.xlim(min(indexes), max(indexes)) plt.ylim(min(indexes), max(indexes)) # plt.ylabel('y').set_rotation(30) - plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical") - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize) + plt.yticks( + np.arange(min(indexes), max(indexes), max(indexes) // 3), + fontsize=LegendFontsize, + rotation="vertical", + ) + plt.xticks( + np.arange(min(indexes), max(indexes), 
max(indexes) // 5), + fontsize=LegendFontsize, + ) ax.scatter(indexes, labels, marker="*", s=0.5, c="tab:red", alpha=0.8) ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Test accuracy") @@ -129,7 +147,9 @@ def main(search_space, meta_file: str, weight_dir, save_dir, xdata): plt.grid(zorder=0) ax.set_axisbelow(True) plt.legend(loc=0, fontsize=LegendFontsize) - ax.set_xlabel("architecture ranking sorted by the test accuracy ", fontsize=LabelSize) + ax.set_xlabel( + "architecture ranking sorted by the test accuracy ", fontsize=LabelSize + ) ax.set_ylabel("architecture ranking computed by weight watcher", fontsize=LabelSize) save_path = (save_dir / "{:}-{:}-test-ww.pdf".format(search_space, xdata)).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") @@ -148,9 +168,18 @@ if __name__ == "__main__": default="./output/vis-nas-bench/", help="The base-name of folder to save checkpoints and log.", ) - parser.add_argument("--search_space", type=str, default=None, choices=["tss", "sss"], help="The search space.") parser.add_argument( - "--base_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file and weight dir." + "--search_space", + type=str, + default=None, + choices=["tss", "sss"], + help="The search space.", + ) + parser.add_argument( + "--base_path", + type=str, + default=None, + help="The path to the NAS-Bench-201 benchmark file and weight dir.", ) parser.add_argument("--dataset", type=str, default=None, help=".") args = parser.parse_args() @@ -160,6 +189,8 @@ if __name__ == "__main__": meta_file = Path(args.base_path + ".pth") weight_dir = Path(args.base_path + "-full") assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) - assert weight_dir.exists() and weight_dir.is_dir(), "invalid path for weight dir : {:}".format(weight_dir) + assert ( + weight_dir.exists() and weight_dir.is_dir() + ), "invalid path for weight dir : {:}".format(weight_dir) main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset) diff --git a/exps/experimental/vis-nats-bench-algos.py b/exps/experimental/vis-nats-bench-algos.py index 8e53f98..2d9ea17 100644 --- a/exps/experimental/vis-nats-bench-algos.py +++ b/exps/experimental/vis-nats-bench-algos.py @@ -42,7 +42,9 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): for alg, path in alg2path.items(): data = torch.load(path) for index, info in data.items(): - info["time_w_arch"] = [(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])] + info["time_w_arch"] = [ + (x, y) for x, y in zip(info["all_total_times"], info["all_archs"]) + ] for j, arch in enumerate(info["all_archs"]): assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( alg, search_space, dataset, index, j @@ -57,12 +59,16 @@ def query_performance(api, data, dataset, ticket): time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) time_a, arch_a = time_w_arch[0] time_b, arch_b = time_w_arch[1] - info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) - info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False) + info_a = api.get_more_info( + arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) + info_b = api.get_more_info( + arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False + ) accuracy_a, accuracy_b = info_a["test-accuracy"], 
info_b["test-accuracy"] - interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (ticket - time_a) / ( - time_b - time_a - ) * accuracy_b + interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + ( + ticket - time_a + ) / (time_b - time_a) * accuracy_b results.append(interplate) return sum(results) / len(results) @@ -85,7 +91,11 @@ y_max_s = { ("ImageNet16-120", "sss"): 46, } -name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} +name2label = { + "cifar10": "CIFAR-10", + "cifar100": "CIFAR-100", + "ImageNet16-120": "ImageNet-16-120", +} def visualize_curve(api, vis_save_dir, search_space, max_time): @@ -100,7 +110,9 @@ def visualize_curve(api, vis_save_dir, search_space, max_time): alg2data = fetch_data(search_space=search_space, dataset=dataset) alg2accuracies = OrderedDict() total_tickets = 150 - time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)] + time_tickets = [ + float(i) / total_tickets * max_time for i in range(total_tickets) + ] colors = ["b", "g", "c", "m", "y"] ax.set_xlim(0, 200) ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) @@ -111,10 +123,20 @@ def visualize_curve(api, vis_save_dir, search_space, max_time): accuracy = query_performance(api, data, dataset, ticket) accuracies.append(accuracy) alg2accuracies[alg] = accuracies - ax.plot([x / 100 for x in time_tickets], accuracies, c=colors[idx], label="{:}".format(alg)) + ax.plot( + [x / 100 for x in time_tickets], + accuracies, + c=colors[idx], + label="{:}".format(alg), + ) ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize) - ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize) - ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4) + ax.set_ylabel( + "Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize + ) + ax.set_title( + "Searching results on {:}".format(name2label[dataset]), + fontsize=LabelSize + 4, + ) ax.legend(loc=4, fontsize=LegendFontsize) fig, axs = plt.subplots(1, 3, figsize=figsize) @@ -129,12 +151,25 @@ def visualize_curve(api, vis_save_dir, search_space, max_time): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument( - "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." + parser = argparse.ArgumentParser( + description="NAS-Bench-X", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--save_dir", + type=str, + default="output/vis-nas-bench/nas-algos", + help="Folder to save checkpoints and log.", + ) + parser.add_argument( + "--search_space", + type=str, + choices=["tss", "sss"], + help="Choose the search space.", + ) + parser.add_argument( + "--max_time", type=float, default=20000, help="The maximum time budget." 
) - parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.") - parser.add_argument("--max_time", type=float, default=20000, help="The maximum time budget.") args = parser.parse_args() save_dir = Path(args.save_dir) diff --git a/exps/experimental/vis-nats-bench-ws.py b/exps/experimental/vis-nats-bench-ws.py index c3715b3..47500e1 100644 --- a/exps/experimental/vis-nats-bench-ws.py +++ b/exps/experimental/vis-nats-bench-ws.py @@ -29,7 +29,9 @@ from log_utils import time_string # def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'): -def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"): +def fetch_data( + root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3" +): ss_dir = "{:}-{:}".format(root_dir, search_space) alg2name, alg2path = OrderedDict(), OrderedDict() seeds = [777, 888, 999] @@ -45,8 +47,12 @@ def fetch_data(root_dir="./output/search", search_space="tss", dataset=None, suf # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) - alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(suffix) - alg2name["masking + Gumbel-Softmax"] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix) + alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format( + suffix + ) + alg2name[ + "masking + Gumbel-Softmax" + ] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix) alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix) for alg, name in alg2name.items(): alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth") @@ -86,7 +92,11 @@ y_max_s = { ("ImageNet16-120", "sss"): 46, } -name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} +name2label = { + "cifar10": "CIFAR-10", + "cifar100": "CIFAR-100", + "ImageNet16-120": "ImageNet-16-120", +} def visualize_curve(api, vis_save_dir, search_space): @@ -111,10 +121,17 @@ def visualize_curve(api, vis_save_dir, search_space): try: structures, accs = [_[iepoch - 1] for _ in data], [] except: - raise ValueError("This alg {:} on {:} has invalid checkpoints.".format(alg, dataset)) + raise ValueError( + "This alg {:} on {:} has invalid checkpoints.".format( + alg, dataset + ) + ) for structure in structures: info = api.get_more_info( - structure, dataset=dataset, hp=90 if api.search_space_name == "size" else 200, is_random=False + structure, + dataset=dataset, + hp=90 if api.search_space_name == "size" else 200, + is_random=False, ) accs.append(info["test-accuracy"]) accuracies.append(sum(accs) / len(accs)) @@ -122,8 +139,13 @@ def visualize_curve(api, vis_save_dir, search_space): alg2accuracies[alg] = accuracies ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg)) ax.set_xlabel("The searching epoch", fontsize=LabelSize) - ax.set_ylabel("Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize) - ax.set_title("Searching results on {:}".format(name2label[dataset]), fontsize=LabelSize + 4) + ax.set_ylabel( + "Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize + ) + ax.set_title( + "Searching results on {:}".format(name2label[dataset]), + fontsize=LabelSize + 4, + ) ax.legend(loc=4, fontsize=LegendFontsize) fig, axs = plt.subplots(1, 3, figsize=figsize) @@ -138,12 +160,22 @@ def visualize_curve(api, vis_save_dir, 
search_space): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument( - "--save_dir", type=str, default="output/vis-nas-bench/nas-algos", help="Folder to save checkpoints and log." + parser = argparse.ArgumentParser( + description="NAS-Bench-X", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - "--search_space", type=str, default="tss", choices=["tss", "sss"], help="Choose the search space." + "--save_dir", + type=str, + default="output/vis-nas-bench/nas-algos", + help="Folder to save checkpoints and log.", + ) + parser.add_argument( + "--search_space", + type=str, + default="tss", + choices=["tss", "sss"], + help="Choose the search space.", ) args = parser.parse_args() diff --git a/exps/experimental/visualize-nas-bench-x.py b/exps/experimental/visualize-nas-bench-x.py index a8e4786..505e2bb 100644 --- a/exps/experimental/visualize-nas-bench-x.py +++ b/exps/experimental/visualize-nas-bench-x.py @@ -33,9 +33,15 @@ def visualize_info(api, vis_save_dir, indicator): # print ('{:} start to visualize {:} information'.format(time_string(), api)) vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) - cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) - imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar10", indicator + ) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar100", indicator + ) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "ImageNet16-120", indicator + ) cifar010_info = torch.load(cifar010_cache_path) cifar100_info = torch.load(cifar100_cache_path) imagenet_info = torch.load(imagenet_cache_path) @@ -63,8 +69,15 @@ def visualize_info(api, vis_save_dir, indicator): plt.xlim(min(indexes), max(indexes)) plt.ylim(min(indexes), max(indexes)) # plt.ylabel('y').set_rotation(30) - plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical") - plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize) + plt.yticks( + np.arange(min(indexes), max(indexes), max(indexes) // 3), + fontsize=LegendFontsize, + rotation="vertical", + ) + plt.xticks( + np.arange(min(indexes), max(indexes), max(indexes) // 5), + fontsize=LegendFontsize, + ) ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8) ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) @@ -100,7 +113,9 @@ def visualize_sss_info(api, dataset, vis_save_dir): train_accs.append(info["train-accuracy"]) test_accs.append(info["test-accuracy"]) if dataset == "cifar10": - info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False) + info = api.get_more_info( + index, "cifar10-valid", hp="90", is_random=False + ) valid_accs.append(info["valid-accuracy"]) else: valid_accs.append(info["valid-accuracy"]) @@ -272,7 +287,9 @@ def visualize_tss_info(api, dataset, vis_save_dir): train_accs.append(info["train-accuracy"]) test_accs.append(info["test-accuracy"]) if dataset == "cifar10": - info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False) + info = 
api.get_more_info( + index, "cifar10-valid", hp="200", is_random=False + ) valid_accs.append(info["valid-accuracy"]) else: valid_accs.append(info["valid-accuracy"]) @@ -297,7 +314,9 @@ def visualize_tss_info(api, dataset, vis_save_dir): ) print("{:} collect data done.".format(time_string())) - resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"] + resnet = [ + "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|" + ] resnet_indexes = [api.query_index_by_arch(x) for x in resnet] largest_indexes = [ api.query_index_by_arch( @@ -429,9 +448,15 @@ def visualize_rank_info(api, vis_save_dir, indicator): # print ('{:} start to visualize {:} information'.format(time_string(), api)) vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) - cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) - imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar10", indicator + ) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar100", indicator + ) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "ImageNet16-120", indicator + ) cifar010_info = torch.load(cifar010_cache_path) cifar100_info = torch.load(cifar100_cache_path) imagenet_info = torch.load(imagenet_cache_path) @@ -466,8 +491,17 @@ def visualize_rank_info(api, vis_save_dir, indicator): ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8) ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) - ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)) - ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name)) + ax.scatter( + [-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name) + ) + ax.scatter( + [-1], + [-1], + marker="o", + s=100, + c="tab:blue", + label="{:} validation".format(name), + ) ax.legend(loc=4, fontsize=LegendFontsize) ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize) ax.set_ylabel("architecture ranking", fontsize=LabelSize) @@ -479,9 +513,13 @@ def visualize_rank_info(api, vis_save_dir, indicator): labels = get_labels(imagenet_info) plot_ax(labels, ax3, "ImageNet-16-120") - save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve() + save_path = ( + vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator) + ).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") - save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve() + save_path = ( + vis_save_dir / "{:}-same-relative-rank.png".format(indicator) + ).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") print("{:} save into {:}".format(time_string(), save_path)) plt.close("all") @@ -502,9 +540,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator): # print ('{:} start to visualize {:} information'.format(time_string(), api)) vis_save_dir.mkdir(parents=True, exist_ok=True) - cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) - cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) - imagenet_cache_path = 
vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) + cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar10", indicator + ) + cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "cifar100", indicator + ) + imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( + "ImageNet16-120", indicator + ) cifar010_info = torch.load(cifar010_cache_path) cifar100_info = torch.load(cifar100_cache_path) imagenet_info = torch.load(imagenet_cache_path) @@ -570,7 +614,9 @@ def visualize_all_rank_info(api, vis_save_dir, indicator): yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], ) ax1.set_title("Correlation coefficient over ALL candidates") - ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)) + ax2.set_title( + "Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar) + ) save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve() fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") print("{:} save into {:}".format(time_string(), save_path)) @@ -578,9 +624,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="NAS-Bench-X", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + description="NAS-Bench-X", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( - "--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log." + "--save_dir", + type=str, + default="output/vis-nas-bench", + help="Folder to save checkpoints and log.", ) # use for train the model args = parser.parse_args() diff --git a/exps/prepare.py b/exps/prepare.py index 8e66f85..6ac50a4 100644 --- a/exps/prepare.py +++ b/exps/prepare.py @@ -16,7 +16,8 @@ lib_dir = (Path(__file__).parent / ".." / "lib").resolve() if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) parser = argparse.ArgumentParser( - description="Prepare splits for searching", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="Prepare splits for searching", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("--name", type=str, help="The dataset name.") parser.add_argument("--root", type=str, help="The directory to the dataset.") @@ -73,7 +74,15 @@ def main(): for index in valid: class2numV[targets[index]] += 1 class2numT, class2numV = dict(class2numT), dict(class2numV) - torch.save({"train": train, "valid": valid, "class2numTrain": class2numT, "class2numValid": class2numV}, save_path) + torch.save( + { + "train": train, + "valid": valid, + "class2numTrain": class2numT, + "class2numValid": class2numV, + }, + save_path, + ) print("-" * 80) diff --git a/exps/search-shape.py b/exps/search-shape.py index ed52ba7..2511752 100644 --- a/exps/search-shape.py +++ b/exps/search-shape.py @@ -14,7 +14,11 @@ lib_dir = (Path(__file__).parent / ".." 
/ "lib").resolve() print("lib_dir : {:}".format(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, configure2str, obtain_search_single_args as obtain_args +from config_utils import ( + load_config, + configure2str, + obtain_search_single_args as obtain_args, +) from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint from procedures import get_optim_scheduler, get_procedures from datasets import get_datasets, SearchDataset @@ -34,10 +38,16 @@ def main(args): logger = prepare_logger(args) # prepare dataset - train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) + train_data, valid_data, xshape, class_num = get_datasets( + args.dataset, args.data_path, args.cutout_length + ) # train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) valid_loader = torch.utils.data.DataLoader( - valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True + valid_data, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + pin_memory=True, ) split_file_path = Path(args.split_path) @@ -48,9 +58,9 @@ def main(args): assert ( len(set(train_split).intersection(set(valid_split))) == 0 ), "There should be 0 element that belongs to both train and valid" - assert len(train_split) + len(valid_split) == len(train_data), "{:} + {:} vs {:}".format( - len(train_split), len(valid_split), len(train_data) - ) + assert len(train_split) + len(valid_split) == len( + train_data + ), "{:} + {:} vs {:}".format(len(train_split), len(valid_split), len(train_data)) search_dataset = SearchDataset(args.dataset, train_data, train_split, valid_split) search_train_loader = torch.utils.data.DataLoader( @@ -76,12 +86,18 @@ def main(args): sampler=None, ) # get configures - model_config = load_config(args.model_config, {"class_num": class_num, "search_mode": args.search_shape}, logger) + model_config = load_config( + args.model_config, + {"class_num": class_num, "search_mode": args.search_shape}, + logger, + ) # obtain the model search_model = obtain_search_model(model_config) MAX_FLOP, param = get_model_infos(search_model, xshape) - optim_config = load_config(args.optim_config, {"class_num": class_num, "FLOP": MAX_FLOP}, logger) + optim_config = load_config( + args.optim_config, {"class_num": class_num, "FLOP": MAX_FLOP}, logger + ) logger.log("Model Information : {:}".format(search_model.get_message())) logger.log("MAX_FLOP = {:} M".format(MAX_FLOP)) logger.log("Params = {:} M".format(param)) @@ -89,7 +105,9 @@ def main(args): logger.log("search-data: {:}".format(search_dataset)) logger.log("search_train_loader : {:} samples".format(len(train_split))) logger.log("search_valid_loader : {:} samples".format(len(valid_split))) - base_optimizer, scheduler, criterion = get_optim_scheduler(search_model.base_parameters(), optim_config) + base_optimizer, scheduler, criterion = get_optim_scheduler( + search_model.base_parameters(), optim_config + ) arch_optimizer = torch.optim.Adam( search_model.arch_parameters(), lr=optim_config.arch_LR, @@ -101,7 +119,11 @@ def main(args): logger.log("scheduler : {:}".format(scheduler)) logger.log("criterion : {:}".format(criterion)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + 
logger.path("model"), + logger.path("best"), + ) network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() # load checkpoint @@ -114,14 +136,24 @@ def main(args): resume_path = last_info else: raise ValueError("Something is wrong.") - logger.log("=> loading checkpoint of the last-info '{:}' start".format(resume_path)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(resume_path) + ) checkpoint = torch.load(resume_path) if "last_checkpoint" in checkpoint: last_checkpoint_path = checkpoint["last_checkpoint"] if not last_checkpoint_path.exists(): - logger.log("Does not find {:}, try another path".format(last_checkpoint_path)) - last_checkpoint_path = resume_path.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name - assert last_checkpoint_path.exists(), "can not find the checkpoint from {:}".format(last_checkpoint_path) + logger.log( + "Does not find {:}, try another path".format(last_checkpoint_path) + ) + last_checkpoint_path = ( + resume_path.parent + / last_checkpoint_path.parent.name + / last_checkpoint_path.name + ) + assert ( + last_checkpoint_path.exists() + ), "can not find the checkpoint from {:}".format(last_checkpoint_path) checkpoint = torch.load(last_checkpoint_path) start_epoch = checkpoint["epoch"] + 1 search_model.load_state_dict(checkpoint["search_model"]) @@ -132,11 +164,22 @@ def main(args): arch_genotypes = checkpoint["arch_genotypes"] discrepancies = checkpoint["discrepancies"] logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(resume_path, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + resume_path, start_epoch + ) ) else: - logger.log("=> do not find the last-info file : {:} or resume : {:}".format(last_info, args.resume)) - start_epoch, valid_accuracies, arch_genotypes, discrepancies = 0, {"best": -1}, {}, {} + logger.log( + "=> do not find the last-info file : {:} or resume : {:}".format( + last_info, args.resume + ) + ) + start_epoch, valid_accuracies, arch_genotypes, discrepancies = ( + 0, + {"best": -1}, + {}, + {}, + ) # main procedure train_func, valid_func = get_procedures(args.procedure) @@ -144,15 +187,26 @@ def main(args): start_time, epoch_time = time.time(), AverageMeter() for epoch in range(start_epoch, total_epoch): scheduler.update(epoch, 0.0) - search_model.set_tau(args.gumbel_tau_max, args.gumbel_tau_min, epoch * 1.0 / total_epoch) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch), True)) + search_model.set_tau( + args.gumbel_tau_max, args.gumbel_tau_min, epoch * 1.0 / total_epoch + ) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.avg * (total_epoch - epoch), True) + ) epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch) LRs = scheduler.get_lr() find_best = False logger.log( "\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}, tau={:}, FLOP={:.2f}".format( - time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler, search_model.tau, MAX_FLOP + time_string(), + epoch_str, + need_time, + min(LRs), + max(LRs), + scheduler, + search_model.tau, + MAX_FLOP, ) ) @@ -177,10 +231,17 @@ def main(args): # log the results logger.log( "***{:s}*** TRAIN [{:}] base-loss = {:.6f}, arch-loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}".format( - time_string(), epoch_str, train_base_loss, train_arch_loss, train_acc1, train_acc5 + time_string(), + epoch_str, + train_base_loss, + train_arch_loss, + 
train_acc1, + train_acc5, ) ) - cur_FLOP, genotype = search_model.get_flop("genotype", model_config._asdict(), None) + cur_FLOP, genotype = search_model.get_flop( + "genotype", model_config._asdict(), None + ) arch_genotypes[epoch] = genotype arch_genotypes["last"] = genotype logger.log("[{:}] genotype : {:}".format(epoch_str, genotype)) @@ -189,7 +250,11 @@ def main(args): discrepancies[epoch] = discrepancy logger.log( "[{:}] FLOP : {:.2f} MB, ratio : {:.4f}, Expected-ratio : {:.4f}, Discrepancy : {:.3f}".format( - epoch_str, cur_FLOP, cur_FLOP / MAX_FLOP, args.FLOP_ratio, np.mean(discrepancy) + epoch_str, + cur_FLOP, + cur_FLOP / MAX_FLOP, + args.FLOP_ratio, + np.mean(discrepancy), ) ) @@ -202,7 +267,12 @@ def main(args): if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): logger.log("-" * 150) valid_loss, valid_acc1, valid_acc5 = valid_func( - search_valid_loader, network, criterion, epoch_str, args.print_freq_eval, logger + search_valid_loader, + network, + criterion, + epoch_str, + args.print_freq_eval, + logger, ) valid_accuracies[epoch] = valid_acc1 logger.log( @@ -222,7 +292,12 @@ def main(args): find_best = True logger.log( "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format( - epoch, valid_acc1, valid_acc5, 100 - valid_acc1, 100 - valid_acc5, model_best_path + epoch, + valid_acc1, + valid_acc5, + 100 - valid_acc1, + 100 - valid_acc5, + model_best_path, ) ) @@ -262,9 +337,15 @@ def main(args): logger.log("") logger.log("-" * 100) - last_config_path = logger.path("log") / "seed-{:}-last.config".format(args.rand_seed) + last_config_path = logger.path("log") / "seed-{:}-last.config".format( + args.rand_seed + ) configure2str(arch_genotypes["last"], str(last_config_path)) - logger.log("save the last config int {:} :\n{:}".format(last_config_path, arch_genotypes["last"])) + logger.log( + "save the last config int {:} :\n{:}".format( + last_config_path, arch_genotypes["last"] + ) + ) best_arch, valid_acc = arch_genotypes["best"], valid_accuracies["best"] for key, config in arch_genotypes.items(): @@ -275,11 +356,17 @@ def main(args): if valid_acc < valid_accuracies[key]: best_arch, valid_acc = config, valid_accuracies[key] print( - "Best-Arch : {:}\nRatio={:}, Valid-ACC={:}".format(best_arch, best_arch["estimated_FLOP"] / MAX_FLOP, valid_acc) + "Best-Arch : {:}\nRatio={:}, Valid-ACC={:}".format( + best_arch, best_arch["estimated_FLOP"] / MAX_FLOP, valid_acc + ) + ) + best_config_path = logger.path("log") / "seed-{:}-best.config".format( + args.rand_seed ) - best_config_path = logger.path("log") / "seed-{:}-best.config".format(args.rand_seed) configure2str(best_arch, str(best_config_path)) - logger.log("save the last config int {:} :\n{:}".format(best_config_path, best_arch)) + logger.log( + "save the last config int {:} :\n{:}".format(best_config_path, best_arch) + ) logger.log("\n" + "-" * 200) logger.log( "Finish training/validation in {:}, and save final checkpoint into {:}".format( diff --git a/exps/search-transformable.py b/exps/search-transformable.py index 9b81a1f..f22f346 100644 --- a/exps/search-transformable.py +++ b/exps/search-transformable.py @@ -35,10 +35,16 @@ def main(args): logger = prepare_logger(args) # prepare dataset - train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) + train_data, valid_data, xshape, class_num = get_datasets( + args.dataset, args.data_path, args.cutout_length + ) # train_loader = 
torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) valid_loader = torch.utils.data.DataLoader( - valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True + valid_data, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + pin_memory=True, ) split_file_path = Path(args.split_path) @@ -49,9 +55,9 @@ def main(args): assert ( len(set(train_split).intersection(set(valid_split))) == 0 ), "There should be 0 element that belongs to both train and valid" - assert len(train_split) + len(valid_split) == len(train_data), "{:} + {:} vs {:}".format( - len(train_split), len(valid_split), len(train_data) - ) + assert len(train_split) + len(valid_split) == len( + train_data + ), "{:} + {:} vs {:}".format(len(train_split), len(valid_split), len(train_data)) search_dataset = SearchDataset(args.dataset, train_data, train_split, valid_split) search_train_loader = torch.utils.data.DataLoader( @@ -78,18 +84,26 @@ def main(args): ) # get configures if args.ablation_num_select is None or args.ablation_num_select <= 0: - model_config = load_config(args.model_config, {"class_num": class_num, "search_mode": "shape"}, logger) + model_config = load_config( + args.model_config, {"class_num": class_num, "search_mode": "shape"}, logger + ) else: model_config = load_config( args.model_config, - {"class_num": class_num, "search_mode": "ablation", "num_random_select": args.ablation_num_select}, + { + "class_num": class_num, + "search_mode": "ablation", + "num_random_select": args.ablation_num_select, + }, logger, ) # obtain the model search_model = obtain_search_model(model_config) MAX_FLOP, param = get_model_infos(search_model, xshape) - optim_config = load_config(args.optim_config, {"class_num": class_num, "FLOP": MAX_FLOP}, logger) + optim_config = load_config( + args.optim_config, {"class_num": class_num, "FLOP": MAX_FLOP}, logger + ) logger.log("Model Information : {:}".format(search_model.get_message())) logger.log("MAX_FLOP = {:} M".format(MAX_FLOP)) logger.log("Params = {:} M".format(param)) @@ -97,7 +111,9 @@ def main(args): logger.log("search-data: {:}".format(search_dataset)) logger.log("search_train_loader : {:} samples".format(len(train_split))) logger.log("search_valid_loader : {:} samples".format(len(valid_split))) - base_optimizer, scheduler, criterion = get_optim_scheduler(search_model.base_parameters(), optim_config) + base_optimizer, scheduler, criterion = get_optim_scheduler( + search_model.base_parameters(), optim_config + ) arch_optimizer = torch.optim.Adam( search_model.arch_parameters(optim_config.arch_LR), lr=optim_config.arch_LR, @@ -109,7 +125,11 @@ def main(args): logger.log("scheduler : {:}".format(scheduler)) logger.log("criterion : {:}".format(criterion)) - last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best") + last_info, model_base_path, model_best_path = ( + logger.path("info"), + logger.path("model"), + logger.path("best"), + ) network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() # load checkpoint @@ -122,14 +142,24 @@ def main(args): resume_path = last_info else: raise ValueError("Something is wrong.") - logger.log("=> loading checkpoint of the last-info '{:}' start".format(resume_path)) + logger.log( + "=> loading checkpoint of the last-info '{:}' start".format(resume_path) + ) checkpoint = torch.load(resume_path) if "last_checkpoint" in checkpoint: last_checkpoint_path 
= checkpoint["last_checkpoint"] if not last_checkpoint_path.exists(): - logger.log("Does not find {:}, try another path".format(last_checkpoint_path)) - last_checkpoint_path = resume_path.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name - assert last_checkpoint_path.exists(), "can not find the checkpoint from {:}".format(last_checkpoint_path) + logger.log( + "Does not find {:}, try another path".format(last_checkpoint_path) + ) + last_checkpoint_path = ( + resume_path.parent + / last_checkpoint_path.parent.name + / last_checkpoint_path.name + ) + assert ( + last_checkpoint_path.exists() + ), "can not find the checkpoint from {:}".format(last_checkpoint_path) checkpoint = torch.load(last_checkpoint_path) start_epoch = checkpoint["epoch"] + 1 # for key, value in checkpoint['search_model'].items(): @@ -143,11 +173,23 @@ def main(args): discrepancies = checkpoint["discrepancies"] max_bytes = checkpoint["max_bytes"] logger.log( - "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(resume_path, start_epoch) + "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( + resume_path, start_epoch + ) ) else: - logger.log("=> do not find the last-info file : {:} or resume : {:}".format(last_info, args.resume)) - start_epoch, valid_accuracies, arch_genotypes, discrepancies, max_bytes = 0, {"best": -1}, {}, {}, {} + logger.log( + "=> do not find the last-info file : {:} or resume : {:}".format( + last_info, args.resume + ) + ) + start_epoch, valid_accuracies, arch_genotypes, discrepancies, max_bytes = ( + 0, + {"best": -1}, + {}, + {}, + {}, + ) # main procedure train_func, valid_func = get_procedures(args.procedure) @@ -155,15 +197,26 @@ def main(args): start_time, epoch_time = time.time(), AverageMeter() for epoch in range(start_epoch, total_epoch): scheduler.update(epoch, 0.0) - search_model.set_tau(args.gumbel_tau_max, args.gumbel_tau_min, epoch * 1.0 / total_epoch) - need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch), True)) + search_model.set_tau( + args.gumbel_tau_max, args.gumbel_tau_min, epoch * 1.0 / total_epoch + ) + need_time = "Time Left: {:}".format( + convert_secs2time(epoch_time.avg * (total_epoch - epoch), True) + ) epoch_str = "epoch={:03d}/{:03d}".format(epoch, total_epoch) LRs = scheduler.get_lr() find_best = False logger.log( "\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}, tau={:}, FLOP={:.2f}".format( - time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler, search_model.tau, MAX_FLOP + time_string(), + epoch_str, + need_time, + min(LRs), + max(LRs), + scheduler, + search_model.tau, + MAX_FLOP, ) ) @@ -188,21 +241,35 @@ def main(args): # log the results logger.log( "***{:s}*** TRAIN [{:}] base-loss = {:.6f}, arch-loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}".format( - time_string(), epoch_str, train_base_loss, train_arch_loss, train_acc1, train_acc5 + time_string(), + epoch_str, + train_base_loss, + train_arch_loss, + train_acc1, + train_acc5, ) ) - cur_FLOP, genotype = search_model.get_flop("genotype", model_config._asdict(), None) + cur_FLOP, genotype = search_model.get_flop( + "genotype", model_config._asdict(), None + ) arch_genotypes[epoch] = genotype arch_genotypes["last"] = genotype logger.log("[{:}] genotype : {:}".format(epoch_str, genotype)) # save the configuration - configure2str(genotype, str(logger.path("log") / "seed-{:}-temp.config".format(args.rand_seed))) + configure2str( + genotype, + str(logger.path("log") / 
"seed-{:}-temp.config".format(args.rand_seed)), + ) arch_info, discrepancy = search_model.get_arch_info() logger.log(arch_info) discrepancies[epoch] = discrepancy logger.log( "[{:}] FLOP : {:.2f} MB, ratio : {:.4f}, Expected-ratio : {:.4f}, Discrepancy : {:.3f}".format( - epoch_str, cur_FLOP, cur_FLOP / MAX_FLOP, args.FLOP_ratio, np.mean(discrepancy) + epoch_str, + cur_FLOP, + cur_FLOP / MAX_FLOP, + args.FLOP_ratio, + np.mean(discrepancy), ) ) @@ -215,7 +282,12 @@ def main(args): if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): logger.log("-" * 150) valid_loss, valid_acc1, valid_acc5 = valid_func( - search_valid_loader, network, criterion, epoch_str, args.print_freq_eval, logger + search_valid_loader, + network, + criterion, + epoch_str, + args.print_freq_eval, + logger, ) valid_accuracies[epoch] = valid_acc1 logger.log( @@ -235,15 +307,26 @@ def main(args): find_best = True logger.log( "Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.".format( - epoch, valid_acc1, valid_acc5, 100 - valid_acc1, 100 - valid_acc5, model_best_path + epoch, + valid_acc1, + valid_acc5, + 100 - valid_acc1, + 100 - valid_acc5, + model_best_path, ) ) # log the GPU memory usage # num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0 - num_bytes = torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 + num_bytes = ( + torch.cuda.max_memory_cached(next(network.parameters()).device) * 1.0 + ) logger.log( "[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]".format( - next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9 + next(network.parameters()).device, + int(num_bytes), + num_bytes / 1e3, + num_bytes / 1e6, + num_bytes / 1e9, ) ) max_bytes[epoch] = num_bytes @@ -285,9 +368,15 @@ def main(args): logger.log("") logger.log("-" * 100) - last_config_path = logger.path("log") / "seed-{:}-last.config".format(args.rand_seed) + last_config_path = logger.path("log") / "seed-{:}-last.config".format( + args.rand_seed + ) configure2str(arch_genotypes["last"], str(last_config_path)) - logger.log("save the last config int {:} :\n{:}".format(last_config_path, arch_genotypes["last"])) + logger.log( + "save the last config int {:} :\n{:}".format( + last_config_path, arch_genotypes["last"] + ) + ) best_arch, valid_acc = arch_genotypes["best"], valid_accuracies["best"] for key, config in arch_genotypes.items(): @@ -298,15 +387,23 @@ def main(args): if valid_acc <= valid_accuracies[key]: best_arch, valid_acc = config, valid_accuracies[key] print( - "Best-Arch : {:}\nRatio={:}, Valid-ACC={:}".format(best_arch, best_arch["estimated_FLOP"] / MAX_FLOP, valid_acc) + "Best-Arch : {:}\nRatio={:}, Valid-ACC={:}".format( + best_arch, best_arch["estimated_FLOP"] / MAX_FLOP, valid_acc + ) + ) + best_config_path = logger.path("log") / "seed-{:}-best.config".format( + args.rand_seed ) - best_config_path = logger.path("log") / "seed-{:}-best.config".format(args.rand_seed) configure2str(best_arch, str(best_config_path)) - logger.log("save the last config int {:} :\n{:}".format(best_config_path, best_arch)) + logger.log( + "save the last config int {:} :\n{:}".format(best_config_path, best_arch) + ) logger.log("\n" + "-" * 200) logger.log( "Finish training/validation in {:} with Max-GPU-Memory of {:.2f} GB, and save final checkpoint into {:}".format( - convert_secs2time(epoch_time.sum, True), max(v for k, v in 
max_bytes.items()) / 1e9, logger.path("info") + convert_secs2time(epoch_time.sum, True), + max(v for k, v in max_bytes.items()) / 1e9, + logger.path("info"), ) ) logger.close() diff --git a/exps/show-dataset.py b/exps/show-dataset.py index 2a79f33..ee76831 100644 --- a/exps/show-dataset.py +++ b/exps/show-dataset.py @@ -24,11 +24,17 @@ from nats_bench import create def show_imagenet_16_120(dataset_dir=None): if dataset_dir is None: torch_home_dir = ( - os.environ["TORCH_HOME"] if "TORCH_HOME" in os.environ else os.path.join(os.environ["HOME"], ".torch") + os.environ["TORCH_HOME"] + if "TORCH_HOME" in os.environ + else os.path.join(os.environ["HOME"], ".torch") ) dataset_dir = os.path.join(torch_home_dir, "cifar.python", "ImageNet16") - train_data, valid_data, xshape, class_num = get_datasets("ImageNet16-120", dataset_dir, -1) - split_info = load_config("configs/nas-benchmark/ImageNet16-120-split.txt", None, None) + train_data, valid_data, xshape, class_num = get_datasets( + "ImageNet16-120", dataset_dir, -1 + ) + split_info = load_config( + "configs/nas-benchmark/ImageNet16-120-split.txt", None, None + ) print("=" * 10 + " ImageNet-16-120 " + "=" * 10) print("Training Data: {:}".format(train_data)) print("Evaluation Data: {:}".format(valid_data)) @@ -45,8 +51,12 @@ if __name__ == "__main__": test_acc_200e = [] for index in range(10000): info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="12") - valid_acc_12e.append(info["valid-accuracy"]) # the validation accuracy after training the model by 12 epochs - test_acc_12e.append(info["test-accuracy"]) # the test accuracy after training the model by 12 epochs + valid_acc_12e.append( + info["valid-accuracy"] + ) # the validation accuracy after training the model by 12 epochs + test_acc_12e.append( + info["test-accuracy"] + ) # the test accuracy after training the model by 12 epochs info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="200") test_acc_200e.append( info["test-accuracy"] diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index 6228fd6..21efdb5 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -65,7 +65,11 @@ def retrieve_configs(): path = config_dir / name assert path.exists(), "{:} does not exist.".format(path) alg2paths[alg] = str(path) - print("The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format(idx, len(alg2names), alg, path)) + print( + "The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format( + idx, len(alg2names), alg, path + ) + ) return alg2paths @@ -100,13 +104,30 @@ if __name__ == "__main__": alg2paths = retrieve_configs() parser = argparse.ArgumentParser("Baselines") - parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.") parser.add_argument( - "--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator." + "--save_dir", + type=str, + default="./outputs/qlib-baselines", + help="The checkpoint directory.", + ) + parser.add_argument( + "--market", + type=str, + default="all", + choices=["csi100", "csi300", "all"], + help="The market indicator.", ) parser.add_argument("--times", type=int, default=10, help="The repeated run times.") - parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") - parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.") + parser.add_argument( + "--gpu", type=int, default=0, help="The GPU ID used for train / test." 
+ ) + parser.add_argument( + "--alg", + type=str, + choices=list(alg2paths.keys()), + required=True, + help="The algorithm name.", + ) args = parser.parse_args() main(args, alg2paths[args.alg]) diff --git a/exps/trading/organize_results.py b/exps/trading/organize_results.py index 5901a47..b29615d 100644 --- a/exps/trading/organize_results.py +++ b/exps/trading/organize_results.py @@ -55,7 +55,13 @@ class QResult: new_dict[xkey] = values return new_dict - def info(self, keys: List[Text], separate: Text = "& ", space: int = 25, verbose: bool = True): + def info( + self, + keys: List[Text], + separate: Text = "& ", + space: int = 25, + verbose: bool = True, + ): avaliable_keys = [] for key in keys: if key not in self.result: @@ -89,7 +95,10 @@ def compare_results(heads, values, names, space=10, verbose=True, sort_key=False if verbose: print(info_str_dict["head"]) if sort_key: - lines = sorted(list(zip(values, info_str_dict["lines"])), key=lambda x: float(x[0].split(" ")[0])) + lines = sorted( + list(zip(values, info_str_dict["lines"])), + key=lambda x: float(x[0].split(" ")[0]), + ) lines = [x[1] for x in lines] else: lines = info_str_dict["lines"] @@ -136,7 +145,11 @@ def query_info(save_dir, verbose): if verbose: print( "====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format( - idx + 1, len(experiments), experiment.name, len(recorders), len(recorders) + not_finished + idx + 1, + len(experiments), + experiment.name, + len(recorders), + len(recorders) + not_finished, ) ) result = QResult() @@ -149,7 +162,9 @@ def query_info(save_dir, verbose): head_strs.append(head_str) value_strs.append(value_str) names.append(experiment.name) - info_str_dict = compare_results(head_strs, value_strs, names, space=10, verbose=verbose) + info_str_dict = compare_results( + head_strs, value_strs, names, space=10, verbose=verbose + ) info_value_dict = dict(heads=head_strs, values=value_strs, names=names) return info_str_dict, info_value_dict @@ -169,9 +184,18 @@ if __name__ == "__main__": raise argparse.ArgumentTypeError("Boolean value expected.") parser.add_argument( - "--save_dir", type=str, nargs="+", default=["./outputs/qlib-baselines"], help="The checkpoint directory." 
+ "--save_dir", + type=str, + nargs="+", + default=["./outputs/qlib-baselines"], + help="The checkpoint directory.", + ) + parser.add_argument( + "--verbose", + type=str2bool, + default=False, + help="Print detailed log information or not.", ) - parser.add_argument("--verbose", type=str2bool, default=False, help="Print detailed log information or not.") args = parser.parse_args() print("Show results of {:}".format(args.save_dir)) @@ -184,4 +208,11 @@ if __name__ == "__main__": _, info_dict = query_info(save_dir, args.verbose) all_info_dict.append(info_dict) info_dict = QResult.merge_dict(all_info_dict) - compare_results(info_dict["heads"], info_dict["values"], info_dict["names"], space=10, verbose=True, sort_key=True) + compare_results( + info_dict["heads"], + info_dict["values"], + info_dict["names"], + space=10, + verbose=True, + sort_key=True, + ) diff --git a/exps/trading/workflow_tt.py b/exps/trading/workflow_tt.py index 08be891..3df31ba 100644 --- a/exps/trading/workflow_tt.py +++ b/exps/trading/workflow_tt.py @@ -39,7 +39,10 @@ def main(xargs): "fit_end_time": "2014-12-31", "instruments": xargs.market, "infer_processors": [ - {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}}, + { + "class": "RobustZScoreNorm", + "kwargs": {"fields_group": "feature", "clip_outlier": True}, + }, {"class": "Fillna", "kwargs": {"fields_group": "feature"}}, ], "learn_processors": [ @@ -90,7 +93,11 @@ def main(xargs): } record_config = [ - {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()}, + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + "kwargs": dict(), + }, { "class": "SigAnaRecord", "module_path": "qlib.workflow.record_temp", @@ -111,18 +118,37 @@ def main(xargs): for irun in range(xargs.times): xmodel_config = model_config.copy() xmodel_config = update_gpu(xmodel_config, xargs.gpu) - task_config = dict(model=xmodel_config, dataset=dataset_config, record=record_config) + task_config = dict( + model=xmodel_config, dataset=dataset_config, record=record_config + ) - run_exp(task_config, dataset, xargs.name, "recorder-{:02d}-{:02d}".format(irun, xargs.times), save_dir) + run_exp( + task_config, + dataset, + xargs.name, + "recorder-{:02d}-{:02d}".format(irun, xargs.times), + save_dir, + ) if __name__ == "__main__": parser = argparse.ArgumentParser("Vanilla Transformable Transformer") - parser.add_argument("--save_dir", type=str, default="./outputs/vtt-runs", help="The checkpoint directory.") - parser.add_argument("--name", type=str, default="Transformer", help="The experiment name.") + parser.add_argument( + "--save_dir", + type=str, + default="./outputs/vtt-runs", + help="The checkpoint directory.", + ) + parser.add_argument( + "--name", type=str, default="Transformer", help="The experiment name." + ) parser.add_argument("--times", type=int, default=10, help="The repeated run times.") - parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") - parser.add_argument("--market", type=str, default="all", help="The market indicator.") + parser.add_argument( + "--gpu", type=int, default=0, help="The GPU ID used for train / test." + ) + parser.add_argument( + "--market", type=str, default="all", help="The market indicator." 
+ ) args = parser.parse_args() main(args) diff --git a/lib/layers/drop.py b/lib/layers/drop.py index f9ebef7..9be8ab0 100644 --- a/lib/layers/drop.py +++ b/lib/layers/drop.py @@ -21,149 +21,209 @@ import torch.nn.functional as F def drop_block_2d( - x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, - with_noise: bool = False, inplace: bool = False, batchwise: bool = False): - """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + x, + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, + batchwise: bool = False, +): + """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf - DropBlock with an experimental gaussian noise option. This layer has been tested on a few training - runs with success, but needs further validation and possibly optimization for lower runtime impact. - """ - B, C, H, W = x.shape - total_size = W * H - clipped_block_size = min(block_size, min(W, H)) - # seed_drop_rate, the gamma parameter - gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( - (W - block_size + 1) * (H - block_size + 1)) + DropBlock with an experimental gaussian noise option. This layer has been tested on a few training + runs with success, but needs further validation and possibly optimization for lower runtime impact. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + # seed_drop_rate, the gamma parameter + gamma = ( + gamma_scale + * drop_prob + * total_size + / clipped_block_size ** 2 + / ((W - block_size + 1) * (H - block_size + 1)) + ) - # Forces the block to be inside the feature map. - w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) - valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ - ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) - valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) + # Forces the block to be inside the feature map. 
+ w_i, h_i = torch.meshgrid( + torch.arange(W).to(x.device), torch.arange(H).to(x.device) + ) + valid_block = ( + (w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2) + ) & ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) - if batchwise: - # one mask for whole batch, quite a bit faster - uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) - else: - uniform_noise = torch.rand_like(x) - block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) - block_mask = -F.max_pool2d( - -block_mask, - kernel_size=clipped_block_size, # block_size, - stride=1, - padding=clipped_block_size // 2) - - if with_noise: - normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) - if inplace: - x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + if batchwise: + # one mask for whole batch, quite a bit faster + uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) else: - x = x * block_mask + normal_noise * (1 - block_mask) - else: - normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) - if inplace: - x.mul_(block_mask * normalize_scale) + uniform_noise = torch.rand_like(x) + block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2, + ) + + if with_noise: + normal_noise = ( + torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) + if batchwise + else torch.randn_like(x) + ) + if inplace: + x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + else: + x = x * block_mask + normal_noise * (1 - block_mask) else: - x = x * block_mask * normalize_scale - return x + normalize_scale = ( + block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7) + ).to(x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x def drop_block_fast_2d( - x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, - gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): - """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + x: torch.Tensor, + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, + batchwise: bool = False, +): + """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf - DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid - block mask at edges. - """ - B, C, H, W = x.shape - total_size = W * H - clipped_block_size = min(block_size, min(W, H)) - gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( - (W - block_size + 1) * (H - block_size + 1)) + DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid + block mask at edges. 
+ """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + gamma = ( + gamma_scale + * drop_prob + * total_size + / clipped_block_size ** 2 + / ((W - block_size + 1) * (H - block_size + 1)) + ) - if batchwise: - # one mask for whole batch, quite a bit faster - block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma - else: - # mask per batch element - block_mask = torch.rand_like(x) < gamma - block_mask = F.max_pool2d( - block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) - - if with_noise: - normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) - if inplace: - x.mul_(1. - block_mask).add_(normal_noise * block_mask) + if batchwise: + # one mask for whole batch, quite a bit faster + block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma else: - x = x * (1. - block_mask) + normal_noise * block_mask - else: - block_mask = 1 - block_mask - normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) - if inplace: - x.mul_(block_mask * normalize_scale) + # mask per batch element + block_mask = torch.rand_like(x) < gamma + block_mask = F.max_pool2d( + block_mask.to(x.dtype), + kernel_size=clipped_block_size, + stride=1, + padding=clipped_block_size // 2, + ) + + if with_noise: + normal_noise = ( + torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) + if batchwise + else torch.randn_like(x) + ) + if inplace: + x.mul_(1.0 - block_mask).add_(normal_noise * block_mask) + else: + x = x * (1.0 - block_mask) + normal_noise * block_mask else: - x = x * block_mask * normalize_scale - return x + block_mask = 1 - block_mask + normalize_scale = ( + block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7) + ).to(dtype=x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x class DropBlock2d(nn.Module): - """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf - """ - def __init__(self, - drop_prob=0.1, - block_size=7, - gamma_scale=1.0, - with_noise=False, - inplace=False, - batchwise=False, - fast=True): - super(DropBlock2d, self).__init__() - self.drop_prob = drop_prob - self.gamma_scale = gamma_scale - self.block_size = block_size - self.with_noise = with_noise - self.inplace = inplace - self.batchwise = batchwise - self.fast = fast # FIXME finish comparisons of fast vs not + """DropBlock. 
See https://arxiv.org/pdf/1810.12890.pdf""" - def forward(self, x): - if not self.training or not self.drop_prob: - return x - if self.fast: - return drop_block_fast_2d( - x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) - else: - return drop_block_2d( - x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + def __init__( + self, + drop_prob=0.1, + block_size=7, + gamma_scale=1.0, + with_noise=False, + inplace=False, + batchwise=False, + fast=True, + ): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + self.inplace = inplace + self.batchwise = batchwise + self.fast = fast # FIXME finish comparisons of fast vs not + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + if self.fast: + return drop_block_fast_2d( + x, + self.drop_prob, + self.block_size, + self.gamma_scale, + self.with_noise, + self.inplace, + self.batchwise, + ) + else: + return drop_block_2d( + x, + self.drop_prob, + self.block_size, + self.gamma_scale, + self.with_noise, + self.inplace, + self.batchwise, + ) -def drop_path(x, drop_prob: float = 0., training: bool = False): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for - changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use - 'survival rate' as the argument. + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. - """ - if drop_prob == 0. or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
- """ - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/lib/layers/mlp.py b/lib/layers/mlp.py index ffd3f50..00635a8 100644 --- a/lib/layers/mlp.py +++ b/lib/layers/mlp.py @@ -1,24 +1,29 @@ import torch.nn as nn from typing import Optional -class MLP(nn.Module): - # MLP: FC -> Activation -> Drop -> FC -> Drop - def __init__(self, in_features, hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer=nn.GELU, - drop: Optional[float] = None): - super(MLP, self).__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop or 0) - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x +class MLP(nn.Module): + # MLP: FC -> Activation -> Drop -> FC -> Drop + def __init__( + self, + in_features, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer=nn.GELU, + drop: Optional[float] = None, + ): + super(MLP, self).__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop or 0) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/lib/layers/positional_embedding.py b/lib/layers/positional_embedding.py index 67b6c46..f40f8c9 100644 --- a/lib/layers/positional_embedding.py +++ b/lib/layers/positional_embedding.py @@ -5,31 +5,31 @@ import torch import torch.nn as nn import math -class PositionalEncoder(nn.Module): - # Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - # https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65 - def __init__(self, d_model, max_seq_len, dropout=0.1): - super(PositionalEncoder, self).__init__() - self.d_model = d_model - # create constant 'pe' matrix with values dependant on - # pos and i - pe = torch.zeros(max_seq_len, d_model) - for pos in range(max_seq_len): - for i in range(0, d_model): - div = 10000 ** ((i // 2) * 2 / d_model) - value = pos / div - if i % 2 == 0: - pe[pos, i] = math.sin(value) - else: - pe[pos, i] = math.cos(value) - pe = pe.unsqueeze(0) - self.dropout = nn.Dropout(p=dropout) - self.register_buffer('pe', pe) - - - def forward(self, x): - batch, seq, fdim = x.shape[:3] - embeddings = self.pe[:, :seq, :fdim] - outs = self.dropout(x + embeddings) - return outs +class PositionalEncoder(nn.Module): + # Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + # https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65 + + def __init__(self, d_model, max_seq_len, dropout=0.1): + super(PositionalEncoder, self).__init__() + self.d_model = d_model + # create constant 'pe' matrix with values dependant on + # pos and i + pe = 
torch.zeros(max_seq_len, d_model) + for pos in range(max_seq_len): + for i in range(0, d_model): + div = 10000 ** ((i // 2) * 2 / d_model) + value = pos / div + if i % 2 == 0: + pe[pos, i] = math.sin(value) + else: + pe[pos, i] = math.cos(value) + pe = pe.unsqueeze(0) + self.dropout = nn.Dropout(p=dropout) + self.register_buffer("pe", pe) + + def forward(self, x): + batch, seq, fdim = x.shape[:3] + embeddings = self.pe[:, :seq, :fdim] + outs = self.dropout(x + embeddings) + return outs diff --git a/lib/layers/super_mlp.py b/lib/layers/super_mlp.py index bf48941..3f25ee8 100644 --- a/lib/layers/super_mlp.py +++ b/lib/layers/super_mlp.py @@ -2,24 +2,21 @@ import torch.nn as nn from torch.nn.parameter import Parameter from typing import Optional +from layers.super_module import SuperModule -class Linear(nn.Module): - """Applies a linear transformation to the incoming data: :math:`y = xA^T + b` - """ - __constants__ = ['in_features', 'out_features'] - in_features: int - out_features: int - weight: Tensor + + +class SuperLinear(SuperModule): + """Applies a linear transformation to the incoming data: :math:`y = xA^T + b`""" def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: - super(Linear, self).__init__() + super(SuperLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(out_features, in_features)) if bias: self.bias = Parameter(torch.Tensor(out_features)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: @@ -33,28 +30,33 @@ class Linear(nn.Module): return F.linear(input, self.weight, self.bias) def extra_repr(self) -> str: - return 'in_features={}, out_features={}, bias={}'.format( + return "in_features={:}, out_features={:}, bias={:}".format( self.in_features, self.out_features, self.bias is not None ) -class SuperMLP(nn.Module): - # MLP: FC -> Activation -> Drop -> FC -> Drop - def __init__(self, in_features, hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer=nn.GELU, - drop: Optional[float] = None): - super(MLP, self).__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop or 0) - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x +class SuperMLP(nn.Module): + # MLP: FC -> Activation -> Drop -> FC -> Drop + def __init__( + self, + in_features, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer=nn.GELU, + drop: Optional[float] = None, + ): + super(SuperMLP, self).__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop or 0) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/lib/layers/super_module.py b/lib/layers/super_module.py index a5d3968..7aa7c1a 100644 --- a/lib/layers/super_module.py +++ b/lib/layers/super_module.py @@ -1,7 +1,17 @@ +##################################################### +#
Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 # +##################################################### + +import abc import torch.nn as nn -class SuperModule(nn.Module): - def __init__(self): - super(SuperModule, self).__init__() - +class SuperModule(abc.ABC, nn.Module): + """This class equips the nn.Module class with the ability to apply AutoDL.""" + + def __init__(self): + super(SuperModule, self).__init__() + + @abc.abstractmethod + def abstract_search_space(self): + raise NotImplementedError diff --git a/lib/layers/weight_init.py b/lib/layers/weight_init.py index a70904b..478a462 100644 --- a/lib/layers/weight_init.py +++ b/lib/layers/weight_init.py @@ -5,57 +5,59 @@ import warnings def _no_grad_trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1. + math.erf(x / math.sqrt(2.))) / 2. + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2) + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - # type: (Tensor, float, float, float, float) -> Tensor - r"""Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds.
The method used for generating the random values works - best when :math:`a \leq \text{mean} \leq b`. - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - Examples: - >>> w = torch.empty(3, 5) - >>> nn.init.trunc_normal_(w) - """ - return _no_grad_trunc_normal_(tensor, mean, std, a, b) +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/lib/spaces/__init__.py b/lib/spaces/__init__.py index cb7034a..a6e0807 100644 --- a/lib/spaces/__init__.py +++ b/lib/spaces/__init__.py @@ -6,5 +6,6 @@ from .basic_space import Categorical from .basic_space import Continuous +from .basic_space import Integer from .basic_op import has_categorical from .basic_op import has_continuous diff --git a/lib/spaces/basic_space.py b/lib/spaces/basic_space.py index 61fcad0..35fdc35 100644 --- a/lib/spaces/basic_space.py +++ b/lib/spaces/basic_space.py @@ -92,6 +92,32 @@ class Categorical(Space): return sample +class Integer(Categorical): + """A space that contains integer values.""" + + def __init__(self, lower: int, upper: int, default: Optional[int] = None): + if not isinstance(lower, int) or not isinstance(upper, int): + raise ValueError( + "The lower [{:}] and upper [{:}] must be int.".format(lower, upper) + ) + data = list(range(lower, upper + 1)) + self._raw_lower = lower + self._raw_upper = upper + self._raw_default = default + if default is not None and (default < lower or default > upper): + raise ValueError("The default value [{:}] is out of range.".format(default)) + default = data.index(default) if default is not None else None + super(Integer, self).__init__(*data, default=default) + + def __repr__(self): + return "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( + name=self.__class__.__name__, + lower=self._raw_lower, + upper=self._raw_upper, + default=self._raw_default, + ) + + np_float_types = (np.float16, np.float32, np.float64) np_int_types = ( np.uint8, diff --git a/tests/test_basic_space.py b/tests/test_basic_space.py index ae79df5..9ea74c7 100644 --- a/tests/test_basic_space.py +++ b/tests/test_basic_space.py @@ -13,6 +13,7 @@ if str(lib_dir) not in sys.path: from spaces import Categorical from spaces import Continuous +from spaces import Integer class TestBasicSpace(unittest.TestCase): @@ -26,6 +27,12 @@ class TestBasicSpace(unittest.TestCase): "Categorical(candidates=[1, 2, 3, 4], default_index=None)", str(space) ) + def test_integer(self): + space = Integer(lower=1, upper=4) + for i in range(4): + self.assertEqual(space[i], i + 1) + self.assertEqual("Integer(lower=1, upper=4, default=None)", str(space)) + def test_continuous(self): random.seed(999) space =
Continuous(0, 1)