first commit

This commit is contained in:
CownowAn
2024-03-15 14:38:51 +00:00
commit bc2ed1304f
321 changed files with 44802 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
from pathlib import Path
import sys
dir_path = (Path(__file__).parent).resolve()
if str(dir_path) not in sys.path: sys.path.insert(0, str(dir_path))
from .architecture import train_single_model

View File

@@ -0,0 +1,173 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
###############################################################
from functions import evaluate_for_seed
from nas_bench_201_models import CellStructure, CellArchitectures, get_search_spaces
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
from nas_bench_201_datasets import get_datasets
from procedures import get_machine_info
from procedures import save_checkpoint, copy_checkpoint
from config_utils import load_config
from pathlib import Path
from copy import deepcopy
import os
import sys
import time
import torch
import random
import argparse
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
NASBENCH201_CONFIG_PATH = os.path.join(
os.getcwd(), 'main_exp', 'transfer_nag')
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
arch_config, workers, logger):
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
all_infos = {'info': machine_info}
all_dataset_keys = []
# look all the datasets
for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data
task = None
train_data, valid_data, xshape, class_num = get_datasets(
dataset, xpath, -1, task)
# load the configuration
if dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
if use_less:
config_path = os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config')
else:
config_path = os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset))
p = os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset))
if not os.path.exists(p):
import json
label_list = list(range(len(train_data)))
random.shuffle(label_list)
strlist = [str(label_list[i]) for i in range(len(label_list))]
splited = {'train': ["int", strlist[:len(train_data) // 2]],
'valid': ["int", strlist[len(train_data) // 2:]]}
with open(p, 'w') as f:
f.write(json.dumps(splited))
split_info = load_config(os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset)), None, None)
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
config = load_config(
config_path, {'class_num': class_num, 'xshape': xshape}, logger)
# data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size,
shuffle=True, num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
shuffle=False, num_workers=workers, pin_memory=True)
splits = load_config(os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}-test-split.txt'.format(dataset)), None, None)
ValLoaders = {'ori-test': valid_loader,
'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
splits.xvalid),
num_workers=workers, pin_memory=True),
'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
splits.xtest),
num_workers=workers, pin_memory=True)
}
dataset_key = '{:}'.format(dataset)
if bool(split):
dataset_key = dataset_key + '-valid'
logger.log(
'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.
format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
dataset_key, config))
for key, value in ValLoaders.items():
logger.log(
'Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value)))
results = evaluate_for_seed(
arch_config, config, arch, train_loader, ValLoaders, seed, logger)
all_infos[dataset_key] = results
all_dataset_keys.append(dataset_key)
all_infos['all_dataset_keys'] = all_dataset_keys
return all_infos
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
seeds, model_str, arch_config):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.set_num_threads(workers)
save_dir = Path(save_dir)
logger = Logger(str(save_dir), 0, False)
if model_str in CellArchitectures:
arch = CellArchitectures[model_str]
logger.log(
'The model string is found in pre-defined architecture dict : {:}'.format(model_str))
else:
try:
arch = CellStructure.str2structure(model_str)
except:
raise ValueError(
'Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
assert arch.check_valid_op(get_search_spaces(
'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
# assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
logger.log('Start train-evaluate {:}'.format(arch.tostr()))
logger.log('arch_config : {:}'.format(arch_config))
start_time, seed_time = time.time(), AverageMeter()
for _is, seed in enumerate(seeds):
logger.log(
'\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds),
seed))
to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
if to_save_name.exists():
logger.log(
'Find the existing file {:}, directly load!'.format(to_save_name))
checkpoint = torch.load(to_save_name)
else:
logger.log(
'Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
seed, arch_config, workers, logger)
torch.save(checkpoint, to_save_name)
# log information
logger.log('{:}'.format(checkpoint['info']))
all_dataset_keys = checkpoint['all_dataset_keys']
for dataset_key in all_dataset_keys:
logger.log('\n{:} dataset : {:} {:}'.format(
'-' * 15, dataset_key, '-' * 15))
dataset_info = checkpoint[dataset_key]
# logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
logger.log('Flops = {:} MB, Params = {:} MB'.format(
dataset_info['flop'], dataset_info['param']))
logger.log('config : {:}'.format(dataset_info['config']))
logger.log('Training State (finish) = {:}'.format(
dataset_info['finish-train']))
last_epoch = dataset_info['total_epoch'] - 1
train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
# measure elapsed time
seed_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format(convert_secs2time(
seed_time.avg * (len(seeds) - _is - 1), True))
logger.log(
'\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed,
need_time))
logger.close()

View File

@@ -0,0 +1,13 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .configure_utils import load_config, dict2config#, configure2str
#from .basic_args import obtain_basic_args
#from .attention_args import obtain_attention_args
#from .random_baseline import obtain_RandomSearch_args
#from .cls_kd_args import obtain_cls_kd_args
#from .cls_init_args import obtain_cls_init_args
#from .search_single_args import obtain_search_single_args
#from .search_args import obtain_search_args
# for network pruning
#from .pruning_args import obtain_pruning_args

View File

@@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, json
from os import path as osp
from pathlib import Path
from collections import namedtuple
support_types = ('str', 'int', 'bool', 'float', 'none')
def convert_param(original_lists):
assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
ctype, value = original_lists[0], original_lists[1]
assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
is_list = isinstance(value, list)
if not is_list: value = [value]
outs = []
for x in value:
if ctype == 'int':
x = int(x)
elif ctype == 'str':
x = str(x)
elif ctype == 'bool':
x = bool(int(x))
elif ctype == 'float':
x = float(x)
elif ctype == 'none':
if x.lower() != 'none':
raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
x = None
else:
raise TypeError('Does not know this type : {:}'.format(ctype))
outs.append(x)
if not is_list: outs = outs[0]
return outs
def load_config(path, extra, logger):
path = str(path)
if hasattr(logger, 'log'): logger.log(path)
assert os.path.exists(path), 'Can not find {:}'.format(path)
# Reading data back
with open(path, 'r') as f:
data = json.load(f)
content = { k: convert_param(v) for k,v in data.items()}
assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
if isinstance(extra, dict): content = {**content, **extra}
Arguments = namedtuple('Configure', ' '.join(content.keys()))
content = Arguments(**content)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content
def configure2str(config, xpath=None):
if not isinstance(config, dict):
config = config._asdict()
def cstring(x):
return "\"{:}\"".format(x)
def gtype(x):
if isinstance(x, list): x = x[0]
if isinstance(x, str) : return 'str'
elif isinstance(x, bool) : return 'bool'
elif isinstance(x, int): return 'int'
elif isinstance(x, float): return 'float'
elif x is None : return 'none'
else: raise ValueError('invalid : {:}'.format(x))
def cvalue(x, xtype):
if isinstance(x, list): is_list = True
else:
is_list, x = False, [x]
temps = []
for temp in x:
if xtype == 'bool' : temp = cstring(int(temp))
elif xtype == 'none': temp = cstring('None')
else : temp = cstring(temp)
temps.append( temp )
if is_list:
return "[{:}]".format( ', '.join( temps ) )
else:
return temps[0]
xstrings = []
for key, value in config.items():
xtype = gtype(value)
string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype))
xstrings.append(string)
Fstring = '{\n' + ',\n'.join(xstrings) + '\n}'
if xpath is not None:
parent = Path(xpath).resolve().parent
parent.mkdir(parents=True, exist_ok=True)
if osp.isfile(xpath): os.remove(xpath)
with open(xpath, "w") as text_file:
text_file.write('{:}'.format(Fstring))
return Fstring
def dict2config(xdict, logger):
assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) )
Arguments = namedtuple('Configure', ' '.join(xdict.keys()))
content = Arguments(**xdict)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "200"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "50"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "200"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "200"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

