Add int search space

D-X-Y
2021-03-18 16:02:55 +08:00
parent ece6ac5f41
commit 63c8bb9bc8
67 changed files with 5150 additions and 1474 deletions
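
The hunks below are consistent with running an automatic code formatter at its default 88-character line length: long one-line calls and asserts are rewrapped onto multiple lines with trailing commas. A minimal sketch of reproducing one such rewrap, assuming the tool was Black (the commit itself does not name a formatter, so that attribution is an assumption):

    import black  # assumption: Black produced this formatting; not stated in the commit

    before = (
        'assert meta_num_archs == len(meta_archs), '
        '"invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs))\n'
    )
    # black.Mode() defaults to line_length=88, matching the wrapping seen below
    print(black.format_str(before, mode=black.Mode()))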

View File

@@ -18,12 +18,16 @@ def check_files(save_dir, meta_file, basestr):
meta_infos = torch.load(meta_file, map_location="cpu")
meta_archs = meta_infos["archs"]
meta_num_archs = meta_infos["total"]
assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format(
meta_num_archs, len(meta_archs)
)
assert meta_num_archs == len(
meta_archs
), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs))
sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr))))
print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs)))
print(
"{:} find {:} directories used to save checkpoints".format(
time_string(), len(sub_model_dirs)
)
)
subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
num_seeds = defaultdict(lambda: 0)
@@ -34,21 +38,29 @@ def check_files(save_dir, meta_file, basestr):
for checkpoint in xcheckpoints:
temp_names = checkpoint.name.split("-")
assert (
len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed"
len(temp_names) == 4
and temp_names[0] == "arch"
and temp_names[2] == "seed"
), "invalid checkpoint name : {:}".format(checkpoint.name)
arch_indexes.add(temp_names[1])
subdir2archs[sub_dir] = sorted(list(arch_indexes))
num_evaluated_arch += len(arch_indexes)
# count number of seeds for each architecture
for arch_index in arch_indexes:
num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1
num_seeds[
len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))
] += 1
print(
"There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).".format(
num_evaluated_arch, meta_num_archs, sum(k * v for k, v in num_seeds.items())
)
)
for key in sorted(list(num_seeds.keys())):
print("There are {:5d} architectures that are evaluated {:} times.".format(num_seeds[key], key))
print(
"There are {:5d} architectures that are evaluated {:} times.".format(
num_seeds[key], key
)
)
dir2ckps, dir2ckp_exists = dict(), dict()
start_time, epoch_time = time.time(), AverageMeter()
@@ -62,12 +74,14 @@ def check_files(save_dir, meta_file, basestr):
numrs = defaultdict(lambda: 0)
all_checkpoints, all_ckp_exists = [], []
for arch_index in arch_indexes:
checkpoints = ["arch-{:}-seed-{:04d}.pth".format(arch_index, seed) for seed in seeds]
checkpoints = [
"arch-{:}-seed-{:04d}.pth".format(arch_index, seed) for seed in seeds
]
ckp_exists = [(sub_dir / x).exists() for x in checkpoints]
arch_index = int(arch_index)
assert 0 <= arch_index < len(meta_archs), "invalid arch-index {:} (not found in meta_archs)".format(
arch_index
)
assert (
0 <= arch_index < len(meta_archs)
), "invalid arch-index {:} (not found in meta_archs)".format(arch_index)
all_checkpoints += checkpoints
all_ckp_exists += ckp_exists
numrs[sum(ckp_exists)] += 1
@@ -76,7 +90,9 @@ def check_files(save_dir, meta_file, basestr):
# measure time
epoch_time.update(time.time() - start_time)
start_time = time.time()
numrstr = ", ".join(["{:}: {:03d}".format(x, numrs[x]) for x in sorted(numrs.keys())])
numrstr = ", ".join(
["{:}: {:03d}".format(x, numrs[x]) for x in sorted(numrs.keys())]
)
print(
"{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}".format(
time_string(),
@@ -95,7 +111,8 @@ def check_files(save_dir, meta_file, basestr):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NAS Benchmark 201", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="NAS Benchmark 201",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--base_save_dir",
@@ -104,9 +121,14 @@ if __name__ == "__main__":
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument(
"--meta_path", type=str, default="./output/NAS-BENCH-201-4/meta-node-4.pth", help="The meta file path."
"--meta_path",
type=str,
default="./output/NAS-BENCH-201-4/meta-node-4.pth",
help="The meta file path.",
)
parser.add_argument(
"--base_str", type=str, default="C16-N5", help="The basic string."
)
parser.add_argument("--base_str", type=str, default="C16-N5", help="The basic string.")
args = parser.parse_args()
save_dir = Path(args.base_save_dir)
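
For reference, check_files relies on the checkpoint naming convention arch-{index}-seed-{seed:04d}.pth, and num_seeds maps a seed count to the number of architectures evaluated that many times. A self-contained sketch of that bookkeeping (the file names are illustrative, not from a real run):

    from collections import defaultdict

    names = [
        "arch-000001-seed-0777.pth",
        "arch-000001-seed-0888.pth",
        "arch-000002-seed-0777.pth",
    ]
    per_arch = defaultdict(set)
    for name in names:
        temp_names = name[: -len(".pth")].split("-")
        assert len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed"
        per_arch[temp_names[1]].add(temp_names[3])
    num_seeds = defaultdict(lambda: 0)
    for seeds in per_arch.values():
        num_seeds[len(seeds)] += 1
    print(dict(num_seeds))  # {2: 1, 1: 1} -> one arch with 2 seeds, one with 1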

View File

@@ -10,7 +10,9 @@ from setuptools import setup
def read(fname="README.md"):
with open(os.path.join(os.path.dirname(__file__), fname), encoding="utf-8") as cfile:
with open(
os.path.join(os.path.dirname(__file__), fname), encoding="utf-8"
) as cfile:
return cfile.read()

View File

@@ -76,7 +76,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
return losses.avg, top1.avg, top5.avg, batch_time.sum
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
def evaluate_for_seed(
arch_config, config, arch, train_loader, valid_loaders, seed, logger
):
prepare_seed(seed) # random seed
net = get_cell_based_tiny_net(
@@ -94,14 +96,29 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, se
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
flop, param = get_model_infos(net, config.xshape)
logger.log("Network : {:}".format(net.get_message()), False)
logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed))
logger.log(
"{:} Seed-------------------------- {:} --------------------------".format(
time_string(), seed
)
)
logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
# train and valid
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
# start training
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
start_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
config.epochs + config.warmup,
)
(
train_losses,
train_acc1es,
train_acc5es,
valid_losses,
valid_acc1es,
valid_acc5es,
) = ({}, {}, {}, {}, {}, {})
train_times, valid_times = {}, {}
for epoch in range(total_epoch):
scheduler.update(epoch, 0.0)
@@ -126,7 +143,9 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, se
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True))
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
)
logger.log(
"{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format(
time_string(),
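
The "Time Left" string above comes from a simple running-average estimate: mean wall time per finished epoch times the number of epochs remaining. A stand-alone sketch of the pattern; AverageMeter here is a minimal stand-in for the repo's helper, and plain seconds are printed instead of calling convert_secs2time:

    import time

    class AverageMeter:
        # minimal stand-in; the repo's AverageMeter tracks val/sum/count/avg similarly
        def __init__(self):
            self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0
        def update(self, val):
            self.val = val
            self.sum += val
            self.count += 1
            self.avg = self.sum / self.count

    epoch_time, total_epoch = AverageMeter(), 5
    start_time = time.time()
    for epoch in range(total_epoch):
        time.sleep(0.01)  # stand-in for one epoch of training
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        print("epoch", epoch, "-> est. seconds left:",
              round(epoch_time.avg * (total_epoch - epoch - 1), 3))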

View File

@@ -22,7 +22,9 @@ from models import CellStructure, CellArchitectures, get_search_spaces
from functions import evaluate_for_seed
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger):
def evaluate_all_datasets(
arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
):
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
all_infos = {"info": machine_info}
all_dataset_keys = []
@@ -36,27 +38,39 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c
config_path = "configs/nas-benchmark/LESS.config"
else:
config_path = "configs/nas-benchmark/CIFAR.config"
split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
split_info = load_config(
"configs/nas-benchmark/cifar-split.txt", None, None
)
elif dataset.startswith("ImageNet16"):
if use_less:
config_path = "configs/nas-benchmark/LESS.config"
else:
config_path = "configs/nas-benchmark/ImageNet-16.config"
split_info = load_config("configs/nas-benchmark/{:}-split.txt".format(dataset), None, None)
split_info = load_config(
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
)
else:
raise ValueError("invalid dataset : {:}".format(dataset))
config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger)
config = load_config(
config_path, {"class_num": class_num, "xshape": xshape}, logger
)
# check whether use splited validation set
if bool(split):
assert dataset == "cifar10"
ValLoaders = {
"ori-test": torch.utils.data.DataLoader(
valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
valid_data,
batch_size=config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
}
assert len(train_data) == len(split_info.train) + len(
split_info.valid
), "invalid length : {:} vs {:} + {:}".format(len(train_data), len(split_info.train), len(split_info.valid))
), "invalid length : {:} vs {:} + {:}".format(
len(train_data), len(split_info.train), len(split_info.valid)
)
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
@@ -79,47 +93,67 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c
else:
# data loader
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
train_data,
batch_size=config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
valid_data,
batch_size=config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
if dataset == "cifar10":
ValLoaders = {"ori-test": valid_loader}
elif dataset == "cifar100":
cifar100_splits = load_config("configs/nas-benchmark/cifar100-test-split.txt", None, None)
cifar100_splits = load_config(
"configs/nas-benchmark/cifar100-test-split.txt", None, None
)
ValLoaders = {
"ori-test": valid_loader,
"x-valid": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_splits.xvalid
),
num_workers=workers,
pin_memory=True,
),
"x-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_splits.xtest
),
num_workers=workers,
pin_memory=True,
),
}
elif dataset == "ImageNet16-120":
imagenet16_splits = load_config("configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None)
imagenet16_splits = load_config(
"configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None
)
ValLoaders = {
"ori-test": valid_loader,
"x-valid": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
imagenet16_splits.xvalid
),
num_workers=workers,
pin_memory=True,
),
"x-test": torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest),
sampler=torch.utils.data.sampler.SubsetRandomSampler(
imagenet16_splits.xtest
),
num_workers=workers,
pin_memory=True,
),
@@ -132,13 +166,24 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c
dataset_key = dataset_key + "-valid"
logger.log(
"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size
dataset_key,
len(train_data),
len(valid_data),
len(train_loader),
len(valid_loader),
config.batch_size,
)
)
logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config))
logger.log(
"Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
)
for key, value in ValLoaders.items():
logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)))
results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger)
logger.log(
"Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))
)
results = evaluate_for_seed(
arch_config, config, arch, train_loader, ValLoaders, seed, logger
)
all_infos[dataset_key] = results
all_dataset_keys.append(dataset_key)
all_infos["all_dataset_keys"] = all_dataset_keys
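
The x-valid / x-test loaders built above share one underlying dataset and differ only in the index lists fed to SubsetRandomSampler. A toy, self-contained version of that pattern (the dataset and index lists are illustrative):

    import torch

    data = torch.utils.data.TensorDataset(torch.arange(10).float())
    xvalid_idx, xtest_idx = [0, 2, 4, 6, 8], [1, 3, 5, 7, 9]
    x_valid = torch.utils.data.DataLoader(
        data,
        batch_size=4,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(xvalid_idx),
    )
    x_test = torch.utils.data.DataLoader(
        data,
        batch_size=4,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(xtest_idx),
    )
    print(len(x_valid), len(x_test))  # 2 batches each: ceil(5 / 4)
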
@@ -146,7 +191,18 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c
def main(
save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config
save_dir,
workers,
datasets,
xpaths,
splits,
use_less,
srange,
arch_index,
seeds,
cover_mode,
meta_info,
arch_config,
):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
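
main (and train_single_model further down) pin cuDNN to deterministic kernels and seed each run, so a given (arch, seed) pair reproduces its training curve. A sketch of the seeding side; prepare_seed is this repo's helper, and the stub below is an assumption about what it does:

    import random
    import numpy as np
    import torch

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True

    def prepare_seed_stub(seed: int):
        # assumed behavior of the repo's prepare_seed(); details may differ
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    prepare_seed_stub(777)
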
@@ -154,7 +210,9 @@ def main(
torch.backends.cudnn.deterministic = True
torch.set_num_threads(workers)
assert len(srange) == 2 and 0 <= srange[0] <= srange[1], "invalid srange : {:}".format(srange)
assert (
len(srange) == 2 and 0 <= srange[0] <= srange[1]
), "invalid srange : {:}".format(srange)
if use_less:
sub_dir = Path(save_dir) / "{:06d}-{:06d}-C{:}-N{:}-LESS".format(
@@ -170,9 +228,9 @@ def main(
assert srange[1] < meta_info["total"], "invalid range : {:}-{:} vs. {:}".format(
srange[0], srange[1], meta_info["total"]
)
assert arch_index == -1 or srange[0] <= arch_index <= srange[1], "invalid range : {:} vs. {:} vs. {:}".format(
srange[0], arch_index, srange[1]
)
assert (
arch_index == -1 or srange[0] <= arch_index <= srange[1]
), "invalid range : {:} vs. {:} vs. {:}".format(srange[0], arch_index, srange[1])
if arch_index == -1:
to_evaluate_indexes = list(range(srange[0], srange[1] + 1))
else:
@@ -200,7 +258,13 @@ def main(
arch = all_archs[index]
logger.log(
"\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}".format(
"-" * 15, i, len(to_evaluate_indexes), index, meta_info["total"], seeds, "-" * 15
"-" * 15,
i,
len(to_evaluate_indexes),
index,
meta_info["total"],
seeds,
"-" * 15,
)
)
# logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15))
@@ -212,10 +276,18 @@ def main(
to_save_name = sub_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed)
if to_save_name.exists():
if cover_mode:
logger.log("Find existing file : {:}, remove it before evaluation".format(to_save_name))
logger.log(
"Find existing file : {:}, remove it before evaluation".format(
to_save_name
)
)
os.remove(str(to_save_name))
else:
logger.log("Find existing file : {:}, skip this evaluation".format(to_save_name))
logger.log(
"Find existing file : {:}, skip this evaluation".format(
to_save_name
)
)
has_continue = True
continue
results = evaluate_all_datasets(
@@ -232,7 +304,13 @@ def main(
torch.save(results, to_save_name)
logger.log(
"{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}".format(
"-" * 15, i, len(to_evaluate_indexes), index, meta_info["total"], seed, to_save_name
"-" * 15,
i,
len(to_evaluate_indexes),
index,
meta_info["total"],
seed,
to_save_name,
)
)
# measure elapsed time
@@ -242,7 +320,9 @@ def main(
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True)
)
logger.log("This arch costs : {:}".format(convert_secs2time(epoch_time.val, True)))
logger.log(
"This arch costs : {:}".format(convert_secs2time(epoch_time.val, True))
)
logger.log("{:}".format("*" * 100))
logger.log(
"{:} {:74s} {:}".format(
@@ -258,7 +338,9 @@ def main(
logger.close()
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config):
def train_single_model(
save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config
):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
@@ -269,19 +351,32 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se
Path(save_dir)
/ "specifics"
/ "{:}-{:}-{:}-{:}".format(
"LESS" if use_less else "FULL", model_str, arch_config["channel"], arch_config["num_cells"]
"LESS" if use_less else "FULL",
model_str,
arch_config["channel"],
arch_config["num_cells"],
)
)
logger = Logger(str(save_dir), 0, False)
if model_str in CellArchitectures:
arch = CellArchitectures[model_str]
logger.log("The model string is found in pre-defined architecture dict : {:}".format(model_str))
logger.log(
"The model string is found in pre-defined architecture dict : {:}".format(
model_str
)
)
else:
try:
arch = CellStructure.str2structure(model_str)
except:
raise ValueError("Invalid model string : {:}. It can not be found or parsed.".format(model_str))
assert arch.check_valid_op(get_search_spaces("cell", "full")), "{:} has the invalid op.".format(arch)
raise ValueError(
"Invalid model string : {:}. It can not be found or parsed.".format(
model_str
)
)
assert arch.check_valid_op(
get_search_spaces("cell", "full")
), "{:} has the invalid op.".format(arch)
logger.log("Start train-evaluate {:}".format(arch.tostr()))
logger.log("arch_config : {:}".format(arch_config))
@@ -294,27 +389,55 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se
)
to_save_name = save_dir / "seed-{:04d}.pth".format(seed)
if to_save_name.exists():
logger.log("Find the existing file {:}, directly load!".format(to_save_name))
logger.log(
"Find the existing file {:}, directly load!".format(to_save_name)
)
checkpoint = torch.load(to_save_name)
else:
logger.log("Does not find the existing file {:}, train and evaluate!".format(to_save_name))
logger.log(
"Does not find the existing file {:}, train and evaluate!".format(
to_save_name
)
)
checkpoint = evaluate_all_datasets(
arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
arch,
datasets,
xpaths,
splits,
use_less,
seed,
arch_config,
workers,
logger,
)
torch.save(checkpoint, to_save_name)
# log information
logger.log("{:}".format(checkpoint["info"]))
all_dataset_keys = checkpoint["all_dataset_keys"]
for dataset_key in all_dataset_keys:
logger.log("\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15))
logger.log(
"\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)
)
dataset_info = checkpoint[dataset_key]
# logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
logger.log("Flops = {:} MB, Params = {:} MB".format(dataset_info["flop"], dataset_info["param"]))
logger.log(
"Flops = {:} MB, Params = {:} MB".format(
dataset_info["flop"], dataset_info["param"]
)
)
logger.log("config : {:}".format(dataset_info["config"]))
logger.log("Training State (finish) = {:}".format(dataset_info["finish-train"]))
logger.log(
"Training State (finish) = {:}".format(dataset_info["finish-train"])
)
last_epoch = dataset_info["total_epoch"] - 1
train_acc1es, train_acc5es = dataset_info["train_acc1es"], dataset_info["train_acc5es"]
valid_acc1es, valid_acc5es = dataset_info["valid_acc1es"], dataset_info["valid_acc5es"]
train_acc1es, train_acc5es = (
dataset_info["train_acc1es"],
dataset_info["train_acc5es"],
)
valid_acc1es, valid_acc5es = (
dataset_info["valid_acc1es"],
dataset_info["valid_acc5es"],
)
logger.log(
"Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format(
train_acc1es[last_epoch],
@@ -328,7 +451,9 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se
# measure elapsed time
seed_time.update(time.time() - start_time)
start_time = time.time()
need_time = "Time Left: {:}".format(convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True))
need_time = "Time Left: {:}".format(
convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)
)
logger.log(
"\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format(
_is, len(seeds), seed, need_time
@@ -340,7 +465,11 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se
def generate_meta_info(save_dir, max_node, divide=40):
aa_nas_bench_ss = get_search_spaces("cell", "nas-bench-201")
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)))
print(
"There are {:} archs vs {:}.".format(
len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)
)
)
random.seed(88) # please do not change this line for reproducibility
random.shuffle(archs)
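
A worked version of the count printed above: a cell with max_node nodes has max_node * (max_node - 1) / 2 edges, each picking one operation, so the full space has |ops| ** num_edges architectures. For NAS-Bench-201 (max_node = 4 and five operations, all of which appear in the architecture strings below) that is 5 ** 6 = 15625:

    ops = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
    max_node = 4
    num_edges = (max_node - 1) * max_node // 2  # 6 edges
    print(len(ops) ** num_edges)  # 15625
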
@@ -352,10 +481,12 @@ def generate_meta_info(save_dir, max_node, divide=40):
== "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|"
), "please check the 0-th architecture : {:}".format(archs[0])
assert (
archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
archs[9].tostr()
== "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
), "please check the 9-th architecture : {:}".format(archs[9])
assert (
archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
archs[123].tostr()
== "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
), "please check the 123-th architecture : {:}".format(archs[123])
total_arch = len(archs)
@@ -374,11 +505,21 @@ def generate_meta_info(save_dir, max_node, divide=40):
and valid_split[10] == 18
and valid_split[111] == 242
), "{:} {:} {:} - {:} {:} {:}".format(
train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111]
train_split[0],
train_split[10],
train_split[111],
valid_split[0],
valid_split[10],
valid_split[111],
)
splits = {num: {"train": train_split, "valid": valid_split}}
info = {"archs": [x.tostr() for x in archs], "total": total_arch, "max_node": max_node, "splits": splits}
info = {
"archs": [x.tostr() for x in archs],
"total": total_arch,
"max_node": max_node,
"splits": splits,
}
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
@@ -404,7 +545,11 @@ def generate_meta_info(save_dir, max_node, divide=40):
start, xend - 1
)
)
print("save the training script into {:} and {:}".format(script_name_full, script_name_less))
print(
"save the training script into {:} and {:}".format(
script_name_full, script_name_less
)
)
full_file.close()
less_file.close()
@@ -425,29 +570,56 @@ if __name__ == "__main__":
# mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()]
# parser = argparse.ArgumentParser(description='Algorithm-Agnostic NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description="NAS-Bench-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="NAS-Bench-201",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--mode", type=str, required=True, help="The script mode.")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument("--max_node", type=int, help="The maximum node in a cell.")
# use for train the model
parser.add_argument("--workers", type=int, default=8, help="number of data loading workers (default: 2)")
parser.add_argument("--srange", type=int, nargs="+", help="The range of models to be evaluated")
parser.add_argument(
"--arch_index", type=int, default=-1, help="The architecture index to be evaluated (cover mode)."
"--workers",
type=int,
default=8,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--srange", type=int, nargs="+", help="The range of models to be evaluated"
)
parser.add_argument(
"--arch_index",
type=int,
default=-1,
help="The architecture index to be evaluated (cover mode).",
)
parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.")
parser.add_argument("--xpaths", type=str, nargs="+", help="The root path for this dataset.")
parser.add_argument("--splits", type=int, nargs="+", help="The root path for this dataset.")
parser.add_argument("--use_less", type=int, default=0, choices=[0, 1], help="Using the less-training-epoch config.")
parser.add_argument("--seeds", type=int, nargs="+", help="The range of models to be evaluated")
parser.add_argument(
"--xpaths", type=str, nargs="+", help="The root path for this dataset."
)
parser.add_argument(
"--splits", type=int, nargs="+", help="The root path for this dataset."
)
parser.add_argument(
"--use_less",
type=int,
default=0,
choices=[0, 1],
help="Using the less-training-epoch config.",
)
parser.add_argument(
"--seeds", type=int, nargs="+", help="The range of models to be evaluated"
)
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
args = parser.parse_args()
assert args.mode in ["meta", "new", "cover"] or args.mode.startswith("specific-"), "invalid mode : {:}".format(
args.mode
)
assert args.mode in ["meta", "new", "cover"] or args.mode.startswith(
"specific-"
), "invalid mode : {:}".format(args.mode)
if args.mode == "meta":
generate_meta_info(args.save_dir, args.max_node)
@@ -470,11 +642,15 @@ if __name__ == "__main__":
assert meta_path.exists(), "{:} does not exist.".format(meta_path)
meta_info = torch.load(meta_path)
# check whether args is ok
assert len(args.srange) == 2 and args.srange[0] <= args.srange[1], "invalid length of srange args: {:}".format(
args.srange
assert (
len(args.srange) == 2 and args.srange[0] <= args.srange[1]
), "invalid length of srange args: {:}".format(args.srange)
assert len(args.seeds) > 0, "invalid length of seeds args: {:}".format(
args.seeds
)
assert len(args.seeds) > 0, "invalid length of seeds args: {:}".format(args.seeds)
assert len(args.datasets) == len(args.xpaths) == len(args.splits), "invalid infos : {:} vs {:} vs {:}".format(
assert (
len(args.datasets) == len(args.xpaths) == len(args.splits)
), "invalid infos : {:} vs {:} vs {:}".format(
len(args.datasets), len(args.xpaths), len(args.splits)
)
assert args.workers > 0, "invalid number of workers : {:}".format(args.workers)
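
The srange / arch_index checks above implement a simple contract: srange is an inclusive [start, end] window of architecture indexes, and arch_index == -1 means evaluating the whole window, otherwise just that one index inside it. A compact sketch:

    def indexes_to_evaluate(srange, arch_index=-1):
        assert len(srange) == 2 and 0 <= srange[0] <= srange[1], "invalid srange : {:}".format(srange)
        assert arch_index == -1 or srange[0] <= arch_index <= srange[1]
        if arch_index == -1:
            return list(range(srange[0], srange[1] + 1))
        return [arch_index]

    print(indexes_to_evaluate([0, 4]))     # [0, 1, 2, 3, 4]
    print(indexes_to_evaluate([0, 4], 2))  # [2]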

View File

@@ -13,7 +13,12 @@ from nas_201_api import NASBench201API as API
if __name__ == "__main__":
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.")
parser.add_argument(
"--api_path",
type=str,
default=None,
help="The path to the NAS-Bench-201 benchmark file.",
)
args = parser.parse_args()
meta_file = Path(args.api_path)
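
For context, --api_path points at a released NAS-Bench-201 benchmark file such as NAS-Bench-201-v1_0-e61699.pth (the name used elsewhere in this commit). A hedged sketch of loading it once downloaded; the nas_201_api calls are written from memory of the released package:

    from pathlib import Path
    from nas_201_api import NASBench201API as API

    meta_file = Path("NAS-Bench-201-v1_0-e61699.pth")  # illustrative local path
    api = API(str(meta_file))
    print(len(api))  # 15625 architectures indexed by the benchmark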

View File

@@ -19,7 +19,9 @@ from models import CellStructure, get_cell_based_tiny_net
from nas_201_api import NASBench201API, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
api = NASBench201API("{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"]))
api = NASBench201API(
"{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"])
)
def create_result_count(
@@ -55,33 +57,56 @@ def create_result_count(
network.load_state_dict(xresult.get_net_param())
if "train_times" in results: # new version
xresult.update_train_info(
results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"]
results["train_acc1es"],
results["train_acc5es"],
results["train_losses"],
results["train_times"],
)
xresult.update_eval(
results["valid_acc1es"], results["valid_losses"], results["valid_times"]
)
xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"])
else:
if dataset == "cifar10-valid":
xresult.update_OLD_eval("x-valid", results["valid_acc1es"], results["valid_losses"])
xresult.update_OLD_eval(
"x-valid", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda()
)
xresult.update_OLD_eval("ori-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
xresult.update_OLD_eval(
"ori-test",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
xresult.update_latency(latencies)
elif dataset == "cifar10":
xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"])
xresult.update_OLD_eval(
"ori-test", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda()
)
xresult.update_latency(latencies)
elif dataset == "cifar100" or dataset == "ImageNet16-120":
xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"])
xresult.update_OLD_eval(
"ori-test", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda()
)
xresult.update_OLD_eval("x-valid", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
xresult.update_OLD_eval(
"x-valid",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda()
)
xresult.update_OLD_eval("x-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
xresult.update_OLD_eval(
"x-test",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
xresult.update_latency(latencies)
else:
raise ValueError("invalid dataset name : {:}".format(dataset))
@@ -89,7 +114,11 @@ def create_result_count(
def account_one_arch(
arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text], dataloader_dict: Dict[Text, Any]
arch_index: int,
arch_str: Text,
checkpoints: List[Text],
datasets: List[Text],
dataloader_dict: Dict[Text, Any],
) -> ArchResults:
information = ArchResults(arch_index, arch_str)
@@ -99,12 +128,18 @@ def account_one_arch(
ok_dataset = 0
for dataset in datasets:
if dataset not in checkpoint:
print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path))
print(
"Can not find {:} in arch-{:} from {:}".format(
dataset, arch_index, checkpoint_path
)
)
continue
else:
ok_dataset += 1
results = checkpoint[dataset]
assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
assert results[
"finish-train"
], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
arch_index, used_seed, dataset, checkpoint_path
)
arch_config = {
@@ -114,17 +149,22 @@ def account_one_arch(
"class_num": results["config"]["class_num"],
}
xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict)
xresult = create_result_count(
used_seed, dataset, arch_config, results, dataloader_dict
)
information.update(dataset, int(used_seed), xresult)
if ok_dataset == 0:
raise ValueError("{:} does not find any data".format(checkpoint_path))
return information
def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch_info_less: ArchResults):
def correct_time_related_info(
arch_index: int, arch_info_full: ArchResults, arch_info_less: ArchResults
):
# calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
cifar010_latency = (
api.get_latency(arch_index, "cifar10-valid", hp="200") + api.get_latency(arch_index, "cifar10", hp="200")
api.get_latency(arch_index, "cifar10-valid", hp="200")
+ api.get_latency(arch_index, "cifar10", hp="200")
) / 2
arch_info_full.reset_latency("cifar10-valid", None, cifar010_latency)
arch_info_full.reset_latency("cifar10", None, cifar010_latency)
@@ -139,7 +179,9 @@ def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch
arch_info_full.reset_latency("ImageNet16-120", None, image_latency)
arch_info_less.reset_latency("ImageNet16-120", None, image_latency)
train_per_epoch_time = list(arch_info_less.query("cifar10-valid", 777).train_times.values())
train_per_epoch_time = list(
arch_info_less.query("cifar10-valid", 777).train_times.values()
)
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time, eval_x_valid_time = [], []
for key, value in arch_info_less.query("cifar10-valid", 777).eval_times.items():
@@ -149,7 +191,9 @@ def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch
eval_x_valid_time.append(value)
else:
raise ValueError("-- {:} --".format(key))
eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time))
eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(
np.mean(eval_x_valid_time)
)
nums = {
"ImageNet16-120-train": 151700,
"ImageNet16-120-valid": 3000,
@@ -162,36 +206,72 @@ def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch
"cifar100-test": 10000,
"cifar100-valid": 5000,
}
eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums["cifar10-valid-valid"] + nums["cifar10-test"])
eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (
nums["cifar10-valid-valid"] + nums["cifar10-test"]
)
for arch_info in [arch_info_less, arch_info_full]:
arch_info.reset_pseudo_train_times(
"cifar10-valid", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-valid-train"]
"cifar10-valid",
None,
train_per_epoch_time
/ nums["cifar10-valid-train"]
* nums["cifar10-valid-train"],
)
arch_info.reset_pseudo_train_times(
"cifar10", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"]
"cifar10",
None,
train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"],
)
arch_info.reset_pseudo_train_times(
"cifar100", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"]
"cifar100",
None,
train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"],
)
arch_info.reset_pseudo_train_times(
"ImageNet16-120", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["ImageNet16-120-train"]
"ImageNet16-120",
None,
train_per_epoch_time
/ nums["cifar10-valid-train"]
* nums["ImageNet16-120-train"],
)
arch_info.reset_pseudo_eval_times(
"cifar10-valid", None, "x-valid", eval_per_sample * nums["cifar10-valid-valid"]
)
arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"])
arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"])
arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"])
arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"])
arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"])
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "x-valid", eval_per_sample * nums["ImageNet16-120-valid"]
"cifar10-valid",
None,
"x-valid",
eval_per_sample * nums["cifar10-valid-valid"],
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "x-test", eval_per_sample * nums["ImageNet16-120-valid"]
"cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"]
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120", None, "ori-test", eval_per_sample * nums["ImageNet16-120-test"]
"cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"]
)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"]
)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"]
)
arch_info.reset_pseudo_eval_times(
"cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"]
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120",
None,
"x-valid",
eval_per_sample * nums["ImageNet16-120-valid"],
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120",
None,
"x-test",
eval_per_sample * nums["ImageNet16-120-valid"],
)
arch_info.reset_pseudo_eval_times(
"ImageNet16-120",
None,
"ori-test",
eval_per_sample * nums["ImageNet16-120-test"],
)
# arch_info_full.debug_test()
# arch_info_less.debug_test()
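
The reset_pseudo_train_times calls above rescale one measured per-epoch time (taken on the cifar10-valid train split) by training-set size, and charge evaluation time per sample. A sketch of the arithmetic; the 20-second figure is illustrative, and the counts mirror the nums table above, with cifar10-valid-train = 25000 from the benchmark's CIFAR-10 split:

    nums = {
        "cifar10-valid-train": 25000,
        "cifar10-train": 50000,
        "cifar100-train": 50000,
        "ImageNet16-120-train": 151700,
    }
    train_per_epoch_time = 20.0  # seconds on cifar10-valid; illustrative
    for name in ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"):
        scaled = train_per_epoch_time / nums["cifar10-valid-train"] * nums[name + "-train"]
        print("{:14s} pseudo per-epoch time: {:7.1f}s".format(name, scaled))
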
@@ -202,12 +282,16 @@ def simplify(save_dir, meta_file, basestr, target_dir):
meta_infos = torch.load(meta_file, map_location="cpu")
meta_archs = meta_infos["archs"] # a list of architecture strings
meta_num_archs = meta_infos["total"]
assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format(
meta_num_archs, len(meta_archs)
)
assert meta_num_archs == len(
meta_archs
), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs))
sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr))))
print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs)))
print(
"{:} find {:} directories used to save checkpoints".format(
time_string(), len(sub_model_dirs)
)
)
subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
num_seeds = defaultdict(lambda: 0)
@@ -217,14 +301,18 @@ def simplify(save_dir, meta_file, basestr, target_dir):
for checkpoint in xcheckpoints:
temp_names = checkpoint.name.split("-")
assert (
len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed"
len(temp_names) == 4
and temp_names[0] == "arch"
and temp_names[2] == "seed"
), "invalid checkpoint name : {:}".format(checkpoint.name)
arch_indexes.add(temp_names[1])
subdir2archs[sub_dir] = sorted(list(arch_indexes))
num_evaluated_arch += len(arch_indexes)
# count number of seeds for each architecture
for arch_index in arch_indexes:
num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1
num_seeds[
len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))
] += 1
print(
"{:} There are {:5d} architectures that have been evaluated ({:} in total).".format(
time_string(), num_evaluated_arch, meta_num_archs
@@ -232,7 +320,9 @@ def simplify(save_dir, meta_file, basestr, target_dir):
)
for key in sorted(list(num_seeds.keys())):
print(
"{:} There are {:5d} architectures that are evaluated {:} times.".format(time_string(), num_seeds[key], key)
"{:} There are {:5d} architectures that are evaluated {:} times.".format(
time_string(), num_seeds[key], key
)
)
dataloader_dict = get_nas_bench_loaders(6)
@@ -243,8 +333,15 @@ def simplify(save_dir, meta_file, basestr, target_dir):
if not to_save_allarc.exists():
to_save_allarc.mkdir(parents=True, exist_ok=True)
assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format(target_dir)
arch2infos, datasets = {}, ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format(
target_dir
)
arch2infos, datasets = {}, (
"cifar10-valid",
"cifar10",
"cifar100",
"ImageNet16-120",
)
evaluated_indexes = set()
target_full_dir = save_dir / target_dir
target_less_dir = save_dir / "{:}-LESS".format(target_dir)
@@ -253,30 +350,46 @@ def simplify(save_dir, meta_file, basestr, target_dir):
end_time = time.time()
arch_time = AverageMeter()
for idx, arch_index in enumerate(arch_indexes):
checkpoints = list(target_full_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))
checkpoints = list(
target_full_dir.glob("arch-{:}-seed-*.pth".format(arch_index))
)
ckps_less = list(target_less_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))
# create the arch info for each architecture
try:
arch_info_full = account_one_arch(
arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict
arch_index,
meta_archs[int(arch_index)],
checkpoints,
datasets,
dataloader_dict,
)
arch_info_less = account_one_arch(
arch_index, meta_archs[int(arch_index)], ckps_less, datasets, dataloader_dict
arch_index,
meta_archs[int(arch_index)],
ckps_less,
datasets,
dataloader_dict,
)
num_seeds[len(checkpoints)] += 1
except:
print("Loading {:} failed, : {:}".format(arch_index, checkpoints))
continue
assert int(arch_index) not in evaluated_indexes, "conflict arch-index : {:}".format(arch_index)
assert 0 <= int(arch_index) < len(meta_archs), "invalid arch-index {:} (not found in meta_archs)".format(
arch_index
)
assert (
int(arch_index) not in evaluated_indexes
), "conflict arch-index : {:}".format(arch_index)
assert (
0 <= int(arch_index) < len(meta_archs)
), "invalid arch-index {:} (not found in meta_archs)".format(arch_index)
arch_info = {"full": arch_info_full, "less": arch_info_less}
evaluated_indexes.add(int(arch_index))
arch2infos[int(arch_index)] = arch_info
# to correct the latency and training_time info.
arch_info_full, arch_info_less = correct_time_related_info(int(arch_index), arch_info_full, arch_info_less)
to_save_data = OrderedDict(full=arch_info_full.state_dict(), less=arch_info_less.state_dict())
arch_info_full, arch_info_less = correct_time_related_info(
int(arch_index), arch_info_full, arch_info_less
)
to_save_data = OrderedDict(
full=arch_info_full.state_dict(), less=arch_info_less.state_dict()
)
torch.save(to_save_data, to_save_allarc / "{:}-FULL.pth".format(arch_index))
arch_info["full"].clear_params()
arch_info["less"].clear_params()
@@ -284,14 +397,19 @@ def simplify(save_dir, meta_file, basestr, target_dir):
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
need_time = "{:}".format(convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1), True))
need_time = "{:}".format(
convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1), True)
)
print(
"{:} {:} [{:03d}/{:03d}] : {:} still need {:}".format(
time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time
)
)
# measure time
xstrs = ["{:}:{:03d}".format(key, num_seeds[key]) for key in sorted(list(num_seeds.keys()))]
xstrs = [
"{:}:{:03d}".format(key, num_seeds[key])
for key in sorted(list(num_seeds.keys()))
]
print("{:} {:} done : {:}".format(time_string(), target_dir, xstrs))
final_infos = {
"meta_archs": meta_archs,
@@ -303,7 +421,9 @@ def simplify(save_dir, meta_file, basestr, target_dir):
save_file_name = to_save_simply / "{:}.pth".format(target_dir)
torch.save(final_infos, save_file_name)
print(
"Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name)
"Save {:} / {:} architecture results into {:}.".format(
len(evaluated_indexes), meta_num_archs, save_file_name
)
)
@@ -311,12 +431,16 @@ def merge_all(save_dir, meta_file, basestr):
meta_infos = torch.load(meta_file, map_location="cpu")
meta_archs = meta_infos["archs"]
meta_num_archs = meta_infos["total"]
assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format(
meta_num_archs, len(meta_archs)
)
assert meta_num_archs == len(
meta_archs
), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs))
sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr))))
print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs)))
print(
"{:} find {:} directories used to save checkpoints".format(
time_string(), len(sub_model_dirs)
)
)
for index, sub_dir in enumerate(sub_model_dirs):
arch_info_files = sorted(list(sub_dir.glob("arch-*-seed-*.pth")))
print(
@@ -330,11 +454,16 @@ def merge_all(save_dir, meta_file, basestr):
ckp_path = sub_dir.parent / "simplifies" / "{:}.pth".format(sub_dir.name)
if ckp_path.exists():
sub_ckps = torch.load(ckp_path, map_location="cpu")
assert sub_ckps["total_archs"] == meta_num_archs and sub_ckps["basestr"] == basestr
assert (
sub_ckps["total_archs"] == meta_num_archs
and sub_ckps["basestr"] == basestr
)
xarch2infos = sub_ckps["arch2infos"]
xevalindexs = sub_ckps["evaluated_indexes"]
for eval_index in xevalindexs:
assert eval_index not in evaluated_indexes and eval_index not in arch2infos
assert (
eval_index not in evaluated_indexes and eval_index not in arch2infos
)
# arch2infos[eval_index] = xarch2infos[eval_index].state_dict()
arch2infos[eval_index] = {
"full": xarch2infos[eval_index]["full"].state_dict(),
@@ -351,7 +480,11 @@ def merge_all(save_dir, meta_file, basestr):
# print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path))
evaluated_indexes = sorted(list(evaluated_indexes))
print("Finally, there are {:} architectures that have been trained and evaluated.".format(len(evaluated_indexes)))
print(
"Finally, there are {:} architectures that have been trained and evaluated.".format(
len(evaluated_indexes)
)
)
to_save_simply = save_dir / "simplifies"
if not to_save_simply.exists():
@@ -365,16 +498,24 @@ def merge_all(save_dir, meta_file, basestr):
save_file_name = to_save_simply / "{:}-final-infos.pth".format(basestr)
torch.save(final_infos, save_file_name)
print(
"Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name)
"Save {:} / {:} architecture results into {:}.".format(
len(evaluated_indexes), meta_num_archs, save_file_name
)
)
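
merge_all folds the per-directory simplified checkpoints into one file, insisting that no architecture index appears in two sub-checkpoints. A compact sketch of that accumulation, with plain dictionaries standing in for the loaded sub-checkpoints:

    sub_ckps = [
        {"evaluated_indexes": [0, 1], "arch2infos": {0: "info-0", 1: "info-1"}},
        {"evaluated_indexes": [2], "arch2infos": {2: "info-2"}},
    ]
    arch2infos, evaluated_indexes = {}, set()
    for ckp in sub_ckps:
        for eval_index in ckp["evaluated_indexes"]:
            assert eval_index not in evaluated_indexes and eval_index not in arch2infos
            arch2infos[eval_index] = ckp["arch2infos"][eval_index]
            evaluated_indexes.add(eval_index)
    print(sorted(evaluated_indexes))  # [0, 1, 2]
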
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NAS-BENCH-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="NAS-BENCH-201",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--mode",
type=str,
choices=["cal", "merge"],
help="The running mode for this script.",
)
parser.add_argument("--mode", type=str, choices=["cal", "merge"], help="The running mode for this script.")
parser.add_argument(
"--base_save_dir",
type=str,
@@ -382,16 +523,26 @@ if __name__ == "__main__":
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument("--target_dir", type=str, help="The target directory.")
parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.")
parser.add_argument("--channel", type=int, default=16, help="The number of channels.")
parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.")
parser.add_argument(
"--max_node", type=int, default=4, help="The maximum node in a cell."
)
parser.add_argument(
"--channel", type=int, default=16, help="The number of channels."
)
parser.add_argument(
"--num_cells", type=int, default=5, help="The number of cells in one stage."
)
args = parser.parse_args()
save_dir = Path(args.base_save_dir)
meta_path = save_dir / "meta-node-{:}.pth".format(args.max_node)
assert save_dir.exists(), "invalid save dir path : {:}".format(save_dir)
assert meta_path.exists(), "invalid saved meta path : {:}".format(meta_path)
print("start the statistics of our nas-benchmark from {:} using {:}.".format(save_dir, args.target_dir))
print(
"start the statistics of our nas-benchmark from {:} using {:}.".format(
save_dir, args.target_dir
)
)
basestr = "C{:}-N{:}".format(args.channel, args.num_cells)
if args.mode == "cal":

View File

@@ -48,33 +48,56 @@ def create_result_count(used_seed, dataset, arch_config, results, dataloader_dic
network.load_state_dict(xresult.get_net_param())
if "train_times" in results: # new version
xresult.update_train_info(
results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"]
results["train_acc1es"],
results["train_acc5es"],
results["train_losses"],
results["train_times"],
)
xresult.update_eval(
results["valid_acc1es"], results["valid_losses"], results["valid_times"]
)
xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"])
else:
if dataset == "cifar10-valid":
xresult.update_OLD_eval("x-valid", results["valid_acc1es"], results["valid_losses"])
xresult.update_OLD_eval(
"x-valid", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda()
)
xresult.update_OLD_eval("ori-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
xresult.update_OLD_eval(
"ori-test",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
xresult.update_latency(latencies)
elif dataset == "cifar10":
xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"])
xresult.update_OLD_eval(
"ori-test", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda()
)
xresult.update_latency(latencies)
elif dataset == "cifar100" or dataset == "ImageNet16-120":
xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"])
xresult.update_OLD_eval(
"ori-test", results["valid_acc1es"], results["valid_losses"]
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda()
)
xresult.update_OLD_eval("x-valid", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
xresult.update_OLD_eval(
"x-valid",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
loss, top1, top5, latencies = pure_evaluate(
dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda()
)
xresult.update_OLD_eval("x-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
xresult.update_OLD_eval(
"x-test",
{results["total_epoch"] - 1: top1},
{results["total_epoch"] - 1: loss},
)
xresult.update_latency(latencies)
else:
raise ValueError("invalid dataset name : {:}".format(dataset))
@@ -88,11 +111,15 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
checkpoint = torch.load(checkpoint_path, map_location="cpu")
used_seed = checkpoint_path.name.split("-")[-1].split(".")[0]
for dataset in datasets:
assert dataset in checkpoint, "Can not find {:} in arch-{:} from {:}".format(
assert (
dataset in checkpoint
), "Can not find {:} in arch-{:} from {:}".format(
dataset, arch_index, checkpoint_path
)
results = checkpoint[dataset]
assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
assert results[
"finish-train"
], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
arch_index, used_seed, dataset, checkpoint_path
)
arch_config = {
@@ -102,7 +129,9 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
"class_num": results["config"]["class_num"],
}
xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict)
xresult = create_result_count(
used_seed, dataset, arch_config, results, dataloader_dict
)
information.update(dataset, int(used_seed), xresult)
return information
@@ -118,14 +147,29 @@ def GET_DataLoaders(workers):
cifar_config = load_config(cifar_config_path, None, None)
print("{:} Create data-loader for all datasets".format(time_string()))
print("-" * 200)
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets("cifar10", str(torch_dir / "cifar.python"), -1)
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets(
"cifar10", str(torch_dir / "cifar.python"), -1
)
print(
"original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num
)
)
cifar10_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None)
assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [
cifar10_splits = load_config(
root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None
)
assert cifar10_splits.train[:10] == [
0,
5,
7,
11,
13,
15,
16,
17,
20,
24,
] and cifar10_splits.valid[:10] == [
1,
2,
3,
@@ -141,7 +185,11 @@ def GET_DataLoaders(workers):
temp_dataset.transform = VALID_CIFAR10.transform
# data loader
trainval_cifar10_loader = torch.utils.data.DataLoader(
TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
TRAIN_CIFAR10,
batch_size=cifar_config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
train_cifar10_loader = torch.utils.data.DataLoader(
TRAIN_CIFAR10,
@@ -158,7 +206,11 @@ def GET_DataLoaders(workers):
pin_memory=True,
)
test__cifar10_loader = torch.utils.data.DataLoader(
VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
VALID_CIFAR10,
batch_size=cifar_config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
print(
"CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format(
@@ -182,14 +234,29 @@ def GET_DataLoaders(workers):
)
print("-" * 200)
# CIFAR-100
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets("cifar100", str(torch_dir / "cifar.python"), -1)
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets(
"cifar100", str(torch_dir / "cifar.python"), -1
)
print(
"original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num
)
)
cifar100_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None)
assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [
cifar100_splits = load_config(
root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None
)
assert cifar100_splits.xvalid[:10] == [
1,
3,
4,
5,
8,
10,
13,
14,
15,
16,
] and cifar100_splits.xtest[:10] == [
0,
2,
6,
@@ -202,7 +269,11 @@ def GET_DataLoaders(workers):
24,
]
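
The hard-coded index prefixes in these asserts are sanity checks that the split files under configs/nas-benchmark load intact and in a fixed order. A minimal sketch of the pattern, with an illustrative loaded list:

    # illustrative stand-in for cifar100_splits.xvalid as returned by load_config
    xvalid = [1, 3, 4, 5, 8, 10, 13, 14, 15, 16, 17, 21]
    assert xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16], (
        "cifar100-test-split.txt seems corrupted"
    )
    print("split prefix verified")
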
train_cifar100_loader = torch.utils.data.DataLoader(
TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
TRAIN_CIFAR100,
batch_size=cifar_config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
valid_cifar100_loader = torch.utils.data.DataLoader(
VALID_CIFAR100,
@@ -218,9 +289,15 @@ def GET_DataLoaders(workers):
num_workers=workers,
pin_memory=True,
)
print("CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader)))
print("CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)))
print("CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader)))
print(
"CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader))
)
print(
"CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))
)
print(
"CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader))
)
print("-" * 200)
imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config"
@@ -233,8 +310,23 @@ def GET_DataLoaders(workers):
len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num
)
)
imagenet_splits = load_config(root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", None, None)
assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [
imagenet_splits = load_config(
root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt",
None,
None,
)
assert imagenet_splits.xvalid[:10] == [
1,
2,
3,
6,
7,
8,
9,
12,
16,
18,
] and imagenet_splits.xtest[:10] == [
0,
4,
5,
@@ -304,12 +396,16 @@ def simplify(save_dir, meta_file, basestr, target_dir):
meta_archs = meta_infos["archs"] # a list of architecture strings
meta_num_archs = meta_infos["total"]
meta_max_node = meta_infos["max_node"]
assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format(
meta_num_archs, len(meta_archs)
)
assert meta_num_archs == len(
meta_archs
), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs))
sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr))))
print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs)))
print(
"{:} find {:} directories used to save checkpoints".format(
time_string(), len(sub_model_dirs)
)
)
subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
num_seeds = defaultdict(lambda: 0)
@@ -319,14 +415,18 @@ def simplify(save_dir, meta_file, basestr, target_dir):
for checkpoint in xcheckpoints:
temp_names = checkpoint.name.split("-")
assert (
len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed"
len(temp_names) == 4
and temp_names[0] == "arch"
and temp_names[2] == "seed"
), "invalid checkpoint name : {:}".format(checkpoint.name)
arch_indexes.add(temp_names[1])
subdir2archs[sub_dir] = sorted(list(arch_indexes))
num_evaluated_arch += len(arch_indexes)
# count number of seeds for each architecture
for arch_index in arch_indexes:
num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1
num_seeds[
len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))
] += 1
print(
"{:} There are {:5d} architectures that have been evaluated ({:} in total).".format(
time_string(), num_evaluated_arch, meta_num_archs
@@ -334,7 +434,9 @@ def simplify(save_dir, meta_file, basestr, target_dir):
)
for key in sorted(list(num_seeds.keys())):
print(
"{:} There are {:5d} architectures that are evaluated {:} times.".format(time_string(), num_seeds[key], key)
"{:} There are {:5d} architectures that are evaluated {:} times.".format(
time_string(), num_seeds[key], key
)
)
dataloader_dict = GET_DataLoaders(6)
@@ -346,8 +448,15 @@ def simplify(save_dir, meta_file, basestr, target_dir):
if not to_save_allarc.exists():
to_save_allarc.mkdir(parents=True, exist_ok=True)
assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format(target_dir)
arch2infos, datasets = {}, ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format(
target_dir
)
arch2infos, datasets = {}, (
"cifar10-valid",
"cifar10",
"cifar100",
"ImageNet16-120",
)
evaluated_indexes = set()
target_directory = save_dir / target_dir
target_less_dir = save_dir / "{:}-LESS".format(target_dir)
@@ -356,24 +465,36 @@ def simplify(save_dir, meta_file, basestr, target_dir):
end_time = time.time()
arch_time = AverageMeter()
for idx, arch_index in enumerate(arch_indexes):
checkpoints = list(
target_directory.glob("arch-{:}-seed-*.pth".format(arch_index))
)
ckps_less = list(target_less_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))
# create the arch info for each architecture
try:
arch_info_full = account_one_arch(
arch_index,
meta_archs[int(arch_index)],
checkpoints,
datasets,
dataloader_dict,
)
arch_info_less = account_one_arch(
arch_index,
meta_archs[int(arch_index)],
ckps_less,
["cifar10-valid"],
dataloader_dict,
)
num_seeds[len(checkpoints)] += 1
except Exception:
print("Loading {:} failed : {:}".format(arch_index, checkpoints))
continue
assert (
int(arch_index) not in evaluated_indexes
), "conflict arch-index : {:}".format(arch_index)
assert (
0 <= int(arch_index) < len(meta_archs)
), "invalid arch-index {:} (not found in meta_archs)".format(arch_index)
arch_info = {"full": arch_info_full, "less": arch_info_less}
evaluated_indexes.add(int(arch_index))
arch2infos[int(arch_index)] = arch_info
@@ -390,14 +511,19 @@ def simplify(save_dir, meta_file, basestr, target_dir):
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
need_time = "{:}".format(convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1), True))
need_time = "{:}".format(
convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1), True)
)
print(
"{:} {:} [{:03d}/{:03d}] : {:} still need {:}".format(
time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time
)
)
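# The "need {:}" field is a running-average ETA: the mean per-architecture
# wall time so far, extrapolated to the architectures still pending.
# Stripped of the AverageMeter/convert_secs2time helpers, the estimate is:
def eta_seconds(elapsed, num_done, num_total):
    # average cost of one finished item times the number remaining
    return elapsed / max(num_done, 1) * (num_total - num_done)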
# measure time
xstrs = ["{:}:{:03d}".format(key, num_seeds[key]) for key in sorted(list(num_seeds.keys()))]
xstrs = [
"{:}:{:03d}".format(key, num_seeds[key])
for key in sorted(list(num_seeds.keys()))
]
print("{:} {:} done : {:}".format(time_string(), target_dir, xstrs))
final_infos = {
"meta_archs": meta_archs,
@@ -409,7 +535,9 @@ def simplify(save_dir, meta_file, basestr, target_dir):
save_file_name = to_save_simply / "{:}.pth".format(target_dir)
torch.save(final_infos, save_file_name)
print(
"Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name)
"Save {:} / {:} architecture results into {:}.".format(
len(evaluated_indexes), meta_num_archs, save_file_name
)
)
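# The saved summary can be reloaded for offline inspection; a minimal
# sketch (the file name is illustrative, following the save above with the
# default C16-N5 basestr):
reloaded = torch.load("simplifies/C16-N5.pth", map_location="cpu")
print(sorted(reloaded.keys()))  # expect at least "meta_archs" plus the per-arch results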
@@ -418,12 +546,16 @@ def merge_all(save_dir, meta_file, basestr):
meta_archs = meta_infos["archs"]
meta_num_archs = meta_infos["total"]
meta_max_node = meta_infos["max_node"]
assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format(
meta_num_archs, len(meta_archs)
)
assert meta_num_archs == len(
meta_archs
), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs))
sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr))))
print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs)))
print(
"{:} find {:} directories used to save checkpoints".format(
time_string(), len(sub_model_dirs)
)
)
for index, sub_dir in enumerate(sub_model_dirs):
arch_info_files = sorted(list(sub_dir.glob("arch-*-seed-*.pth")))
print(
@@ -437,11 +569,16 @@ def merge_all(save_dir, meta_file, basestr):
ckp_path = sub_dir.parent / "simplifies" / "{:}.pth".format(sub_dir.name)
if ckp_path.exists():
sub_ckps = torch.load(ckp_path, map_location="cpu")
assert sub_ckps["total_archs"] == meta_num_archs and sub_ckps["basestr"] == basestr
assert (
sub_ckps["total_archs"] == meta_num_archs
and sub_ckps["basestr"] == basestr
)
xarch2infos = sub_ckps["arch2infos"]
xevalindexs = sub_ckps["evaluated_indexes"]
for eval_index in xevalindexs:
assert (
eval_index not in evaluated_indexes and eval_index not in arch2infos
)
# arch2infos[eval_index] = xarch2infos[eval_index].state_dict()
arch2infos[eval_index] = {
"full": xarch2infos[eval_index]["full"].state_dict(),
@@ -458,7 +595,11 @@ def merge_all(save_dir, meta_file, basestr):
# print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path))
evaluated_indexes = sorted(list(evaluated_indexes))
print("Finally, there are {:} architectures that have been trained and evaluated.".format(len(evaluated_indexes)))
print(
"Finally, there are {:} architectures that have been trained and evaluated.".format(
len(evaluated_indexes)
)
)
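# The loop above is a union of disjoint per-directory result dicts; the
# asserts guard disjointness before each update. The invariant in isolation:
def merge_disjoint(dst, src):
    # refuse to merge if any architecture index appears in both dicts
    assert not (dst.keys() & src.keys()), "conflicting arch indexes"
    dst.update(src)
    return dst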
to_save_simply = save_dir / "simplifies"
if not to_save_simply.exists():
@@ -472,16 +613,24 @@ def merge_all(save_dir, meta_file, basestr):
save_file_name = to_save_simply / "{:}-final-infos.pth".format(basestr)
torch.save(final_infos, save_file_name)
print(
"Save {:} / {:} architecture results into {:}.".format(len(evaluated_indexes), meta_num_archs, save_file_name)
"Save {:} / {:} architecture results into {:}.".format(
len(evaluated_indexes), meta_num_archs, save_file_name
)
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NAS-BENCH-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="NAS-BENCH-201",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--mode",
type=str,
choices=["cal", "merge"],
help="The running mode for this script.",
)
parser.add_argument("--mode", type=str, choices=["cal", "merge"], help="The running mode for this script.")
parser.add_argument(
"--base_save_dir",
type=str,
@@ -489,16 +638,26 @@ if __name__ == "__main__":
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument("--target_dir", type=str, help="The target directory.")
parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.")
parser.add_argument("--channel", type=int, default=16, help="The number of channels.")
parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.")
parser.add_argument(
"--max_node", type=int, default=4, help="The maximum node in a cell."
)
parser.add_argument(
"--channel", type=int, default=16, help="The number of channels."
)
parser.add_argument(
"--num_cells", type=int, default=5, help="The number of cells in one stage."
)
args = parser.parse_args()
save_dir = Path(args.base_save_dir)
meta_path = save_dir / "meta-node-{:}.pth".format(args.max_node)
assert save_dir.exists(), "invalid save dir path : {:}".format(save_dir)
assert meta_path.exists(), "invalid saved meta path : {:}".format(meta_path)
print("start the statistics of our nas-benchmark from {:} using {:}.".format(save_dir, args.target_dir))
print(
"start the statistics of our nas-benchmark from {:} using {:}.".format(
save_dir, args.target_dir
)
)
basestr = "C{:}-N{:}".format(args.channel, args.num_cells)
if args.mode == "cal":

View File

@@ -25,7 +25,11 @@ def check_unique_arch(meta_file):
def get_unique_matrix(archs, consider_zero):
UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
print("{:} create unique-string ({:}/{:}) done".format(time_string(), len(set(UniquStrs)), len(UniquStrs)))
print(
"{:} create unique-string ({:}/{:}) done".format(
time_string(), len(set(UniquStrs)), len(UniquStrs)
)
)
Unique2Index = dict()
for index, xstr in enumerate(UniquStrs):
if xstr not in Unique2Index:
@@ -47,16 +51,32 @@ def check_unique_arch(meta_file):
unique_num += 1
return sm_matrix, unique_ids, unique_num
print("There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs)))
print(
"There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs))
)
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None)
print("{:} There are {:} unique architectures (considering nothing).".format(time_string(), unique_num))
print(
"{:} There are {:} unique architectures (considering nothing).".format(
time_string(), unique_num
)
)
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False)
print("{:} There are {:} unique architectures (not considering zero).".format(time_string(), unique_num))
print(
"{:} There are {:} unique architectures (not considering zero).".format(
time_string(), unique_num
)
)
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True)
print("{:} There are {:} unique architectures (considering zero).".format(time_string(), unique_num))
print(
"{:} There are {:} unique architectures (considering zero).".format(
time_string(), unique_num
)
)
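# get_unique_matrix treats two cells as the same architecture iff their
# to_unique_str(...) forms match, so uniqueness reduces to first-seen
# indexing over canonical strings. The core idiom, in isolation:
def first_seen_ids(keys):
    seen = {}
    return [seen.setdefault(key, len(seen)) for key in keys]

assert first_seen_ids(["a", "b", "a"]) == [0, 1, 0]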
def check_cor_for_bandit(
meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False
):
if isinstance(meta_file, API):
api = meta_file
else:
@@ -69,7 +89,9 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
imagenet_test = []
imagenet_valid = []
for idx, arch in enumerate(api):
results = api.get_more_info(idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand)
results = api.get_more_info(
idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand
)
cifar10_currs.append(results["valid-accuracy"])
# --->>>>>
results = api.get_more_info(idx, "cifar10-valid", None, False, is_rand)
@@ -89,13 +111,23 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, n
cors = []
for basestr, xlist in zip(
["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"],
[
cifar10_valid,
cifar10_test,
cifar100_valid,
cifar100_test,
imagenet_valid,
imagenet_test,
],
):
correlation = get_cor(cifar10_currs, xlist)
if need_print:
print(
"With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format(
test_epoch, "012" if use_less_or_not else "200", basestr, correlation
test_epoch,
"012" if use_less_or_not else "200",
basestr,
correlation,
)
)
cors.append(correlation)
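# get_cor is the repo's own helper; if it computes a rank correlation (an
# assumption on my part, not confirmed by this diff), the scipy equivalent
# for the two accuracy lists would be:
#   from scipy import stats
#   rho, _pvalue = stats.spearmanr(cifar10_currs, xlist)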
@@ -113,7 +145,11 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
# xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
xstrs = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"]
correlations = np.array(corrs)
print("------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(test_epoch, "012" if use_less_or_not else "200"))
print(
"------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(
test_epoch, "012" if use_less_or_not else "200"
)
)
for idx, xstr in enumerate(xstrs):
print(
"{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format(
@@ -135,7 +171,12 @@ if __name__ == "__main__":
default="./output/search-cell-nas-bench-201/visuals",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.")
parser.add_argument(
"--api_path",
type=str,
default=None,
help="The path to the NAS-Bench-201 benchmark file.",
)
args = parser.parse_args()
vis_save_dir = Path(args.save_dir)

View File

@@ -47,15 +47,21 @@ def visualize_relative_ranking(vis_save_dir):
print("{:} start to visualize relative ranking".format(time_string()))
# maximum accuracy with ResNet-level params 11472
x_010_accs = [
cifar010_info["test_accs"][i] if cifar010_info["params"][i] <= cifar010_info["params"][11472] else -1
cifar010_info["test_accs"][i]
if cifar010_info["params"][i] <= cifar010_info["params"][11472]
else -1
for i in indexes
]
x_100_accs = [
cifar100_info["test_accs"][i] if cifar100_info["params"][i] <= cifar100_info["params"][11472] else -1
cifar100_info["test_accs"][i]
if cifar100_info["params"][i] <= cifar100_info["params"][11472]
else -1
for i in indexes
]
x_img_accs = [
imagenet_info["test_accs"][i] if imagenet_info["params"][i] <= imagenet_info["params"][11472] else -1
imagenet_info["test_accs"][i]
if imagenet_info["params"][i] <= imagenet_info["params"][11472]
else -1
for i in indexes
]
@@ -79,8 +85,15 @@ def visualize_relative_ranking(vis_save_dir):
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(0)
plt.yticks(
np.arange(min(indexes), max(indexes), max(indexes) // 6),
fontsize=LegendFontsize,
rotation="vertical",
)
plt.xticks(
np.arange(min(indexes), max(indexes), max(indexes) // 6),
fontsize=LegendFontsize,
)
# ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100')
# ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8, label='ImageNet-16-120')
# ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10')
@@ -113,7 +126,9 @@ def visualize_relative_ranking(vis_save_dir):
)
fig = plt.figure(figsize=figsize)
plt.axis("off")
h = sns.heatmap(
CoRelMatrix, annot=True, annot_kws={"size": sns_size}, fmt=".3f", linewidths=0.5
)
save_path = (vis_save_dir / "co-relation-all.pdf").resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
print("{:} save into {:}".format(time_string(), save_path))
@@ -142,8 +157,16 @@ def visualize_relative_ranking(vis_save_dir):
)
fig = plt.figure(figsize=figsize)
plt.axis("off")
h = sns.heatmap(
CoRelMatrix,
annot=True,
annot_kws={"size": sns_size},
fmt=".3f",
linewidths=0.5,
)
save_path = (
vis_save_dir / "co-relation-top-{:}.pdf".format(len(selected_indexes))
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
@@ -155,7 +178,14 @@ def visualize_info(meta_file, dataset, vis_save_dir):
if not cache_file_path.exists():
print("Do not find cache file : {:}".format(cache_file_path))
nas_bench = API(str(meta_file))
params, flops, train_accs, valid_accs, test_accs, otest_accs = (
[],
[],
[],
[],
[],
[],
)
for index in range(len(nas_bench)):
info = nas_bench.query_by_index(index, use_12epochs_result=False)
resx = info.get_comput_costs(dataset)
@@ -239,7 +269,13 @@ def visualize_info(meta_file, dataset, vis_save_dir):
plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
ax.scatter(params, valid_accs, marker="o", s=0.5, c="tab:blue")
ax.scatter(
[resnet["params"]], [resnet["valid_acc"]], marker="*", s=resnet_scale, c="tab:orange", label="resnet", alpha=0.4
[resnet["params"]],
[resnet["valid_acc"]],
marker="*",
s=resnet_scale,
c="tab:orange",
label="resnet",
alpha=0.4,
)
plt.grid(zorder=0)
ax.set_axisbelow(True)
@@ -321,7 +357,10 @@ def visualize_info(meta_file, dataset, vis_save_dir):
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(0, max(indexes))
plt.xticks(
np.arange(min(indexes), max(indexes), max(indexes) // 5),
fontsize=LegendFontsize,
)
if dataset == "cifar10":
plt.ylim(50, 100)
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
@@ -357,7 +396,11 @@ def visualize_info(meta_file, dataset, vis_save_dir):
def visualize_rank_over_time(meta_file, vis_save_dir):
print("\n" + "-" * 150)
vis_save_dir.mkdir(parents=True, exist_ok=True)
print("{:} start to visualize rank-over-time into {:}".format(time_string(), vis_save_dir))
print(
"{:} start to visualize rank-over-time into {:}".format(
time_string(), vis_save_dir
)
)
cache_file_path = vis_save_dir / "rank-over-time-cache-info.pth"
if not cache_file_path.exists():
print("Do not find cache file : {:}".format(cache_file_path))
@@ -434,17 +477,26 @@ def visualize_rank_over_time(meta_file, vis_save_dir):
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
plt.yticks(
np.arange(min(indexes), max(indexes), max(indexes) // 6),
fontsize=LegendFontsize,
rotation="vertical",
)
plt.xticks(
np.arange(min(indexes), max(indexes), max(indexes) // 6),
fontsize=LegendFontsize,
)
ax.scatter(indexes, valid_ord_lbls, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-10 validation")
ax.scatter(
[-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-10 validation"
)
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10 test")
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc="upper left", fontsize=LegendFontsize)
ax.set_xlabel("architecture ranking in the final test accuracy", fontsize=LabelSize)
ax.set_xlabel(
"architecture ranking in the final test accuracy", fontsize=LabelSize
)
ax.set_ylabel("architecture ranking in the validation set", fontsize=LabelSize)
save_path = (vis_save_dir / "time-{:03d}.pdf".format(sepoch)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
@@ -464,7 +516,9 @@ def write_video(save_dir):
# shape = (ximage.shape[1], ximage.shape[0])
shape = (1000, 1000)
# writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 25, shape)
writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape)
writer = cv2.VideoWriter(
str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape
)
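# cv2.VideoWriter takes (path, fourcc, fps, frame_size); the 5 above is the
# frame rate, and every written frame must match frame_size (hence the
# resize below) or it is typically dropped. Smallest end-to-end sketch
# (assumes opencv-python and numpy):
#   demo = cv2.VideoWriter("demo.avi", cv2.VideoWriter_fourcc(*"MJPG"), 5, (100, 100))
#   demo.write(np.zeros((100, 100, 3), dtype=np.uint8))
#   demo.release()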
for idx, image in enumerate(images):
ximage = cv2.imread(str(image))
_image = cv2.resize(ximage, shape)
@@ -490,9 +544,13 @@ def plot_results_nas_v2(api, dataset_xset_a, dataset_xset_b, root, file_name, y_
accuracies = []
for x in all_indexes:
info = api.arch2infos_full[x]
metrics = info.get_metrics(
dataset_xset_a[0], dataset_xset_a[1], None, False
)
accuracies_A.append(metrics["accuracy"])
metrics = info.get_metrics(
dataset_xset_b[0], dataset_xset_b[1], None, False
)
accuracies_B.append(metrics["accuracy"])
accuracies.append((accuracies_A[-1], accuracies_B[-1]))
if indexes is None:
@@ -580,7 +638,14 @@ def plot_results_nas(api, dataset, xset, root, file_name, y_lims):
plt.ylabel("The accuracy (%)", fontsize=LabelSize)
for idx, legend in enumerate(legends):
plt.plot(
indexes,
All_Accs[legend],
color=color_set[idx],
linestyle="-",
label="{:}".format(legend),
lw=2,
)
print(
"{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}".format(
legend,
@@ -646,13 +711,19 @@ def just_show(api):
return xresults
for xkey in xpaths.keys():
all_paths = ["{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]]
all_paths = [
"{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]
]
all_datas = [torch.load(xpath) for xpath in all_paths]
accyss = [get_accs(xdatas) for xdatas in all_datas]
accyss = np.array(accyss)
print("\nxkey = {:}".format(xkey))
for i in range(accyss.shape[1]):
print("---->>>> {:.2f}$\\pm${:.2f}".format(accyss[:, i].mean(), accyss[:, i].std()))
print(
"---->>>> {:.2f}$\\pm${:.2f}".format(
accyss[:, i].mean(), accyss[:, i].std()
)
)
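# accyss stacks one row per random seed, so the mean/std above aggregate
# repeated runs of the same configuration. The same reduction in isolation
# (numbers made up purely for illustration):
#   accyss = np.array([[93.2, 71.1], [93.5, 70.8], [93.1, 71.4]])
#   accyss[:, 0].mean(), accyss[:, 0].std()  # column 0 across three seeds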
print("\n{:}".format(get_accs(None, 11472))) # resnet
pairs = [
@@ -665,10 +736,16 @@ def just_show(api):
]
for dataset, metric_on_set in pairs:
arch_index, highest_acc = api.find_best(dataset, metric_on_set)
print("[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}".format(dataset, metric_on_set, arch_index, highest_acc))
print(
"[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}".format(
dataset, metric_on_set, arch_index, highest_acc
)
)
def show_nas_sharing_w(
api, dataset, subset, vis_save_dir, sufix, file_name, y_lims, x_maxs
):
color_set = ["r", "b", "g", "c", "m", "y", "k"]
dpi, width, height = 300, 3400, 2600
LabelSize, LegendFontsize = 28, 28
@@ -685,12 +762,24 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_l
plt.ylabel("The accuracy (%)", fontsize=LabelSize)
xpaths = {
"RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/".format(sufix),
"DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/".format(sufix),
"DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/".format(sufix),
"GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/".format(sufix),
"SETN": "output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/".format(sufix),
"ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/".format(sufix),
"RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/".format(
sufix
),
"DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/".format(
sufix
),
"DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/".format(
sufix
),
"GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/".format(
sufix
),
"SETN": "output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/".format(
sufix
),
"ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/".format(
sufix
),
}
"""
xseeds = {'RSPS' : [5349, 59613, 5983],
@@ -713,16 +802,20 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_l
def get_accs(xdata):
epochs, xresults = xdata["epoch"], []
if -1 in xdata["genotypes"]:
metrics = api.arch2infos_full[
api.query_index_by_arch(xdata["genotypes"][-1])
].get_metrics(dataset, subset, None, False)
else:
metrics = api.arch2infos_full[api.random()].get_metrics(
dataset, subset, None, False
)
xresults.append(metrics["accuracy"])
for iepoch in range(epochs):
genotype = xdata["genotypes"][iepoch]
index = api.query_index_by_arch(genotype)
metrics = api.arch2infos_full[index].get_metrics(
dataset, subset, None, False
)
xresults.append(metrics["accuracy"])
return xresults
@@ -735,7 +828,9 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_l
for idx, method in enumerate(xxxstrs):
xkey = method
all_paths = ["{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]]
all_paths = [
"{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]
]
all_datas = [torch.load(xpath, map_location="cpu") for xpath in all_paths]
accyss = [get_accs(xdatas) for xdatas in all_datas]
accyss = np.array(accyss)
@@ -762,7 +857,9 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, sufix, file_name, y_l
fig.savefig(str(save_path), dpi=dpi, bbox_inches="tight", format="pdf")
def show_nas_sharing_w_v2(
api, data_sub_a, data_sub_b, vis_save_dir, sufix, file_name, y_lims, x_maxs
):
color_set = ["r", "b", "g", "c", "m", "y", "k"]
dpi, width, height = 300, 3400, 2600
LabelSize, LegendFontsize = 28, 28
@@ -779,12 +876,24 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file
plt.ylabel("The accuracy (%)", fontsize=LabelSize)
xpaths = {
"RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/".format(sufix),
"DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/".format(sufix),
"DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/".format(sufix),
"GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/".format(sufix),
"SETN": "output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/".format(sufix),
"ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/".format(sufix),
"RSPS": "output/search-cell-nas-bench-201/RANDOM-NAS-cifar10-{:}/checkpoint/".format(
sufix
),
"DARTS-V1": "output/search-cell-nas-bench-201/DARTS-V1-cifar10-{:}/checkpoint/".format(
sufix
),
"DARTS-V2": "output/search-cell-nas-bench-201/DARTS-V2-cifar10-{:}/checkpoint/".format(
sufix
),
"GDAS": "output/search-cell-nas-bench-201/GDAS-cifar10-{:}/checkpoint/".format(
sufix
),
"SETN": "output/search-cell-nas-bench-201/SETN-cifar10-{:}/checkpoint/".format(
sufix
),
"ENAS": "output/search-cell-nas-bench-201/ENAS-cifar10-{:}/checkpoint/".format(
sufix
),
}
"""
xseeds = {'RSPS' : [5349, 59613, 5983],
@@ -807,16 +916,20 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file
def get_accs(xdata, dataset, subset):
epochs, xresults = xdata["epoch"], []
if -1 in xdata["genotypes"]:
metrics = api.arch2infos_full[
api.query_index_by_arch(xdata["genotypes"][-1])
].get_metrics(dataset, subset, None, False)
else:
metrics = api.arch2infos_full[api.random()].get_metrics(
dataset, subset, None, False
)
xresults.append(metrics["accuracy"])
for iepoch in range(epochs):
genotype = xdata["genotypes"][iepoch]
index = api.query_index_by_arch(genotype)
metrics = api.arch2infos_full[index].get_metrics(
dataset, subset, None, False
)
xresults.append(metrics["accuracy"])
return xresults
@@ -829,10 +942,16 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file
for idx, method in enumerate(xxxstrs):
xkey = method
all_paths = ["{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]]
all_paths = [
"{:}/seed-{:}-basic.pth".format(xpaths[xkey], seed) for seed in xseeds[xkey]
]
all_datas = [torch.load(xpath, map_location="cpu") for xpath in all_paths]
accyss_A = np.array(
[get_accs(xdatas, data_sub_a[0], data_sub_a[1]) for xdatas in all_datas]
)
accyss_B = np.array(
[get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas]
)
epochs = list(range(accyss_A.shape[1]))
for j, accyss in enumerate([accyss_A, accyss_B]):
if x_maxs == 50:
@@ -859,7 +978,9 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file
)
setname = data_sub_a if j == 0 else data_sub_b
print(
"{:} -- {:} ---- {:.2f}$\\pm${:.2f}".format(method, setname, accyss[:, -1].mean(), accyss[:, -1].std())
"{:} -- {:} ---- {:.2f}$\\pm${:.2f}".format(
method, setname, accyss[:, -1].mean(), accyss[:, -1].std()
)
)
# plt.legend(loc=4, fontsize=LegendFontsize)
plt.legend(loc=0, fontsize=LegendFontsize)
@@ -871,7 +992,10 @@ def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, sufix, file
def show_reinforce(api, root, dataset, xset, file_name, y_lims):
print("root-path={:}, dataset={:}, xset={:}".format(root, dataset, xset))
LRs = ["0.01", "0.02", "0.1", "0.2", "0.5"]
checkpoints = ["./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth".format(x) for x in LRs]
checkpoints = [
"./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth".format(x)
for x in LRs
]
acc_lr_dict, indexes = {}, None
for lr, checkpoint in zip(LRs, checkpoints):
all_indexes, accuracies = torch.load(checkpoint, map_location="cpu"), []
@@ -882,7 +1006,11 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims):
if indexes is None:
indexes = list(range(len(accuracies)))
acc_lr_dict[lr] = np.array(sorted(accuracies))
print("LR={:.3f}, mean={:}, std={:}".format(float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std()))
print(
"LR={:.3f}, mean={:}, std={:}".format(
float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std()
)
)
color_set = ["r", "b", "g", "c", "m", "y", "k"]
dpi, width, height = 300, 3400, 2600
@@ -903,7 +1031,15 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims):
legend = "LR={:.2f}".format(float(LR))
# color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
color, linestyle = color_set[idx], "-"
plt.plot(
indexes,
acc_lr_dict[LR],
color=color,
linestyle=linestyle,
label=legend,
lw=2,
alpha=0.8,
)
print(
"{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}".format(
legend,
@@ -922,7 +1058,10 @@ def show_reinforce(api, root, dataset, xset, file_name, y_lims):
def show_rea(api, root, dataset, xset, file_name, y_lims):
print("root-path={:}, dataset={:}, xset={:}".format(root, dataset, xset))
SSs = [3, 5, 10]
checkpoints = ["./output/search-cell-nas-bench-201/R-EA-cifar10-SS{:}/results.pth".format(x) for x in SSs]
checkpoints = [
"./output/search-cell-nas-bench-201/R-EA-cifar10-SS{:}/results.pth".format(x)
for x in SSs
]
acc_ss_dict, indexes = {}, None
for ss, checkpoint in zip(SSs, checkpoints):
all_indexes, accuracies = torch.load(checkpoint, map_location="cpu"), []
@@ -933,7 +1072,11 @@ def show_rea(api, root, dataset, xset, file_name, y_lims):
if indexes is None:
indexes = list(range(len(accuracies)))
acc_ss_dict[ss] = np.array(sorted(accuracies))
print("Sample-Size={:2d}, mean={:}, std={:}".format(ss, acc_ss_dict[ss].mean(), acc_ss_dict[ss].std()))
print(
"Sample-Size={:2d}, mean={:}, std={:}".format(
ss, acc_ss_dict[ss].mean(), acc_ss_dict[ss].std()
)
)
color_set = ["r", "b", "g", "c", "m", "y", "k"]
dpi, width, height = 300, 3400, 2600
@@ -954,7 +1097,15 @@ def show_rea(api, root, dataset, xset, file_name, y_lims):
legend = "sample-size={:2d}".format(ss)
# color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
color, linestyle = color_set[idx], "-"
plt.plot(
indexes,
acc_ss_dict[ss],
color=color,
linestyle=linestyle,
label=legend,
lw=2,
alpha=0.8,
)
print(
"{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}".format(
legend,
@@ -973,7 +1124,8 @@ def show_rea(api, root, dataset, xset, file_name, y_lims):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NAS-Bench-201", formatter_class=argparse.ArgumentDefaultsHelpFormatter
description="NAS-Bench-201",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir",
@@ -981,7 +1133,12 @@ if __name__ == "__main__":
default="./output/search-cell-nas-bench-201/visuals",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.")
parser.add_argument(
"--api_path",
type=str,
default=None,
help="The path to the NAS-Bench-201 benchmark file.",
)
args = parser.parse_args()
vis_save_dir = Path(args.save_dir)
@@ -1066,9 +1223,25 @@ if __name__ == "__main__":
)
show_nas_sharing_w(
api, "cifar10-valid", "x-valid", vis_save_dir, "BN0", "BN0-XX-CIFAR010-VALID.pdf", (0, 100, 10), 250
api,
"cifar10-valid",
"x-valid",
vis_save_dir,
"BN0",
"BN0-XX-CIFAR010-VALID.pdf",
(0, 100, 10),
250,
)
show_nas_sharing_w(
api,
"cifar10",
"ori-test",
vis_save_dir,
"BN0",
"BN0-XX-CIFAR010-TEST.pdf",
(0, 100, 10),
250,
)
show_nas_sharing_w(api, "cifar10", "ori-test", vis_save_dir, "BN0", "BN0-XX-CIFAR010-TEST.pdf", (0, 100, 10), 250)
"""
for x_maxs in [50, 250]:
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)