102->201 / NAS->autoDL / more configs of TAS / reorganize docs / fix bugs in NAS baselines

D-X-Y
2020-01-15 00:52:06 +11:00
parent 33384a78af
commit bb2f405961
62 changed files with 789 additions and 412 deletions

View File

@@ -0,0 +1,84 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# python exps/NAS-Bench-201/check.py --base_save_dir
##################################################
import os, sys, time, argparse, collections
from shutil import copyfile
import torch
import torch.nn as nn
from pathlib import Path
from collections import defaultdict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time
def check_files(save_dir, meta_file, basestr):
meta_infos = torch.load(meta_file, map_location='cpu')
meta_archs = meta_infos['archs']
meta_num_archs = meta_infos['total']
meta_max_node = meta_infos['max_node']
assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
print ('{:} found {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))
subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
num_seeds = defaultdict(lambda: 0)
for index, sub_dir in enumerate(sub_model_dirs):
xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
#xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth'))
arch_indexes = set()
for checkpoint in xcheckpoints:
temp_names = checkpoint.name.split('-')
assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
arch_indexes.add( temp_names[1] )
subdir2archs[sub_dir] = sorted(list(arch_indexes))
num_evaluated_arch += len(arch_indexes)
# count number of seeds for each architecture
for arch_index in arch_indexes:
num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1
print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k*v for k, v in num_seeds.items())))
for key in sorted( list( num_seeds.keys() ) ): print ('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key))
dir2ckps, dir2ckp_exists = dict(), dict()
start_time, epoch_time = time.time(), AverageMeter()
for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
seeds = [777, 888, 999]
numrs = defaultdict(lambda: 0)
all_checkpoints, all_ckp_exists = [], []
for arch_index in arch_indexes:
checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds]
ckp_exists = [(sub_dir/x).exists() for x in checkpoints]
arch_index = int(arch_index)
assert 0 <= arch_index < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
all_checkpoints += checkpoints
all_ckp_exists += ckp_exists
numrs[sum(ckp_exists)] += 1
dir2ckps[ str(sub_dir) ] = all_checkpoints
dir2ckp_exists[ str(sub_dir) ] = all_ckp_exists
# measure time
epoch_time.update(time.time() - start_time)
start_time = time.time()
numrstr = ', '.join( ['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs.keys())] )
print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(time_string(), IDX+1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists), sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs)-IDX-1), True), numrstr))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--max_node', type=int, default=4, help='The maximum node in a cell.')
parser.add_argument('--channel', type=int, default=16, help='The number of channels.')
parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.')
args = parser.parse_args()
save_dir = Path( args.base_save_dir )
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
print ('check NAS-Bench-201 in {:}'.format(save_dir))
basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
check_files(save_dir, meta_path, basestr)
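
A fuller invocation sketch for the checker above (the flag values are simply the argparse defaults; 000000-000389-C16-N5 illustrates the '{:06d}-{:06d}-C{:}-N{:}' sub-directory pattern that main.py produces):

python exps/NAS-Bench-201/check.py --base_save_dir ./output/NAS-BENCH-201-4 --max_node 4 --channel 16 --num_cells 5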

View File

@@ -0,0 +1,28 @@
import os
from setuptools import setup
def read(fname='README.md'):
with open(os.path.join(os.path.dirname(__file__), fname), encoding='utf-8') as cfile:
return cfile.read()
setup(
name = "nas_bench_201",
version = "1.0",
author = "Xuanyi Dong",
author_email = "dongxuanyi888@gmail.com",
description = "API for NAS-Bench-201 (a benchmark for neural architecture search).",
license = "MIT",
keywords = "NAS Dataset API DeepLearning",
url = "https://github.com/D-X-Y/NAS-Bench-201",
packages=['nas_201_api'],
long_description=read('README.md'),
long_description_content_type='text/markdown',
classifiers=[
"Programming Language :: Python",
"Topic :: Database",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: MIT License",
],
)
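
For reference, a minimal sketch of installing and importing the package defined above (standard setuptools workflow; the benchmark file name is the example used in test-correlation.py below):

# from the repository root
#   pip install .
from nas_201_api import NASBench201API as API
api = API('NAS-Bench-201-v1_0-e61699.pth') # example path; any released benchmark file works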

View File

@@ -0,0 +1,136 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from procedures import prepare_seed, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from config_utils import dict2config
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net
__all__ = ['evaluate_for_seed', 'pure_evaluate']
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
latencies = []
network.eval()
with torch.no_grad():
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
targets = targets.cuda(non_blocking=True)
inputs = inputs.cuda(non_blocking=True)
data_time.update(time.time() - end)
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
batch_time.update(time.time() - end)
if batch is None or batch == inputs.size(0):
batch = inputs.size(0)
latencies.append( batch_time.val - data_time.val )
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
end = time.time()
if len(latencies) > 2: latencies = latencies[1:] # drop the first (warm-up) measurement
return losses.avg, top1.avg, top5.avg, latencies
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
if mode == 'train' : network.train()
elif mode == 'valid': network.eval()
else: raise ValueError("The mode is not right : {:}".format(mode))
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
targets = targets.cuda(non_blocking=True)
if mode == 'train': optimizer.zero_grad()
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
# backward
if mode == 'train':
loss.backward()
optimizer.step()
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
# count time
batch_time.update(time.time() - end)
end = time.time()
return losses.avg, top1.avg, top5.avg, batch_time.sum
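# Note: in 'train' mode the learning rate is adjusted every iteration via
# scheduler.update(None, i / len(xloader)) above, so the scheduler returned by
# get_optim_scheduler is assumed to support such fractional-epoch updates.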
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
prepare_seed(seed) # random seed
net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny',
'C': arch_config['channel'], 'N': arch_config['num_cells'],
'genotype': arch, 'num_classes': config.class_num}
, None)
)
#net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
flop, param = get_model_infos(net, config.xshape)
logger.log('Network : {:}'.format(net.get_message()), False)
logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
# train and valid
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
# start training
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
train_times , valid_times = {}, {}
for epoch in range(total_epoch):
scheduler.update(epoch, 0.0)
train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
train_losses[epoch] = train_loss
train_acc1es[epoch] = train_acc1
train_acc5es[epoch] = train_acc5
train_times [epoch] = train_tm
with torch.no_grad():
for key, xloader in valid_loaders.items():
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloader, network, criterion, None, None, 'valid')
valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss
valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1
valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5
valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) )
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5))
info_seed = {'flop' : flop,
'param': param,
'channel' : arch_config['channel'],
'num_cells' : arch_config['num_cells'],
'config' : config._asdict(),
'total_epoch' : total_epoch ,
'train_losses': train_losses,
'train_acc1es': train_acc1es,
'train_acc5es': train_acc5es,
'train_times' : train_times,
'valid_losses': valid_losses,
'valid_acc1es': valid_acc1es,
'valid_acc5es': valid_acc5es,
'valid_times' : valid_times,
'net_state_dict': net.state_dict(),
'net_string' : '{:}'.format(net),
'finish-train': True
}
return info_seed
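
A sketch of consuming the dictionary returned by evaluate_for_seed, assuming a checkpoint written by main.py (the file name is hypothetical; the per-dataset key 'cifar10-valid' is attached by evaluate_all_datasets in main.py, and validation entries are keyed as '<loader-name>@<epoch>' as above):

ckpt = torch.load('arch-000001-seed-0777.pth', map_location='cpu')
info = ckpt['cifar10-valid'] # one info_seed dict as returned above
last = info['total_epoch'] - 1
print(info['train_acc1es'][last], info['valid_acc1es']['x-valid@{:}'.format(last)])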

exps/NAS-Bench-201/main.py (new file)
View File

@@ -0,0 +1,316 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019-2020 #
###############################################################
import os, sys, time, torch, random, argparse
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config
from procedures import save_checkpoint, copy_checkpoint
from procedures import get_machine_info
from datasets import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
from models import CellStructure, CellArchitectures, get_search_spaces
from functions import evaluate_for_seed
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger):
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
all_infos = {'info': machine_info}
all_dataset_keys = []
# loop over all the datasets
for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configuration
if dataset == 'cifar10' or dataset == 'cifar100':
if use_less: config_path = 'configs/nas-benchmark/LESS.config'
else : config_path = 'configs/nas-benchmark/CIFAR.config'
split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
elif dataset.startswith('ImageNet16'):
if use_less: config_path = 'configs/nas-benchmark/LESS.config'
else : config_path = 'configs/nas-benchmark/ImageNet-16.config'
split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
config = load_config(config_path, \
{'class_num': class_num,
'xshape' : xshape}, \
logger)
# check whether to use the split validation set
if bool(split):
assert dataset == 'cifar10'
ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid))
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
# data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True)
ValLoaders['x-valid'] = valid_loader
else:
# data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
if dataset == 'cifar10':
ValLoaders = {'ori-test': valid_loader}
elif dataset == 'cifar100':
cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
ValLoaders = {'ori-test': valid_loader,
'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True),
'x-test' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True)
}
elif dataset == 'ImageNet16-120':
imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
ValLoaders = {'ori-test': valid_loader,
'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True),
'x-test' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True)
}
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
dataset_key = '{:}'.format(dataset)
if bool(split): dataset_key = dataset_key + '-valid'
logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
for key, value in ValLoaders.items():
logger.log('Evaluate ---->>>> {:10s} with {:} batches'.format(key, len(value)))
results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger)
all_infos[dataset_key] = results
all_dataset_keys.append( dataset_key )
all_infos['all_dataset_keys'] = all_dataset_keys
return all_infos
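# Note: the loader names above ('ori-test', 'x-valid', 'x-test') become the split
# names stored with each result; 'ori-test' is the untouched test set, while the
# x-* splits come from the *-test-split.txt files loaded via load_config.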
def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
#torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
torch.set_num_threads( workers )
assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange)
if use_less:
sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
else:
sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
logger = Logger(str(sub_dir), 0, False)
all_archs = meta_info['archs']
assert srange[1] < meta_info['total'], 'invalid range : {:}-{:} vs. {:}'.format(srange[0], srange[1], meta_info['total'])
assert arch_index == -1 or srange[0] <= arch_index <= srange[1], 'invalid range : {:} vs. {:} vs. {:}'.format(srange[0], arch_index, srange[1])
if arch_index == -1:
to_evaluate_indexes = list(range(srange[0], srange[1]+1))
else:
to_evaluate_indexes = [arch_index]
logger.log('xargs : seeds = {:}'.format(seeds))
logger.log('xargs : arch_index = {:}'.format(arch_index))
logger.log('xargs : cover_mode = {:}'.format(cover_mode))
logger.log('-'*100)
logger.log('Start evaluating range: start={:06d}, arch-index={:06d}, end={:06d} / total={:06d} with cover-mode={:}'.format(srange[0], arch_index, srange[1], meta_info['total'], cover_mode))
for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
logger.log('--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
logger.log('--->>> architecture config : {:}'.format(arch_config))
start_time, epoch_time = time.time(), AverageMeter()
for i, index in enumerate(to_evaluate_indexes):
arch = all_archs[index]
logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seeds, '-'*15))
#logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15))
logger.log('{:} {:} {:}'.format('-'*15, arch, '-'*15))
# test this arch on different datasets with different seeds
has_continue = False
for seed in seeds:
to_save_name = sub_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
if to_save_name.exists():
if cover_mode:
logger.log('Found existing file {:}; removing it before evaluation'.format(to_save_name))
os.remove(str(to_save_name))
else :
logger.log('Found existing file {:}; skipping this evaluation'.format(to_save_name))
has_continue = True
continue
results = evaluate_all_datasets(CellStructure.str2structure(arch), \
datasets, xpaths, splits, use_less, seed, \
arch_config, workers, logger)
torch.save(results, to_save_name)
logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name))
# measure elapsed time
if not has_continue: epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) )
logger.log('This arch costs : {:}'.format( convert_secs2time(epoch_time.val, True) ))
logger.log('{:}'.format('*'*100))
logger.log('{:} {:74s} {:}'.format('*'*10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, meta_info['total'], need_time), '*'*10))
logger.log('{:}'.format('*'*100))
logger.close()
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = True
torch.set_num_threads( workers )
save_dir = Path(save_dir) / 'specifics' / '{:}-{:}-{:}-{:}'.format('LESS' if use_less else 'FULL', model_str, arch_config['channel'], arch_config['num_cells'])
logger = Logger(str(save_dir), 0, False)
if model_str in CellArchitectures:
arch = CellArchitectures[model_str]
logger.log('The model string is found in pre-defined architecture dict : {:}'.format(model_str))
else:
try:
arch = CellStructure.str2structure(model_str)
except Exception:
raise ValueError('Invalid model string : {:}. It cannot be found or parsed.'.format(model_str))
assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} contains an invalid op.'.format(arch)
logger.log('Start train-evaluate {:}'.format(arch.tostr()))
logger.log('arch_config : {:}'.format(arch_config))
start_time, seed_time = time.time(), AverageMeter()
for _is, seed in enumerate(seeds):
logger.log('\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds), seed))
to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
if to_save_name.exists():
logger.log('Found existing file {:}; loading it directly!'.format(to_save_name))
checkpoint = torch.load(to_save_name)
else:
logger.log('No existing file {:}; training and evaluating!'.format(to_save_name))
checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger)
torch.save(checkpoint, to_save_name)
# log information
logger.log('{:}'.format(checkpoint['info']))
all_dataset_keys = checkpoint['all_dataset_keys']
for dataset_key in all_dataset_keys:
logger.log('\n{:} dataset : {:} {:}'.format('-'*15, dataset_key, '-'*15))
dataset_info = checkpoint[dataset_key]
#logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
logger.log('Flops = {:} MB, Params = {:} MB'.format(dataset_info['flop'], dataset_info['param']))
logger.log('config : {:}'.format(dataset_info['config']))
logger.log('Training State (finish) = {:}'.format(dataset_info['finish-train']))
last_epoch = dataset_info['total_epoch'] - 1
train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
logger.log('Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%'.format(train_acc1es[last_epoch], train_acc5es[last_epoch], 100-train_acc1es[last_epoch], valid_acc1es[last_epoch], valid_acc5es[last_epoch], 100-valid_acc1es[last_epoch]))
# measure elapsed time
seed_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format( convert_secs2time(seed_time.avg * (len(seeds)-_is-1), True) )
logger.log('\n<<<***>>> The {:02d}/{:02d}-th seed {:} finished; the remaining seeds need {:}'.format(_is, len(seeds), seed, need_time))
logger.close()
def generate_meta_info(save_dir, max_node, divide=40):
aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-201')
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print ('There are {:} archs (expected {:}).'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2)))
random.seed( 88 ) # please do not change this line for reproducibility
random.shuffle( archs )
# to test fixed-random shuffle
#print ('arch [0] : {:}\n---->>>> {:}'.format( archs[0], archs[0].tostr() ))
#print ('arch [9] : {:}\n---->>>> {:}'.format( archs[9], archs[9].tostr() ))
assert archs[0 ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0])
assert archs[9 ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9])
assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123])
total_arch = len(archs)
num = 50000
indexes_5W = list(range(num))
random.seed( 1021 )
random.shuffle( indexes_5W )
train_split = sorted( list(set(indexes_5W[:num//2])) )
valid_split = sorted( list(set(indexes_5W[num//2:])) )
assert len(train_split) + len(valid_split) == num
assert train_split[0] == 0 and train_split[10] == 26 and train_split[111] == 203 and valid_split[0] == 1 and valid_split[10] == 18 and valid_split[111] == 242, '{:} {:} {:} - {:} {:} {:}'.format(train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111])
splits = {num: {'train': train_split, 'valid': valid_split} }
info = {'archs' : [x.tostr() for x in archs],
'total' : total_arch,
'max_node' : max_node,
'splits': splits}
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
save_name = save_dir / 'meta-node-{:}.pth'.format(max_node)
assert not save_name.exists(), '{:} already exists'.format(save_name)
torch.save(info, save_name)
print ('save the meta file into {:}'.format(save_name))
script_name_full = save_dir / 'BENCH-201-N{:}.opt-full.script'.format(max_node)
script_name_less = save_dir / 'BENCH-201-N{:}.opt-less.script'.format(max_node)
full_file = open(str(script_name_full), 'w')
less_file = open(str(script_name_less), 'w')
gaps = total_arch // divide
for start in range(0, total_arch, gaps):
xend = min(start+gaps, total_arch)
full_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 0 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1))
less_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 1 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1))
print ('save the training script into {:} and {:}'.format(script_name_full, script_name_less))
full_file.close()
less_file.close()
script_name = save_dir / 'meta-node-{:}.cal-script.txt'.format(max_node)
macro = 'OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0'
with open(str(script_name), 'w') as cfile:
for start in range(0, total_arch, gaps):
xend = min(start+gaps, total_arch)
cfile.write('{:} python exps/NAS-Bench-201/statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1))
print ('save the post-processing script into {:}'.format(script_name))
if __name__ == '__main__':
#mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()]
#parser = argparse.ArgumentParser(description='Algorithm-Agnostic NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--mode' , type=str, required=True, help='The script mode.')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--max_node', type=int, help='The maximum node in a cell.')
# use for train the model
parser.add_argument('--workers', type=int, default=8, help='The number of data loading workers (default: 8).')
parser.add_argument('--srange' , type=int, nargs='+', help='The start and end index of architectures to be evaluated.')
parser.add_argument('--arch_index', type=int, default=-1, help='The architecture index to be evaluated (cover mode).')
parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.')
parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for each dataset.')
parser.add_argument('--splits', type=int, nargs='+', help='Whether to use the split validation set for each dataset (0 or 1).')
parser.add_argument('--use_less', type=int, default=0, choices=[0,1], help='Using the less-training-epoch config.')
parser.add_argument('--seeds' , type=int, nargs='+', help='The seeds used to train each model (e.g., 777 888 999).')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
args = parser.parse_args()
assert args.mode in ['meta', 'new', 'cover'] or args.mode.startswith('specific-'), 'invalid mode : {:}'.format(args.mode)
if args.mode == 'meta':
generate_meta_info(args.save_dir, args.max_node)
elif args.mode.startswith('specific'):
assert len(args.mode.split('-')) == 2, 'invalid mode : {:}'.format(args.mode)
model_str = args.mode.split('-')[1]
train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells})
else:
meta_path = Path(args.save_dir) / 'meta-node-{:}.pth'.format(args.max_node)
assert meta_path.exists(), '{:} does not exist.'.format(meta_path)
meta_info = torch.load( meta_path )
# check whether args is ok
assert len(args.srange) == 2 and args.srange[0] <= args.srange[1], 'invalid length of srange args: {:}'.format(args.srange)
assert len(args.seeds) > 0, 'invalid length of seeds args: {:}'.format(args.seeds)
assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))
assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)
main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
tuple(args.srange), args.arch_index, tuple(args.seeds), \
args.mode == 'cover', meta_info, \
{'channel': args.channel, 'num_cells': args.num_cells})
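
For orientation, a sketch of driving the three modes above (step 2 is one line of the .opt-full.script that generate_meta_info writes; 'specific-<name>' stands for any key of CellArchitectures, and the elided flags follow the argparse definitions above):

# 1) build the meta file and the training scripts
python exps/NAS-Bench-201/main.py --mode meta --save_dir ./output/NAS-BENCH-201-4 --max_node 4
# 2) train one slice of architectures with three seeds (full schedule)
bash ./scripts-search/NAS-Bench-201/train-models.sh 0     0   389 -1 '777 888 999'
# 3) train a single named architecture
python exps/NAS-Bench-201/main.py --mode specific-<name> --channel 16 --num_cells 5 ...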

View File

@@ -0,0 +1,295 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, argparse, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
from collections import defaultdict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time
from config_utils import load_config, dict2config
from datasets import get_datasets
# NAS-Bench-201 related module or function
from models import CellStructure, get_cell_based_tiny_net
from nas_201_api import ArchResults, ResultsCount
from functions import pure_evaluate
def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict):
xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'], \
results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None)
net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes':arch_config['class_num']}, None)
network = get_cell_based_tiny_net(net_config)
network.load_state_dict(xresult.get_net_param())
if 'train_times' in results: # new version
xresult.update_train_info(results['train_acc1es'], results['train_acc5es'], results['train_losses'], results['train_times'])
xresult.update_eval(results['valid_acc1es'], results['valid_losses'], results['valid_times'])
else:
if dataset == 'cifar10-valid':
xresult.update_OLD_eval('x-valid' , results['valid_acc1es'], results['valid_losses'])
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format('cifar10', 'test')], network.cuda())
xresult.update_OLD_eval('ori-test', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
xresult.update_latency(latencies)
elif dataset == 'cifar10':
xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses'])
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda())
xresult.update_latency(latencies)
elif dataset == 'cifar100' or dataset == 'ImageNet16-120':
xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses'])
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'valid')], network.cuda())
xresult.update_OLD_eval('x-valid', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda())
xresult.update_OLD_eval('x-test' , {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
xresult.update_latency(latencies)
else:
raise ValueError('invalid dataset name : {:}'.format(dataset))
return xresult
def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dict):
information = ArchResults(arch_index, arch_str)
for checkpoint_path in checkpoints:
checkpoint = torch.load(checkpoint_path, map_location='cpu')
used_seed = checkpoint_path.name.split('-')[-1].split('.')[0]
for dataset in datasets:
assert dataset in checkpoint, 'Cannot find {:} in arch-{:} from {:}'.format(dataset, arch_index, checkpoint_path)
results = checkpoint[dataset]
assert results['finish-train'], 'arch-{:} (seed={:}) did not finish training on {:} ::: {:}'.format(arch_index, used_seed, dataset, checkpoint_path)
arch_config = {'channel': results['channel'], 'num_cells': results['num_cells'], 'arch_str': arch_str, 'class_num': results['config']['class_num']}
xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict)
information.update(dataset, int(used_seed), xresult)
return information
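# Checkpoint names follow 'arch-{index:06d}-seed-{seed:04d}.pth' (written by main.py),
# which is why used_seed above is recovered from the last '-'-separated field.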
def GET_DataLoaders(workers):
torch.set_num_threads(workers)
root_dir = (Path(__file__).parent / '..' / '..').resolve()
torch_dir = Path(os.environ['TORCH_HOME'])
# cifar
cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config'
cifar_config = load_config(cifar_config_path, None, None)
print ('{:} Create data-loader for all datasets'.format(time_string()))
print ('-'*200)
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1)
print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num))
cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None)
assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14]
temp_dataset = deepcopy(TRAIN_CIFAR10)
temp_dataset.transform = VALID_CIFAR10.transform
# data loader
trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
train_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True)
valid_cifar10_loader = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True)
test__cifar10_loader = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
print ('CIFAR-10 : trval-loader has {:3d} batches with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size))
print ('CIFAR-10 : train-loader has {:3d} batches with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size))
print ('CIFAR-10 : valid-loader has {:3d} batches with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size))
print ('CIFAR-10 : test--loader has {:3d} batches with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size))
print ('-'*200)
# CIFAR-100
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1)
print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num))
cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None)
assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24]
train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True)
test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True)
print ('CIFAR-100 : train-loader has {:3d} batches'.format(len(train_cifar100_loader)))
print ('CIFAR-100 : valid-loader has {:3d} batches'.format(len(valid_cifar100_loader)))
print ('CIFAR-100 : test--loader has {:3d} batches'.format(len(test__cifar100_loader)))
print ('-'*200)
imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config'
imagenet16_config = load_config(imagenet16_config_path, None, None)
TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1)
print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num))
imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None)
assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20]
train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True)
test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True)
print ('ImageNet-16-120 : train-loader has {:3d} batches with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size))
print ('ImageNet-16-120 : valid-loader has {:3d} batches with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size))
print ('ImageNet-16-120 : test--loader has {:3d} batches with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size))
# 'cifar10', 'cifar100', 'ImageNet16-120'
loaders = {'cifar10@trainval': trainval_cifar10_loader,
'cifar10@train' : train_cifar10_loader,
'cifar10@valid' : valid_cifar10_loader,
'cifar10@test' : test__cifar10_loader,
'cifar100@train' : train_cifar100_loader,
'cifar100@valid' : valid_cifar100_loader,
'cifar100@test' : test__cifar100_loader,
'ImageNet16-120@train': train_imagenet_loader,
'ImageNet16-120@valid': valid_imagenet_loader,
'ImageNet16-120@test' : test__imagenet_loader}
return loaders
def simplify(save_dir, meta_file, basestr, target_dir):
meta_infos = torch.load(meta_file, map_location='cpu')
meta_archs = meta_infos['archs'] # a list of architecture strings
meta_num_archs = meta_infos['total']
meta_max_node = meta_infos['max_node']
assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
print ('{:} found {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))
subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
num_seeds = defaultdict(lambda: 0)
for index, sub_dir in enumerate(sub_model_dirs):
xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
arch_indexes = set()
for checkpoint in xcheckpoints:
temp_names = checkpoint.name.split('-')
assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
arch_indexes.add( temp_names[1] )
subdir2archs[sub_dir] = sorted(list(arch_indexes))
num_evaluated_arch += len(arch_indexes)
# count number of seeds for each architecture
for arch_index in arch_indexes:
num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1
print('{:} There are {:5d} architectures that have been evaluated ({:} in total).'.format(time_string(), num_evaluated_arch, meta_num_archs))
for key in sorted( list( num_seeds.keys() ) ): print ('{:} There are {:5d} architectures that are evaluated {:} times.'.format(time_string(), num_seeds[key], key))
dataloader_dict = GET_DataLoaders( 6 )
to_save_simply = save_dir / 'simplifies'
to_save_allarc = save_dir / 'simplifies' / 'architectures'
if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True)
if not to_save_allarc.exists(): to_save_allarc.mkdir(parents=True, exist_ok=True)
assert (save_dir / target_dir) in subdir2archs, 'cannot find {:}'.format(target_dir)
arch2infos, datasets = {}, ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120')
evaluated_indexes = set()
target_directory = save_dir / target_dir
target_less_dir = save_dir / '{:}-LESS'.format(target_dir)
arch_indexes = subdir2archs[ target_directory ]
num_seeds = defaultdict(lambda: 0)
end_time = time.time()
arch_time = AverageMeter()
for idx, arch_index in enumerate(arch_indexes):
checkpoints = list(target_directory.glob('arch-{:}-seed-*.pth'.format(arch_index)))
ckps_less = list(target_less_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))
# create the arch info for each architecture
try:
arch_info_full = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict)
arch_info_less = account_one_arch(arch_index, meta_archs[int(arch_index)], ckps_less, ['cifar10-valid'], dataloader_dict)
num_seeds[ len(checkpoints) ] += 1
except Exception as e:
print('Loading {:} failed : {:} ({:})'.format(arch_index, checkpoints, e))
continue
assert int(arch_index) not in evaluated_indexes, 'conflict arch-index : {:}'.format(arch_index)
assert 0 <= int(arch_index) < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
arch_info = {'full': arch_info_full, 'less': arch_info_less}
evaluated_indexes.add( int(arch_index) )
arch2infos[int(arch_index)] = arch_info
torch.save({'full': arch_info_full.state_dict(),
'less': arch_info_less.state_dict()}, to_save_allarc / '{:}-FULL.pth'.format(arch_index))
arch_info['full'].clear_params()
arch_info['less'].clear_params()
torch.save({'full': arch_info_full.state_dict(),
'less': arch_info_less.state_dict()}, to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index))
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
need_time = '{:}'.format( convert_secs2time(arch_time.avg * (len(arch_indexes)-idx-1), True) )
print('{:} {:} [{:03d}/{:03d}] : {:} still need {:}'.format(time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time))
# measure time
xstrs = ['{:}:{:03d}'.format(key, num_seeds[key]) for key in sorted( list( num_seeds.keys() ) ) ]
print('{:} {:} done : {:}'.format(time_string(), target_dir, xstrs))
final_infos = {'meta_archs' : meta_archs,
'total_archs': meta_num_archs,
'basestr' : basestr,
'arch2infos' : arch2infos,
'evaluated_indexes': evaluated_indexes}
save_file_name = to_save_simply / '{:}.pth'.format(target_dir)
torch.save(final_infos, save_file_name)
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
def merge_all(save_dir, meta_file, basestr):
meta_infos = torch.load(meta_file, map_location='cpu')
meta_archs = meta_infos['archs']
meta_num_archs = meta_infos['total']
meta_max_node = meta_infos['max_node']
assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
print ('{:} found {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))
for index, sub_dir in enumerate(sub_model_dirs):
arch_info_files = sorted( list(sub_dir.glob('arch-*-seed-*.pth') ) )
print ('The {:02d}/{:02d}-th directory : {:} : {:} runs.'.format(index, len(sub_model_dirs), sub_dir, len(arch_info_files)))
arch2infos, evaluated_indexes = dict(), set()
for IDX, sub_dir in enumerate(sub_model_dirs):
ckp_path = sub_dir.parent / 'simplifies' / '{:}.pth'.format(sub_dir.name)
if ckp_path.exists():
sub_ckps = torch.load(ckp_path, map_location='cpu')
assert sub_ckps['total_archs'] == meta_num_archs and sub_ckps['basestr'] == basestr
xarch2infos = sub_ckps['arch2infos']
xevalindexs = sub_ckps['evaluated_indexes']
for eval_index in xevalindexs:
assert eval_index not in evaluated_indexes and eval_index not in arch2infos
#arch2infos[eval_index] = xarch2infos[eval_index].state_dict()
arch2infos[eval_index] = {'full': xarch2infos[eval_index]['full'].state_dict(),
'less': xarch2infos[eval_index]['less'].state_dict()}
evaluated_indexes.add( eval_index )
print ('{:} [{:03d}/{:03d}] merge data from {:} with {:} models.'.format(time_string(), IDX, len(sub_model_dirs), ckp_path, len(xevalindexs)))
else:
raise ValueError('Cannot find {:}'.format(ckp_path))
#print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path))
evaluated_indexes = sorted( list( evaluated_indexes ) )
print ('Finally, there are {:} architectures that have been trained and evaluated.'.format(len(evaluated_indexes)))
to_save_simply = save_dir / 'simplifies'
if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True)
final_infos = {'meta_archs' : meta_archs,
'total_archs': meta_num_archs,
'arch2infos' : arch2infos,
'evaluated_indexes': evaluated_indexes}
save_file_name = to_save_simply / '{:}-final-infos.pth'.format(basestr)
torch.save(final_infos, save_file_name)
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--mode' , type=str, choices=['cal', 'merge'], help='The running mode for this script.')
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--target_dir' , type=str, help='The target directory.')
parser.add_argument('--max_node' , type=int, default=4, help='The maximum node in a cell.')
parser.add_argument('--channel' , type=int, default=16, help='The number of channels.')
parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.')
args = parser.parse_args()
save_dir = Path( args.base_save_dir )
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
print ('start the statistics of our nas-benchmark from {:} using {:}.'.format(save_dir, args.target_dir))
basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
if args.mode == 'cal':
simplify(save_dir, meta_path, basestr, args.target_dir)
elif args.mode == 'merge':
merge_all(save_dir, meta_path, basestr)
else:
raise ValueError('invalid mode : {:}'.format(args.mode))
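
The post-processing above is driven per sub-directory and then merged; a usage sketch (the 'cal' line mirrors the script generated by main.py, with 000000-000389-C16-N5 as an example target directory):

python exps/NAS-Bench-201/statistics.py --mode cal --target_dir 000000-000389-C16-N5
python exps/NAS-Bench-201/statistics.py --mode merge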

View File

@@ -0,0 +1,223 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
##################################################
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces, CellStructure
from nas_201_api import NASBench201API as API
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
elif xargs.dataset.startswith('ImageNet16'):
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
imagenet16_split = load_config(split_Fpath, None, None)
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
else:
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
config_path = 'configs/nas-benchmark/algos/DARTS.config'
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space}, None)
search_model = get_cell_based_tiny_net(model_config)
logger.log('search-model :\n{:}'.format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
logger.log('w-optimizer : {:}'.format(w_optimizer))
logger.log('a-optimizer : {:}'.format(a_optimizer))
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
flop, param = get_model_infos(search_model, xshape)
#logger.log('{:}'.format(search_model))
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log('{:} create API = {:} done'.format(time_string(), api))
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
logger.close()
def check_unique_arch(meta_file):
api = API(str(meta_file))
arch_strs = deepcopy(api.meta_archs)
xarchs = [CellStructure.str2structure(x) for x in arch_strs]
def get_unique_matrix(archs, consider_zero):
UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs)))
Unique2Index = dict()
for index, xstr in enumerate(UniquStrs):
if xstr not in Unique2Index: Unique2Index[xstr] = list()
Unique2Index[xstr].append( index )
sm_matrix = torch.eye(len(archs)).bool()
for _, xlist in Unique2Index.items():
for i in xlist:
for j in xlist:
sm_matrix[i,j] = True
unique_ids, unique_num = [-1 for _ in archs], 0
for i in range(len(unique_ids)):
if unique_ids[i] > -1: continue
neighbours = sm_matrix[i].nonzero().view(-1).tolist()
for nghb in neighbours:
assert unique_ids[nghb] == -1, 'impossible'
unique_ids[nghb] = unique_num
unique_num += 1
return sm_matrix, unique_ids, unique_num
print ('There are {:} valid-archs'.format( sum(arch.check_valid() for arch in xarchs) ))
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None)
print ('{:} There are {:} unique architectures (considering nothing).'.format(time_string(), unique_num))
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False)
print ('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num))
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True)
print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False):
if isinstance(meta_file, API):
api = meta_file
else:
api = API(str(meta_file))
cifar10_currs = []
cifar10_valid = []
cifar10_test = []
cifar100_valid = []
cifar100_test = []
imagenet_test = []
imagenet_valid = []
for idx, arch in enumerate(api):
results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
cifar10_currs.append( results['valid-accuracy'] )
# --->>>>>
results = api.get_more_info(idx, 'cifar10-valid' , None, False, is_rand)
cifar10_valid.append( results['valid-accuracy'] )
results = api.get_more_info(idx, 'cifar10' , None, False, is_rand)
cifar10_test.append( results['test-accuracy'] )
results = api.get_more_info(idx, 'cifar100' , None, False, is_rand)
cifar100_test.append( results['test-accuracy'] )
cifar100_valid.append( results['valid-accuracy'] )
results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand)
imagenet_test.append( results['test-accuracy'] )
imagenet_valid.append( results['valid-accuracy'] )
def get_cor(A, B):
return float(np.corrcoef(A, B)[0,1])
cors = []
for basestr, xlist in zip(['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]):
correlation = get_cor(cifar10_currs, xlist)
if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
cors.append( correlation )
#print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
#print('-'*200)
#print('*'*230)
return cors
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
corrs = []
for i in tqdm(range(100)):
x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
corrs.append( x )
#xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
correlations = np.array(corrs)
print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200'))
for idx, xstr in enumerate(xstrs):
print ('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(xstr, correlations[:,idx].mean(), correlations[:,idx].std(), correlations[:,idx].mean(), correlations[:,idx].std()))
print('')
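# A small sketch (ours, not in the original) of a normal-approximation 95%
# confidence interval for the mean correlation over the 100 random trials
# drawn in check_cor_for_bandit_v2.
def approx_mean_ci(values, z=1.96):
  values = np.array(values, dtype=np.float64)
  half = z * values.std() / np.sqrt(len(values))
  return float(values.mean() - half), float(values.mean() + half)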
if __name__ == '__main__':
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.')
args = parser.parse_args()
vis_save_dir = Path(args.save_dir)
vis_save_dir.mkdir(parents=True, exist_ok=True)
meta_file = Path(args.api_path)
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
#check_unique_arch(meta_file)
api = API(str(meta_file))
#for iepoch in [11, 25, 50, 100, 150, 175, 200]:
# check_cor_for_bandit(api, 6, iepoch)
# check_cor_for_bandit(api, 12, iepoch)
check_cor_for_bandit_v2(api, 6, True, True)
check_cor_for_bandit_v2(api, 12, True, True)
check_cor_for_bandit_v2(api, 12, False, True)
check_cor_for_bandit_v2(api, 24, False, True)
check_cor_for_bandit_v2(api, 100, False, True)
check_cor_for_bandit_v2(api, 150, False, True)
check_cor_for_bandit_v2(api, 175, False, True)
check_cor_for_bandit_v2(api, 200, False, True)
print('----')

View File

@@ -0,0 +1,740 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
##################################################
import os, sys, time, argparse, collections
from tqdm import tqdm
from collections import OrderedDict
import numpy as np
import torch
from pathlib import Path
from collections import defaultdict
import matplotlib
matplotlib.use('agg') # select the non-interactive backend before anything imports pyplot (seaborn does)
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import time_string
from nas_201_api import NASBench201API as API
def calculate_correlation(*vectors):
matrix = []
for i, vectori in enumerate(vectors):
x = []
for j, vectorj in enumerate(vectors):
x.append( np.corrcoef(vectori, vectorj)[0,1] )
matrix.append( x )
return np.array(matrix)
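# Illustrative check (not in the original file): the matrix returned by
# calculate_correlation is symmetric with a unit diagonal; two perfectly
# correlated vectors also give 1.0 off the diagonal.
def _sanity_check_correlation():
  m = calculate_correlation([1, 2, 3], [2, 4, 6])
  assert m.shape == (2, 2) and abs(m[0, 1] - 1.0) < 1e-8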
def visualize_relative_ranking(vis_save_dir):
print ('\n' + '-'*100)
cifar010_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar10')
cifar100_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar100')
imagenet_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('ImageNet16-120')
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info['params'])))
print ('{:} start to visualize relative ranking'.format(time_string()))
# index 11472 is the ResNet-like architecture; mask out (set to -1) any arch with more parameters than it
x_010_accs = [ cifar010_info['test_accs'][i] if cifar010_info['params'][i] <= cifar010_info['params'][11472] else -1 for i in indexes]
x_100_accs = [ cifar100_info['test_accs'][i] if cifar100_info['params'][i] <= cifar100_info['params'][11472] else -1 for i in indexes]
x_img_accs = [ imagenet_info['test_accs'][i] if imagenet_info['params'][i] <= imagenet_info['params'][11472] else -1 for i in indexes]
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
cifar100_labels, imagenet_labels = [], []
for idx in cifar010_ord_indexes:
cifar100_labels.append( cifar100_ord_indexes.index(idx) )
imagenet_labels.append( imagenet_ord_indexes.index(idx) )
print ('{:} prepare data done.'.format(time_string()))
dpi, width, height = 300, 2600, 2600
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 18
resnet_scale, resnet_alpha = 120, 0.5
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
#plt.ylabel('y').set_rotation(0)
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical')
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize)
#ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100')
#ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8, label='ImageNet-16-120')
#ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10')
ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8)
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120')
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=0, fontsize=LegendFontsize)
ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
ax.set_ylabel('architecture ranking', fontsize=LabelSize)
save_path = (vis_save_dir / 'relative-rank.pdf').resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (vis_save_dir / 'relative-rank.png').resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
# calculate correlation
sns_size = 15
CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs'])
fig = plt.figure(figsize=figsize)
plt.axis('off')
h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5)
save_path = (vis_save_dir / 'co-relation-all.pdf').resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
print ('{:} save into {:}'.format(time_string(), save_path))
# calculate correlation
acc_bars = [92, 93]
for acc_bar in acc_bars:
selected_indexes = []
for i, acc in enumerate(cifar010_info['test_accs']):
if acc > acc_bar: selected_indexes.append( i )
print ('selected {:} architectures with CIFAR-10 test accuracy > {:}'.format(len(selected_indexes), acc_bar))
cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ]
cifar010_test_accs = np.array(cifar010_info['test_accs']) [ selected_indexes ]
cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ]
cifar100_test_accs = np.array(cifar100_info['test_accs']) [ selected_indexes ]
imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ]
imagenet_test_accs = np.array(imagenet_info['test_accs']) [ selected_indexes ]
CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs)
fig = plt.figure(figsize=figsize)
plt.axis('off')
h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5)
save_path = (vis_save_dir / 'co-relation-top-{:}.pdf'.format(len(selected_indexes))).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
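# A sketch (same inputs as visualize_relative_ranking assumes) of an
# O(n log n) replacement for the list.index() loop that builds
# cifar100_labels / imagenet_labels above; the helper name is ours.
def rank_labels(ref_accs, other_accs):
  ref_order = np.argsort(ref_accs)  # arch ids sorted by the reference accuracy
  other_rank = np.empty(len(other_accs), dtype=np.int64)
  other_rank[np.argsort(other_accs)] = np.arange(len(other_accs))
  return other_rank[ref_order].tolist()  # rank in the other dataset, in reference order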
def visualize_info(meta_file, dataset, vis_save_dir):
print ('{:} start to visualize {:} information'.format(time_string(), dataset))
cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
if not cache_file_path.exists():
print ('Did not find the cache file : {:}'.format(cache_file_path))
nas_bench = API(str(meta_file))
params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
for index in range( len(nas_bench) ):
info = nas_bench.query_by_index(index, use_12epochs_result=False)
resx = info.get_comput_costs(dataset) ; flop, param = resx['flops'], resx['params']
if dataset == 'cifar10':
res = info.get_metrics('cifar10', 'train') ; train_acc = res['accuracy']
res = info.get_metrics('cifar10-valid', 'x-valid') ; valid_acc = res['accuracy']
res = info.get_metrics('cifar10', 'ori-test') ; test_acc = res['accuracy']
otest_acc = test_acc # CIFAR-10 has no extra test split, so the original test set is reused
else:
res = info.get_metrics(dataset, 'train') ; train_acc = res['accuracy']
res = info.get_metrics(dataset, 'x-valid') ; valid_acc = res['accuracy']
res = info.get_metrics(dataset, 'x-test') ; test_acc = res['accuracy']
res = info.get_metrics(dataset, 'ori-test') ; otest_acc = res['accuracy']
if index == 11472: # resnet
resnet = {'params':param, 'flops': flop, 'index': 11472, 'train_acc': train_acc, 'valid_acc': valid_acc, 'test_acc': test_acc, 'otest_acc': otest_acc}
flops.append( flop )
params.append( param )
train_accs.append( train_acc )
valid_accs.append( valid_acc )
test_accs.append( test_acc )
otest_accs.append( otest_acc )
#resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97}
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
info['resnet'] = resnet
torch.save(info, cache_file_path)
else:
print ('Found the cache file : {:}'.format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
resnet = info['resnet']
print ('{:} collect data done.'.format(time_string()))
indexes = list(range(len(params)))
dpi, width, height = 300, 2600, 2600
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 22, 22
resnet_scale, resnet_alpha = 120, 0.5
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
if dataset == 'cifar10':
plt.ylim(50, 100)
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
elif dataset == 'cifar100':
plt.ylim(25, 75)
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
else:
plt.ylim(0, 50)
plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue')
ax.scatter([resnet['params']], [resnet['valid_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=0.4)
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=4, fontsize=LegendFontsize)
ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize)
save_path = (vis_save_dir / '{:}-param-vs-valid.pdf'.format(dataset)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (vis_save_dir / '{:}-param-vs-valid.png'.format(dataset)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
if dataset == 'cifar10':
plt.ylim(50, 100)
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
elif dataset == 'cifar100':
plt.ylim(25, 75)
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
else:
plt.ylim(0, 50)
plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
ax.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
ax.scatter([resnet['params']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
plt.grid()
ax.set_axisbelow(True)
plt.legend(loc=4, fontsize=LegendFontsize)
ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
save_path = (vis_save_dir / '{:}-param-vs-test.pdf'.format(dataset)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (vis_save_dir / '{:}-param-vs-test.png'.format(dataset)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
if dataset == 'cifar10':
plt.ylim(50, 100)
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
elif dataset == 'cifar100':
plt.ylim(20, 100)
plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize)
else:
plt.ylim(25, 76)
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
ax.scatter([resnet['params']], [resnet['train_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
plt.grid()
ax.set_axisbelow(True)
plt.legend(loc=4, fontsize=LegendFontsize)
ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
ax.set_ylabel('the training accuracy (%)', fontsize=LabelSize)
save_path = (vis_save_dir / '{:}-param-vs-train.pdf'.format(dataset)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (vis_save_dir / '{:}-param-vs-train.png'.format(dataset)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(0, max(indexes))
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
if dataset == 'cifar10':
plt.ylim(50, 100)
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
elif dataset == 'cifar100':
plt.ylim(25, 75)
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
else:
plt.ylim(0, 50)
plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
ax.scatter([resnet['index']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
plt.grid()
ax.set_axisbelow(True)
plt.legend(loc=4, fontsize=LegendFontsize)
ax.set_xlabel('architecture ID', fontsize=LabelSize)
ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
save_path = (vis_save_dir / '{:}-test-over-ID.pdf'.format(dataset)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (vis_save_dir / '{:}-test-over-ID.png'.format(dataset)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
def visualize_rank_over_time(meta_file, vis_save_dir):
print ('\n' + '-'*150)
vis_save_dir.mkdir(parents=True, exist_ok=True)
print ('{:} start to visualize rank-over-time into {:}'.format(time_string(), vis_save_dir))
cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
if not cache_file_path.exists():
print ('Did not find the cache file : {:}'.format(cache_file_path))
nas_bench = API(str(meta_file))
print ('{:} load nas_bench done'.format(time_string()))
params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
# loop over every architecture and, for each one, collect its metrics at all 200 epochs
for index in tqdm(range(len(nas_bench))):
info = nas_bench.query_by_index(index, use_12epochs_result=False)
for iepoch in range(200):
res = info.get_metrics('cifar10' , 'train' , iepoch) ; train_acc = res['accuracy']
res = info.get_metrics('cifar10-valid', 'x-valid' , iepoch) ; valid_acc = res['accuracy']
res = info.get_metrics('cifar10' , 'ori-test', iepoch) ; test_acc = res['accuracy']
otest_acc = test_acc # the original CIFAR-10 test set also serves as the 'otest' split
train_accs[iepoch].append( train_acc )
valid_accs[iepoch].append( valid_acc )
test_accs [iepoch].append( test_acc )
otest_accs[iepoch].append( otest_acc )
if iepoch == 0:
res = info.get_comput_costs('cifar10') ; flop, param = res['flops'], res['params']
flops.append( flop )
params.append( param )
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
torch.save(info, cache_file_path)
else:
print ('Found the cache file : {:}'.format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
print ('{:} collect data done.'.format(time_string()))
#selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
selected_epochs = list( range(200) )
x_xtests = test_accs[199]
indexes = list(range(len(x_xtests)))
ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])
for sepoch in selected_epochs:
x_valids = valid_accs[sepoch]
valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
valid_ord_lbls = []
for idx in ord_idxs:
valid_ord_lbls.append( valid_ord_idxs.index(idx) )
# labeled data
dpi, width, height = 300, 2600, 2600
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 18
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical')
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize)
ax.scatter(indexes, valid_ord_lbls, marker='^', s=0.5, c='tab:green', alpha=0.8)
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-10 validation')
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10 test')
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc='upper left', fontsize=LegendFontsize)
ax.set_xlabel('architecture ranking in the final test accuracy', fontsize=LabelSize)
ax.set_ylabel('architecture ranking in the validation set', fontsize=LabelSize)
save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
def write_video(save_dir):
import cv2
video_save_path = save_dir / 'time.avi'
print ('{:} start creating the video at {:}'.format(time_string(), video_save_path))
images = sorted( list( save_dir.glob('time-*.png') ) )
ximage = cv2.imread(str(images[0]))
#shape = (ximage.shape[1], ximage.shape[0])
shape = (1000, 1000)
#writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 25, shape)
writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape)
for idx, image in enumerate(images):
ximage = cv2.imread(str(image))
_image = cv2.resize(ximage, shape)
writer.write(_image)
writer.release()
print ('wrote the video [{:} frames] into {:}'.format(len(images), video_save_path))
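# Hypothetical variant (not in the original): emit an mp4 instead of an avi,
# assuming the local OpenCV build ships the widely available 'mp4v' codec.
def write_video_mp4(save_dir, fps=5, shape=(1000, 1000)):
  import cv2
  images = sorted(save_dir.glob('time-*.png'))
  writer = cv2.VideoWriter(str(save_dir / 'time.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), fps, shape)
  for image in images:
    writer.write(cv2.resize(cv2.imread(str(image)), shape))
  writer.release()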
def plot_results_nas_v2(api, dataset_xset_a, dataset_xset_b, root, file_name, y_lims):
#print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
print ('dataset/xset pairs : {:} and {:}'.format(dataset_xset_a, dataset_xset_b))
checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth',
'./output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth',
'./output/search-cell-nas-bench-201/RAND-cifar10/results.pth',
'./output/search-cell-nas-bench-201/BOHB-cifar10/results.pth'
]
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
All_Accs_A, All_Accs_B = OrderedDict(), OrderedDict()
for legend, checkpoint in zip(legends, checkpoints):
all_indexes = torch.load(checkpoint, map_location='cpu')
accuracies_A, accuracies_B = [], []
accuracies = []
for x in all_indexes:
info = api.arch2infos_full[ x ]
metrics = info.get_metrics(dataset_xset_a[0], dataset_xset_a[1], None, False)
accuracies_A.append( metrics['accuracy'] )
metrics = info.get_metrics(dataset_xset_b[0], dataset_xset_b[1], None, False)
accuracies_B.append( metrics['accuracy'] )
accuracies.append( (accuracies_A[-1], accuracies_B[-1]) )
if indexes is None: indexes = list(range(len(all_indexes)))
accuracies = sorted(accuracies)
All_Accs_A[legend] = [x[0] for x in accuracies]
All_Accs_B[legend] = [x[1] for x in accuracies]
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
dpi, width, height = 300, 3400, 2600
LabelSize, LegendFontsize = 28, 28
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
x_axis = np.arange(0, 600)
plt.xlim(0, max(indexes))
plt.ylim(y_lims[0], y_lims[1])
interval_x, interval_y = 100, y_lims[2]
plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
plt.grid()
plt.xlabel('The index of runs', fontsize=LabelSize)
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
for idx, legend in enumerate(legends):
plt.plot(indexes, All_Accs_B[legend], color=color_set[idx], linestyle='--', label='{:}'.format(legend), lw=1, alpha=0.5)
plt.plot(indexes, All_Accs_A[legend], color=color_set[idx], linestyle='-', lw=1)
for All_Accs in [All_Accs_A, All_Accs_B]:
print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend])))
plt.legend(loc=4, fontsize=LegendFontsize)
save_path = root / '{:}'.format(file_name)
print('save figure into {:}\n'.format(save_path))
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
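# A tiny helper (ours, not in the repo) for the '{:.2f}$\pm${:.2f}' mean/std
# pattern that recurs in the print statements of this file.
def mean_std_str(values):
  values = np.array(values, dtype=np.float64)
  return '{:.2f}$\\pm${:.2f}'.format(values.mean(), values.std())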
def plot_results_nas(api, dataset, xset, root, file_name, y_lims):
print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth',
'./output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth',
'./output/search-cell-nas-bench-201/RAND-cifar10/results.pth',
'./output/search-cell-nas-bench-201/BOHB-cifar10/results.pth'
]
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
All_Accs = OrderedDict()
for legend, checkpoint in zip(legends, checkpoints):
all_indexes = torch.load(checkpoint, map_location='cpu')
accuracies = []
for x in all_indexes:
info = api.arch2infos_full[ x ]
metrics = info.get_metrics(dataset, xset, None, False)
accuracies.append( metrics['accuracy'] )
if indexes is None: indexes = list(range(len(all_indexes)))
All_Accs[legend] = sorted(accuracies)
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
dpi, width, height = 300, 3400, 2600
LabelSize, LegendFontsize = 28, 28
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
x_axis = np.arange(0, 600)
plt.xlim(0, max(indexes))
plt.ylim(y_lims[0], y_lims[1])
interval_x, interval_y = 100, y_lims[2]
plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
plt.grid()
plt.xlabel('The index of runs', fontsize=LabelSize)
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
for idx, legend in enumerate(legends):
plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2)
print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend])))
plt.legend(loc=4, fontsize=LegendFontsize)
save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name)
print('save figure into {:}\n'.format(save_path))
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
def just_show(api):
xtimes = {'RSPS' : [8082.5, 7794.2, 8144.7],
'DARTS-V1': [11582.1, 11347.0, 11948.2],
'DARTS-V2': [35694.7, 36132.7, 35518.0],
'GDAS' : [31334.1, 31478.6, 32016.7],
'SETN' : [33528.8, 33831.5, 35058.3],
'ENAS' : [14340.2, 13817.3, 14018.9]}
for xkey, xlist in xtimes.items():
xlist = np.array(xlist)
print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean()))
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
}
xseeds = {'RSPS' : [5349, 59613, 5983],
'DARTS-V1': [11416, 72873, 81184],
'DARTS-V2': [43330, 79405, 79423],
'GDAS' : [19677, 884, 95950],
'SETN' : [20518, 61817, 89144],
'ENAS' : [3231, 34238, 96929],
}
def get_accs(xdata, index=-1):
if index == -1:
epochs = xdata['epoch']
genotype = xdata['genotypes'][epochs-1]
index = api.query_index_by_arch(genotype)
pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')]
xresults = []
for dataset, xset in pairs:
metrics = api.arch2infos_full[index].get_metrics(dataset, xset, None, False)
xresults.append( metrics['accuracy'] )
return xresults
for xkey in xpaths.keys():
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths]
accyss = [get_accs(xdatas) for xdatas in all_datas]
accyss = np.array( accyss )
print('\nxkey = {:}'.format(xkey))
for i in range(accyss.shape[1]): print('---->>>> {:.2f}$\\pm${:.2f}'.format(accyss[:,i].mean(), accyss[:,i].std()))
print('\n{:}'.format(get_accs(None, 11472))) # resnet
pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')]
for dataset, metric_on_set in pairs:
arch_index, highest_acc = api.find_best(dataset, metric_on_set)
print ('[{:10s}-{:10s}] ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc))
def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_maxs):
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
dpi, width, height = 300, 3400, 2600
LabelSize, LegendFontsize = 28, 28
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
#x_maxs = 250
plt.xlim(0, x_maxs+1)
plt.ylim(y_lims[0], y_lims[1])
interval_x, interval_y = x_maxs // 5, y_lims[2]
plt.xticks(np.arange(0, x_maxs+1, interval_x), fontsize=LegendFontsize)
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
plt.grid()
plt.xlabel('The searching epoch', fontsize=LabelSize)
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
}
xseeds = {'RSPS' : [5349, 59613, 5983],
'DARTS-V1': [11416, 72873, 81184, 28640],
'DARTS-V2': [43330, 79405, 79423],
'GDAS' : [19677, 884, 95950],
'SETN' : [20518, 61817, 89144],
'ENAS' : [3231, 34238, 96929],
}
def get_accs(xdata):
epochs, xresults = xdata['epoch'], []
if -1 in xdata['genotypes']:
metrics = api.arch2infos_full[ api.query_index_by_arch(xdata['genotypes'][-1]) ].get_metrics(dataset, subset, None, False)
else:
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
xresults.append( metrics['accuracy'] )
for iepoch in range(epochs):
genotype = xdata['genotypes'][iepoch]
index = api.query_index_by_arch(genotype)
metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False)
xresults.append( metrics['accuracy'] )
return xresults
if x_maxs == 50:
xox, xxxstrs = 'v2', ['DARTS-V1', 'DARTS-V2']
elif x_maxs == 250:
xox, xxxstrs = 'v1', ['RSPS', 'GDAS', 'SETN', 'ENAS']
else: raise ValueError('invalid x_maxs={:}'.format(x_maxs))
for idx, method in enumerate(xxxstrs):
xkey = method
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths]
accyss = [get_accs(xdatas) for xdatas in all_datas]
accyss = np.array( accyss )
epochs = list(range(accyss.shape[1]))
plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx], linestyle='-', label='{:}'.format(method), lw=2)
plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx])
#plt.legend(loc=4, fontsize=LegendFontsize)
plt.legend(loc=0, fontsize=LegendFontsize)
save_path = vis_save_dir / '{:}-{:}-{:}-{:}'.format(xox, dataset, subset, file_name)
print('save figure into {:}\n'.format(save_path))
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, y_lims, x_maxs):
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k', 'orange'] # eight colors so color_set[idx*2+j] stays in range even for the four-method (x_maxs=250) case
dpi, width, height = 300, 3400, 2600
LabelSize, LegendFontsize = 28, 28
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
#x_maxs = 250
plt.xlim(0, x_maxs+1)
plt.ylim(y_lims[0], y_lims[1])
interval_x, interval_y = x_maxs // 5, y_lims[2]
plt.xticks(np.arange(0, x_maxs+1, interval_x), fontsize=LegendFontsize)
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
plt.grid()
plt.xlabel('The searching epoch', fontsize=LabelSize)
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
}
xseeds = {'RSPS' : [5349, 59613, 5983],
'DARTS-V1': [11416, 72873, 81184, 28640],
'DARTS-V2': [43330, 79405, 79423],
'GDAS' : [19677, 884, 95950],
'SETN' : [20518, 61817, 89144],
'ENAS' : [3231, 34238, 96929],
}
def get_accs(xdata, dataset, subset):
epochs, xresults = xdata['epoch'], []
if -1 in xdata['genotypes']:
metrics = api.arch2infos_full[ api.query_index_by_arch(xdata['genotypes'][-1]) ].get_metrics(dataset, subset, None, False)
else:
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
xresults.append( metrics['accuracy'] )
for iepoch in range(epochs):
genotype = xdata['genotypes'][iepoch]
index = api.query_index_by_arch(genotype)
metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False)
xresults.append( metrics['accuracy'] )
return xresults
if x_maxs == 50:
xox, xxxstrs = 'v2', ['DARTS-V1', 'DARTS-V2']
elif x_maxs == 250:
xox, xxxstrs = 'v1', ['RSPS', 'GDAS', 'SETN', 'ENAS']
else: raise ValueError('invalid x_maxs={:}'.format(x_maxs))
for idx, method in enumerate(xxxstrs):
xkey = method
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths]
accyss_A = np.array( [get_accs(xdatas, data_sub_a[0], data_sub_a[1]) for xdatas in all_datas] )
accyss_B = np.array( [get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas] )
epochs = list(range(accyss_A.shape[1]))
for j, accyss in enumerate([accyss_A, accyss_B]):
plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx*2+j], linestyle='-' if j==0 else '--', label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9)
plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx*2+j])
#plt.legend(loc=4, fontsize=LegendFontsize)
plt.legend(loc=0, fontsize=LegendFontsize)
save_path = vis_save_dir / '{:}-{:}'.format(xox, file_name)
print('save figure into {:}\n'.format(save_path))
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
def show_reinforce(api, root, dataset, xset, file_name, y_lims):
print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
LRs = ['0.01', '0.02', '0.1', '0.2', '0.5', '1.0', '1.5', '2.0', '2.5', '3.0']
checkpoints = ['./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth'.format(x) for x in LRs]
acc_lr_dict, indexes = {}, None
for lr, checkpoint in zip(LRs, checkpoints):
all_indexes, accuracies = torch.load(checkpoint, map_location='cpu'), []
for x in all_indexes:
info = api.arch2infos_full[ x ]
metrics = info.get_metrics(dataset, xset, None, False)
accuracies.append( metrics['accuracy'] )
if indexes is None: indexes = list(range(len(accuracies)))
acc_lr_dict[lr] = np.array( sorted(accuracies) )
print ('LR={:.3f}, mean={:}, std={:}'.format(float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std()))
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
dpi, width, height = 300, 3400, 2600
LabelSize, LegendFontsize = 28, 22
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
x_axis = np.arange(0, 600)
plt.xlim(0, max(indexes))
plt.ylim(y_lims[0], y_lims[1])
interval_x, interval_y = 100, y_lims[2]
plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
plt.grid()
plt.xlabel('The index of runs', fontsize=LabelSize)
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
for idx, LR in enumerate(LRs):
legend = 'LR={:.2f}'.format(float(LR))
color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8)
print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]), np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR])))
plt.legend(loc=4, fontsize=LegendFontsize)
save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name)
print('save figure into {:}\n'.format(save_path))
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.')
args = parser.parse_args()
vis_save_dir = Path(args.save_dir)
vis_save_dir.mkdir(parents=True, exist_ok=True)
meta_file = Path(args.api_path)
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
#visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
#write_video(vis_save_dir / 'over-time')
#visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
#visualize_info(str(meta_file), 'cifar100', vis_save_dir)
#visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
#visualize_relative_ranking(vis_save_dir)
api = API(args.api_path)
show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (75, 95, 5))
#import pdb; pdb.set_trace() # debugging leftover; disabled so the plotting calls below can run
for x_maxs in [50, 250]:
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test') , vis_save_dir, 'DARTS-CIFAR010.pdf', (0, 100,10), 50)
show_nas_sharing_w_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ) , vis_save_dir, 'DARTS-CIFAR100.pdf', (0, 100,10), 50)
show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ) , vis_save_dir, 'DARTS-ImageNet.pdf', (0, 100,10), 50)
#just_show(api)
"""
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1))
plot_results_nas_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3))
plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2))
"""