View File

@@ -0,0 +1,153 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import time
import torch
from procedures import prepare_seed, get_optim_scheduler
from nasbench_utils import get_model_infos, obtain_accuracy
from config_utils import dict2config
from log_utils import AverageMeter, time_string, convert_secs2time
from nas_bench_201_models import get_cell_based_tiny_net
__all__ = ['evaluate_for_seed', 'pure_evaluate']
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
latencies = []
network.eval()
with torch.no_grad():
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
targets = targets.cuda(non_blocking=True)
inputs = inputs.cuda(non_blocking=True)
data_time.update(time.time() - end)
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
batch_time.update(time.time() - end)
if batch is None or batch == inputs.size(0):
batch = inputs.size(0)
latencies.append(batch_time.val - data_time.val)
# record loss and accuracy
prec1, prec5 = obtain_accuracy(
logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
end = time.time()
if len(latencies) > 2:
latencies = latencies[1:]
return losses.avg, top1.avg, top5.avg, latencies
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
if mode == 'train':
network.train()
elif mode == 'valid':
network.eval()
else:
raise ValueError("The mode is not right : {:}".format(mode))
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train':
scheduler.update(None, 1.0 * i / len(xloader))
targets = targets.cuda(non_blocking=True)
if mode == 'train':
optimizer.zero_grad()
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
# backward
if mode == 'train':
loss.backward()
optimizer.step()
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
# count time
batch_time.update(time.time() - end)
end = time.time()
return losses.avg, top1.avg, top5.avg, batch_time.sum
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
prepare_seed(seed) # random seed
net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny',
'C': arch_config['channel'], 'N': arch_config['num_cells'],
'genotype': arch, 'num_classes': config.class_num}, None)
)
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
if 'ckpt_path' in arch_config.keys():
ckpt = torch.load(arch_config['ckpt_path'])
ckpt['classifier.weight'] = net.state_dict()['classifier.weight']
ckpt['classifier.bias'] = net.state_dict()['classifier.bias']
net.load_state_dict(ckpt)
flop, param = get_model_infos(net, config.xshape)
logger.log('Network : {:}'.format(net.get_message()), False)
logger.log(
'{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
# train and valid
optimizer, scheduler, criterion = get_optim_scheduler(
net.parameters(), config)
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
# network, criterion = torch.nn.DataParallel(net).to(torch.device(f"cuda:{device}")), criterion.to(torch.device(f"cuda:{device}"))
# start training
start_time, epoch_time, total_epoch = time.time(
), AverageMeter(), config.epochs + config.warmup
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {
}, {}, {}, {}, {}, {}
train_times, valid_times = {}, {}
for epoch in range(total_epoch):
scheduler.update(epoch, 0.0)
train_loss, train_acc1, train_acc5, train_tm = procedure(
train_loader, network, criterion, scheduler, optimizer, 'train')
train_losses[epoch] = train_loss
train_acc1es[epoch] = train_acc1
train_acc5es[epoch] = train_acc5
train_times[epoch] = train_tm
with torch.no_grad():
for key, xloder in valid_loaders.items():
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
xloder, network, criterion, None, None, 'valid')
valid_losses['{:}@{:}'.format(key, epoch)] = valid_loss
valid_acc1es['{:}@{:}'.format(key, epoch)] = valid_acc1
valid_acc5es['{:}@{:}'.format(key, epoch)] = valid_acc5
valid_times['{:}@{:}'.format(key, epoch)] = valid_tm
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format(convert_secs2time(
epoch_time.avg * (total_epoch-epoch-1), True))
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(
time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5))
info_seed = {'flop': flop,
'param': param,
'channel': arch_config['channel'],
'num_cells': arch_config['num_cells'],
'config': config._asdict(),
'total_epoch': total_epoch,
'train_losses': train_losses,
'train_acc1es': train_acc1es,
'train_acc5es': train_acc5es,
'train_times': train_times,
'valid_losses': valid_losses,
'valid_acc1es': valid_acc1es,
'valid_acc5es': valid_acc5es,
'valid_times': valid_times,
'net_state_dict': net.state_dict(),
'net_string': '{:}'.format(net),
'finish-train': True
}
return info_seed

View File

@@ -0,0 +1,9 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# every package does not rely on pytorch or tensorflow
# I tried to list all dependency here: os, sys, time, numpy, (possibly) matplotlib
from .logger import Logger#, PrintLogger
from .meter import AverageMeter
from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_secs2time
from .time_utils import time_string, convert_secs2time

View File

@@ -0,0 +1,150 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from pathlib import Path
import importlib, warnings
import os, sys, time, numpy as np
if sys.version_info.major == 2: # Python 2.x
from StringIO import StringIO as BIO
else: # Python 3.x
from io import BytesIO as BIO
if importlib.util.find_spec('tensorflow'):
import tensorflow as tf
class PrintLogger(object):
def __init__(self):
"""Create a summary writer logging to log_dir."""
self.name = 'PrintLogger'
def log(self, string):
print (string)
def close(self):
print ('-'*30 + ' close printer ' + '-'*30)
class Logger(object):
def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
"""Create a summary writer logging to log_dir."""
self.seed = int(seed)
self.log_dir = Path(log_dir)
self.model_dir = Path(log_dir) / 'checkpoint'
self.log_dir.mkdir (parents=True, exist_ok=True)
if create_model_dir:
self.model_dir.mkdir(parents=True, exist_ok=True)
#self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
self.use_tf = bool(use_tf)
self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h', time.gmtime(time.time()) )))
#self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
self.logger_path = self.log_dir / 'seed-{:}-T-{:}.log'.format(self.seed, time.strftime('%d-%h-at-%H-%M-%S', time.gmtime(time.time())))
self.logger_file = open(self.logger_path, 'w')
if self.use_tf:
self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
self.writer = tf.summary.FileWriter(str(self.tensorboard_dir))
else:
self.writer = None
def __repr__(self):
return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__))
def path(self, mode):
valids = ('model', 'best', 'info', 'log')
if mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed)
elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed)
elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed)
elif mode == 'log' : return self.log_dir
else: raise TypeError('Unknow mode = {:}, valid modes = {:}'.format(mode, valids))
def extract_log(self):
return self.logger_file
def close(self):
self.logger_file.close()
if self.writer is not None:
self.writer.close()
def log(self, string, save=True, stdout=False):
if stdout:
sys.stdout.write(string); sys.stdout.flush()
else:
print (string)
if save:
self.logger_file.write('{:}\n'.format(string))
self.logger_file.flush()
def scalar_summary(self, tags, values, step):
"""Log a scalar variable."""
if not self.use_tf:
warnings.warn('Do set use-tensorflow installed but call scalar_summary')
else:
assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values))
if not isinstance(tags, list):
tags, values = [tags], [values]
for tag, value in zip(tags, values):
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
self.writer.flush()
def image_summary(self, tag, images, step):
"""Log a list of images."""
import scipy
if not self.use_tf:
warnings.warn('Do set use-tensorflow installed but call scalar_summary')
return
img_summaries = []
for i, img in enumerate(images):
# Write the image to a string
try:
s = StringIO()
except:
s = BytesIO()
scipy.misc.toimage(img).save(s, format="png")
# Create an Image object
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1])
# Create a Summary value
img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum))
# Create and write Summary
summary = tf.Summary(value=img_summaries)
self.writer.add_summary(summary, step)
self.writer.flush()
def histo_summary(self, tag, values, step, bins=1000):
"""Log a histogram of the tensor of values."""
if not self.use_tf: raise ValueError('Do not have tensorflow')
import tensorflow as tf
# Create a histogram using numpy
counts, bin_edges = np.histogram(values, bins=bins)
# Fill the fields of the histogram proto
hist = tf.HistogramProto()
hist.min = float(np.min(values))
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values**2))
# Drop the start of the first bin
bin_edges = bin_edges[1:]
# Add bin edges and counts
for edge in bin_edges:
hist.bucket_limit.append(edge)
for c in counts:
hist.bucket.append(c)
# Create and write Summary
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
self.writer.add_summary(summary, step)
self.writer.flush()

