102->201 / NAS->autoDL / more configs of TAS / reorganize docs / fix bugs in NAS baselines
This commit is contained in:
84
exps/NAS-Bench-201/check.py
Normal file
84
exps/NAS-Bench-201/check.py
Normal file
@@ -0,0 +1,84 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# python exps/NAS-Bench-201/check.py --base_save_dir
|
||||
##################################################
|
||||
import os, sys, time, argparse, collections
|
||||
from shutil import copyfile
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
|
||||
|
||||
def check_files(save_dir, meta_file, basestr):
|
||||
meta_infos = torch.load(meta_file, map_location='cpu')
|
||||
meta_archs = meta_infos['archs']
|
||||
meta_num_archs = meta_infos['total']
|
||||
meta_max_node = meta_infos['max_node']
|
||||
assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
|
||||
|
||||
sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
|
||||
print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))
|
||||
|
||||
subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
|
||||
num_seeds = defaultdict(lambda: 0)
|
||||
for index, sub_dir in enumerate(sub_model_dirs):
|
||||
xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
|
||||
#xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth'))
|
||||
arch_indexes = set()
|
||||
for checkpoint in xcheckpoints:
|
||||
temp_names = checkpoint.name.split('-')
|
||||
assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
|
||||
arch_indexes.add( temp_names[1] )
|
||||
subdir2archs[sub_dir] = sorted(list(arch_indexes))
|
||||
num_evaluated_arch += len(arch_indexes)
|
||||
# count number of seeds for each architecture
|
||||
for arch_index in arch_indexes:
|
||||
num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1
|
||||
print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k*v for k, v in num_seeds.items())))
|
||||
for key in sorted( list( num_seeds.keys() ) ): print ('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key))
|
||||
|
||||
dir2ckps, dir2ckp_exists = dict(), dict()
|
||||
start_time, epoch_time = time.time(), AverageMeter()
|
||||
for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
|
||||
seeds = [777, 888, 999]
|
||||
numrs = defaultdict(lambda: 0)
|
||||
all_checkpoints, all_ckp_exists = [], []
|
||||
for arch_index in arch_indexes:
|
||||
checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds]
|
||||
ckp_exists = [(sub_dir/x).exists() for x in checkpoints]
|
||||
arch_index = int(arch_index)
|
||||
assert 0 <= arch_index < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
|
||||
all_checkpoints += checkpoints
|
||||
all_ckp_exists += ckp_exists
|
||||
numrs[sum(ckp_exists)] += 1
|
||||
dir2ckps[ str(sub_dir) ] = all_checkpoints
|
||||
dir2ckp_exists[ str(sub_dir) ] = all_ckp_exists
|
||||
# measure time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
numrstr = ', '.join( ['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs.keys())] )
|
||||
print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(time_string(), IDX+1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists), sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs)-IDX-1), True), numrstr))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
|
||||
parser.add_argument('--max_node', type=int, default=4, help='The maximum node in a cell.')
|
||||
parser.add_argument('--channel', type=int, default=16, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.')
|
||||
args = parser.parse_args()
|
||||
|
||||
save_dir = Path( args.base_save_dir )
|
||||
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
|
||||
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
|
||||
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
|
||||
print ('check NAS-Bench-201 in {:}'.format(save_dir))
|
||||
|
||||
basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
|
||||
check_files(save_dir, meta_path, basestr)
|
28
exps/NAS-Bench-201/dist-setup.py
Normal file
28
exps/NAS-Bench-201/dist-setup.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
from setuptools import setup
|
||||
|
||||
|
||||
def read(fname='README.md'):
|
||||
with open(os.path.join(os.path.dirname(__file__), fname), encoding='utf-8') as cfile:
|
||||
return cfile.read()
|
||||
|
||||
|
||||
setup(
|
||||
name = "nas_bench_201",
|
||||
version = "1.0",
|
||||
author = "Xuanyi Dong",
|
||||
author_email = "dongxuanyi888@gmail.com",
|
||||
description = "API for NAS-Bench-201 (a benchmark for neural architecture search).",
|
||||
license = "MIT",
|
||||
keywords = "NAS Dataset API DeepLearning",
|
||||
url = "https://github.com/D-X-Y/NAS-Bench-201",
|
||||
packages=['nas_201_api'],
|
||||
long_description=read('README.md'),
|
||||
long_description_content_type='text/markdown',
|
||||
classifiers=[
|
||||
"Programming Language :: Python",
|
||||
"Topic :: Database",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
],
|
||||
)
|
136
exps/NAS-Bench-201/functions.py
Normal file
136
exps/NAS-Bench-201/functions.py
Normal file
@@ -0,0 +1,136 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from config_utils import dict2config
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
|
||||
__all__ = ['evaluate_for_seed', 'pure_evaluate']
|
||||
|
||||
|
||||
|
||||
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
latencies = []
|
||||
network.eval()
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
inputs = inputs.cuda(non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
batch_time.update(time.time() - end)
|
||||
if batch is None or batch == inputs.size(0):
|
||||
batch = inputs.size(0)
|
||||
latencies.append( batch_time.val - data_time.val )
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
end = time.time()
|
||||
if len(latencies) > 2: latencies = latencies[1:]
|
||||
return losses.avg, top1.avg, top5.avg, latencies
|
||||
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode == 'train' : network.train()
|
||||
elif mode == 'valid': network.eval()
|
||||
else: raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
|
||||
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
if mode == 'train': optimizer.zero_grad()
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# backward
|
||||
if mode == 'train':
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
# count time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||
|
||||
|
||||
|
||||
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
|
||||
|
||||
prepare_seed(seed) # random seed
|
||||
net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny',
|
||||
'C': arch_config['channel'], 'N': arch_config['num_cells'],
|
||||
'genotype': arch, 'num_classes': config.class_num}
|
||||
, None)
|
||||
)
|
||||
#net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
||||
flop, param = get_model_infos(net, config.xshape)
|
||||
logger.log('Network : {:}'.format(net.get_message()), False)
|
||||
logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
|
||||
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
|
||||
# train and valid
|
||||
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
|
||||
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
|
||||
train_times , valid_times = {}, {}
|
||||
for epoch in range(total_epoch):
|
||||
scheduler.update(epoch, 0.0)
|
||||
|
||||
train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
|
||||
train_losses[epoch] = train_loss
|
||||
train_acc1es[epoch] = train_acc1
|
||||
train_acc5es[epoch] = train_acc5
|
||||
train_times [epoch] = train_tm
|
||||
with torch.no_grad():
|
||||
for key, xloder in valid_loaders.items():
|
||||
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloder , network, criterion, None, None, 'valid')
|
||||
valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss
|
||||
valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1
|
||||
valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5
|
||||
valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) )
|
||||
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5))
|
||||
info_seed = {'flop' : flop,
|
||||
'param': param,
|
||||
'channel' : arch_config['channel'],
|
||||
'num_cells' : arch_config['num_cells'],
|
||||
'config' : config._asdict(),
|
||||
'total_epoch' : total_epoch ,
|
||||
'train_losses': train_losses,
|
||||
'train_acc1es': train_acc1es,
|
||||
'train_acc5es': train_acc5es,
|
||||
'train_times' : train_times,
|
||||
'valid_losses': valid_losses,
|
||||
'valid_acc1es': valid_acc1es,
|
||||
'valid_acc5es': valid_acc5es,
|
||||
'valid_times' : valid_times,
|
||||
'net_state_dict': net.state_dict(),
|
||||
'net_string' : '{:}'.format(net),
|
||||
'finish-train': True
|
||||
}
|
||||
return info_seed
|
316
exps/NAS-Bench-201/main.py
Normal file
316
exps/NAS-Bench-201/main.py
Normal file
@@ -0,0 +1,316 @@
|
||||
###############################################################
|
||||
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019-2020 #
|
||||
###############################################################
|
||||
import os, sys, time, torch, random, argparse
|
||||
from PIL import ImageFile
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config
|
||||
from procedures import save_checkpoint, copy_checkpoint
|
||||
from procedures import get_machine_info
|
||||
from datasets import get_datasets
|
||||
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||
from models import CellStructure, CellArchitectures, get_search_spaces
|
||||
from functions import evaluate_for_seed
|
||||
|
||||
|
||||
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger):
|
||||
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
|
||||
all_infos = {'info': machine_info}
|
||||
all_dataset_keys = []
|
||||
# look all the datasets
|
||||
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||
# train valid data
|
||||
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
|
||||
# load the configurature
|
||||
if dataset == 'cifar10' or dataset == 'cifar100':
|
||||
if use_less: config_path = 'configs/nas-benchmark/LESS.config'
|
||||
else : config_path = 'configs/nas-benchmark/CIFAR.config'
|
||||
split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
|
||||
elif dataset.startswith('ImageNet16'):
|
||||
if use_less: config_path = 'configs/nas-benchmark/LESS.config'
|
||||
else : config_path = 'configs/nas-benchmark/ImageNet-16.config'
|
||||
split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
config = load_config(config_path, \
|
||||
{'class_num': class_num,
|
||||
'xshape' : xshape}, \
|
||||
logger)
|
||||
# check whether use splited validation set
|
||||
if bool(split):
|
||||
assert dataset == 'cifar10'
|
||||
ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
|
||||
assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid))
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True)
|
||||
ValLoaders['x-valid'] = valid_loader
|
||||
else:
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
|
||||
if dataset == 'cifar10':
|
||||
ValLoaders = {'ori-test': valid_loader}
|
||||
elif dataset == 'cifar100':
|
||||
cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
|
||||
ValLoaders = {'ori-test': valid_loader,
|
||||
'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True),
|
||||
'x-test' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True)
|
||||
}
|
||||
elif dataset == 'ImageNet16-120':
|
||||
imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
|
||||
ValLoaders = {'ori-test': valid_loader,
|
||||
'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True),
|
||||
'x-test' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True)
|
||||
}
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
|
||||
dataset_key = '{:}'.format(dataset)
|
||||
if bool(split): dataset_key = dataset_key + '-valid'
|
||||
logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
|
||||
for key, value in ValLoaders.items():
|
||||
logger.log('Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value)))
|
||||
results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger)
|
||||
all_infos[dataset_key] = results
|
||||
all_dataset_keys.append( dataset_key )
|
||||
all_infos['all_dataset_keys'] = all_dataset_keys
|
||||
return all_infos
|
||||
|
||||
|
||||
def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
#torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads( workers )
|
||||
|
||||
assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange)
|
||||
|
||||
if use_less:
|
||||
sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
|
||||
else:
|
||||
sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
|
||||
logger = Logger(str(sub_dir), 0, False)
|
||||
|
||||
all_archs = meta_info['archs']
|
||||
assert srange[1] < meta_info['total'], 'invalid range : {:}-{:} vs. {:}'.format(srange[0], srange[1], meta_info['total'])
|
||||
assert arch_index == -1 or srange[0] <= arch_index <= srange[1], 'invalid range : {:} vs. {:} vs. {:}'.format(srange[0], arch_index, srange[1])
|
||||
if arch_index == -1:
|
||||
to_evaluate_indexes = list(range(srange[0], srange[1]+1))
|
||||
else:
|
||||
to_evaluate_indexes = [arch_index]
|
||||
logger.log('xargs : seeds = {:}'.format(seeds))
|
||||
logger.log('xargs : arch_index = {:}'.format(arch_index))
|
||||
logger.log('xargs : cover_mode = {:}'.format(cover_mode))
|
||||
logger.log('-'*100)
|
||||
|
||||
logger.log('Start evaluating range =: {:06d} vs. {:06d} vs. {:06d} / {:06d} with cover-mode={:}'.format(srange[0], arch_index, srange[1], meta_info['total'], cover_mode))
|
||||
for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
|
||||
logger.log('--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
|
||||
logger.log('--->>> architecture config : {:}'.format(arch_config))
|
||||
|
||||
|
||||
start_time, epoch_time = time.time(), AverageMeter()
|
||||
for i, index in enumerate(to_evaluate_indexes):
|
||||
arch = all_archs[index]
|
||||
logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seeds, '-'*15))
|
||||
#logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15))
|
||||
logger.log('{:} {:} {:}'.format('-'*15, arch, '-'*15))
|
||||
|
||||
# test this arch on different datasets with different seeds
|
||||
has_continue = False
|
||||
for seed in seeds:
|
||||
to_save_name = sub_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
|
||||
if to_save_name.exists():
|
||||
if cover_mode:
|
||||
logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name))
|
||||
os.remove(str(to_save_name))
|
||||
else :
|
||||
logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
|
||||
has_continue = True
|
||||
continue
|
||||
results = evaluate_all_datasets(CellStructure.str2structure(arch), \
|
||||
datasets, xpaths, splits, use_less, seed, \
|
||||
arch_config, workers, logger)
|
||||
torch.save(results, to_save_name)
|
||||
logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name))
|
||||
# measure elapsed time
|
||||
if not has_continue: epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) )
|
||||
logger.log('This arch costs : {:}'.format( convert_secs2time(epoch_time.val, True) ))
|
||||
logger.log('{:}'.format('*'*100))
|
||||
logger.log('{:} {:74s} {:}'.format('*'*10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, meta_info['total'], need_time), '*'*10))
|
||||
logger.log('{:}'.format('*'*100))
|
||||
|
||||
logger.close()
|
||||
|
||||
|
||||
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
#torch.backends.cudnn.benchmark = True
|
||||
torch.set_num_threads( workers )
|
||||
|
||||
save_dir = Path(save_dir) / 'specifics' / '{:}-{:}-{:}-{:}'.format('LESS' if use_less else 'FULL', model_str, arch_config['channel'], arch_config['num_cells'])
|
||||
logger = Logger(str(save_dir), 0, False)
|
||||
if model_str in CellArchitectures:
|
||||
arch = CellArchitectures[model_str]
|
||||
logger.log('The model string is found in pre-defined architecture dict : {:}'.format(model_str))
|
||||
else:
|
||||
try:
|
||||
arch = CellStructure.str2structure(model_str)
|
||||
except:
|
||||
raise ValueError('Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
|
||||
assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
|
||||
logger.log('Start train-evaluate {:}'.format(arch.tostr()))
|
||||
logger.log('arch_config : {:}'.format(arch_config))
|
||||
|
||||
start_time, seed_time = time.time(), AverageMeter()
|
||||
for _is, seed in enumerate(seeds):
|
||||
logger.log('\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds), seed))
|
||||
to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
|
||||
if to_save_name.exists():
|
||||
logger.log('Find the existing file {:}, directly load!'.format(to_save_name))
|
||||
checkpoint = torch.load(to_save_name)
|
||||
else:
|
||||
logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
|
||||
checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger)
|
||||
torch.save(checkpoint, to_save_name)
|
||||
# log information
|
||||
logger.log('{:}'.format(checkpoint['info']))
|
||||
all_dataset_keys = checkpoint['all_dataset_keys']
|
||||
for dataset_key in all_dataset_keys:
|
||||
logger.log('\n{:} dataset : {:} {:}'.format('-'*15, dataset_key, '-'*15))
|
||||
dataset_info = checkpoint[dataset_key]
|
||||
#logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
|
||||
logger.log('Flops = {:} MB, Params = {:} MB'.format(dataset_info['flop'], dataset_info['param']))
|
||||
logger.log('config : {:}'.format(dataset_info['config']))
|
||||
logger.log('Training State (finish) = {:}'.format(dataset_info['finish-train']))
|
||||
last_epoch = dataset_info['total_epoch'] - 1
|
||||
train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
|
||||
valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
|
||||
logger.log('Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%'.format(train_acc1es[last_epoch], train_acc5es[last_epoch], 100-train_acc1es[last_epoch], valid_acc1es[last_epoch], valid_acc5es[last_epoch], 100-valid_acc1es[last_epoch]))
|
||||
# measure elapsed time
|
||||
seed_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(seed_time.avg * (len(seeds)-_is-1), True) )
|
||||
logger.log('\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed, need_time))
|
||||
logger.close()
|
||||
|
||||
|
||||
def generate_meta_info(save_dir, max_node, divide=40):
|
||||
aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-201')
|
||||
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
|
||||
print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2)))
|
||||
|
||||
random.seed( 88 ) # please do not change this line for reproducibility
|
||||
random.shuffle( archs )
|
||||
# to test fixed-random shuffle
|
||||
#print ('arch [0] : {:}\n---->>>> {:}'.format( archs[0], archs[0].tostr() ))
|
||||
#print ('arch [9] : {:}\n---->>>> {:}'.format( archs[9], archs[9].tostr() ))
|
||||
assert archs[0 ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0])
|
||||
assert archs[9 ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9])
|
||||
assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123])
|
||||
total_arch = len(archs)
|
||||
|
||||
num = 50000
|
||||
indexes_5W = list(range(num))
|
||||
random.seed( 1021 )
|
||||
random.shuffle( indexes_5W )
|
||||
train_split = sorted( list(set(indexes_5W[:num//2])) )
|
||||
valid_split = sorted( list(set(indexes_5W[num//2:])) )
|
||||
assert len(train_split) + len(valid_split) == num
|
||||
assert train_split[0] == 0 and train_split[10] == 26 and train_split[111] == 203 and valid_split[0] == 1 and valid_split[10] == 18 and valid_split[111] == 242, '{:} {:} {:} - {:} {:} {:}'.format(train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111])
|
||||
splits = {num: {'train': train_split, 'valid': valid_split} }
|
||||
|
||||
info = {'archs' : [x.tostr() for x in archs],
|
||||
'total' : total_arch,
|
||||
'max_node' : max_node,
|
||||
'splits': splits}
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_name = save_dir / 'meta-node-{:}.pth'.format(max_node)
|
||||
assert not save_name.exists(), '{:} already exist'.format(save_name)
|
||||
torch.save(info, save_name)
|
||||
print ('save the meta file into {:}'.format(save_name))
|
||||
|
||||
script_name_full = save_dir / 'BENCH-201-N{:}.opt-full.script'.format(max_node)
|
||||
script_name_less = save_dir / 'BENCH-201-N{:}.opt-less.script'.format(max_node)
|
||||
full_file = open(str(script_name_full), 'w')
|
||||
less_file = open(str(script_name_less), 'w')
|
||||
gaps = total_arch // divide
|
||||
for start in range(0, total_arch, gaps):
|
||||
xend = min(start+gaps, total_arch)
|
||||
full_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 0 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1))
|
||||
less_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 1 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1))
|
||||
print ('save the training script into {:} and {:}'.format(script_name_full, script_name_less))
|
||||
full_file.close()
|
||||
less_file.close()
|
||||
|
||||
script_name = save_dir / 'meta-node-{:}.cal-script.txt'.format(max_node)
|
||||
macro = 'OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0'
|
||||
with open(str(script_name), 'w') as cfile:
|
||||
for start in range(0, total_arch, gaps):
|
||||
xend = min(start+gaps, total_arch)
|
||||
cfile.write('{:} python exps/NAS-Bench-201/statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1))
|
||||
print ('save the post-processing script into {:}'.format(script_name))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()]
|
||||
#parser = argparse.ArgumentParser(description='Algorithm-Agnostic NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--mode' , type=str, required=True, help='The script mode.')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--max_node', type=int, help='The maximum node in a cell.')
|
||||
# use for train the model
|
||||
parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--srange' , type=int, nargs='+', help='The range of models to be evaluated')
|
||||
parser.add_argument('--arch_index', type=int, default=-1, help='The architecture index to be evaluated (cover mode).')
|
||||
parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.')
|
||||
parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for this dataset.')
|
||||
parser.add_argument('--splits', type=int, nargs='+', help='The root path for this dataset.')
|
||||
parser.add_argument('--use_less', type=int, default=0, choices=[0,1], help='Using the less-training-epoch config.')
|
||||
parser.add_argument('--seeds' , type=int, nargs='+', help='The range of models to be evaluated')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.mode in ['meta', 'new', 'cover'] or args.mode.startswith('specific-'), 'invalid mode : {:}'.format(args.mode)
|
||||
|
||||
if args.mode == 'meta':
|
||||
generate_meta_info(args.save_dir, args.max_node)
|
||||
elif args.mode.startswith('specific'):
|
||||
assert len(args.mode.split('-')) == 2, 'invalid mode : {:}'.format(args.mode)
|
||||
model_str = args.mode.split('-')[1]
|
||||
train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
|
||||
tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells})
|
||||
else:
|
||||
meta_path = Path(args.save_dir) / 'meta-node-{:}.pth'.format(args.max_node)
|
||||
assert meta_path.exists(), '{:} does not exist.'.format(meta_path)
|
||||
meta_info = torch.load( meta_path )
|
||||
# check whether args is ok
|
||||
assert len(args.srange) == 2 and args.srange[0] <= args.srange[1], 'invalid length of srange args: {:}'.format(args.srange)
|
||||
assert len(args.seeds) > 0, 'invalid length of seeds args: {:}'.format(args.seeds)
|
||||
assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))
|
||||
assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)
|
||||
|
||||
main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
|
||||
tuple(args.srange), args.arch_index, tuple(args.seeds), \
|
||||
args.mode == 'cover', meta_info, \
|
||||
{'channel': args.channel, 'num_cells': args.num_cells})
|
295
exps/NAS-Bench-201/statistics.py
Normal file
295
exps/NAS-Bench-201/statistics.py
Normal file
@@ -0,0 +1,295 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, argparse, collections
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from config_utils import load_config, dict2config
|
||||
from datasets import get_datasets
|
||||
# NAS-Bench-201 related module or function
|
||||
from models import CellStructure, get_cell_based_tiny_net
|
||||
from nas_201_api import ArchResults, ResultsCount
|
||||
from functions import pure_evaluate
|
||||
|
||||
|
||||
|
||||
def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict):
|
||||
xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'], \
|
||||
results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None)
|
||||
|
||||
net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes':arch_config['class_num']}, None)
|
||||
network = get_cell_based_tiny_net(net_config)
|
||||
network.load_state_dict(xresult.get_net_param())
|
||||
if 'train_times' in results: # new version
|
||||
xresult.update_train_info(results['train_acc1es'], results['train_acc5es'], results['train_losses'], results['train_times'])
|
||||
xresult.update_eval(results['valid_acc1es'], results['valid_losses'], results['valid_times'])
|
||||
else:
|
||||
if dataset == 'cifar10-valid':
|
||||
xresult.update_OLD_eval('x-valid' , results['valid_acc1es'], results['valid_losses'])
|
||||
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format('cifar10', 'test')], network.cuda())
|
||||
xresult.update_OLD_eval('ori-test', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
|
||||
xresult.update_latency(latencies)
|
||||
elif dataset == 'cifar10':
|
||||
xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses'])
|
||||
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda())
|
||||
xresult.update_latency(latencies)
|
||||
elif dataset == 'cifar100' or dataset == 'ImageNet16-120':
|
||||
xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses'])
|
||||
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'valid')], network.cuda())
|
||||
xresult.update_OLD_eval('x-valid', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
|
||||
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda())
|
||||
xresult.update_OLD_eval('x-test' , {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
|
||||
xresult.update_latency(latencies)
|
||||
else:
|
||||
raise ValueError('invalid dataset name : {:}'.format(dataset))
|
||||
return xresult
|
||||
|
||||
|
||||
|
||||
def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dict):
|
||||
information = ArchResults(arch_index, arch_str)
|
||||
|
||||
for checkpoint_path in checkpoints:
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
used_seed = checkpoint_path.name.split('-')[-1].split('.')[0]
|
||||
for dataset in datasets:
|
||||
assert dataset in checkpoint, 'Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, checkpoint_path)
|
||||
results = checkpoint[dataset]
|
||||
assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, used_seed, dataset, checkpoint_path)
|
||||
arch_config = {'channel': results['channel'], 'num_cells': results['num_cells'], 'arch_str': arch_str, 'class_num': results['config']['class_num']}
|
||||
|
||||
xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict)
|
||||
information.update(dataset, int(used_seed), xresult)
|
||||
return information
|
||||
|
||||
|
||||
|
||||
def GET_DataLoaders(workers):
|
||||
|
||||
torch.set_num_threads(workers)
|
||||
|
||||
root_dir = (Path(__file__).parent / '..' / '..').resolve()
|
||||
torch_dir = Path(os.environ['TORCH_HOME'])
|
||||
# cifar
|
||||
cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config'
|
||||
cifar_config = load_config(cifar_config_path, None, None)
|
||||
print ('{:} Create data-loader for all datasets'.format(time_string()))
|
||||
print ('-'*200)
|
||||
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1)
|
||||
print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num))
|
||||
cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None)
|
||||
assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14]
|
||||
temp_dataset = deepcopy(TRAIN_CIFAR10)
|
||||
temp_dataset.transform = VALID_CIFAR10.transform
|
||||
# data loader
|
||||
trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
train_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True)
|
||||
valid_cifar10_loader = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True)
|
||||
test__cifar10_loader = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
|
||||
print ('CIFAR-10 : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size))
|
||||
print ('-'*200)
|
||||
# CIFAR-100
|
||||
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1)
|
||||
print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num))
|
||||
cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None)
|
||||
assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24]
|
||||
train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True)
|
||||
test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True)
|
||||
print ('CIFAR-100 : train-loader has {:3d} batch'.format(len(train_cifar100_loader)))
|
||||
print ('CIFAR-100 : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader)))
|
||||
print ('CIFAR-100 : test--loader has {:3d} batch'.format(len(test__cifar100_loader)))
|
||||
print ('-'*200)
|
||||
|
||||
imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config'
|
||||
imagenet16_config = load_config(imagenet16_config_path, None, None)
|
||||
TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1)
|
||||
print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num))
|
||||
imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None)
|
||||
assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20]
|
||||
train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True)
|
||||
test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True)
|
||||
print ('ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size))
|
||||
print ('ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size))
|
||||
print ('ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size))
|
||||
|
||||
# 'cifar10', 'cifar100', 'ImageNet16-120'
|
||||
loaders = {'cifar10@trainval': trainval_cifar10_loader,
|
||||
'cifar10@train' : train_cifar10_loader,
|
||||
'cifar10@valid' : valid_cifar10_loader,
|
||||
'cifar10@test' : test__cifar10_loader,
|
||||
'cifar100@train' : train_cifar100_loader,
|
||||
'cifar100@valid' : valid_cifar100_loader,
|
||||
'cifar100@test' : test__cifar100_loader,
|
||||
'ImageNet16-120@train': train_imagenet_loader,
|
||||
'ImageNet16-120@valid': valid_imagenet_loader,
|
||||
'ImageNet16-120@test' : test__imagenet_loader}
|
||||
return loaders
|
||||
|
||||
|
||||
|
||||
def simplify(save_dir, meta_file, basestr, target_dir):
|
||||
meta_infos = torch.load(meta_file, map_location='cpu')
|
||||
meta_archs = meta_infos['archs'] # a list of architecture strings
|
||||
meta_num_archs = meta_infos['total']
|
||||
meta_max_node = meta_infos['max_node']
|
||||
assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
|
||||
|
||||
sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
|
||||
print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))
|
||||
|
||||
subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
|
||||
num_seeds = defaultdict(lambda: 0)
|
||||
for index, sub_dir in enumerate(sub_model_dirs):
|
||||
xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
|
||||
arch_indexes = set()
|
||||
for checkpoint in xcheckpoints:
|
||||
temp_names = checkpoint.name.split('-')
|
||||
assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
|
||||
arch_indexes.add( temp_names[1] )
|
||||
subdir2archs[sub_dir] = sorted(list(arch_indexes))
|
||||
num_evaluated_arch += len(arch_indexes)
|
||||
# count number of seeds for each architecture
|
||||
for arch_index in arch_indexes:
|
||||
num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1
|
||||
print('{:} There are {:5d} architectures that have been evaluated ({:} in total).'.format(time_string(), num_evaluated_arch, meta_num_archs))
|
||||
for key in sorted( list( num_seeds.keys() ) ): print ('{:} There are {:5d} architectures that are evaluated {:} times.'.format(time_string(), num_seeds[key], key))
|
||||
|
||||
dataloader_dict = GET_DataLoaders( 6 )
|
||||
|
||||
to_save_simply = save_dir / 'simplifies'
|
||||
to_save_allarc = save_dir / 'simplifies' / 'architectures'
|
||||
if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True)
|
||||
if not to_save_allarc.exists(): to_save_allarc.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
assert (save_dir / target_dir) in subdir2archs, 'can not find {:}'.format(target_dir)
|
||||
arch2infos, datasets = {}, ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120')
|
||||
evaluated_indexes = set()
|
||||
target_directory = save_dir / target_dir
|
||||
target_less_dir = save_dir / '{:}-LESS'.format(target_dir)
|
||||
arch_indexes = subdir2archs[ target_directory ]
|
||||
num_seeds = defaultdict(lambda: 0)
|
||||
end_time = time.time()
|
||||
arch_time = AverageMeter()
|
||||
for idx, arch_index in enumerate(arch_indexes):
|
||||
checkpoints = list(target_directory.glob('arch-{:}-seed-*.pth'.format(arch_index)))
|
||||
ckps_less = list(target_less_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))
|
||||
# create the arch info for each architecture
|
||||
try:
|
||||
arch_info_full = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict)
|
||||
arch_info_less = account_one_arch(arch_index, meta_archs[int(arch_index)], ckps_less, ['cifar10-valid'], dataloader_dict)
|
||||
num_seeds[ len(checkpoints) ] += 1
|
||||
except:
|
||||
print('Loading {:} failed, : {:}'.format(arch_index, checkpoints))
|
||||
continue
|
||||
assert int(arch_index) not in evaluated_indexes, 'conflict arch-index : {:}'.format(arch_index)
|
||||
assert 0 <= int(arch_index) < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
|
||||
arch_info = {'full': arch_info_full, 'less': arch_info_less}
|
||||
evaluated_indexes.add( int(arch_index) )
|
||||
arch2infos[int(arch_index)] = arch_info
|
||||
torch.save({'full': arch_info_full.state_dict(),
|
||||
'less': arch_info_less.state_dict()}, to_save_allarc / '{:}-FULL.pth'.format(arch_index))
|
||||
arch_info['full'].clear_params()
|
||||
arch_info['less'].clear_params()
|
||||
torch.save({'full': arch_info_full.state_dict(),
|
||||
'less': arch_info_less.state_dict()}, to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index))
|
||||
# measure elapsed time
|
||||
arch_time.update(time.time() - end_time)
|
||||
end_time = time.time()
|
||||
need_time = '{:}'.format( convert_secs2time(arch_time.avg * (len(arch_indexes)-idx-1), True) )
|
||||
print('{:} {:} [{:03d}/{:03d}] : {:} still need {:}'.format(time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time))
|
||||
# measure time
|
||||
xstrs = ['{:}:{:03d}'.format(key, num_seeds[key]) for key in sorted( list( num_seeds.keys() ) ) ]
|
||||
print('{:} {:} done : {:}'.format(time_string(), target_dir, xstrs))
|
||||
final_infos = {'meta_archs' : meta_archs,
|
||||
'total_archs': meta_num_archs,
|
||||
'basestr' : basestr,
|
||||
'arch2infos' : arch2infos,
|
||||
'evaluated_indexes': evaluated_indexes}
|
||||
save_file_name = to_save_simply / '{:}.pth'.format(target_dir)
|
||||
torch.save(final_infos, save_file_name)
|
||||
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
|
||||
|
||||
|
||||
|
||||
def merge_all(save_dir, meta_file, basestr):
|
||||
meta_infos = torch.load(meta_file, map_location='cpu')
|
||||
meta_archs = meta_infos['archs']
|
||||
meta_num_archs = meta_infos['total']
|
||||
meta_max_node = meta_infos['max_node']
|
||||
assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
|
||||
|
||||
sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
|
||||
print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))
|
||||
for index, sub_dir in enumerate(sub_model_dirs):
|
||||
arch_info_files = sorted( list(sub_dir.glob('arch-*-seed-*.pth') ) )
|
||||
print ('The {:02d}/{:02d}-th directory : {:} : {:} runs.'.format(index, len(sub_model_dirs), sub_dir, len(arch_info_files)))
|
||||
|
||||
arch2infos, evaluated_indexes = dict(), set()
|
||||
for IDX, sub_dir in enumerate(sub_model_dirs):
|
||||
ckp_path = sub_dir.parent / 'simplifies' / '{:}.pth'.format(sub_dir.name)
|
||||
if ckp_path.exists():
|
||||
sub_ckps = torch.load(ckp_path, map_location='cpu')
|
||||
assert sub_ckps['total_archs'] == meta_num_archs and sub_ckps['basestr'] == basestr
|
||||
xarch2infos = sub_ckps['arch2infos']
|
||||
xevalindexs = sub_ckps['evaluated_indexes']
|
||||
for eval_index in xevalindexs:
|
||||
assert eval_index not in evaluated_indexes and eval_index not in arch2infos
|
||||
#arch2infos[eval_index] = xarch2infos[eval_index].state_dict()
|
||||
arch2infos[eval_index] = {'full': xarch2infos[eval_index]['full'].state_dict(),
|
||||
'less': xarch2infos[eval_index]['less'].state_dict()}
|
||||
evaluated_indexes.add( eval_index )
|
||||
print ('{:} [{:03d}/{:03d}] merge data from {:} with {:} models.'.format(time_string(), IDX, len(sub_model_dirs), ckp_path, len(xevalindexs)))
|
||||
else:
|
||||
raise ValueError('Can not find {:}'.format(ckp_path))
|
||||
#print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path))
|
||||
|
||||
evaluated_indexes = sorted( list( evaluated_indexes ) )
|
||||
print ('Finally, there are {:} architectures that have been trained and evaluated.'.format(len(evaluated_indexes)))
|
||||
|
||||
to_save_simply = save_dir / 'simplifies'
|
||||
if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True)
|
||||
final_infos = {'meta_archs' : meta_archs,
|
||||
'total_archs': meta_num_archs,
|
||||
'arch2infos' : arch2infos,
|
||||
'evaluated_indexes': evaluated_indexes}
|
||||
save_file_name = to_save_simply / '{:}-final-infos.pth'.format(basestr)
|
||||
torch.save(final_infos, save_file_name)
|
||||
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--mode' , type=str, choices=['cal', 'merge'], help='The running mode for this script.')
|
||||
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
|
||||
parser.add_argument('--target_dir' , type=str, help='The target directory.')
|
||||
parser.add_argument('--max_node' , type=int, default=4, help='The maximum node in a cell.')
|
||||
parser.add_argument('--channel' , type=int, default=16, help='The number of channels.')
|
||||
parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.')
|
||||
args = parser.parse_args()
|
||||
|
||||
save_dir = Path( args.base_save_dir )
|
||||
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
|
||||
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
|
||||
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
|
||||
print ('start the statistics of our nas-benchmark from {:} using {:}.'.format(save_dir, args.target_dir))
|
||||
basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
|
||||
|
||||
if args.mode == 'cal':
|
||||
simplify(save_dir, meta_path, basestr, args.target_dir)
|
||||
elif args.mode == 'merge':
|
||||
merge_all(save_dir, meta_path, basestr)
|
||||
else:
|
||||
raise ValueError('invalid mode : {:}'.format(args.mode))
|
223
exps/NAS-Bench-201/test-correlation.py
Normal file
223
exps/NAS-Bench-201/test-correlation.py
Normal file
@@ -0,0 +1,223 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
########################################################
|
||||
# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
|
||||
########################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net, get_search_spaces, CellStructure
|
||||
from nas_201_api import NASBench201API as API
|
||||
|
||||
|
||||
def valid_func(xloader, network, criterion):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
network.eval()
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# prediction
|
||||
_, logits = network(arch_inputs)
|
||||
arch_loss = criterion(logits, arch_targets)
|
||||
# record
|
||||
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||
|
||||
|
||||
def main(xargs):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads( xargs.workers )
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
imagenet16_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space}, None)
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
logger.log('search-model :\n{:}'.format(search_model))
|
||||
|
||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||
logger.log('w-optimizer : {:}'.format(w_optimizer))
|
||||
logger.log('a-optimizer : {:}'.format(a_optimizer))
|
||||
logger.log('w-scheduler : {:}'.format(w_scheduler))
|
||||
logger.log('criterion : {:}'.format(criterion))
|
||||
flop, param = get_model_infos(search_model, xshape)
|
||||
#logger.log('{:}'.format(search_model))
|
||||
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||
if xargs.arch_nas_dataset is None:
|
||||
api = None
|
||||
else:
|
||||
api = API(xargs.arch_nas_dataset)
|
||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||
|
||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||
|
||||
logger.close()
|
||||
|
||||
|
||||
def check_unique_arch(meta_file):
|
||||
api = API(str(meta_file))
|
||||
arch_strs = deepcopy(api.meta_archs)
|
||||
xarchs = [CellStructure.str2structure(x) for x in arch_strs]
|
||||
def get_unique_matrix(archs, consider_zero):
|
||||
UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
|
||||
print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs)))
|
||||
Unique2Index = dict()
|
||||
for index, xstr in enumerate(UniquStrs):
|
||||
if xstr not in Unique2Index: Unique2Index[xstr] = list()
|
||||
Unique2Index[xstr].append( index )
|
||||
sm_matrix = torch.eye(len(archs)).bool()
|
||||
for _, xlist in Unique2Index.items():
|
||||
for i in xlist:
|
||||
for j in xlist:
|
||||
sm_matrix[i,j] = True
|
||||
unique_ids, unique_num = [-1 for _ in archs], 0
|
||||
for i in range(len(unique_ids)):
|
||||
if unique_ids[i] > -1: continue
|
||||
neighbours = sm_matrix[i].nonzero().view(-1).tolist()
|
||||
for nghb in neighbours:
|
||||
assert unique_ids[nghb] == -1, 'impossible'
|
||||
unique_ids[nghb] = unique_num
|
||||
unique_num += 1
|
||||
return sm_matrix, unique_ids, unique_num
|
||||
|
||||
print ('There are {:} valid-archs'.format( sum(arch.check_valid() for arch in xarchs) ))
|
||||
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None)
|
||||
print ('{:} There are {:} unique architectures (considering nothing).'.format(time_string(), unique_num))
|
||||
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False)
|
||||
print ('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num))
|
||||
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True)
|
||||
print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
|
||||
|
||||
|
||||
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False):
|
||||
if isinstance(meta_file, API):
|
||||
api = meta_file
|
||||
else:
|
||||
api = API(str(meta_file))
|
||||
cifar10_currs = []
|
||||
cifar10_valid = []
|
||||
cifar10_test = []
|
||||
cifar100_valid = []
|
||||
cifar100_test = []
|
||||
imagenet_test = []
|
||||
imagenet_valid = []
|
||||
for idx, arch in enumerate(api):
|
||||
results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
|
||||
cifar10_currs.append( results['valid-accuracy'] )
|
||||
# --->>>>>
|
||||
results = api.get_more_info(idx, 'cifar10-valid' , None, False, is_rand)
|
||||
cifar10_valid.append( results['valid-accuracy'] )
|
||||
results = api.get_more_info(idx, 'cifar10' , None, False, is_rand)
|
||||
cifar10_test.append( results['test-accuracy'] )
|
||||
results = api.get_more_info(idx, 'cifar100' , None, False, is_rand)
|
||||
cifar100_test.append( results['test-accuracy'] )
|
||||
cifar100_valid.append( results['valid-accuracy'] )
|
||||
results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand)
|
||||
imagenet_test.append( results['test-accuracy'] )
|
||||
imagenet_valid.append( results['valid-accuracy'] )
|
||||
def get_cor(A, B):
|
||||
return float(np.corrcoef(A, B)[0,1])
|
||||
cors = []
|
||||
for basestr, xlist in zip(['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_valid, cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]):
|
||||
correlation = get_cor(cifar10_currs, xlist)
|
||||
if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
|
||||
cors.append( correlation )
|
||||
#print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
|
||||
#print('-'*200)
|
||||
#print('*'*230)
|
||||
return cors
|
||||
|
||||
|
||||
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
|
||||
corrs = []
|
||||
for i in tqdm(range(100)):
|
||||
x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
|
||||
corrs.append( x )
|
||||
#xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
|
||||
xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
|
||||
correlations = np.array(corrs)
|
||||
print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200'))
|
||||
for idx, xstr in enumerate(xstrs):
|
||||
print ('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(xstr, correlations[:,idx].mean(), correlations[:,idx].std(), correlations[:,idx].mean(), correlations[:,idx].std()))
|
||||
print('')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
|
||||
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
|
||||
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.')
|
||||
args = parser.parse_args()
|
||||
|
||||
vis_save_dir = Path(args.save_dir)
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
meta_file = Path(args.api_path)
|
||||
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
|
||||
|
||||
#check_unique_arch(meta_file)
|
||||
api = API(str(meta_file))
|
||||
#for iepoch in [11, 25, 50, 100, 150, 175, 200]:
|
||||
# check_cor_for_bandit(api, 6, iepoch)
|
||||
# check_cor_for_bandit(api, 12, iepoch)
|
||||
check_cor_for_bandit_v2(api, 6, True, True)
|
||||
check_cor_for_bandit_v2(api, 12, True, True)
|
||||
check_cor_for_bandit_v2(api, 12, False, True)
|
||||
check_cor_for_bandit_v2(api, 24, False, True)
|
||||
check_cor_for_bandit_v2(api, 100, False, True)
|
||||
check_cor_for_bandit_v2(api, 150, False, True)
|
||||
check_cor_for_bandit_v2(api, 175, False, True)
|
||||
check_cor_for_bandit_v2(api, 200, False, True)
|
||||
print('----')
|
740
exps/NAS-Bench-201/visualize.py
Normal file
740
exps/NAS-Bench-201/visualize.py
Normal file
@@ -0,0 +1,740 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
|
||||
##################################################
|
||||
import os, sys, time, argparse, collections
|
||||
from tqdm import tqdm
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import matplotlib
|
||||
import seaborn as sns
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from log_utils import time_string
|
||||
from nas_201_api import NASBench201API as API
|
||||
|
||||
|
||||
|
||||
def calculate_correlation(*vectors):
|
||||
matrix = []
|
||||
for i, vectori in enumerate(vectors):
|
||||
x = []
|
||||
for j, vectorj in enumerate(vectors):
|
||||
x.append( np.corrcoef(vectori, vectorj)[0,1] )
|
||||
matrix.append( x )
|
||||
return np.array(matrix)
|
||||
|
||||
|
||||
|
||||
def visualize_relative_ranking(vis_save_dir):
|
||||
print ('\n' + '-'*100)
|
||||
cifar010_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar10')
|
||||
cifar100_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar100')
|
||||
imagenet_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('ImageNet16-120')
|
||||
cifar010_info = torch.load(cifar010_cache_path)
|
||||
cifar100_info = torch.load(cifar100_cache_path)
|
||||
imagenet_info = torch.load(imagenet_cache_path)
|
||||
indexes = list(range(len(cifar010_info['params'])))
|
||||
|
||||
print ('{:} start to visualize relative ranking'.format(time_string()))
|
||||
# maximum accuracy with ResNet-level params 11472
|
||||
x_010_accs = [ cifar010_info['test_accs'][i] if cifar010_info['params'][i] <= cifar010_info['params'][11472] else -1 for i in indexes]
|
||||
x_100_accs = [ cifar100_info['test_accs'][i] if cifar100_info['params'][i] <= cifar100_info['params'][11472] else -1 for i in indexes]
|
||||
x_img_accs = [ imagenet_info['test_accs'][i] if imagenet_info['params'][i] <= imagenet_info['params'][11472] else -1 for i in indexes]
|
||||
|
||||
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
|
||||
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
|
||||
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
|
||||
|
||||
cifar100_labels, imagenet_labels = [], []
|
||||
for idx in cifar010_ord_indexes:
|
||||
cifar100_labels.append( cifar100_ord_indexes.index(idx) )
|
||||
imagenet_labels.append( imagenet_ord_indexes.index(idx) )
|
||||
print ('{:} prepare data done.'.format(time_string()))
|
||||
|
||||
dpi, width, height = 300, 2600, 2600
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 18, 18
|
||||
resnet_scale, resnet_alpha = 120, 0.5
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xlim(min(indexes), max(indexes))
|
||||
plt.ylim(min(indexes), max(indexes))
|
||||
#plt.ylabel('y').set_rotation(0)
|
||||
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical')
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize)
|
||||
#ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100')
|
||||
#ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8, label='ImageNet-16-120')
|
||||
#ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10')
|
||||
ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
|
||||
ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8)
|
||||
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
|
||||
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
|
||||
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
|
||||
ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120')
|
||||
plt.grid(zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
|
||||
ax.set_ylabel('architecture ranking', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / 'relative-rank.pdf').resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / 'relative-rank.png').resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
# calculate correlation
|
||||
sns_size = 15
|
||||
CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs'])
|
||||
fig = plt.figure(figsize=figsize)
|
||||
plt.axis('off')
|
||||
h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5)
|
||||
save_path = (vis_save_dir / 'co-relation-all.pdf').resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
# calculate correlation
|
||||
acc_bars = [92, 93]
|
||||
for acc_bar in acc_bars:
|
||||
selected_indexes = []
|
||||
for i, acc in enumerate(cifar010_info['test_accs']):
|
||||
if acc > acc_bar: selected_indexes.append( i )
|
||||
print ('select {:} architectures'.format(len(selected_indexes)))
|
||||
cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ]
|
||||
cifar010_test_accs = np.array(cifar010_info['test_accs']) [ selected_indexes ]
|
||||
cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ]
|
||||
cifar100_test_accs = np.array(cifar100_info['test_accs']) [ selected_indexes ]
|
||||
imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ]
|
||||
imagenet_test_accs = np.array(imagenet_info['test_accs']) [ selected_indexes ]
|
||||
CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs)
|
||||
fig = plt.figure(figsize=figsize)
|
||||
plt.axis('off')
|
||||
h = sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt='.3f', linewidths=0.5)
|
||||
save_path = (vis_save_dir / 'co-relation-top-{:}.pdf'.format(len(selected_indexes))).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
plt.close('all')
|
||||
|
||||
|
||||
|
||||
def visualize_info(meta_file, dataset, vis_save_dir):
|
||||
print ('{:} start to visualize {:} information'.format(time_string(), dataset))
|
||||
cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
|
||||
if not cache_file_path.exists():
|
||||
print ('Do not find cache file : {:}'.format(cache_file_path))
|
||||
nas_bench = API(str(meta_file))
|
||||
params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
|
||||
for index in range( len(nas_bench) ):
|
||||
info = nas_bench.query_by_index(index, use_12epochs_result=False)
|
||||
resx = info.get_comput_costs(dataset) ; flop, param = resx['flops'], resx['params']
|
||||
if dataset == 'cifar10':
|
||||
res = info.get_metrics('cifar10', 'train') ; train_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10-valid', 'x-valid') ; valid_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10', 'ori-test') ; test_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10', 'ori-test') ; otest_acc = res['accuracy']
|
||||
else:
|
||||
res = info.get_metrics(dataset, 'train') ; train_acc = res['accuracy']
|
||||
res = info.get_metrics(dataset, 'x-valid') ; valid_acc = res['accuracy']
|
||||
res = info.get_metrics(dataset, 'x-test') ; test_acc = res['accuracy']
|
||||
res = info.get_metrics(dataset, 'ori-test') ; otest_acc = res['accuracy']
|
||||
if index == 11472: # resnet
|
||||
resnet = {'params':param, 'flops': flop, 'index': 11472, 'train_acc': train_acc, 'valid_acc': valid_acc, 'test_acc': test_acc, 'otest_acc': otest_acc}
|
||||
flops.append( flop )
|
||||
params.append( param )
|
||||
train_accs.append( train_acc )
|
||||
valid_accs.append( valid_acc )
|
||||
test_accs.append( test_acc )
|
||||
otest_accs.append( otest_acc )
|
||||
#resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97}
|
||||
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
|
||||
info['resnet'] = resnet
|
||||
torch.save(info, cache_file_path)
|
||||
else:
|
||||
print ('Find cache file : {:}'.format(cache_file_path))
|
||||
info = torch.load(cache_file_path)
|
||||
params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
|
||||
resnet = info['resnet']
|
||||
print ('{:} collect data done.'.format(time_string()))
|
||||
|
||||
indexes = list(range(len(params)))
|
||||
dpi, width, height = 300, 2600, 2600
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 22, 22
|
||||
resnet_scale, resnet_alpha = 120, 0.5
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
|
||||
if dataset == 'cifar10':
|
||||
plt.ylim(50, 100)
|
||||
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
|
||||
elif dataset == 'cifar100':
|
||||
plt.ylim(25, 75)
|
||||
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
|
||||
else:
|
||||
plt.ylim(0, 50)
|
||||
plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
|
||||
ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax.scatter([resnet['params']], [resnet['valid_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=0.4)
|
||||
plt.grid(zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
|
||||
ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / '{:}-param-vs-valid.pdf'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / '{:}-param-vs-valid.png'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
|
||||
if dataset == 'cifar10':
|
||||
plt.ylim(50, 100)
|
||||
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
|
||||
elif dataset == 'cifar100':
|
||||
plt.ylim(25, 75)
|
||||
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
|
||||
else:
|
||||
plt.ylim(0, 50)
|
||||
plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
|
||||
ax.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax.scatter([resnet['params']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
|
||||
plt.grid()
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
|
||||
ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / '{:}-param-vs-test.pdf'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / '{:}-param-vs-test.png'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
|
||||
if dataset == 'cifar10':
|
||||
plt.ylim(50, 100)
|
||||
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
|
||||
elif dataset == 'cifar100':
|
||||
plt.ylim(20, 100)
|
||||
plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize)
|
||||
else:
|
||||
plt.ylim(25, 76)
|
||||
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
|
||||
ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax.scatter([resnet['params']], [resnet['train_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
|
||||
plt.grid()
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
|
||||
ax.set_ylabel('the trarining accuracy (%)', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / '{:}-param-vs-train.pdf'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / '{:}-param-vs-train.png'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xlim(0, max(indexes))
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
|
||||
if dataset == 'cifar10':
|
||||
plt.ylim(50, 100)
|
||||
plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
|
||||
elif dataset == 'cifar100':
|
||||
plt.ylim(25, 75)
|
||||
plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
|
||||
else:
|
||||
plt.ylim(0, 50)
|
||||
plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
|
||||
ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
|
||||
ax.scatter([resnet['index']], [resnet['test_acc']], marker='*', s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
|
||||
plt.grid()
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
ax.set_xlabel('architecture ID', fontsize=LabelSize)
|
||||
ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / '{:}-test-over-ID.pdf'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / '{:}-test-over-ID.png'.format(dataset)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
plt.close('all')
|
||||
|
||||
|
||||
|
||||
def visualize_rank_over_time(meta_file, vis_save_dir):
|
||||
print ('\n' + '-'*150)
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
print ('{:} start to visualize rank-over-time into {:}'.format(time_string(), vis_save_dir))
|
||||
cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
|
||||
if not cache_file_path.exists():
|
||||
print ('Do not find cache file : {:}'.format(cache_file_path))
|
||||
nas_bench = API(str(meta_file))
|
||||
print ('{:} load nas_bench done'.format(time_string()))
|
||||
params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
|
||||
#for iepoch in range(200): for index in range( len(nas_bench) ):
|
||||
for index in tqdm(range(len(nas_bench))):
|
||||
info = nas_bench.query_by_index(index, use_12epochs_result=False)
|
||||
for iepoch in range(200):
|
||||
res = info.get_metrics('cifar10' , 'train' , iepoch) ; train_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10-valid', 'x-valid' , iepoch) ; valid_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10' , 'ori-test', iepoch) ; test_acc = res['accuracy']
|
||||
res = info.get_metrics('cifar10' , 'ori-test', iepoch) ; otest_acc = res['accuracy']
|
||||
train_accs[iepoch].append( train_acc )
|
||||
valid_accs[iepoch].append( valid_acc )
|
||||
test_accs [iepoch].append( test_acc )
|
||||
otest_accs[iepoch].append( otest_acc )
|
||||
if iepoch == 0:
|
||||
res = info.get_comput_costs('cifar10') ; flop, param = res['flops'], res['params']
|
||||
flops.append( flop )
|
||||
params.append( param )
|
||||
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs, 'otest_accs': otest_accs}
|
||||
torch.save(info, cache_file_path)
|
||||
else:
|
||||
print ('Find cache file : {:}'.format(cache_file_path))
|
||||
info = torch.load(cache_file_path)
|
||||
params, flops, train_accs, valid_accs, test_accs, otest_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs'], info['otest_accs']
|
||||
print ('{:} collect data done.'.format(time_string()))
|
||||
#selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
|
||||
selected_epochs = list( range(200) )
|
||||
x_xtests = test_accs[199]
|
||||
indexes = list(range(len(x_xtests)))
|
||||
ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])
|
||||
for sepoch in selected_epochs:
|
||||
x_valids = valid_accs[sepoch]
|
||||
valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
|
||||
valid_ord_lbls = []
|
||||
for idx in ord_idxs:
|
||||
valid_ord_lbls.append( valid_ord_idxs.index(idx) )
|
||||
# labeled data
|
||||
dpi, width, height = 300, 2600, 2600
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
LabelSize, LegendFontsize = 18, 18
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111)
|
||||
plt.xlim(min(indexes), max(indexes))
|
||||
plt.ylim(min(indexes), max(indexes))
|
||||
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize, rotation='vertical')
|
||||
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//6), fontsize=LegendFontsize)
|
||||
ax.scatter(indexes, valid_ord_lbls, marker='^', s=0.5, c='tab:green', alpha=0.8)
|
||||
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
|
||||
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-10 validation')
|
||||
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10 test')
|
||||
plt.grid(zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
plt.legend(loc='upper left', fontsize=LegendFontsize)
|
||||
ax.set_xlabel('architecture ranking in the final test accuracy', fontsize=LabelSize)
|
||||
ax.set_ylabel('architecture ranking in the validation set', fontsize=LabelSize)
|
||||
save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
|
||||
print ('{:} save into {:}'.format(time_string(), save_path))
|
||||
plt.close('all')
|
||||
|
||||
|
||||
|
||||
def write_video(save_dir):
|
||||
import cv2
|
||||
video_save_path = save_dir / 'time.avi'
|
||||
print ('{:} start create video for {:}'.format(time_string(), video_save_path))
|
||||
images = sorted( list( save_dir.glob('time-*.png') ) )
|
||||
ximage = cv2.imread(str(images[0]))
|
||||
#shape = (ximage.shape[1], ximage.shape[0])
|
||||
shape = (1000, 1000)
|
||||
#writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 25, shape)
|
||||
writer = cv2.VideoWriter(str(video_save_path), cv2.VideoWriter_fourcc(*"MJPG"), 5, shape)
|
||||
for idx, image in enumerate(images):
|
||||
ximage = cv2.imread(str(image))
|
||||
_image = cv2.resize(ximage, shape)
|
||||
writer.write(_image)
|
||||
writer.release()
|
||||
print ('write video [{:} frames] into {:}'.format(len(images), video_save_path))
|
||||
|
||||
|
||||
|
||||
def plot_results_nas_v2(api, dataset_xset_a, dataset_xset_b, root, file_name, y_lims):
|
||||
#print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
|
||||
print ('root-path : {:} and {:}'.format(dataset_xset_a, dataset_xset_b))
|
||||
checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth',
|
||||
'./output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth',
|
||||
'./output/search-cell-nas-bench-201/RAND-cifar10/results.pth',
|
||||
'./output/search-cell-nas-bench-201/BOHB-cifar10/results.pth'
|
||||
]
|
||||
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
|
||||
All_Accs_A, All_Accs_B = OrderedDict(), OrderedDict()
|
||||
for legend, checkpoint in zip(legends, checkpoints):
|
||||
all_indexes = torch.load(checkpoint, map_location='cpu')
|
||||
accuracies_A, accuracies_B = [], []
|
||||
accuracies = []
|
||||
for x in all_indexes:
|
||||
info = api.arch2infos_full[ x ]
|
||||
metrics = info.get_metrics(dataset_xset_a[0], dataset_xset_a[1], None, False)
|
||||
accuracies_A.append( metrics['accuracy'] )
|
||||
metrics = info.get_metrics(dataset_xset_b[0], dataset_xset_b[1], None, False)
|
||||
accuracies_B.append( metrics['accuracy'] )
|
||||
accuracies.append( (accuracies_A[-1], accuracies_B[-1]) )
|
||||
if indexes is None: indexes = list(range(len(all_indexes)))
|
||||
accuracies = sorted(accuracies)
|
||||
All_Accs_A[legend] = [x[0] for x in accuracies]
|
||||
All_Accs_B[legend] = [x[1] for x in accuracies]
|
||||
|
||||
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||
dpi, width, height = 300, 3400, 2600
|
||||
LabelSize, LegendFontsize = 28, 28
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
fig = plt.figure(figsize=figsize)
|
||||
x_axis = np.arange(0, 600)
|
||||
plt.xlim(0, max(indexes))
|
||||
plt.ylim(y_lims[0], y_lims[1])
|
||||
interval_x, interval_y = 100, y_lims[2]
|
||||
plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
|
||||
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
||||
plt.grid()
|
||||
plt.xlabel('The index of runs', fontsize=LabelSize)
|
||||
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||
|
||||
for idx, legend in enumerate(legends):
|
||||
plt.plot(indexes, All_Accs_B[legend], color=color_set[idx], linestyle='--', label='{:}'.format(legend), lw=1, alpha=0.5)
|
||||
plt.plot(indexes, All_Accs_A[legend], color=color_set[idx], linestyle='-', lw=1)
|
||||
for All_Accs in [All_Accs_A, All_Accs_B]:
|
||||
print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend])))
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
save_path = root / '{:}'.format(file_name)
|
||||
print('save figure into {:}\n'.format(save_path))
|
||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
|
||||
|
||||
|
||||
|
||||
def plot_results_nas(api, dataset, xset, root, file_name, y_lims):
|
||||
print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
|
||||
checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth',
|
||||
'./output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth',
|
||||
'./output/search-cell-nas-bench-201/RAND-cifar10/results.pth',
|
||||
'./output/search-cell-nas-bench-201/BOHB-cifar10/results.pth'
|
||||
]
|
||||
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
|
||||
All_Accs = OrderedDict()
|
||||
for legend, checkpoint in zip(legends, checkpoints):
|
||||
all_indexes = torch.load(checkpoint, map_location='cpu')
|
||||
accuracies = []
|
||||
for x in all_indexes:
|
||||
info = api.arch2infos_full[ x ]
|
||||
metrics = info.get_metrics(dataset, xset, None, False)
|
||||
accuracies.append( metrics['accuracy'] )
|
||||
if indexes is None: indexes = list(range(len(all_indexes)))
|
||||
All_Accs[legend] = sorted(accuracies)
|
||||
|
||||
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||
dpi, width, height = 300, 3400, 2600
|
||||
LabelSize, LegendFontsize = 28, 28
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
fig = plt.figure(figsize=figsize)
|
||||
x_axis = np.arange(0, 600)
|
||||
plt.xlim(0, max(indexes))
|
||||
plt.ylim(y_lims[0], y_lims[1])
|
||||
interval_x, interval_y = 100, y_lims[2]
|
||||
plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
|
||||
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
||||
plt.grid()
|
||||
plt.xlabel('The index of runs', fontsize=LabelSize)
|
||||
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||
|
||||
for idx, legend in enumerate(legends):
|
||||
plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2)
|
||||
print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend])))
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name)
|
||||
print('save figure into {:}\n'.format(save_path))
|
||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
|
||||
|
||||
def just_show(api):
|
||||
xtimes = {'RSPS' : [8082.5, 7794.2, 8144.7],
|
||||
'DARTS-V1': [11582.1, 11347.0, 11948.2],
|
||||
'DARTS-V2': [35694.7, 36132.7, 35518.0],
|
||||
'GDAS' : [31334.1, 31478.6, 32016.7],
|
||||
'SETN' : [33528.8, 33831.5, 35058.3],
|
||||
'ENAS' : [14340.2, 13817.3, 14018.9]}
|
||||
for xkey, xlist in xtimes.items():
|
||||
xlist = np.array(xlist)
|
||||
print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean()))
|
||||
|
||||
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
|
||||
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
|
||||
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
|
||||
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
|
||||
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
|
||||
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
|
||||
}
|
||||
xseeds = {'RSPS' : [5349, 59613, 5983],
|
||||
'DARTS-V1': [11416, 72873, 81184],
|
||||
'DARTS-V2': [43330, 79405, 79423],
|
||||
'GDAS' : [19677, 884, 95950],
|
||||
'SETN' : [20518, 61817, 89144],
|
||||
'ENAS' : [3231, 34238, 96929],
|
||||
}
|
||||
|
||||
def get_accs(xdata, index=-1):
|
||||
if index == -1:
|
||||
epochs = xdata['epoch']
|
||||
genotype = xdata['genotypes'][epochs-1]
|
||||
index = api.query_index_by_arch(genotype)
|
||||
pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')]
|
||||
xresults = []
|
||||
for dataset, xset in pairs:
|
||||
metrics = api.arch2infos_full[index].get_metrics(dataset, xset, None, False)
|
||||
xresults.append( metrics['accuracy'] )
|
||||
return xresults
|
||||
|
||||
for xkey in xpaths.keys():
|
||||
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
|
||||
all_datas = [torch.load(xpath) for xpath in all_paths]
|
||||
accyss = [get_accs(xdatas) for xdatas in all_datas]
|
||||
accyss = np.array( accyss )
|
||||
print('\nxkey = {:}'.format(xkey))
|
||||
for i in range(accyss.shape[1]): print('---->>>> {:.2f}$\\pm${:.2f}'.format(accyss[:,i].mean(), accyss[:,i].std()))
|
||||
|
||||
print('\n{:}'.format(get_accs(None, 11472))) # resnet
|
||||
pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')]
|
||||
for dataset, metric_on_set in pairs:
|
||||
arch_index, highest_acc = api.find_best(dataset, metric_on_set)
|
||||
print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc))
|
||||
|
||||
|
||||
def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_maxs):
|
||||
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||
dpi, width, height = 300, 3400, 2600
|
||||
LabelSize, LegendFontsize = 28, 28
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
fig = plt.figure(figsize=figsize)
|
||||
#x_maxs = 250
|
||||
plt.xlim(0, x_maxs+1)
|
||||
plt.ylim(y_lims[0], y_lims[1])
|
||||
interval_x, interval_y = x_maxs // 5, y_lims[2]
|
||||
plt.xticks(np.arange(0, x_maxs+1, interval_x), fontsize=LegendFontsize)
|
||||
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
||||
plt.grid()
|
||||
plt.xlabel('The searching epoch', fontsize=LabelSize)
|
||||
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||
|
||||
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
|
||||
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
|
||||
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
|
||||
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
|
||||
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
|
||||
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
|
||||
}
|
||||
xseeds = {'RSPS' : [5349, 59613, 5983],
|
||||
'DARTS-V1': [11416, 72873, 81184, 28640],
|
||||
'DARTS-V2': [43330, 79405, 79423],
|
||||
'GDAS' : [19677, 884, 95950],
|
||||
'SETN' : [20518, 61817, 89144],
|
||||
'ENAS' : [3231, 34238, 96929],
|
||||
}
|
||||
|
||||
def get_accs(xdata):
|
||||
epochs, xresults = xdata['epoch'], []
|
||||
if -1 in xdata['genotypes']:
|
||||
metrics = api.arch2infos_full[ api.query_index_by_arch(xdata['genotypes'][-1]) ].get_metrics(dataset, subset, None, False)
|
||||
else:
|
||||
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
|
||||
xresults.append( metrics['accuracy'] )
|
||||
for iepoch in range(epochs):
|
||||
genotype = xdata['genotypes'][iepoch]
|
||||
index = api.query_index_by_arch(genotype)
|
||||
metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False)
|
||||
xresults.append( metrics['accuracy'] )
|
||||
return xresults
|
||||
|
||||
if x_maxs == 50:
|
||||
xox, xxxstrs = 'v2', ['DARTS-V1', 'DARTS-V2']
|
||||
elif x_maxs == 250:
|
||||
xox, xxxstrs = 'v1', ['RSPS', 'GDAS', 'SETN', 'ENAS']
|
||||
else: raise ValueError('invalid x_maxs={:}'.format(x_maxs))
|
||||
|
||||
for idx, method in enumerate(xxxstrs):
|
||||
xkey = method
|
||||
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
|
||||
all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths]
|
||||
accyss = [get_accs(xdatas) for xdatas in all_datas]
|
||||
accyss = np.array( accyss )
|
||||
epochs = list(range(accyss.shape[1]))
|
||||
plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx], linestyle='-', label='{:}'.format(method), lw=2)
|
||||
plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx])
|
||||
#plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||
save_path = vis_save_dir / '{:}-{:}-{:}-{:}'.format(xox, dataset, subset, file_name)
|
||||
print('save figure into {:}\n'.format(save_path))
|
||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
|
||||
|
||||
def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, y_lims, x_maxs):
|
||||
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||
dpi, width, height = 300, 3400, 2600
|
||||
LabelSize, LegendFontsize = 28, 28
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
fig = plt.figure(figsize=figsize)
|
||||
#x_maxs = 250
|
||||
plt.xlim(0, x_maxs+1)
|
||||
plt.ylim(y_lims[0], y_lims[1])
|
||||
interval_x, interval_y = x_maxs // 5, y_lims[2]
|
||||
plt.xticks(np.arange(0, x_maxs+1, interval_x), fontsize=LegendFontsize)
|
||||
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
||||
plt.grid()
|
||||
plt.xlabel('The searching epoch', fontsize=LabelSize)
|
||||
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||
|
||||
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
|
||||
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
|
||||
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
|
||||
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
|
||||
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
|
||||
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
|
||||
}
|
||||
xseeds = {'RSPS' : [5349, 59613, 5983],
|
||||
'DARTS-V1': [11416, 72873, 81184, 28640],
|
||||
'DARTS-V2': [43330, 79405, 79423],
|
||||
'GDAS' : [19677, 884, 95950],
|
||||
'SETN' : [20518, 61817, 89144],
|
||||
'ENAS' : [3231, 34238, 96929],
|
||||
}
|
||||
|
||||
def get_accs(xdata, dataset, subset):
|
||||
epochs, xresults = xdata['epoch'], []
|
||||
if -1 in xdata['genotypes']:
|
||||
metrics = api.arch2infos_full[ api.query_index_by_arch(xdata['genotypes'][-1]) ].get_metrics(dataset, subset, None, False)
|
||||
else:
|
||||
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
|
||||
xresults.append( metrics['accuracy'] )
|
||||
for iepoch in range(epochs):
|
||||
genotype = xdata['genotypes'][iepoch]
|
||||
index = api.query_index_by_arch(genotype)
|
||||
metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False)
|
||||
xresults.append( metrics['accuracy'] )
|
||||
return xresults
|
||||
|
||||
if x_maxs == 50:
|
||||
xox, xxxstrs = 'v2', ['DARTS-V1', 'DARTS-V2']
|
||||
elif x_maxs == 250:
|
||||
xox, xxxstrs = 'v1', ['RSPS', 'GDAS', 'SETN', 'ENAS']
|
||||
else: raise ValueError('invalid x_maxs={:}'.format(x_maxs))
|
||||
|
||||
for idx, method in enumerate(xxxstrs):
|
||||
xkey = method
|
||||
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
|
||||
all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths]
|
||||
accyss_A = np.array( [get_accs(xdatas, data_sub_a[0], data_sub_a[1]) for xdatas in all_datas] )
|
||||
accyss_B = np.array( [get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas] )
|
||||
epochs = list(range(accyss_A.shape[1]))
|
||||
for j, accyss in enumerate([accyss_A, accyss_B]):
|
||||
plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx*2+j], linestyle='-' if j==0 else '--', label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9)
|
||||
plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx*2+j])
|
||||
#plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||
save_path = vis_save_dir / '{:}-{:}'.format(xox, file_name)
|
||||
print('save figure into {:}\n'.format(save_path))
|
||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
|
||||
|
||||
def show_reinforce(api, root, dataset, xset, file_name, y_lims):
|
||||
print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
|
||||
LRs = ['0.01', '0.02', '0.1', '0.2', '0.5', '1.0', '1.5', '2.0', '2.5', '3.0']
|
||||
checkpoints = ['./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth'.format(x) for x in LRs]
|
||||
acc_lr_dict, indexes = {}, None
|
||||
for lr, checkpoint in zip(LRs, checkpoints):
|
||||
all_indexes, accuracies = torch.load(checkpoint, map_location='cpu'), []
|
||||
for x in all_indexes:
|
||||
info = api.arch2infos_full[ x ]
|
||||
metrics = info.get_metrics(dataset, xset, None, False)
|
||||
accuracies.append( metrics['accuracy'] )
|
||||
if indexes is None: indexes = list(range(len(accuracies)))
|
||||
acc_lr_dict[lr] = np.array( sorted(accuracies) )
|
||||
print ('LR={:.3f}, mean={:}, std={:}'.format(float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std()))
|
||||
|
||||
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||
dpi, width, height = 300, 3400, 2600
|
||||
LabelSize, LegendFontsize = 28, 22
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
fig = plt.figure(figsize=figsize)
|
||||
x_axis = np.arange(0, 600)
|
||||
plt.xlim(0, max(indexes))
|
||||
plt.ylim(y_lims[0], y_lims[1])
|
||||
interval_x, interval_y = 100, y_lims[2]
|
||||
plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
|
||||
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
||||
plt.grid()
|
||||
plt.xlabel('The index of runs', fontsize=LabelSize)
|
||||
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||
|
||||
for idx, LR in enumerate(LRs):
|
||||
legend = 'LR={:.2f}'.format(float(LR))
|
||||
color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
|
||||
plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8)
|
||||
print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]), np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR])))
|
||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||
save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name)
|
||||
print('save figure into {:}\n'.format(save_path))
|
||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
|
||||
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.')
|
||||
args = parser.parse_args()
|
||||
|
||||
vis_save_dir = Path(args.save_dir)
|
||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
meta_file = Path(args.api_path)
|
||||
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
|
||||
#visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
|
||||
#write_video(vis_save_dir / 'over-time')
|
||||
#visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
|
||||
#visualize_info(str(meta_file), 'cifar100', vis_save_dir)
|
||||
#visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
|
||||
#visualize_relative_ranking(vis_save_dir)
|
||||
|
||||
api = API(args.api_path)
|
||||
show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (75, 95, 5))
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
for x_maxs in [50, 250]:
|
||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
|
||||
show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test') , vis_save_dir, 'DARTS-CIFAR010.pdf', (0, 100,10), 50)
|
||||
show_nas_sharing_w_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ) , vis_save_dir, 'DARTS-CIFAR100.pdf', (0, 100,10), 50)
|
||||
show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ) , vis_save_dir, 'DARTS-ImageNet.pdf', (0, 100,10), 50)
|
||||
#just_show(api)
|
||||
"""
|
||||
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||
plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1))
|
||||
plot_results_nas_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3))
|
||||
plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2))
|
||||
"""
|
Reference in New Issue
Block a user