View File

@@ -0,0 +1,98 @@
import numpy as np
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.count = 0.0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __repr__(self):
return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
class RecorderMeter(object):
"""Computes and stores the minimum loss value and its epoch index"""
def __init__(self, total_epoch):
self.reset(total_epoch)
def reset(self, total_epoch):
assert total_epoch > 0, 'total_epoch should be greater than 0 vs {:}'.format(total_epoch)
self.total_epoch = total_epoch
self.current_epoch = 0
self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
self.epoch_losses = self.epoch_losses - 1
self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
self.epoch_accuracy= self.epoch_accuracy
def update(self, idx, train_loss, train_acc, val_loss, val_acc):
assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
self.epoch_losses [idx, 0] = train_loss
self.epoch_losses [idx, 1] = val_loss
self.epoch_accuracy[idx, 0] = train_acc
self.epoch_accuracy[idx, 1] = val_acc
self.current_epoch = idx + 1
return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]
def max_accuracy(self, istrain):
if self.current_epoch <= 0: return 0
if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max()
else: return self.epoch_accuracy[:self.current_epoch, 1].max()
def plot_curve(self, save_path):
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
title = 'the accuracy/loss curve of train/val'
dpi = 100
width, height = 1600, 1000
legend_fontsize = 10
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
x_axis = np.array([i for i in range(self.total_epoch)]) # epochs
y_axis = np.zeros(self.total_epoch)
plt.xlim(0, self.total_epoch)
plt.ylim(0, 100)
interval_y = 5
interval_x = 5
plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
plt.yticks(np.arange(0, 100 + interval_y, interval_y))
plt.grid()
plt.title(title, fontsize=20)
plt.xlabel('the training epoch', fontsize=16)
plt.ylabel('accuracy', fontsize=16)
y_axis[:] = self.epoch_accuracy[:, 0]
plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_accuracy[:, 1]
plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_losses[:, 0]
plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_losses[:, 1]
plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
if save_path is not None:
fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
print ('---- save figure {} into {}'.format(title, save_path))
plt.close(fig)

View File

@@ -0,0 +1,42 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import time, sys
import numpy as np
def time_for_file():
ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
return '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
def time_string():
ISOTIMEFORMAT='%Y-%m-%d %X'
string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def time_string_short():
ISOTIMEFORMAT='%Y%m%d'
string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def time_print(string, is_print=True):
if (is_print):
print('{} : {}'.format(time_string(), string))
def convert_secs2time(epoch_time, return_str=False):
need_hour = int(epoch_time / 3600)
need_mins = int((epoch_time - 3600*need_hour) / 60)
need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
if return_str:
str = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
return str
else:
return need_hour, need_mins, need_secs
def print_log(print_string, log):
#if isinstance(log, Logger): log.log('{:}'.format(print_string))
if hasattr(log, 'log'): log.log('{:}'.format(print_string))
else:
print("{:}".format(print_string))
if log is not None:
log.write('{:}\n'.format(print_string))
log.flush()

View File

@@ -0,0 +1,4 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .get_dataset_with_transform import get_datasets

View File

@@ -0,0 +1,179 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import torch.utils.data as data
from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader
from PIL import Image
import os
import numpy as np
def make_dataset(dir, image_ids, targets):
assert (len(image_ids) == len(targets))
images = []
dir = os.path.expanduser(dir)
for i in range(len(image_ids)):
item = (os.path.join(dir, 'data', 'images',
'%s.jpg' % image_ids[i]), targets[i])
images.append(item)
return images
def find_classes(classes_file):
# read classes file, separating out image IDs and class names
image_ids = []
targets = []
f = open(classes_file, 'r')
for line in f:
split_line = line.split(' ')
image_ids.append(split_line[0])
targets.append(' '.join(split_line[1:]))
f.close()
# index class names
classes = np.unique(targets)
class_to_idx = {classes[i]: i for i in range(len(classes))}
targets = [class_to_idx[c] for c in targets]
return (image_ids, targets, classes, class_to_idx)
class FGVCAircraft(data.Dataset):
"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
Args:
root (string): Root directory path to dataset.
class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g. ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
class_types = ('variant', 'family', 'manufacturer')
splits = ('train', 'val', 'trainval', 'test')
def __init__(self, root, class_type='variant', split='train', transform=None,
target_transform=None, loader=default_loader, download=False):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
if class_type not in self.class_types:
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
class_type, ', '.join(self.class_types),
))
self.root = os.path.expanduser(root)
self.root = os.path.join(self.root, 'fgvc-aircraft-2013b')
self.class_type = class_type
self.split = split
self.classes_file = os.path.join(self.root, 'data',
'images_%s_%s.txt' % (self.class_type, self.split))
if download:
self.download()
(image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
samples = make_dataset(self.root, image_ids, targets)
self.transform = transform
self.target_transform = target_transform
self.loader = loader
self.samples = samples
self.classes = classes
self.class_to_idx = class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def _check_exists(self):
return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
os.path.exists(self.classes_file)
def download(self):
"""Download the FGVC-Aircraft data if it doesn't exist already."""
from six.moves import urllib
import tarfile
if self._check_exists():
return
# prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
print('Downloading %s ... (may take a few minutes)' % self.url)
parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
tar_name = self.url.rpartition('/')[-1]
tar_path = os.path.join(parent_dir, tar_name)
data = urllib.request.urlopen(self.url)
# download .tar.gz file
with open(tar_path, 'wb') as f:
f.write(data.read())
# extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
data_folder = tar_path.strip('.tar.gz')
print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
tar = tarfile.open(tar_path)
tar.extractall(parent_dir)
# if necessary, rename data folder to self.root
if not os.path.samefile(data_folder, self.root):
print('Renaming %s to %s ...' % (data_folder, self.root))
os.rename(data_folder, self.root)
# delete .tar.gz file
print('Deleting %s ...' % tar_path)
os.remove(tar_path)
print('Done!')
if __name__ == '__main__':
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b', class_type='manufacturer', split='train', transform=None,
target_transform=None, loader=default_loader, download=False)
print(len(air))
print(len(air))
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b', class_type='manufacturer', split='val', transform=None,
target_transform=None, loader=default_loader, download=False)
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b', class_type='manufacturer', split='trainval', transform=None,
target_transform=None, loader=default_loader, download=False)
print(len(air))
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b/', class_type='manufacturer', split='test', transform=None,
target_transform=None, loader=default_loader, download=False)
print(len(air))
import pdb;
pdb.set_trace()
print(len(air))

View File

@@ -0,0 +1,304 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
# Modified by Hayeon Lee, Eunyoung Hyung 2021. 03.
##################################################
import os
import sys
import torch
import os.path as osp
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
from copy import deepcopy
# from PIL import Image
import random
import pdb
from .aircraft import FGVCAircraft
from .pets import PetDataset
from config_utils import load_config
Dataset2Class = {'cifar10': 10,
'cifar100': 100,
'mnist': 10,
'svhn': 10,
'aircraft': 30,
'pets': 37}
class CUTOUT(object):
def __init__(self, length):
self.length = length
def __repr__(self):
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
imagenet_pca = {
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
'eigvec': np.asarray([
[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203],
])
}
class Lighting(object):
def __init__(self, alphastd,
eigval=imagenet_pca['eigval'],
eigvec=imagenet_pca['eigvec']):
self.alphastd = alphastd
assert eigval.shape == (3,)
assert eigvec.shape == (3, 3)
self.eigval = eigval
self.eigvec = eigvec
def __call__(self, img):
if self.alphastd == 0.:
return img
rnd = np.random.randn(3) * self.alphastd
rnd = rnd.astype('float32')
v = rnd
old_dtype = np.asarray(img).dtype
v = v * self.eigval
v = v.reshape((3, 1))
inc = np.dot(self.eigvec, v).reshape((3,))
img = np.add(img, inc)
if old_dtype == np.uint8:
img = np.clip(img, 0, 255)
img = Image.fromarray(img.astype(old_dtype), 'RGB')
return img
def __repr__(self):
return self.__class__.__name__ + '()'
def get_datasets(name, root, cutout, use_num_cls=None):
if name == 'cifar10':
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
elif name == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name.startswith('mnist'):
mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081]
elif name.startswith('svhn'):
mean, std = [0.4376821, 0.4437697, 0.47280442], [
0.19803012, 0.20101562, 0.19703614]
elif name.startswith('aircraft'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name.startswith('pets'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
else:
raise TypeError("Unknow dataset : {:}".format(name))
# Data Argumentation
if name == 'cifar10' or name == 'cifar100':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
transforms.Normalize(mean, std)]
if cutout > 0:
lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name.startswith('cub200'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
xshape = (1, 3, 32, 32)
elif name.startswith('mnist'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize(mean, std),
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize(mean, std)
])
xshape = (1, 3, 32, 32)
elif name.startswith('svhn'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
xshape = (1, 3, 32, 32)
elif name.startswith('aircraft'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
xshape = (1, 3, 32, 32)
elif name.startswith('pets'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
xshape = (1, 3, 32, 32)
else:
raise TypeError("Unknow dataset : {:}".format(name))
if name == 'cifar10':
train_data = dset.CIFAR10(
root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR10(
root, train=False, transform=test_transform, download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'cifar100':
train_data = dset.CIFAR100(
root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR100(
root, train=False, transform=test_transform, download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'mnist':
train_data = dset.MNIST(
root, train=True, transform=train_transform, download=True)
test_data = dset.MNIST(
root, train=False, transform=test_transform, download=True)
assert len(train_data) == 60000 and len(test_data) == 10000
elif name == 'svhn':
train_data = dset.SVHN(root, split='train',
transform=train_transform, download=True)
test_data = dset.SVHN(root, split='test',
transform=test_transform, download=True)
assert len(train_data) == 73257 and len(test_data) == 26032
elif name == 'aircraft':
train_data = FGVCAircraft(root, class_type='manufacturer', split='trainval',
transform=train_transform, download=False)
test_data = FGVCAircraft(root, class_type='manufacturer', split='test',
transform=test_transform, download=False)
assert len(train_data) == 6667 and len(test_data) == 3333
elif name == 'pets':
train_data = PetDataset(root, train=True, num_cl=37,
val_split=0.15, transforms=train_transform)
test_data = PetDataset(root, train=False, num_cl=37,
val_split=0.15, transforms=test_transform)
else:
raise TypeError("Unknow dataset : {:}".format(name))
class_num = Dataset2Class[name] if use_num_cls is None else len(
use_num_cls)
return train_data, test_data, xshape, class_num
def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, num_cls=None):
if isinstance(batch_size, (list, tuple)):
batch, test_batch = batch_size
else:
batch, test_batch = batch_size, batch_size
if dataset == 'cifar10':
# split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(
'{:}/cifar-split.txt'.format(config_root), None, None)
# search over the proposed training and validation set
train_split, valid_split = cifar_split.train, cifar_split.valid
# logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
# To split data
xvalid_data = deepcopy(train_data)
if hasattr(xvalid_data, 'transforms'): # to avoid a print issue
xvalid_data.transforms = valid_data.transform
xvalid_data.transform = deepcopy(valid_data.transform)
search_data = SearchDataset(
dataset, train_data, train_split, valid_split)
# data loader
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
train_split),
num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
valid_split),
num_workers=workers, pin_memory=True)
elif dataset == 'cifar100':
cifar100_test_split = load_config(
'{:}/cifar100-test-split.txt'.format(config_root), None, None)
search_train_data = train_data
search_valid_data = deepcopy(valid_data)
search_valid_data.transform = train_data.transform
search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
list(range(len(search_train_data))),
cifar100_test_split.xvalid)
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, num_workers=workers,
pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
elif dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
if not os.path.exists('{:}/{}-test-split.txt'.format(config_root, dataset)):
import json
label_list = list(range(len(valid_data)))
random.shuffle(label_list)
strlist = [str(label_list[i]) for i in range(len(label_list))]
split = {'xvalid': ["int", strlist[:len(valid_data) // 2]],
'xtest': ["int", strlist[len(valid_data) // 2:]]}
with open('{:}/{}-test-split.txt'.format(config_root, dataset), 'w') as f:
f.write(json.dumps(split))
test_split = load_config(
'{:}/{}-test-split.txt'.format(config_root, dataset), None, None)
search_train_data = train_data
search_valid_data = deepcopy(valid_data)
search_valid_data.transform = train_data.transform
search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
list(range(len(search_train_data))), test_split.xvalid)
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True,
num_workers=workers, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True,
num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
test_split.xvalid), num_workers=workers, pin_memory=True)
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
return search_loader, train_loader, valid_loader

View File

@@ -0,0 +1,45 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import torch
from glob import glob
from torch.utils.data.dataset import Dataset
import os
from PIL import Image
def load_image(filename):
img = Image.open(filename)
img = img.convert('RGB')
return img
class PetDataset(Dataset):
def __init__(self, root, train=True, num_cl=37, val_split=0.2, transforms=None):
self.data = torch.load(os.path.join(root,'{}{}.pth'.format('train' if train else 'test',
int(100*(1-val_split)) if train else int(100*val_split))))
self.len = len(self.data)
self.transform = transforms
def __getitem__(self, index):
img, label = self.data[index]
if self.transform:
img = self.transform(img)
return img, label
def __len__(self):
return self.len
if __name__ == '__main__':
# Added
import torchvision.transforms as transforms
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose(
[transforms.Resize(256), transforms.RandomRotation(45), transforms.CenterCrop(224),
transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
test_transform = transforms.Compose(
[transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
root = '/w14/dataset/MetaGen/pets'
train_data, test_data = get_pets(root, num_cl=37, val_split=0.2,
tr_transform=train_transform,
te_transform=test_transform)
import pdb;
pdb.set_trace()

View File

@@ -0,0 +1,34 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
def additive_func(A, B):
assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
C = min(A.size(1), B.size(1))
if A.size(1) == B.size(1):
return A + B
elif A.size(1) < B.size(1):
out = B.clone()
out[:,:C] += A
return out
else:
out = A.clone()
out[:,:C] += B
return out
def change_key(key, value):
def func(m):
if hasattr(m, key):
setattr(m, key, value)
return func
def parse_channel_info(xstring):
blocks = xstring.split(' ')
blocks = [x.split('-') for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks

View File

@@ -0,0 +1,45 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from os import path as osp
from typing import List, Text
import torch
__all__ = ['get_cell_based_tiny_net', 'get_search_spaces', \
'CellStructure', 'CellArchitectures'
]
# useful modules
from config_utils import dict2config
from .SharedUtils import change_key
from .cell_searchs import CellStructure, CellArchitectures
# Cell-based NAS Models
def get_cell_based_tiny_net(config):
if config.name == 'infer.tiny':
from .cell_infers import TinyNetwork
if hasattr(config, 'genotype'):
genotype = config.genotype
elif hasattr(config, 'arch_str'):
genotype = CellStructure.str2structure(config.arch_str)
else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
else:
raise ValueError('invalid network name : {:}'.format(config.name))
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name) -> List[Text]:
if xtype == 'cell' or xtype == 'tss': # The topology search space.
from .cell_operations import SearchSpaceNames
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
return SearchSpaceNames[name]
elif xtype == 'sss': # The size search space.
if name == 'nas-bench-301':
return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64],
'numbers': 5}
else:
raise ValueError('Invalid name : {:}'.format(name))
else:
raise ValueError('invalid search-space type is {:}'.format(xtype))

View File

@@ -0,0 +1,4 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork

View File

@@ -0,0 +1,122 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
# Cell for NAS-Bench-201
class InferCell(nn.Module):
def __init__(self, genotype, C_in, C_out, stride):
super(InferCell, self).__init__()
self.layers = nn.ModuleList()
self.node_IN = []
self.node_IX = []
self.genotype = deepcopy(genotype)
for i in range(1, len(genotype)):
node_info = genotype[i-1]
cur_index = []
cur_innod = []
for (op_name, op_in) in node_info:
if op_in == 0:
layer = OPS[op_name](C_in , C_out, stride, True, True)
else:
layer = OPS[op_name](C_out, C_out, 1, True, True)
# import pdb; pdb.set_trace()
cur_index.append( len(self.layers) )
cur_innod.append( op_in )
self.layers.append( layer )
self.node_IX.append( cur_index )
self.node_IN.append( cur_innod )
self.nodes = len(genotype)
self.in_dim = C_in
self.out_dim = C_out
def extra_repr(self):
string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
laystr = []
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
x = '{:}<-({:})'.format(i+1, ','.join(y))
laystr.append( x )
return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr())
def forward(self, inputs):
nodes = [inputs]
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
nodes.append( node_feature )
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
super(NASNetInferCell, self).__init__()
self.reduction = reduction
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
if not reduction:
nodes, concats = genotype['normal'], genotype['normal_concat']
else:
nodes, concats = genotype['reduce'], genotype['reduce_concat']
self._multiplier = len(concats)
self._concats = concats
self._steps = len(nodes)
self._nodes = nodes
self.edges = nn.ModuleDict()
for i, node in enumerate(nodes):
for in_node in node:
name, j = in_node[0], in_node[1]
stride = 2 if reduction and j < 2 else 1
node_str = '{:}<-{:}'.format(i+2, j)
self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)
# [TODO] to support drop_prob in this function..
def forward(self, s0, s1, unused_drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i, node in enumerate(self._nodes):
clist = []
for in_node in node:
name, j = in_node[0], in_node[1]
node_str = '{:}<-{:}'.format(i+2, j)
op = self.edges[ node_str ]
clist.append( op(states[j]) )
states.append( sum(clist) )
return torch.cat([states[x] for x in self._concats], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x

View File

@@ -0,0 +1,66 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell
# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):
def __init__(self, C, N, genotype, num_classes):
super(TinyNetwork, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev = C
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
else:
cell = InferCell(genotype, C_prev, C_curr, 1)
self.cells.append( cell )
C_prev = cell.out_dim
self._Layer= len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
'''
out2 = self.lastact(feature)
out = self.global_pooling( out2 )
out = out.view(out.size(0), -1)
out2 = out2.view(out2.size(0), -1)
logits = self.classifier(out)
return out2, logits
'''
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,308 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
OPS = {
'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
}
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
'nas-bench-201': NAS_BENCH_201,
'nas-bench-301': NAS_BENCH_201,
'darts' : DARTS_SPACE}
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
)
def forward(self, x):
return self.op(x)
class DualSepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(DualSepConv, self).__init__()
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats)
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)
def forward(self, x):
x = self.op_a(x)
x = self.op_b(x)
return x
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, affine=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def extra_repr(self):
string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
return string
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
return residual + basicblock
class POOLING(nn.Module):
def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
def forward(self, inputs):
if self.preprocess: x = self.preprocess(inputs)
else : x = inputs
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x):
if self.C_in == self.C_out:
if self.stride == 1: return x.mul(0.)
else : return x[:,:,::self.stride,::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine))
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
elif stride == 1:
self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
def forward(self, x):
if self.stride == 2:
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
else:
out = self.conv(x)
out = self.bn(out)
return out
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
class PartAwareOp(nn.Module):
def __init__(self, C_in, C_out, stride, part=4):
super().__init__()
self.part = 4
self.hidden = C_in // 3
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.local_conv_list = nn.ModuleList()
for i in range(self.part):
self.local_conv_list.append(
nn.Sequential(nn.ReLU(), nn.Conv2d(C_in, self.hidden, 1), nn.BatchNorm2d(self.hidden, affine=True))
)
self.W_K = nn.Linear(self.hidden, self.hidden)
self.W_Q = nn.Linear(self.hidden, self.hidden)
if stride == 2 : self.last = FactorizedReduce(C_in + self.hidden, C_out, 2)
elif stride == 1: self.last = FactorizedReduce(C_in + self.hidden, C_out, 1)
else: raise ValueError('Invalid Stride : {:}'.format(stride))
def forward(self, x):
batch, C, H, W = x.size()
assert H >= self.part, 'input size too small : {:} vs {:}'.format(x.shape, self.part)
IHs = [0]
for i in range(self.part): IHs.append( min(H, int((i+1)*(float(H)/self.part))) )
local_feat_list = []
for i in range(self.part):
feature = x[:, :, IHs[i]:IHs[i+1], :]
xfeax = self.avg_pool(feature)
xfea = self.local_conv_list[i]( xfeax )
local_feat_list.append( xfea )
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
part_feature = part_feature.transpose(1,2).contiguous()
part_K = self.W_K(part_feature)
part_Q = self.W_Q(part_feature).transpose(1,2).contiguous()
weight_att = torch.bmm(part_K, part_Q)
attention = torch.softmax(weight_att, dim=2)
aggreateF = torch.bmm(attention, part_feature).transpose(1,2).contiguous()
features = []
for i in range(self.part):
feature = aggreateF[:, :, i:i+1].expand(batch, self.hidden, IHs[i+1]-IHs[i])
feature = feature.view(batch, self.hidden, IHs[i+1]-IHs[i], 1)
features.append( feature )
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
final_fea = torch.cat((x,features), dim=1)
outputs = self.last( final_fea )
return outputs
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1. - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x = torch.div(x, keep_prob)
x.mul_(mask)
return x
# Searching for A Robust Neural Architecture in Four GPU Hours
class GDAS_Reduction_Cell(nn.Module):
def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier, affine, track_running_stats):
super(GDAS_Reduction_Cell, self).__init__()
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats)
self.multiplier = multiplier
self.reduction = True
self.ops1 = nn.ModuleList(
[nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
nn.BatchNorm2d(C, affine=True),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C, affine=True)),
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
nn.BatchNorm2d(C, affine=True),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C, affine=True))])
self.ops2 = nn.ModuleList(
[nn.Sequential(
nn.MaxPool2d(3, stride=1, padding=1),
nn.BatchNorm2d(C, affine=True)),
nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
nn.BatchNorm2d(C, affine=True))])
def forward(self, s0, s1, drop_prob = -1):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
X0 = self.ops1[0] (s0)
X1 = self.ops1[1] (s1)
if self.training and drop_prob > 0.:
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
#X2 = self.ops2[0] (X0+X1)
X2 = self.ops2[0] (s0)
X3 = self.ops2[1] (s1)
if self.training and drop_prob > 0.:
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
return torch.cat([X0, X1, X2, X3], dim=1)

View File

@@ -0,0 +1,26 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# The macro structure is defined in NAS-Bench-201
# from .search_model_darts import TinyNetworkDarts
# from .search_model_gdas import TinyNetworkGDAS
# from .search_model_setn import TinyNetworkSETN
# from .search_model_enas import TinyNetworkENAS
# from .search_model_random import TinyNetworkRANDOM
# from .generic_model import GenericNAS201Model
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
# NASNet-based macro structure
# from .search_model_gdas_nasnet import NASNetworkGDAS
# from .search_model_darts_nasnet import NASNetworkDARTS
# nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
# "DARTS-V2": TinyNetworkDarts,
# "GDAS": TinyNetworkGDAS,
# "SETN": TinyNetworkSETN,
# "ENAS": TinyNetworkENAS,
# "RANDOM": TinyNetworkRANDOM,
# "generic": GenericNAS201Model}
# nasnet_super_nets = {"GDAS": NASNetworkGDAS,
# "DARTS": NASNetworkDARTS}

View File

@@ -0,0 +1,198 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from copy import deepcopy
def get_combination(space, num):
combs = []
for i in range(num):
if i == 0:
for func in space:
combs.append( [(func, i)] )
else:
new_combs = []
for string in combs:
for func in space:
xstring = string + [(func, i)]
new_combs.append( xstring )
combs = new_combs
return combs
class Structure:
def __init__(self, genotype):
assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
self.node_num = len(genotype) + 1
self.nodes = []
self.node_N = []
for idx, node_info in enumerate(genotype):
assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
for node_in in node_info:
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
self.node_N.append( len(node_info) )
self.nodes.append( tuple(deepcopy(node_info)) )
def tolist(self, remove_str):
# convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
# note that we re-order the input node in this function
# return the-genotype-list and success [if unsuccess, it is not a connectivity]
genotypes = []
for node_info in self.nodes:
node_info = list( node_info )
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
if len(node_info) == 0: return None, False
genotypes.append( node_info )
return genotypes, True
def node(self, index):
assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
return self.nodes[index]
def tostr(self):
strings = []
for node_info in self.nodes:
string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
string = '|{:}|'.format(string)
strings.append( string )
return '+'.join(strings)
def check_valid(self):
nodes = {0: True}
for i, node_info in enumerate(self.nodes):
sums = []
for op, xin in node_info:
if op == 'none' or nodes[xin] is False: x = False
else: x = True
sums.append( x )
nodes[i+1] = sum(sums) > 0
return nodes[len(self.nodes)]
def to_unique_str(self, consider_zero=False):
# this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
# two operations are special, i.e., none and skip_connect
nodes = {0: '0'}
for i_node, node_info in enumerate(self.nodes):
cur_node = []
for op, xin in node_info:
if consider_zero is None:
x = '('+nodes[xin]+')' + '@{:}'.format(op)
elif consider_zero:
if op == 'none' or nodes[xin] == '#': x = '#' # zero
elif op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
else:
if op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
cur_node.append(x)
nodes[i_node+1] = '+'.join( sorted(cur_node) )
return nodes[ len(self.nodes) ]
def check_valid_op(self, op_names):
for node_info in self.nodes:
for inode_edge in node_info:
#assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
if inode_edge[0] not in op_names: return False
return True
def __repr__(self):
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
def __len__(self):
return len(self.nodes) + 1
def __getitem__(self, index):
return self.nodes[index]
@staticmethod
def str2structure(xstr):
if isinstance(xstr, Structure): return xstr
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
genotypes.append( input_infos )
return Structure( genotypes )
@staticmethod
def str2fullstructure(xstr, default_name='none'):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
all_in_nodes= list(x[1] for x in input_infos)
for j in range(i):
if j not in all_in_nodes: input_infos.append((default_name, j))
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
genotypes.append( tuple(node_info) )
return Structure( genotypes )
@staticmethod
def gen_all(search_space, num, return_ori):
assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
all_archs = get_combination(search_space, 1)
for i, arch in enumerate(all_archs):
all_archs[i] = [ tuple(arch) ]
for inode in range(2, num):
cur_nodes = get_combination(search_space, inode)
new_all_archs = []
for previous_arch in all_archs:
for cur_node in cur_nodes:
new_all_archs.append( previous_arch + [tuple(cur_node)] )
all_archs = new_all_archs
if return_ori:
return all_archs
else:
return [Structure(x) for x in all_archs]
ResNet_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 1), ), # node-2
(('skip_connect', 0), ('skip_connect', 2))] # node-3
)
AllConv3x3_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
)
AllFull_CODE = Structure(
[(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
)
AllConv1x1_CODE = Structure(
[(('nor_conv_1x1', 0), ), # node-1
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
)
AllIdentity_CODE = Structure(
[(('skip_connect', 0), ), # node-1
(('skip_connect', 0), ('skip_connect', 1)), # node-2
(('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
)
architectures = {'resnet' : ResNet_CODE,
'all_c3x3': AllConv3x3_CODE,
'all_c1x1': AllConv1x1_CODE,
'all_idnt': AllIdentity_CODE,
'all_full': AllFull_CODE}

View File

@@ -0,0 +1,167 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual):
super(InferCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL+1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,150 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
else:
self.downsample = None
self.out_dim = planes*self.expansion
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferDepthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
super(InferDepthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.channels = [16]
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, planes, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.channels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,160 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferWidthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
super(InferWidthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,170 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=True, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferImagenetResNet(nn.Module):
def __init__(self, block_name, layers, xblocks, xchannels, deep_stem, num_classes, zero_init_residual):
super(InferImagenetResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'BasicBlock':
block = ResNetBasicblock
elif block_name == 'Bottleneck':
block = ResNetBottleneck
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == len(layers), 'invalid layers : {:} vs xblocks : {:}'.format(layers, xblocks)
self.message = 'InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}'.format(sum(layers)*block.num_conv, sum(xblocks)*block.num_conv, xblocks)
self.num_classes = num_classes
self.xchannels = xchannels
if not deep_stem:
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
else:
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True)
,ConvBNReLU(xchannels[1], xchannels[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 2
self.layers.append( nn.MaxPool2d(kernel_size=3, stride=2, padding=1) )
for stage, layer_blocks in enumerate(layers):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL+1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
assert last_channel_idx + 1 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,122 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from ..initialization import initialize_resnet
from ..SharedUtils import parse_channel_info
class ConvBNReLU(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, groups, has_bn=True, has_relu=True):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
if has_bn: self.bn = nn.BatchNorm2d(out_planes)
else : self.bn = None
if has_relu: self.relu = nn.ReLU6(inplace=True)
else : self.relu = None
def forward(self, x):
out = self.conv( x )
if self.bn: out = self.bn ( out )
if self.relu: out = self.relu( out )
return out
class InvertedResidual(nn.Module):
def __init__(self, channels, stride, expand_ratio, additive):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2], 'invalid stride : {:}'.format(stride)
assert len(channels) in [2, 3], 'invalid channels : {:}'.format(channels)
if len(channels) == 2:
layers = []
else:
layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
layers.extend([
# dw
ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
# pw-linear
ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
])
self.conv = nn.Sequential(*layers)
self.additive = additive
if self.additive and channels[0] != channels[-1]:
self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
else:
self.shortcut = None
self.out_dim = channels[-1]
def forward(self, x):
out = self.conv(x)
# if self.additive: return additive_func(out, x)
if self.shortcut: return out + self.shortcut(x)
else : return out
class InferMobileNetV2(nn.Module):
def __init__(self, num_classes, xchannels, xblocks, dropout):
super(InferMobileNetV2, self).__init__()
block = InvertedResidual
inverted_residual_setting = [
# t, c, n, s
[1, 16 , 1, 1],
[6, 24 , 2, 2],
[6, 32 , 3, 2],
[6, 64 , 4, 2],
[6, 96 , 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
assert len(inverted_residual_setting) == len(xblocks), 'invalid number of layers : {:} vs {:}'.format(len(inverted_residual_setting), len(xblocks))
for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
assert block_num <= ir_setting[2], '{:} vs {:}'.format(block_num, ir_setting)
xchannels = parse_channel_info(xchannels)
#for i, chs in enumerate(xchannels):
# if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
self.xchannels = xchannels
self.message = 'InferMobileNetV2 : xblocks={:}'.format(xblocks)
# building first layer
features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
last_channel_idx = 1
# building inverted residual blocks
for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
for i in range(n):
stride = s if i == 0 else 1
additv = True if i > 0 else False
module = block(self.xchannels[last_channel_idx], stride, t, additv)
features.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c)
last_channel_idx += 1
if i + 1 == xblocks[stage]:
out_channel = module.out_dim
for iiL in range(i+1, n):
last_channel_idx += 1
self.xchannels[last_channel_idx][0] = module.out_dim
break
# building last several layers
features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1))
assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
)
# weight initialization
self.apply( initialize_resnet )
def get_message(self):
return self.message
def forward(self, inputs):
features = self.features(inputs)
vectors = features.mean([2, 3])
predicts = self.classifier(vectors)
return features, predicts

View File

@@ -0,0 +1,58 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from typing import List, Text, Any
import torch.nn as nn
from models.cell_operations import ResNetBasicblock
from models.cell_infers.cells import InferCell
class DynamicShapeTinyNet(nn.Module):
def __init__(self, channels: List[int], genotype: Any, num_classes: int):
super(DynamicShapeTinyNet, self).__init__()
self._channels = channels
if len(channels) % 3 != 2:
raise ValueError('invalid number of layers : {:}'.format(len(channels)))
self._num_stage = N = len(channels) // 3
self.stem = nn.Sequential(
nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(channels[0]))
# layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
c_prev = channels[0]
self.cells = nn.ModuleList()
for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True)
else : cell = InferCell(genotype, c_prev, c_curr, 1)
self.cells.append( cell )
c_prev = cell.out_dim
self._num_layer = len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(c_prev, num_classes)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .InferCifarResNet_width import InferWidthCifarResNet
from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2
from .InferTinyCellNet import DynamicShapeTinyNet

View File

@@ -0,0 +1,5 @@
def parse_channel_info(xstring):
blocks = xstring.split(' ')
blocks = [x.split('-') for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks

View File

@@ -0,0 +1,2 @@
from .evaluation_utils import obtain_accuracy
from .flop_benchmark import get_model_infos

View File

@@ -0,0 +1,17 @@
import torch
def obtain_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
# correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res

View File

@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import numpy as np
def count_parameters_in_MB(model):
if isinstance(model, nn.Module):
return np.sum(np.prod(v.size()) for v in model.parameters())/1e6
else:
return np.sum(np.prod(v.size()) for v in model)/1e6
def get_model_infos(model, shape):
#model = copy.deepcopy( model )
model = add_flops_counting_methods(model)
#model = model.cuda()
model.eval()
#cache_inputs = torch.zeros(*shape).cuda()
#cache_inputs = torch.zeros(*shape)
cache_inputs = torch.rand(*shape)
if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda()
#print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
with torch.no_grad():
_____ = model(cache_inputs)
FLOPs = compute_average_flops_cost( model ) / 1e6
Param = count_parameters_in_MB(model)
if hasattr(model, 'auxiliary_param'):
aux_params = count_parameters_in_MB(model.auxiliary_param())
print ('The auxiliary params of this model is : {:}'.format(aux_params))
print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
Param = Param - aux_params
#print_log('FLOPs : {:} MB'.format(FLOPs), log)
torch.cuda.empty_cache()
model.apply( remove_hook_function )
return FLOPs, Param
# ---- Public functions
def add_flops_counting_methods( model ):
model.__batch_counter__ = 0
add_batch_counter_hook_function( model )
model.apply( add_flops_counter_variable_or_reset )
model.apply( add_flops_counter_hook_function )
return model
def compute_average_flops_cost(model):
"""
A method that will be available after add_flops_counting_methods() is called on a desired net object.
Returns current mean flops consumption per image.
"""
batches_count = model.__batch_counter__
flops_sum = 0
#or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
for module in model.modules():
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.Conv1d) \
or hasattr(module, 'calculate_flop_self'):
flops_sum += module.__flops__
return flops_sum / batches_count
# ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output):
batch_size = inputs[0].size(0)
kernel_size = pool_module.kernel_size
out_C, output_height, output_width = output.shape[1:]
assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())
overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
pool_module.__flops__ += overall_flops
def self_calculate_flops_counter_hook(self_module, inputs, output):
overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
self_module.__flops__ += overall_flops
def fc_flops_counter_hook(fc_module, inputs, output):
batch_size = inputs[0].size(0)
xin, xout = fc_module.in_features, fc_module.out_features
assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout)
overall_flops = batch_size * xin * xout
if fc_module.bias is not None:
overall_flops += batch_size * xout
fc_module.__flops__ += overall_flops
def conv1d_flops_counter_hook(conv_module, inputs, outputs):
batch_size = inputs[0].size(0)
outL = outputs.shape[-1]
[kernel] = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel * in_channels * out_channels / groups
active_elements_count = batch_size * outL
overall_flops = conv_per_position_flops * active_elements_count
if conv_module.bias is not None:
overall_flops += out_channels * active_elements_count
conv_module.__flops__ += overall_flops
def conv2d_flops_counter_hook(conv_module, inputs, output):
batch_size = inputs[0].size(0)
output_height, output_width = output.shape[2:]
kernel_height, kernel_width = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups
active_elements_count = batch_size * output_height * output_width
overall_flops = conv_per_position_flops * active_elements_count
if conv_module.bias is not None:
overall_flops += out_channels * active_elements_count
conv_module.__flops__ += overall_flops
def batch_counter_hook(module, inputs, output):
# Can have multiple inputs, getting the first one
inputs = inputs[0]
batch_size = inputs.shape[0]
module.__batch_counter__ += batch_size
def add_batch_counter_hook_function(module):
if not hasattr(module, '__batch_counter_handle__'):
handle = module.register_forward_hook(batch_counter_hook)
module.__batch_counter_handle__ = handle
def add_flops_counter_variable_or_reset(module):
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.Conv1d) \
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
or hasattr(module, 'calculate_flop_self'):
module.__flops__ = 0
def add_flops_counter_hook_function(module):
if isinstance(module, torch.nn.Conv2d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv2d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Conv1d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv1d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Linear):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(fc_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(pool_flops_counter_hook)
module.__flops_handle__ = handle
elif hasattr(module, 'calculate_flop_self'): # self-defined module
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(self_calculate_flops_counter_hook)
module.__flops_handle__ = handle
def remove_hook_function(module):
hookers = ['__batch_counter_handle__', '__flops_handle__']
for hooker in hookers:
if hasattr(module, hooker):
handle = getattr(module, hooker)
handle.remove()
keys = ['__flops__', '__batch_counter__', '__flops__'] + hookers
for ckey in keys:
if hasattr(module, ckey): delattr(module, ckey)

View File

@@ -0,0 +1,28 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .starts import get_machine_info, save_checkpoint, copy_checkpoint
from .optimizers import get_optim_scheduler
from .starts import prepare_seed #, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
'''
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
from .funcs_nasbench import get_nas_bench_loaders
def get_procedures(procedure):
from .basic_main import basic_train, basic_valid
from .search_main import search_train, search_valid
from .search_main_v2 import search_train_v2
from .simple_KD_main import simple_KD_train, simple_KD_valid
train_funcs = {'basic' : basic_train, \
'search': search_train,'Simple-KD': simple_KD_train, \
'search-v2': search_train_v2}
valid_funcs = {'basic' : basic_valid, \
'search': search_valid,'Simple-KD': simple_KD_valid, \
'search-v2': search_valid}
train_func = train_funcs[procedure]
valid_func = valid_funcs[procedure]
return train_func, valid_func
'''

View File

@@ -0,0 +1,204 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer
class _LRScheduler(object):
def __init__(self, optimizer, warmup_epochs, epochs):
if not isinstance(optimizer, Optimizer):
raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
self.optimizer = optimizer
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
self.max_epochs = epochs
self.warmup_epochs = warmup_epochs
self.current_epoch = 0
self.current_iter = 0
def extra_repr(self):
return ''
def __repr__(self):
return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
+ ', {:})'.format(self.extra_repr()))
def state_dict(self):
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
def get_lr(self):
raise NotImplementedError
def get_min_info(self):
lrs = self.get_lr()
return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)
def get_min_lr(self):
return min( self.get_lr() )
def update(self, cur_epoch, cur_iter):
if cur_epoch is not None:
assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch)
self.current_epoch = cur_epoch
if cur_iter is not None:
assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter)
self.current_iter = cur_iter
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
class CosineAnnealingLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
self.T_max = T_max
self.eta_min = eta_min
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
#if last_epoch < self.T_max:
#if last_epoch < self.max_epochs:
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
#else:
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
elif self.current_epoch >= self.max_epochs:
lr = self.eta_min
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class MultiStepLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
self.milestones = milestones
self.gammas = gammas
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
idx = bisect_right(self.milestones, last_epoch)
lr = base_lr
for x in self.gammas[:idx]: lr *= x
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class ExponentialLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
self.gamma = gamma
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
lr = base_lr * (self.gamma ** last_epoch)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class LinearLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
self.max_LR = max_LR
self.min_LR = min_LR
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
lr = base_lr * (1-ratio)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
def get_optim_scheduler(parameters, config):
assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
if config.optim == 'SGD':
optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
elif config.optim == 'RMSprop':
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
else:
raise ValueError('invalid optim : {:}'.format(config.optim))
if config.scheduler == 'cos':
T_max = getattr(config, 'T_max', config.epochs)
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
elif config.scheduler == 'multistep':
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
elif config.scheduler == 'exponential':
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
elif config.scheduler == 'linear':
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
else:
raise ValueError('invalid scheduler : {:}'.format(config.scheduler))
if config.criterion == 'Softmax':
criterion = torch.nn.CrossEntropyLoss()
elif config.criterion == 'SmoothSoftmax':
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
else:
raise ValueError('invalid criterion : {:}'.format(config.criterion))
return optim, scheduler, criterion

View File

@@ -0,0 +1,64 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch, random, PIL, copy, numpy as np
from os import path as osp
from shutil import copyfile
def prepare_seed(rand_seed):
random.seed(rand_seed)
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)
torch.cuda.manual_seed(rand_seed)
torch.cuda.manual_seed_all(rand_seed)
def prepare_logger(xargs):
args = copy.deepcopy( xargs )
from log_utils import Logger
logger = Logger(args.save_dir, args.rand_seed)
logger.log('Main Function with logger : {:}'.format(logger))
logger.log('Arguments : -------------------------------')
for name, value in args._get_kwargs():
logger.log('{:16} : {:}'.format(name, value))
logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
logger.log("Pillow Version : {:}".format(PIL.__version__))
logger.log("PyTorch Version : {:}".format(torch.__version__))
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
return logger
def get_machine_info():
info = "Python Version : {:}".format(sys.version.replace('\n', ' '))
info+= "\nPillow Version : {:}".format(PIL.__version__)
info+= "\nPyTorch Version : {:}".format(torch.__version__)
info+= "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
info+= "\nCUDA available : {:}".format(torch.cuda.is_available())
info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
if 'CUDA_VISIBLE_DEVICES' in os.environ:
info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES'])
else:
info+= "\nDoes not set CUDA_VISIBLE_DEVICES"
return info
def save_checkpoint(state, filename, logger):
if osp.isfile(filename):
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(filename))
os.remove(filename)
torch.save(state, filename)
assert osp.isfile(filename), 'save filename : {:} failed, which is not found.'.format(filename)
if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename))
return filename
def copy_checkpoint(src, dst, logger):
if osp.isfile(dst):
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(dst))
os.remove(dst)
copyfile(src, dst)
if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst))