first commit

This commit is contained in:
CownowAn
2024-03-15 14:38:51 +00:00
commit bc2ed1304f
321 changed files with 44802 additions and 0 deletions

View File

@@ -0,0 +1,168 @@
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets
This code is for the MobileNetV3 search space experiments.
## Prerequisites
- Python 3.6 (Anaconda)
- PyTorch 1.6.0
- CUDA 10.2
- python-igraph==0.8.2
- tqdm==4.50.2
- torchvision==0.7.0
- scipy==1.5.2
- ofa==0.0.4-2007200808
## MobileNetV3 Search Space
Go to the folder for the MobileNetV3 experiments (i.e. ```MetaD2A_mobilenetV3```).
The overall flow is summarized as follows:
- Building database for Predictor
- Meta-Training Predictor
- Building database for Generator with trained Predictor
- Meta-Training Generator
- Meta-Testing (Searching)
- Evaluating the Searched architecture
## Data Preparation
To download preprocessed data files, run ```get_files/get_preprocessed_data.py```:
```shell script
$ python get_files/get_preprocessed_data.py
```
It will take some time to download and preprocess each dataset.
## Meta Test and Evaluation
### Meta-Test
You can download trained checkpoint files for the generator and the predictor:
```shell script
$ python get_files/get_generator_checkpoint.py
$ python get_files/get_predictor_checkpoint.py
```
If you want to meta-test with your own dataset, first create your own preprocessed data
by modifying ```process_dataset.py```:
```shell script
$ python process_dataset.py
```
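For reference, here is a minimal sketch of the kind of preprocessing involved, assuming, as in ```database/run_manager.py```, that a dataset is represented by image features from an ImageNet-pretrained ResNet-18; the file name and layout below are illustrative, not the exact format ```process_dataset.py``` produces.
```python
import torch
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

# Hypothetical sketch: embed a custom dataset with a pretrained ResNet-18,
# mirroring the feature extractor used in database/run_manager.py.
model = models.resnet18(pretrained=True).eval()
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = ImageFolder('path/to/your/dataset', transform)  # your own images
loader = DataLoader(dataset, batch_size=64, num_workers=4)

features = []
with torch.no_grad():
    for images, _ in loader:
        # ResNet-18 avgpool output is [N, 512, 1, 1]; flatten to [N, 512]
        features.append(feature_extractor(images).squeeze(-1).squeeze(-1))
torch.save(torch.cat(features), 'my_dataset_features.pt')  # illustrative output file
```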
During meta-test, this code automatically generates neural architectures and then
selects high-performing architectures among the candidates.
By setting ```--data-name``` to the name of the dataset (i.e. ```cifar10```, ```cifar100```, ```aircraft100```, ```pets```),
you can evaluate on that specific dataset.
```shell script
# Meta-testing
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 120 --num-gen-arch 200 --data-name {DATASET_NAME}
```
### Architecture Evaluation (MetaD2A vs NSGANetV2)
##### Dataset Preparation
You need to download the Oxford-IIIT Pet dataset to evaluate with ```--data-name pets```:
```shell script
$ python get_files/get_pets.py
```
All other datasets (```cifar10```, ```cifar100```, ```aircraft100```) are downloaded automatically.
##### Evaluation
You can run the searched architecture via ```main.py``` in the ```evaluation``` folder. The code is based on NSGANetV2.
Go to the evaluation folder (i.e. ```evaluation```):
```shell script
$ cd evaluation
```
This automatically runs the top-1 architecture predicted by MetaD2A:
```shell script
python main.py --data-name cifar10 --num-gen-arch 200
```
You can also impose a FLOPs constraint with the ```--bound``` option:
```shell script
python main.py --data-name cifar10 --num-gen-arch 200 --bound 300
```
You can compare MetaD2A with NSGANetV2,
but you first need to download some files provided
by [NSGANetV2](https://github.com/human-analysis/nsganetv2):
```shell script
python main.py --data-name cifar10 --num-gen-arch 200 --model-config flops@232
```
## Meta-Training MetaD2A Model
To build the database for meta-training, you need to set ```IMGNET_PATH```, which is the directory of ILSVRC2012 (ImageNet).
### Database Building for Predictor
We recommend running multiple instances of ```create_database.sh``` simultaneously to build the database faster.
You need to set ```IMGNET_PATH``` in the shell script.
```shell script
# Examples
bash create_database.sh 0,1,2,3 0 49 predictor
bash create_database.sh all 50 99 predictor
...
```
After enough data has been gathered, run ```build_database.py``` to collect the pieces into one file.
```shell script
python build_database.py --model_name predictor --collect
```
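Each entry in the merged file is a dictionary written by ```DatabaseOFA.make_db()``` (see ```database/db_ofa.py```), storing the sampled subnet together with its measured statistics. A quick sanity check might look like the following; the path assumes the default ```--data_path``` layout and is an assumption, not a guaranteed location.
```python
import torch

# Assumed location: {data_path}/predictor/processed/collected_database.pt
db = torch.load('data/predictor/processed/collected_database.pt')
print(f'number of entries: {len(db)}')
entry = db[0]
# Keys written by DatabaseOFA.make_db()
print(entry['top1'], entry['flops'], entry['params'])
print(entry['net'])    # OFA subnet setting (kernel sizes, expand ratios, depths)
print(entry['class'])  # ImageNet classes sampled for this task
```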
We also provide the database we used. To download it, run ```get_files/get_predictor_database.py```:
```shell script
$ python get_files/get_predictor_database.py
```
### Meta-Train Predictor
You can meta-train the predictor as follows:
```shell script
# Meta-training for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56
```
### Database Building for Generator
We recommend running multiple instances of ```create_database.sh``` simultaneously to build the database faster.
```shell script
# Examples
bash create_database.sh 4,5,6,7 0 49 generator
bash create_database.sh all 50 99 generator
...
```
After enough data has been gathered, run ```build_database.py``` to collect the pieces into one file.
```shell script
python build_database.py --model_name generator --collect
```
We also provide the database we used. To download it, run ```get_files/get_generator_database.py```:
```shell script
$ python get_files/get_generator_database.py
```
### Meta-Train Generator
You can meta-train the generator as follows:
```shell script
# Meta-training for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56
```
## Citation
If you find the provided code useful, please cite our work:
```
@inproceedings{
lee2021rapid,
title={Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets},
author={Hayeon Lee and Eunyoung Hyung and Sung Ju Hwang},
booktitle={ICLR},
year={2021}
}
```
## Reference
- [Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks (ICML2019)](https://github.com/juho-lee/set_transformer)
- [D-VAE: A Variational Autoencoder for Directed Acyclic Graphs, Advances in Neural Information Processing Systems (NeurIPS2019)](https://github.com/muhanzhang/D-VAE)
- [Once for All: Train One Network and Specialize it for Efficient Deployment (ICLR2020)](https://github.com/mit-han-lab/once-for-all)
- [NSGANetV2: Evolutionary Multi-Objective Surrogate-Assisted Neural Architecture Search (ECCV2020)](https://github.com/human-analysis/nsganetv2)

View File

@@ -0,0 +1,49 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
import random
import numpy as np
import torch
from parser import get_parser
from predictor import PredictorModel
from database import DatabaseOFA
from utils import load_graph_config
def main():
args = get_parser()
if args.gpu == 'all':
device_list = range(torch.cuda.device_count())
args.gpu = ','.join(str(_) for _ in device_list)
else:
device_list = [int(_) for _ in args.gpu.split(',')]
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
args.device = torch.device("cuda:0")
args.batch_size = args.batch_size * max(len(device_list), 1)
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
args.model_path = os.path.join(args.save_path, args.model_name, 'model')
if args.model_name == 'generator':
graph_config = load_graph_config(
args.graph_data_name, args.nvt, args.data_path)
model = PredictorModel(args, graph_config)
d = DatabaseOFA(args, model)
else:
d = DatabaseOFA(args)
if args.collect:
d.collect_db()
else:
assert args.index is not None
assert args.imgnet is not None
d.make_db()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,15 @@
# Usage: bash create_database.sh <gpu> <start_index> <end_index> <model_name>
# e.g., bash create_database.sh all 0 49 predictor
IMGNET_PATH='/w14/dataset/ILSVRC2012' # PUT YOUR ILSVRC2012 DIR
for ((ind=$2;ind<=$3;ind++))
do
python build_database.py --gpu $1 \
--model_name $4 \
--index $ind \
--imgnet $IMGNET_PATH \
--hs 512 \
--nz 56
done

View File

@@ -0,0 +1,5 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .db_ofa import DatabaseOFA

View File

@@ -0,0 +1,57 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import numpy as np
import torch
__all__ = ['DataProvider']
class DataProvider:
SUB_SEED = 937162211 # random seed for sampling subset
VALID_SEED = 2147483647 # random seed for the validation set
@staticmethod
def name():
""" Return name of the dataset """
raise NotImplementedError
@property
def data_shape(self):
""" Return shape as python list of one data entry """
raise NotImplementedError
@property
def n_classes(self):
""" Return `int` of num classes """
raise NotImplementedError
@property
def save_path(self):
""" local path to save the data """
raise NotImplementedError
@property
def data_url(self):
""" link to download the data """
raise NotImplementedError
@staticmethod
def random_sample_valid_set(train_size, valid_size):
assert train_size > valid_size
g = torch.Generator()
g.manual_seed(DataProvider.VALID_SEED) # set random seed before sampling validation set
rand_indexes = torch.randperm(train_size, generator=g).tolist()
valid_indexes = rand_indexes[:valid_size]
train_indexes = rand_indexes[valid_size:]
return train_indexes, valid_indexes
@staticmethod
def labels_to_one_hot(n_classes, labels):
new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
return new_labels
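# Usage sketch (illustrative, not part of the original file):
# train_idx, valid_idx = DataProvider.random_sample_valid_set(train_size=1000, valid_size=100)
# one_hot = DataProvider.labels_to_one_hot(n_classes=10, labels=np.array([1, 5, 9]))
# The split is deterministic across runs because VALID_SEED fixes the permutation.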

View File

@@ -0,0 +1,107 @@
import os
import torch
import time
import copy
import glob
from .imagenet import ImagenetDataProvider
from .imagenet_loader import ImagenetRunConfig
from .run_manager import RunManager
from ofa.model_zoo import ofa_net
class DatabaseOFA:
def __init__(self, args, predictor=None):
self.path = f'{args.data_path}/{args.model_name}'
self.model_name = args.model_name
self.index = args.index
self.args = args
self.predictor = predictor
ImagenetDataProvider.DEFAULT_PATH = args.imgnet
if not os.path.exists(self.path):
os.makedirs(self.path)
def make_db(self):
self.ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=True)
self.run_config = ImagenetRunConfig(test_batch_size=self.args.batch_size,
n_worker=20)
database = []
st_time = time.time()
f = open(f'{self.path}/txt_{self.index}.txt', 'w')
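# Each iteration builds one database entry for a randomly sampled task (class subset);
# 10000 is an upper bound on the entries produced by this worker index.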
for dn in range(10000):
best_pp = -1
best_info = None
dls = None
with torch.no_grad():
if self.model_name == 'generator':
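# Generator DB: sample 10 OFA subnets for the same task and keep the one
# the meta-trained predictor scores highest (best_pp tracks predicted accuracy).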
for i in range(10):
net_setting = self.ofa_network.sample_active_subnet()
subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
if i == 0:
run_manager = RunManager('.tmp/eval_subnet', self.args, subnet,
self.run_config, init=False, pp=self.predictor)
self.run_config.data_provider.assign_active_img_size(224)
dls = {j: copy.deepcopy(run_manager.data_loader) for j in range(1, 10)}
else:
run_manager = RunManager('.tmp/eval_subnet', self.args, subnet,
self.run_config,
init=False, data_loader=dls[i], pp=self.predictor)
run_manager.reset_running_statistics(net=subnet)
loss, (top1, top5), pred_acc \
= run_manager.validate(net=subnet, net_setting=net_setting)
if best_pp < pred_acc:
best_pp = pred_acc
print('[%d] class=%d,\t loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (
dn, len(run_manager.cls_lst), loss, top1, top5))
info_dict = {'loss': loss,
'top1': top1,
'top5': top5,
'net': net_setting,
'class': run_manager.cls_lst,
'params': run_manager.net_info['params'],
'flops': run_manager.net_info['flops'],
'test_transform': run_manager.test_transform
}
best_info = info_dict
elif self.model_name == 'predictor':
net_setting = self.ofa_network.sample_active_subnet()
subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
run_manager = RunManager('.tmp/eval_subnet', self.args, subnet, self.run_config, init=False)
self.run_config.data_provider.assign_active_img_size(224)
run_manager.reset_running_statistics(net=subnet)
loss, (top1, top5), _ = run_manager.validate(net=subnet)
print('[%d] class=%d,\t loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (
dn, len(run_manager.cls_lst), loss, top1, top5))
best_info = {'loss': loss,
'top1': top1,
'top5': top5,
'net': net_setting,
'class': run_manager.cls_lst,
'params': run_manager.net_info['params'],
'flops': run_manager.net_info['flops'],
'test_transform': run_manager.test_transform
}
database.append(best_info)
if (len(database)) % 10 == 0:
msg = f'{(time.time() - st_time) / 60.0:0.2f}(min) save {len(database)} database, {self.index} id'
print(msg)
f.write(msg + '\n')
f.flush()
torch.save(database, f'{self.path}/database_{self.index}.pt')
def collect_db(self):
if not os.path.exists(self.path + f'/processed'):
os.makedirs(self.path + f'/processed')
database = []
dlst = glob.glob(self.path + '/*.pt')
for filepath in dlst:
database += torch.load(filepath)
assert len(database) != 0
print(f'The number of database: {len(database)}')
torch.save(database, self.path + f'/processed/collected_database.pt')

View File

@@ -0,0 +1,240 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import warnings
import os
import torch
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa_local.imagenet_classification.data_providers.base_provider import DataProvider
from ofa_local.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
from .metaloader import MetaImageNetDataset, EpisodeSampler, MetaDataLoader
__all__ = ['ImagenetDataProvider']
class ImagenetDataProvider(DataProvider):
DEFAULT_PATH = '/dataset/imagenet'
def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = 'None' if distort_color is None else distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
from ofa_local.utils.my_dataloader import MyDataLoader  # local copy, consistent with the imports above
assert isinstance(self.image_size, list)
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size) # active resolution for test
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
########################## modification ########################
train_dataset = self.train_dataset(self.build_train_transform())
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, True, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, True, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
self.valid = None
# test_dataset = self.test_dataset(valid_transforms)
test_dataset = self.meta_test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
# self.test = torch.utils.data.DataLoader(
# test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
# )
sampler = EpisodeSampler(
max_way=1000, query=10, ylst=test_dataset.ylst)
self.test = MetaDataLoader(dataset=test_dataset,
sampler=sampler,
batch_size=test_batch_size,
shuffle=False,
num_workers=4)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'imagenet'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 1000
@property
def save_path(self):
if self._save_path is None:
self._save_path = self.DEFAULT_PATH
if not os.path.exists(self._save_path):
self._save_path = os.path.expanduser('~/dataset/imagenet')
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
return datasets.ImageFolder(self.train_path, _transforms)
def test_dataset(self, _transforms):
return datasets.ImageFolder(self.valid_path, _transforms)
def meta_test_dataset(self, _transforms):
return MetaImageNetDataset('val', max_way=1000, query=10,
dpath='/w14/dataset/ILSVRC2012', transform=_transforms)
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'val')
@property
def normalize(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
# random_resize_crop -> random_horizontal_flip
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
# color augmentation (optional)
color_transform = None
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting BN running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,40 @@
from .imagenet import ImagenetDataProvider
from ofa_local.imagenet_classification.run_manager import RunConfig
__all__ = ['ImagenetRunConfig']
class ImagenetRunConfig(RunConfig):
def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='imagenet', train_batch_size=256, test_batch_size=500, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, **kwargs):
super(ImagenetRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == ImagenetDataProvider.name():
DataProviderClass = ImagenetDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']

View File

@@ -0,0 +1,210 @@
from torch.utils.data.sampler import Sampler
import os
import random
from PIL import Image
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader
import glob
class RandCycleIter:
'''
Return data_list per class
Shuffle the returning order after one epoch
'''
def __init__ (self, data, shuffle=True):
self.data_list = list(data)
self.length = len(self.data_list)
self.i = self.length - 1
self.shuffle = shuffle
def __iter__ (self):
return self
def __next__ (self):
self.i += 1
if self.i == self.length:
self.i = 0
if self.shuffle:
random.shuffle(self.data_list)
return self.data_list[self.i]
class EpisodeSampler(Sampler):
def __init__(self, max_way, query, ylst):
self.max_way = max_way
self.query = query
self.ylst = ylst
# self.n_epi = n_epi
clswise_xidx = defaultdict(list)
for i, y in enumerate(ylst):
clswise_xidx[y].append(i)
self.cws_xidx_iter = [RandCycleIter(cxidx, shuffle=True)
for cxidx in clswise_xidx.values()]
self.n_cls = len(clswise_xidx)
self.create_episode()
def __iter__ (self):
return self.get_index()
def __len__ (self):
return self.get_len()
def create_episode(self):
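# Sample the episode's 'way' as a random multiple of 10 in [10, max_way - 10],
# then draw that many distinct classes (sorted) for this episode.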
self.way = torch.randperm(int(self.max_way/10.0)-1)[0] * 10 + 10
cls_lst = torch.sort(torch.randperm(self.max_way)[:self.way])[0]
self.cls_itr = iter(cls_lst)
self.cls_lst = cls_lst
def get_len(self):
return self.way * self.query
def get_index(self):
x_itr = self.cws_xidx_iter
i, j = 0, 0
while i < self.query * self.way:
if j >= self.query:
j = 0
if j == 0:
cls_idx = next(self.cls_itr).item()
bb = [x_itr[cls_idx]] * self.query
didx = next(zip(*bb))
yield didx[j]
# yield (didx[j], self.way)
i += 1; j += 1
class MetaImageNetDataset(Dataset):
def __init__(self, mode='val',
max_way=1000, query=10,
dpath='/w14/dataset/ILSVRC2012', transform=None):
self.dpath = dpath
self.transform = transform
self.mode = mode
self.max_way = max_way
self.query = query
classes, class_to_idx = self._find_classes(dpath+'/'+mode)
self.classes, self.class_to_idx = classes, class_to_idx
# self.class_folder_lst = \
# glob.glob(dpath+'/'+mode+'/*')
# ## sorting alphabetically
# self.class_folder_lst = sorted(self.class_folder_lst)
self.file_path_lst, self.ylst = [], []
for cls in classes:
xlst = glob.glob(dpath+'/'+mode+'/'+cls+'/*')
self.file_path_lst += xlst[:self.query]
y = class_to_idx[cls]
self.ylst += [y] * len(xlst[:self.query])
# for y, cls in enumerate(self.class_folder_lst):
# xlst = glob.glob(cls+'/*')
# self.file_path_lst += xlst[:self.query]
# self.ylst += [y] * len(xlst[:self.query])
# # self.file_path_lst += [xlst[_] for _ in
# # torch.randperm(len(xlst))[:self.query]]
# # self.ylst += [cls.split('/')[-1]] * len(xlst)
self.way_idx = 0
self.x_idx = 0
self.way = 2
self.cls_lst = None
def __len__(self):
return self.way * self.query
def __getitem__(self, index):
# if self.way != index[1]:
# self.way = index[1]
# index = index[0]
x = Image.open(
self.file_path_lst[index]).convert('RGB')
if self.transform is not None:
x = self.transform(x)
cls_name = self.ylst[index]
y = self.cls_lst.index(cls_name)
# y = self.way_idx
# self.x_idx += 1
# if self.x_idx == self.query:
# self.way_idx += 1
# self.x_idx = 0
# if self.way_idx == self.way:
# self.way_idx = 0
# self.x_idx = 0
return x, y #, cls_name # y # cls_name #y
def _find_classes(self, dir: str):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
class MetaDataLoader(DataLoader):
def __init__(self,
dataset, sampler, batch_size, shuffle, num_workers):
super(MetaDataLoader, self).__init__(
dataset=dataset,
sampler=sampler,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers)
def create_episode(self):
self.sampler.create_episode()
self.dataset.way = self.sampler.way
self.dataset.cls_lst = self.sampler.cls_lst.tolist()
def get_cls_idx(self):
return self.sampler.cls_lst
def get_loader(mode='val', way=10, query=10,
n_epi=100, dpath='/w14/dataset/ILSVRC2012',
transform=None):
dataset = MetaImageNetDataset(mode, way, query, dpath, transform)
sampler = EpisodeSampler(way, query, dataset.ylst)  # EpisodeSampler takes (max_way, query, ylst); n_epi is unused
dataset.way = sampler.way
dataset.cls_lst = sampler.cls_lst
loader = MetaDataLoader(dataset=dataset,
sampler=sampler,
batch_size=10,
shuffle=False,
num_workers=4)
return loader
# trloader = get_loader()
# trloader.create_episode()
# print(len(trloader))
# print(trloader.dataset.way)
# print(trloader.sampler.way)
# for i, episode in enumerate(trloader, start=1):
# print(episode[2])

View File

@@ -0,0 +1,302 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import os
import json
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm
from utils import decode_ofa_mbv3_to_igraph
from ofa_local.utils import get_net_info, cross_entropy_loss_with_soft_target, cross_entropy_with_label_smoothing
from ofa_local.utils import AverageMeter, accuracy, write_log, mix_images, mix_labels, init_models
__all__ = ['RunManager']
import torchvision.models as models
class RunManager:
def __init__(self, path, args, net, run_config, init=True, measure_latency=None,
no_gpu=False, data_loader=None, pp=None):
self.path = path
self.mode = args.model_name
self.net = net
self.run_config = run_config
self.best_acc = 0
self.start_epoch = 0
os.makedirs(self.path, exist_ok=True)
# dataloader
if data_loader is not None:
self.data_loader = data_loader
cls_lst = self.data_loader.get_cls_idx()
self.cls_lst = cls_lst
else:
self.data_loader = self.run_config.valid_loader
self.data_loader.create_episode()
cls_lst = self.data_loader.get_cls_idx()
self.cls_lst = cls_lst
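# Slice the pretrained 1000-way ImageNet classifier down to the classes sampled for this episode.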
state_dict = self.net.classifier.state_dict()
new_state_dict = {'weight': state_dict['linear.weight'][cls_lst],
'bias': state_dict['linear.bias'][cls_lst]}
self.net.classifier = nn.Linear(1280, len(cls_lst), bias=True)
self.net.classifier.load_state_dict(new_state_dict)
# move network to GPU if available
if torch.cuda.is_available() and (not no_gpu):
self.device = torch.device('cuda:0')
self.net = self.net.to(self.device)
cudnn.benchmark = True
else:
self.device = torch.device('cpu')
# net info
net_info = get_net_info(
self.net, self.run_config.data_provider.data_shape, measure_latency, False)
self.net_info = net_info
self.test_transform = self.run_config.data_provider.test.dataset.transform
# criterion
if isinstance(self.run_config.mixup_alpha, float):
self.train_criterion = cross_entropy_loss_with_soft_target
elif self.run_config.label_smoothing > 0:
self.train_criterion = \
lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
else:
self.train_criterion = nn.CrossEntropyLoss()
self.test_criterion = nn.CrossEntropyLoss()
# optimizer
if self.run_config.no_decay_keys:
keys = self.run_config.no_decay_keys.split('#')
net_params = [
self.network.get_parameters(keys, mode='exclude'), # parameters with weight decay
self.network.get_parameters(keys, mode='include'), # parameters without weight decay
]
else:
# noinspection PyBroadException
try:
net_params = self.network.weight_parameters()
except Exception:
net_params = []
for param in self.network.parameters():
if param.requires_grad:
net_params.append(param)
self.optimizer = self.run_config.build_optimizer(net_params)
self.net = torch.nn.DataParallel(self.net)
if self.mode == 'generator':
# PP
save_dir = f'{args.save_path}/predictor/model/ckpt_max_corr.pt'
self.acc_predictor = pp.to('cuda')
self.acc_predictor.load_state_dict(torch.load(save_dir))
self.acc_predictor = torch.nn.DataParallel(self.acc_predictor)
model = models.resnet18(pretrained=True).eval()
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1]).to(self.device)
self.feature_extractor = torch.nn.DataParallel(feature_extractor)
""" save path and log path """
@property
def save_path(self):
if self.__dict__.get('_save_path', None) is None:
save_path = os.path.join(self.path, 'checkpoint')
os.makedirs(save_path, exist_ok=True)
self.__dict__['_save_path'] = save_path
return self.__dict__['_save_path']
@property
def logs_path(self):
if self.__dict__.get('_logs_path', None) is None:
logs_path = os.path.join(self.path, 'logs')
os.makedirs(logs_path, exist_ok=True)
self.__dict__['_logs_path'] = logs_path
return self.__dict__['_logs_path']
@property
def network(self):
return self.net.module if isinstance(self.net, nn.DataParallel) else self.net
def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
write_log(self.logs_path, log_str, prefix, should_print, mode)
""" save and load models """
def save_model(self, checkpoint=None, is_best=False, model_name=None):
if checkpoint is None:
checkpoint = {'state_dict': self.network.state_dict()}
if model_name is None:
model_name = 'checkpoint.pth.tar'
checkpoint['dataset'] = self.run_config.dataset # add `dataset` info to the checkpoint
latest_fname = os.path.join(self.save_path, 'latest.txt')
model_path = os.path.join(self.save_path, model_name)
with open(latest_fname, 'w') as fout:
fout.write(model_path + '\n')
torch.save(checkpoint, model_path)
if is_best:
best_path = os.path.join(self.save_path, 'model_best.pth.tar')
torch.save({'state_dict': checkpoint['state_dict']}, best_path)
def load_model(self, model_fname=None):
latest_fname = os.path.join(self.save_path, 'latest.txt')
if model_fname is None and os.path.exists(latest_fname):
with open(latest_fname, 'r') as fin:
model_fname = fin.readline()
if model_fname[-1] == '\n':
model_fname = model_fname[:-1]
# noinspection PyBroadException
try:
if model_fname is None or not os.path.exists(model_fname):
model_fname = '%s/checkpoint.pth.tar' % self.save_path
with open(latest_fname, 'w') as fout:
fout.write(model_fname + '\n')
print("=> loading checkpoint '{}'".format(model_fname))
checkpoint = torch.load(model_fname, map_location='cpu')
except Exception:
print('fail to load checkpoint from %s' % self.save_path)
return {}
self.network.load_state_dict(checkpoint['state_dict'])
if 'epoch' in checkpoint:
self.start_epoch = checkpoint['epoch'] + 1
if 'best_acc' in checkpoint:
self.best_acc = checkpoint['best_acc']
if 'optimizer' in checkpoint:
self.optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}'".format(model_fname))
return checkpoint
def save_config(self, extra_run_config=None, extra_net_config=None):
""" dump run_config and net_config to the model_folder """
run_save_path = os.path.join(self.path, 'run.config')
if not os.path.isfile(run_save_path):
run_config = self.run_config.config
if extra_run_config is not None:
run_config.update(extra_run_config)
json.dump(run_config, open(run_save_path, 'w'), indent=4)
print('Run configs dump to %s' % run_save_path)
try:
net_save_path = os.path.join(self.path, 'net.config')
net_config = self.network.config
if extra_net_config is not None:
net_config.update(extra_net_config)
json.dump(net_config, open(net_save_path, 'w'), indent=4)
print('Network configs dump to %s' % net_save_path)
except Exception:
print('%s do not support net config' % type(self.network))
""" metric related """
def get_metric_dict(self):
return {
'top1': AverageMeter(),
'top5': AverageMeter(),
}
def update_metric(self, metric_dict, output, labels):
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
metric_dict['top1'].update(acc1[0].item(), output.size(0))
metric_dict['top5'].update(acc5[0].item(), output.size(0))
def get_metric_vals(self, metric_dict, return_dict=False):
if return_dict:
return {
key: metric_dict[key].avg for key in metric_dict
}
else:
return [metric_dict[key].avg for key in metric_dict]
def get_metric_names(self):
return 'top1', 'top5'
""" train and test """
def validate(self, epoch=0, is_test=False, run_str='', net=None,
data_loader=None, no_logs=False, train_mode=False, net_setting=None):
if net is None:
net = self.net
if not isinstance(net, nn.DataParallel):
net = nn.DataParallel(net)
if data_loader is not None:
self.data_loader = data_loader
if train_mode:
net.train()
else:
net.eval()
losses = AverageMeter()
metric_dict = self.get_metric_dict()
features_stack = []
with torch.no_grad():
with tqdm(total=len(self.data_loader),
desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
for i, (images, labels) in enumerate(self.data_loader):
images, labels = images.to(self.device), labels.to(self.device)
if self.mode == 'generator':
features = self.feature_extractor(images).squeeze()
features_stack.append(features)
# compute output
output = net(images)
loss = self.test_criterion(output, labels)
# measure accuracy and record loss
self.update_metric(metric_dict, output, labels)
losses.update(loss.item(), images.size(0))
t.set_postfix({
'loss': losses.avg,
**self.get_metric_vals(metric_dict, return_dict=True),
'img_size': images.size(2),
})
t.update(1)
if self.mode == 'generator':
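# MetaD2A accuracy predictor: D_mu encodes the episode's image features as a set,
# G_mu encodes the sampled subnet as a graph (DAG), and predict() scores the pair.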
features_stack = torch.cat(features_stack)
igraph_g = decode_ofa_mbv3_to_igraph(net_setting)[0]
D_mu = self.acc_predictor.module.set_encode(features_stack.unsqueeze(0).to('cuda'))
G_mu = self.acc_predictor.module.graph_encode(igraph_g)
pred_acc = self.acc_predictor.module.predict(D_mu.unsqueeze(0), G_mu).item()
return losses.avg, self.get_metric_vals(metric_dict), \
pred_acc if self.mode == 'generator' else None
def validate_all_resolution(self, epoch=0, is_test=False, net=None):
if net is None:
net = self.network
if isinstance(self.run_config.data_provider.image_size, list):
img_size_list, loss_list, top1_list, top5_list = [], [], [], []
for img_size in self.run_config.data_provider.image_size:
img_size_list.append(img_size)
self.run_config.data_provider.assign_active_img_size(img_size)
self.reset_running_statistics(net=net)
loss, (top1, top5), _ = self.validate(epoch, is_test, net=net)
loss_list.append(loss)
top1_list.append(top1)
top5_list.append(top5)
return img_size_list, loss_list, top1_list, top5_list
else:
loss, (top1, top5), _ = self.validate(epoch, is_test, net=net)
return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]
def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
from ofa_local.imagenet_classification.elastic_nn.utils import set_running_statistics
if net is None:
net = self.network
if data_loader is None:
data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
set_running_statistics(net, data_loader)

View File

@@ -0,0 +1,4 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################

View File

@@ -0,0 +1,401 @@
from __future__ import print_function
import os
import math
import warnings
import numpy as np
# from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform
import torch.utils.data
import torchvision.transforms as transforms
from torchvision.datasets.folder import default_loader
from PIL import Image
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
def make_dataset(dir, image_ids, targets):
assert(len(image_ids) == len(targets))
images = []
dir = os.path.expanduser(dir)
for i in range(len(image_ids)):
item = (os.path.join(dir, 'data', 'images',
'%s.jpg' % image_ids[i]), targets[i])
images.append(item)
return images
def find_classes(classes_file):
# read classes file, separating out image IDs and class names
image_ids = []
targets = []
f = open(classes_file, 'r')
for line in f:
split_line = line.split(' ')
image_ids.append(split_line[0])
targets.append(' '.join(split_line[1:]))
f.close()
# index class names
classes = np.unique(targets)
class_to_idx = {classes[i]: i for i in range(len(classes))}
targets = [class_to_idx[c] for c in targets]
return (image_ids, targets, classes, class_to_idx)
class FGVCAircraft(torch.utils.data.Dataset):
"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
Args:
root (string): Root directory path to dataset.
class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g. ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
class_types = ('variant', 'family', 'manufacturer')
splits = ('train', 'val', 'trainval', 'test')
def __init__(self, root, class_type='variant', split='train', transform=None,
target_transform=None, loader=default_loader, download=False):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
if class_type not in self.class_types:
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
class_type, ', '.join(self.class_types),
))
self.root = os.path.expanduser(root)
self.class_type = class_type
self.split = split
self.classes_file = os.path.join(self.root, 'data',
'images_%s_%s.txt' % (self.class_type, self.split))
if download:
self.download()
(image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
samples = make_dataset(self.root, image_ids, targets)
self.transform = transform
self.target_transform = target_transform
self.loader = loader
self.samples = samples
self.classes = classes
self.class_to_idx = class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def _check_exists(self):
return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
os.path.exists(self.classes_file)
def download(self):
"""Download the FGVC-Aircraft data if it doesn't exist already."""
from six.moves import urllib
import tarfile
if self._check_exists():
return
# prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
print('Downloading %s ... (may take a few minutes)' % self.url)
parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
tar_name = self.url.rpartition('/')[-1]
tar_path = os.path.join(parent_dir, tar_name)
data = urllib.request.urlopen(self.url)
# download .tar.gz file
with open(tar_path, 'wb') as f:
f.write(data.read())
# extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
data_folder = tar_path[:-len('.tar.gz')]  # str.strip() removes characters, not a suffix
print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
tar = tarfile.open(tar_path)
tar.extractall(parent_dir)
# if necessary, rename data folder to self.root
if not os.path.samefile(data_folder, self.root):
print('Renaming %s to %s ...' % (data_folder, self.root))
os.rename(data_folder, self.root)
# delete .tar.gz file
print('Deleting %s ...' % tar_path)
os.remove(tar_path)
print('Done!')
class FGVCAircraftDataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.samples) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'aircraft'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 100
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/Aircraft' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/Aircraft' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.train_path, _transforms)
dataset = FGVCAircraft(
root=self.train_path, split='trainval', download=True, transform=_transforms)
return dataset
def test_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.valid_path, _transforms)
dataset = FGVCAircraft(
root=self.valid_path, split='test', download=True, transform=_transforms)
return dataset
@property
def train_path(self):
return self.save_path
@property
def valid_path(self):
return self.save_path
@property
def normalize(self):
return transforms.Normalize(
mean=[0.48933587508932375, 0.5183537408957618, 0.5387914411673883],
std=[0.22388883112804625, 0.21641635409388751, 0.24615605842636115])
def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
if image_size is None:
image_size = self.image_size
# if print_log:
# print('Color jitter: %s, resize_scale: %s, img_size: %s' %
# (self.distort_color, self.resize_scale, image_size))
# if self.distort_color == 'torch':
# color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
# elif self.distort_color == 'tf':
# color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
# else:
# color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
img_size_min = min(image_size)
else:
resize_transform_class = transforms.RandomResizedCrop
img_size_min = image_size
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
aa_params = dict(
translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in [0.48933587508932375, 0.5183537408957618,
0.5387914411673883]]),
)
aa_params['interpolation'] = Image.BICUBIC  # bicubic, as in the commented-out _pil_interp('bicubic')
train_transforms += [rand_augment_transform(auto_augment, aa_params)]
# if color_transform is not None:
# train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
if __name__ == '__main__':
data = FGVCAircraft(root='/mnt/datastore/Aircraft',
split='trainval', download=True)
print(len(data.classes))
print(len(data.samples))

View File

@@ -0,0 +1,238 @@
"""
Taken from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
"""
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
class ImageNetPolicy(object):
""" Randomly choose one of the best 24 Sub-policies on ImageNet.
Example:
>>> policy = ImageNetPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment ImageNet Policy"
class CIFAR10Policy(object):
""" Randomly choose one of the best 25 Sub-policies on CIFAR10.
Example:
>>> policy = CIFAR10Policy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> CIFAR10Policy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment CIFAR10 Policy"
class SVHNPolicy(object):
""" Randomly choose one of the best 25 Sub-policies on SVHN.
Example:
>>> policy = SVHNPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> SVHNPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment SVHN Policy"
class SubPolicy(object):
def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
ranges = {
"shearX": np.linspace(0, 0.3, 10),
"shearY": np.linspace(0, 0.3, 10),
"translateX": np.linspace(0, 150 / 331, 10),
"translateY": np.linspace(0, 150 / 331, 10),
"rotate": np.linspace(0, 30, 10),
"color": np.linspace(0.0, 0.9, 10),
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
"solarize": np.linspace(256, 0, 10),
"contrast": np.linspace(0.0, 0.9, 10),
"sharpness": np.linspace(0.0, 0.9, 10),
"brightness": np.linspace(0.0, 0.9, 10),
"autocontrast": [0] * 10,
"equalize": [0] * 10,
"invert": [0] * 10
}
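# Each magnitude index (0-9) picks from the 10-point range above: e.g.
# "rotate" index 8 maps to 30 * 8 / 9 ≈ 26.7 degrees, while "solarize" runs
# downward, so index 0 (threshold 256) leaves the image unchanged.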
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
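# rotate_with_fill rotates in RGBA (leaving transparent corners), composites
# the result over a solid gray canvas so those corners are filled rather than
# black, and converts back to the original mode.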
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
func = {
"shearX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
self.p1 = p1
self.operation1 = func[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
self.p2 = p2
self.operation2 = func[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
return img
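# --- Usage sketch (illustrative; not part of the original file). Any of the
# three policies drops into a torchvision pipeline; the crop size and dataset
# below are placeholder choices for CIFAR-10-sized images.
# import torchvision
# import torchvision.transforms as transforms
# train_transform = transforms.Compose([
#     transforms.RandomCrop(32, padding=4),
#     CIFAR10Policy(),
#     transforms.ToTensor(),
# ])
# trainset = torchvision.datasets.CIFAR10('./data', train=True, download=True,
#                                         transform=train_transform)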

View File

@@ -0,0 +1,657 @@
import os
import math
import numpy as np
import torchvision
import torch.utils.data
import torchvision.transforms as transforms
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class CIFAR10DataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.data) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'cifar10'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 10
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/CIFAR' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/CIFAR' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.train_path, _transforms)
dataset = torchvision.datasets.CIFAR10(
root=self.valid_path, train=True, download=False, transform=_transforms)
return dataset
def test_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.valid_path, _transforms)
dataset = torchvision.datasets.CIFAR10(
root=self.valid_path, train=False, download=False, transform=_transforms)
return dataset
@property
def train_path(self):
# return os.path.join(self.save_path, 'train')
return self.save_path
@property
def valid_path(self):
# return os.path.join(self.save_path, 'val')
return self.save_path
@property
def normalize(self):
return transforms.Normalize(
mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.data)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
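# The fixed seed (DataProvider.SUB_SEED) makes this permutation, and hence the
# sampled subset, reproducible across calls and across distributed replicas.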
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
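# --- Usage sketch (illustrative; the path and sizes are assumptions, and the
# CIFAR data must already exist at save_path since download=False above).
# provider = CIFAR10DataProvider(save_path='./data/CIFAR', train_batch_size=96,
#                                valid_size=5000, image_size=224, n_worker=2)
# images, labels = next(iter(provider.train))  # batches of shape (96, 3, 224, 224)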
class CIFAR100DataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.data) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'cifar100'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 100
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/CIFAR' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/CIFAR' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.train_path, _transforms)
dataset = torchvision.datasets.CIFAR100(
root=self.valid_path, train=True, download=False, transform=_transforms)
return dataset
def test_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.valid_path, _transforms)
dataset = torchvision.datasets.CIFAR100(
root=self.valid_path, train=False, download=False, transform=_transforms)
return dataset
@property
def train_path(self):
# return os.path.join(self.save_path, 'train')
return self.save_path
@property
def valid_path(self):
# return os.path.join(self.save_path, 'val')
return self.save_path
@property
def normalize(self):
return transforms.Normalize(
mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.data)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
class CINIC10DataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.data) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'cinic10'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 10
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/CINIC10' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/CINIC10' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = torchvision.datasets.ImageFolder(self.train_path, transform=_transforms)
# dataset = torchvision.datasets.CIFAR10(
# root=self.valid_path, train=True, download=False, transform=_transforms)
return dataset
def test_dataset(self, _transforms):
dataset = torchvision.datasets.ImageFolder(self.valid_path, transform=_transforms)
# dataset = torchvision.datasets.CIFAR10(
# root=self.valid_path, train=False, download=False, transform=_transforms)
return dataset
@property
def train_path(self):
return os.path.join(self.save_path, 'train_and_valid')
# return self.save_path
@property
def valid_path(self):
return os.path.join(self.save_path, 'test')
# return self.save_path
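# Note: this provider expects a CINIC-10 directory prepared with a merged
# 'train_and_valid' split for training and the standard 'test' split for
# evaluation, both served through ImageFolder.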
@property
def normalize(self):
return transforms.Normalize(
mean=[0.47889522, 0.47227842, 0.43047404], std=[0.24205776, 0.23828046, 0.25874835])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,237 @@
import os
import warnings
import numpy as np
from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class DTDDataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.samples) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'dtd'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 47
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/dtd' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/dtd' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.train_path, _transforms)
return dataset
def test_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.valid_path, _transforms)
return dataset
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'valid')
@property
def normalize(self):
return transforms.Normalize(
mean=[0.5329876098715876, 0.474260843249454, 0.42627281899380676],
std=[0.26549755708788914, 0.25473554309855373, 0.2631728035662832])
def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
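# 'rand-m9-mstd0.5' is a timm RandAugment config string: magnitude 9 with a
# magnitude std-dev of 0.5. The aa_params built below supply the absolute
# translation limit (translate_const, 45% of the shorter side) and the fill
# color (img_mean, the dataset mean scaled to 0-255).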
if image_size is None:
image_size = self.image_size
# if print_log:
# print('Color jitter: %s, resize_scale: %s, img_size: %s' %
# (self.distort_color, self.resize_scale, image_size))
# if self.distort_color == 'torch':
# color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
# elif self.distort_color == 'tf':
# color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
# else:
# color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
img_size_min = min(image_size)
else:
resize_transform_class = transforms.RandomResizedCrop
img_size_min = image_size
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
aa_params = dict(
translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in [0.5329876098715876, 0.474260843249454,
0.42627281899380676]]),
)
aa_params['interpolation'] = _pil_interp('bicubic')
train_transforms += [rand_augment_transform(auto_augment, aa_params)]
# if color_transform is not None:
# train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
# transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.Resize((image_size, image_size), interpolation=3),  # 3 == PIL.Image.BICUBIC
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,241 @@
import warnings
import os
import math
import numpy as np
import PIL
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class Flowers102DataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=32, test_batch_size=512, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
# warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
weights = self.make_weights_for_balanced_classes(
train_dataset.imgs, self.n_classes)
weights = torch.DoubleTensor(weights)
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
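# WeightedRandomSampler draws indices with replacement in proportion to the
# inverse-frequency weights above, so minority classes are oversampled and
# each epoch is roughly class-balanced.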
if valid_size is not None:
raise NotImplementedError("validation dataset not yet implemented")
# valid_dataset = self.valid_dataset(valid_transforms)
# self.train = train_loader_class(
# train_dataset, batch_size=train_batch_size, sampler=train_sampler,
# num_workers=n_worker, pin_memory=True)
# self.valid = torch.utils.data.DataLoader(
# valid_dataset, batch_size=test_batch_size,
# num_workers=n_worker, pin_memory=True)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'flowers102'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 102
@property
def save_path(self):
if self._save_path is None:
# self._save_path = '/mnt/datastore/Oxford102Flowers' # home server
self._save_path = '/mnt/datastore/Flowers102' # home server
if not os.path.exists(self._save_path):
# self._save_path = '/mnt/datastore/Oxford102Flowers' # home server
self._save_path = '/mnt/datastore/Flowers102' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.train_path, _transforms)
return dataset
# def valid_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.valid_path, _transforms)
# return dataset
def test_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.test_path, _transforms)
return dataset
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
# @property
# def valid_path(self):
# return os.path.join(self.save_path, 'train')
@property
def test_path(self):
return os.path.join(self.save_path, 'test')
@property
def normalize(self):
return transforms.Normalize(
mean=[0.5178361839861569, 0.4106749456881299, 0.32864167836880803],
std=[0.2972239085211309, 0.24976049135203868, 0.28533308036347665])
@staticmethod
def make_weights_for_balanced_classes(images, nclasses):
count = [0] * nclasses
# Counts per label
for item in images:
count[item[1]] += 1
weight_per_class = [0.] * nclasses
# Total number of images.
N = float(sum(count))
# Super-sample the smaller classes: weight each class by N / count.
for i in range(nclasses):
weight_per_class[i] = N / float(count[i])
weight = [0] * len(images)
# Calculate a weight per image.
for idx, val in enumerate(images):
weight[idx] = weight_per_class[val[1]]
return weight
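# Worked example: counts [100, 300] over N = 400 give per-class weights
# [4.0, 1.33], so an image from the rarer class is ~3x as likely to be drawn
# by the WeightedRandomSampler built in __init__.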
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
transforms.RandomAffine(
45, translate=(0.4, 0.4), scale=(0.75, 1.5), shear=None, resample=PIL.Image.BILINEAR, fillcolor=0),
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
# transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,225 @@
import warnings
import os
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class ImagenetDataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.samples) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'imagenet'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 1000
@property
def save_path(self):
if self._save_path is None:
# self._save_path = '/dataset/imagenet'
# self._save_path = '/usr/local/soft/temp-datastore/ILSVRC2012' # servers
self._save_path = '/mnt/datastore/ILSVRC2012' # home server
if not os.path.exists(self._save_path):
# self._save_path = os.path.expanduser('~/dataset/imagenet')
# self._save_path = os.path.expanduser('/usr/local/soft/temp-datastore/ILSVRC2012')
self._save_path = '/mnt/datastore/ILSVRC2012' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.train_path, _transforms)
return dataset
def test_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.valid_path, _transforms)
return dataset
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'val')
@property
def normalize(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,237 @@
import os
import math
import warnings
import numpy as np
from PIL import Image  # needed for the interpolation constant used below
# from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class OxfordIIITPetsDataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.samples) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'pets'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 37
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/Oxford-IIITPets' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/Oxford-IIITPets' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.train_path, _transforms)
return dataset
def test_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.valid_path, _transforms)
return dataset
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'valid')
@property
def normalize(self):
return transforms.Normalize(
mean=[0.4828895122298728, 0.4448394893850807, 0.39566558230789783],
std=[0.25925664613996574, 0.2532760018681693, 0.25981017205097917])
def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
if image_size is None:
image_size = self.image_size
# if print_log:
# print('Color jitter: %s, resize_scale: %s, img_size: %s' %
# (self.distort_color, self.resize_scale, image_size))
# if self.distort_color == 'torch':
# color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
# elif self.distort_color == 'tf':
# color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
# else:
# color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
img_size_min = min(image_size)
else:
resize_transform_class = transforms.RandomResizedCrop
img_size_min = image_size
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
aa_params = dict(
translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in [0.4828895122298728, 0.4448394893850807,
0.39566558230789783]]),
)
from PIL import Image # local import; rand_augment expects a PIL interpolation constant
aa_params['interpolation'] = Image.BICUBIC # was transforms.Resize(image_size), which would break rand_augment_transform
train_transforms += [rand_augment_transform(auto_augment, aa_params)]
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
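For orientation, here is a minimal sketch of constructing and consuming this provider. The class name `OxfordIIITPetsDataProvider` and the default path come from the run_manager file later in this commit; the batch sizes are illustrative, and the data is assumed to already sit under `save_path/train` and `save_path/valid`.

```python
# Hedged sketch, not part of the committed code.
provider = OxfordIIITPetsDataProvider(
    save_path='/mnt/datastore/Oxford-IIITPets',
    train_batch_size=32, test_batch_size=250,
    valid_size=None, n_worker=4, image_size=224,
)
images, labels = next(iter(provider.train))
print(provider.n_classes, images.shape)  # 37 torch.Size([32, 3, 224, 224])
```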

View File

@@ -0,0 +1,69 @@
import torch
from glob import glob
from torch.utils.data.dataset import Dataset
import os
from PIL import Image
def load_image(filename):
img = Image.open(filename)
img = img.convert('RGB')
return img
class PetDataset(Dataset):
def __init__(self, root, train=True, num_cl=37, val_split=0.15, transforms=None):
split_pct = int(100 * (1 - val_split)) if train else int(100 * val_split)
pt_name = os.path.join(root, '{}{}.pth'.format('train' if train else 'test', split_pct))
if not os.path.exists(pt_name):
filenames = glob(os.path.join(root, 'images') + '/*.jpg')
classes = set()
data = []
labels = []
for image in filenames:
class_name = os.path.basename(image).rsplit('_', 1)[0] # class name precedes the trailing index
classes.add(class_name)
img = load_image(image)
data.append(img)
labels.append(class_name)
# convert classnames to indices
class2idx = {cl: idx for idx, cl in enumerate(sorted(classes))} # sort for a deterministic mapping
labels = torch.Tensor(list(map(lambda x: class2idx[x], labels))).long()
data = list(zip(data, labels))
class_values = [[] for x in range(num_cl)]
# create arrays for each class type
for d in data:
class_values[d[1].item()].append(d)
train_data = []
val_data = []
for class_dp in class_values:
split_idx = int(len(class_dp) * (1 - val_split))
train_data += class_dp[:split_idx]
val_data += class_dp[split_idx:]
torch.save(train_data, os.path.join(root, 'train{}.pth'.format(int(100 * (1 - val_split)))))
torch.save(val_data, os.path.join(root, 'test{}.pth'.format(int(100 * val_split))))
self.data = torch.load(pt_name)
self.len = len(self.data)
self.transform = transforms
def __getitem__(self, index):
img, label = self.data[index]
if self.transform:
img = self.transform(img)
return img, label
def __len__(self):
return self.len
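A brief usage sketch for this wrapper; the root path is a placeholder. The first call scans `root/images`, builds a per-class stratified split, and caches it as `.pth` files, so later constructions are fast.

```python
# Hedged example; assumes the raw Oxford-IIIT Pets images live under root/images.
import torch
from torchvision import transforms

ds = PetDataset('/path/to/Oxford-IIITPets', train=True, num_cl=37,
                val_split=0.15, transforms=transforms.ToTensor())
img, label = ds[0]
loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)
```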

View File

@@ -0,0 +1,226 @@
import os
import math
import numpy as np
import torchvision
import torch.utils.data
import torchvision.transforms as transforms
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class STL10DataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.data) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'stl10'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 10
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/STL10' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/STL10' # fallback; identical to the default here
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = torchvision.datasets.STL10(
root=self.save_path, split='train', download=False, transform=_transforms)
return dataset
def test_dataset(self, _transforms):
dataset = torchvision.datasets.STL10(
root=self.save_path, split='test', download=False, transform=_transforms)
return dataset
@property
def train_path(self):
# return os.path.join(self.save_path, 'train')
return self.save_path
@property
def valid_path(self):
# return os.path.join(self.save_path, 'val')
return self.save_path
@property
def normalize(self):
return transforms.Normalize(
mean=[0.44671097, 0.4398105, 0.4066468],
std=[0.2603405, 0.25657743, 0.27126738])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.data)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,4 @@
from ofa.imagenet_codebase.networks.proxyless_nets import ProxylessNASNets, proxyless_base, MobileNetV2
from ofa.imagenet_codebase.networks.mobilenet_v3 import MobileNetV3, MobileNetV3Large
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks.nsganetv2 import NSGANetV2

View File

@@ -0,0 +1,126 @@
from timm.models.layers import drop_path
from ofa.imagenet_codebase.modules.layers import *
from ofa.imagenet_codebase.networks import MobileNetV3
class MobileInvertedResidualBlock(MyModule):
"""
Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
imagenet_codebase/networks/proxyless_nets.py to include drop path in training
"""
def __init__(self, mobile_inverted_conv, shortcut, drop_connect_rate=0.0):
super(MobileInvertedResidualBlock, self).__init__()
self.mobile_inverted_conv = mobile_inverted_conv
self.shortcut = shortcut
self.drop_connect_rate = drop_connect_rate
def forward(self, x):
if self.mobile_inverted_conv is None or isinstance(self.mobile_inverted_conv, ZeroLayer):
res = x
elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
res = self.mobile_inverted_conv(x)
else:
# res = self.mobile_inverted_conv(x) + self.shortcut(x)
res = self.mobile_inverted_conv(x)
if self.drop_connect_rate > 0.:
res = drop_path(res, drop_prob=self.drop_connect_rate, training=self.training)
res += self.shortcut(x)
return res
@property
def module_str(self):
return '(%s, %s)' % (
self.mobile_inverted_conv.module_str if self.mobile_inverted_conv is not None else None,
self.shortcut.module_str if self.shortcut is not None else None
)
@property
def config(self):
return {
'name': MobileInvertedResidualBlock.__name__,
'mobile_inverted_conv': self.mobile_inverted_conv.config if self.mobile_inverted_conv is not None else None,
'shortcut': self.shortcut.config if self.shortcut is not None else None,
}
@staticmethod
def build_from_config(config):
mobile_inverted_conv = set_layer_from_config(config['mobile_inverted_conv'])
shortcut = set_layer_from_config(config['shortcut'])
return MobileInvertedResidualBlock(
mobile_inverted_conv, shortcut, drop_connect_rate=config['drop_connect_rate'])
class NSGANetV2(MobileNetV3):
"""
Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
imagenet_codebase/networks/mobilenet_v3.py to include drop path in training
and option to reset classification layer
"""
@staticmethod
def build_from_config(config, drop_connect_rate=0.0):
first_conv = set_layer_from_config(config['first_conv'])
final_expand_layer = set_layer_from_config(config['final_expand_layer'])
feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
classifier = set_layer_from_config(config['classifier'])
blocks = []
for block_idx, block_config in enumerate(config['blocks']):
block_config['drop_connect_rate'] = drop_connect_rate * block_idx / len(config['blocks'])
blocks.append(MobileInvertedResidualBlock.build_from_config(block_config))
net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
if 'bn' in config:
net.set_bn_param(**config['bn'])
else:
net.set_bn_param(momentum=0.1, eps=1e-3)
return net
def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, MobileInvertedResidualBlock):
if isinstance(m.mobile_inverted_conv, MBInvertedConvLayer) and isinstance(m.shortcut, IdentityLayer):
m.mobile_inverted_conv.point_linear.bn.weight.data.zero_()
@staticmethod
def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
# first conv layer
first_conv = ConvLayer(
3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
)
# build mobile blocks
feature_dim = input_channel
blocks = []
for stage_id, block_config_list in cfg.items():
for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
mb_conv = MBInvertedConvLayer(
feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
)
if stride == 1 and out_channel == feature_dim:
shortcut = IdentityLayer(out_channel, out_channel)
else:
shortcut = None
blocks.append(MobileInvertedResidualBlock(mb_conv, shortcut))
feature_dim = out_channel
# final expand layer
final_expand_layer = ConvLayer(
feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
)
feature_dim = feature_dim * 6
# feature mix layer
feature_mix_layer = ConvLayer(
feature_dim, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
)
# classifier
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
@staticmethod
def reset_classifier(model, last_channel, n_classes, dropout_rate=0.0):
model.classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
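The drop-path schedule in `build_from_config` above is linear in the block index; a tiny worked example of the rates it produces:

```python
# With drop_connect_rate=0.2 over 20 blocks, block i trains with rate 0.2 * i / 20.
n_blocks, drop_connect_rate = 20, 0.2
rates = [drop_connect_rate * i / n_blocks for i in range(n_blocks)]
# rates[0] == 0.0 and the rates grow linearly to just under 0.2 (~0.19 for the last block)
```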

View File

@@ -0,0 +1,309 @@
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.imagenet import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.cifar import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import *
from ofa.imagenet_codebase.run_manager.run_manager import *
class ImagenetRunConfig(RunConfig):
def __init__(self, n_epochs=1, init_lr=1e-4, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='imagenet', train_batch_size=128, test_batch_size=512, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
data_path='/mnt/datastore/ILSVRC2012',
**kwargs):
super(ImagenetRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.imagenet_data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == ImagenetDataProvider.name():
DataProviderClass = ImagenetDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.imagenet_data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class CIFARRunConfig(RunConfig):
def __init__(self, n_epochs=5, init_lr=0.01, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='cifar10', train_batch_size=96, test_batch_size=256, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=2, resize_scale=0.08, distort_color=None, image_size=224,
data_path='/mnt/datastore/CIFAR',
**kwargs):
super(CIFARRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.cifar_data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == CIFAR10DataProvider.name():
DataProviderClass = CIFAR10DataProvider
elif self.dataset == CIFAR100DataProvider.name():
DataProviderClass = CIFAR100DataProvider
elif self.dataset == CINIC10DataProvider.name():
DataProviderClass = CINIC10DataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.cifar_data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class Flowers102RunConfig(RunConfig):
def __init__(self, n_epochs=3, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='flowers102', train_batch_size=32, test_batch_size=250, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
data_path='/mnt/datastore/Flowers102',
**kwargs):
super(Flowers102RunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.flowers102_data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == Flowers102DataProvider.name():
DataProviderClass = Flowers102DataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.flowers102_data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class STL10RunConfig(RunConfig):
def __init__(self, n_epochs=5, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='stl10', train_batch_size=96, test_batch_size=256, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
data_path='/mnt/datastore/STL10',
**kwargs):
super(STL10RunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.stl10_data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == STL10DataProvider.name():
DataProviderClass = STL10DataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.stl10_data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class DTDRunConfig(RunConfig):
def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='dtd', train_batch_size=32, test_batch_size=250, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
data_path='/mnt/datastore/dtd',
**kwargs):
super(DTDRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == DTDDataProvider.name():
DataProviderClass = DTDDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class PetsRunConfig(RunConfig):
def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='pets', train_batch_size=32, test_batch_size=250, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
data_path='/mnt/datastore/Oxford-IIITPets',
**kwargs):
super(PetsRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.pets_data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == OxfordIIITPetsDataProvider.name():
DataProviderClass = OxfordIIITPetsDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.pets_data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class AircraftRunConfig(RunConfig):
def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='aircraft', train_batch_size=32, test_batch_size=250, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
data_path='/mnt/datastore/Aircraft',
**kwargs):
super(AircraftRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == FGVCAircraftDataProvider.name():
DataProviderClass = FGVCAircraftDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
def get_run_config(**kwargs):
if kwargs['dataset'] == 'imagenet':
run_config = ImagenetRunConfig(**kwargs)
elif kwargs['dataset'].startswith('cifar') or kwargs['dataset'].startswith('cinic'):
run_config = CIFARRunConfig(**kwargs)
elif kwargs['dataset'] == 'flowers102':
run_config = Flowers102RunConfig(**kwargs)
elif kwargs['dataset'] == 'stl10':
run_config = STL10RunConfig(**kwargs)
elif kwargs['dataset'] == 'dtd':
run_config = DTDRunConfig(**kwargs)
elif kwargs['dataset'] == 'pets':
run_config = PetsRunConfig(**kwargs)
elif kwargs['dataset'].startswith('aircraft'): # covers 'aircraft' and 'aircraft100'
run_config = AircraftRunConfig(**kwargs)
else:
raise NotImplementedError
return run_config
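A hedged example of building a run config through the dispatcher above; `dataset` picks the config class and the remaining kwargs override its defaults. Accessing `.data_provider` lazily constructs the provider, so the data must exist under `data_path`:

```python
run_config = get_run_config(dataset='pets', n_epochs=1,
                            train_batch_size=32, data_path='../data/pets')
print(run_config.data_provider.n_classes)  # 37
```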

View File

@@ -0,0 +1,122 @@
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
import torchvision.utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import FGVCAircraft
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets2 import PetDataset
import torch.utils.data as Data
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.autoaugment import CIFAR10Policy
def get_dataset(data_name, batch_size, data_path, num_workers,
img_size, autoaugment, cutout, cutout_length):
num_class_dict = {
'cifar100': 100,
'cifar10': 10,
'mnist': 10,
'aircraft': 100,
'aircraft100': 100, # needed since get_run_config maps 'aircraft100' here too
# 'aircraft30': 30,
'svhn': 10,
'pets': 37,
}
train_transform, valid_transform = _data_transforms(
data_name, img_size, autoaugment, cutout, cutout_length)
if data_name == 'cifar100':
train_data = torchvision.datasets.CIFAR100(
root=data_path, train=True, download=True, transform=train_transform)
valid_data = torchvision.datasets.CIFAR100(
root=data_path, train=False, download=True, transform=valid_transform)
elif data_name == 'cifar10':
train_data = torchvision.datasets.CIFAR10(
root=data_path, train=True, download=True, transform=train_transform)
valid_data = torchvision.datasets.CIFAR10(
root=data_path, train=False, download=True, transform=valid_transform)
elif data_name.startswith('aircraft'):
print(data_path)
if 'aircraft100' in data_path:
data_path = data_path.replace('aircraft100', 'aircraft/fgvc-aircraft-2013b')
else:
data_path = data_path.replace('aircraft', 'aircraft/fgvc-aircraft-2013b')
train_data = FGVCAircraft(data_path, class_type='variant', split='trainval',
transform=train_transform, download=True)
valid_data = FGVCAircraft(data_path, class_type='variant', split='test',
transform=valid_transform, download=True)
elif data_name.startswith('pets'):
train_data = PetDataset(data_path, train=True, num_cl=37,
val_split=0.15, transforms=train_transform)
valid_data = PetDataset(data_path, train=False, num_cl=37,
val_split=0.15, transforms=valid_transform)
else:
raise KeyError
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=batch_size, shuffle=True, pin_memory=True,
num_workers=num_workers)
valid_queue = torch.utils.data.DataLoader(
valid_data, batch_size=200, shuffle=False, pin_memory=True,
num_workers=num_workers)
return train_queue, valid_queue, num_class_dict[data_name]
class Cutout(object):
def __init__(self, length):
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
def _data_transforms(data_name, img_size, autoaugment, cutout, cutout_length):
if 'cifar' in data_name:
norm_mean = [0.49139968, 0.48215827, 0.44653124]
norm_std = [0.24703233, 0.24348505, 0.26158768]
elif 'aircraft' in data_name:
norm_mean = [0.48933587508932375, 0.5183537408957618, 0.5387914411673883]
norm_std = [0.22388883112804625, 0.21641635409388751, 0.24615605842636115]
elif 'pets' in data_name:
norm_mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783]
norm_std = [0.25925664613996574, 0.2532760018681693, 0.25981017205097917]
else:
raise KeyError
train_transform = transforms.Compose([
transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC), # BICUBIC interpolation
transforms.RandomHorizontalFlip(),
])
if autoaugment:
train_transform.transforms.append(CIFAR10Policy())
train_transform.transforms.append(transforms.ToTensor())
if cutout:
train_transform.transforms.append(Cutout(cutout_length))
train_transform.transforms.append(transforms.Normalize(norm_mean, norm_std))
valid_transform = transforms.Compose([
transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC), # BICUBIC interpolation
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
return train_transform, valid_transform
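An illustrative call of `get_dataset` above; the data path is an assumption (CIFAR-10 is downloaded automatically if missing):

```python
train_queue, valid_queue, n_cls = get_dataset(
    'cifar10', batch_size=96, data_path='../data/cifar10', num_workers=2,
    img_size=224, autoaugment=True, cutout=True, cutout_length=16)
print(n_cls)  # 10
```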

View File

@@ -0,0 +1,233 @@
import os
import torch
import numpy as np
import random
import sys
import transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.eval_utils as eval_utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.run_manager import get_run_config
from ofa.elastic_nn.networks import OFAMobileNetV3
from ofa.imagenet_codebase.run_manager import RunManager
from ofa.elastic_nn.modules.dynamic_op import DynamicSeparableConv2d
from torchprofile import profile_macs
import copy
import json
import warnings
warnings.simplefilter("ignore")
DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = 1
class ArchManager:
def __init__(self):
self.num_blocks = 20
self.num_stages = 5
self.kernel_sizes = [3, 5, 7]
self.expand_ratios = [3, 4, 6]
self.depths = [2, 3, 4]
self.resolutions = [160, 176, 192, 208, 224]
def random_sample(self):
sample = {}
d = []
e = []
ks = []
for i in range(self.num_stages):
d.append(random.choice(self.depths))
for i in range(self.num_blocks):
e.append(random.choice(self.expand_ratios))
ks.append(random.choice(self.kernel_sizes))
sample = {
'wid': None,
'ks': ks,
'e': e,
'd': d,
'r': [random.choice(self.resolutions)]
}
return sample
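# A sampled dict looks like, e.g.:
# {'wid': None, 'ks': [3, 5, 7, ...], 'e': [4, 6, 3, ...], 'd': [2, 4, 3, 2, 4], 'r': [192]}
# with num_blocks (20) entries in 'ks'/'e' and num_stages (5) entries in 'd'.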
def random_resample(self, sample, i):
assert i >= 0 and i < self.num_blocks
sample['ks'][i] = random.choice(self.kernel_sizes)
sample['e'][i] = random.choice(self.expand_ratios)
def random_resample_depth(self, sample, i):
assert i >= 0 and i < self.num_stages
sample['d'][i] = random.choice(self.depths)
def random_resample_resolution(self, sample):
sample['r'][0] = random.choice(self.resolutions)
def parse_string_list(string):
if isinstance(string, str):
# convert '[5 5 5 7 7 7 3 3 7 7 7 3 3]' to [5, 5, 5, 7, 7, 7, 3, 3, 7, 7, 7, 3, 3]
return list(map(int, string[1:-1].split()))
else:
return string
def pad_none(x, depth, max_depth):
new_x, counter = [], 0
for d in depth:
for _ in range(d):
new_x.append(x[counter])
counter += 1
if d < max_depth:
new_x += [None] * (max_depth - d)
return new_x
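# Worked example of the two helpers above:
# parse_string_list('[3 5 7 3 5]') -> [3, 5, 7, 3, 5]
# pad_none([3, 5, 7, 3, 5], depth=[2, 3], max_depth=4)
# -> [3, 5, None, None, 7, 3, 5, None]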
def get_net_info(net, data_shape, measure_latency=None, print_info=True, clean=False, lut=None):
net_info = eval_utils.get_net_info(
net, data_shape, measure_latency, print_info=print_info, clean=clean, lut=lut)
gpu_latency, cpu_latency = None, None
for k in net_info.keys():
if 'gpu' in k:
gpu_latency = np.round(net_info[k]['val'], 2)
if 'cpu' in k:
cpu_latency = np.round(net_info[k]['val'], 2)
return {
'params': np.round(net_info['params'] / 1e6, 2),
'flops': np.round(net_info['flops'] / 1e6, 2),
'gpu': gpu_latency, 'cpu': cpu_latency
}
def validate_config(config, max_depth=4):
kernel_size, exp_ratio, depth = config['ks'], config['e'], config['d']
if isinstance(kernel_size, str): kernel_size = parse_string_list(kernel_size)
if isinstance(exp_ratio, str): exp_ratio = parse_string_list(exp_ratio)
if isinstance(depth, str): depth = parse_string_list(depth)
assert (isinstance(kernel_size, list) or isinstance(kernel_size, int))
assert (isinstance(exp_ratio, list) or isinstance(exp_ratio, int))
assert isinstance(depth, list)
if len(kernel_size) < len(depth) * max_depth:
kernel_size = pad_none(kernel_size, depth, max_depth)
if len(exp_ratio) < len(depth) * max_depth:
exp_ratio = pad_none(exp_ratio, depth, max_depth)
# return {'ks': kernel_size, 'e': exp_ratio, 'd': depth, 'w': config['w']}
return {'ks': kernel_size, 'e': exp_ratio, 'd': depth}
def set_nas_test_dataset(path, test_data_name, max_img):
if test_data_name not in ['mnist', 'svhn', 'cifar10',
'cifar100', 'aircraft', 'pets']: raise ValueError(test_data_name)
dpath = path
num_cls = 10 # mnist, svhn, cifar10
if test_data_name in ['cifar100', 'aircraft']:
num_cls = 100
elif test_data_name == 'pets':
num_cls = 37
x = torch.load(dpath + f'/{test_data_name}bylabel')
img_per_cls = min(int(max_img / num_cls), 20)
return x, img_per_cls, num_cls
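# e.g. set_nas_test_dataset('../data', 'pets', max_img=1000) loads '../data/petsbylabel'
# and returns img_per_cls = min(1000 // 37, 20) = 20 with num_cls = 37.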
class OFAEvaluator:
""" based on OnceForAll supernet taken from https://github.com/mit-han-lab/once-for-all """
def __init__(self, num_gen_arch, img_size, drop_path,
n_classes=1000,
model_path=None,
kernel_size=None, exp_ratio=None, depth=None):
# default configurations
self.kernel_size = [3, 5, 7] if kernel_size is None else kernel_size # depth-wise conv kernel size
self.exp_ratio = [3, 4, 6] if exp_ratio is None else exp_ratio # expansion rate
self.depth = [2, 3, 4] if depth is None else depth # number of MB block repetition
if model_path is not None and 'w1.0' in model_path:
self.width_mult = 1.0
elif model_path is not None and 'w1.2' in model_path:
self.width_mult = 1.2
else:
raise ValueError('model_path must encode the width multiplier (w1.0 or w1.2)')
self.engine = OFAMobileNetV3(
n_classes=n_classes,
dropout_rate=0, width_mult_list=self.width_mult, ks_list=self.kernel_size,
expand_ratio_list=self.exp_ratio, depth_list=self.depth)
init = torch.load(model_path, map_location='cpu')['state_dict']
self.engine.load_weights_from_net(init)
print(f'load {model_path}...')
## metad2a
self.arch_manager = ArchManager()
self.num_gen_arch = num_gen_arch
self.img_size = img_size # used when profiling FLOPs in get_architecture
self.drop_path = drop_path # used when building subnets in get_architecture
def sample_random_architecture(self):
sampled_architecture = self.arch_manager.random_sample()
return sampled_architecture
def get_architecture(self, bound=None):
# self.lines is expected to be populated elsewhere with the generator's output,
# one '<predicted-acc> <architecture dict>' entry per line.
g_lst, pred_acc_lst, x_lst = [], [], []
searched_g, max_pred_acc = None, 0
with torch.no_grad():
for n in range(self.num_gen_arch):
file_acc = self.lines[n].split()[0]
g_dict = ' '.join(self.lines[n].split())
g = json.loads(g_dict.replace("'", "\""))
if bound is not None:
subnet, config = self.sample(config=g)
net = NSGANetV2.build_from_config(subnet.config,
drop_connect_rate=self.drop_path)
inputs = torch.randn(1, 3, self.img_size, self.img_size)
flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
if flops <= bound:
searched_g = g
break
else:
searched_g = g
pred_acc_lst.append(file_acc)
break
if searched_g is None:
raise ValueError('no generated architecture satisfies the given constraint')
return searched_g, pred_acc_lst
def sample(self, config=None):
""" randomly sample a sub-network """
if config is not None:
config = validate_config(config)
self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
else:
config = self.engine.sample_active_subnet()
subnet = self.engine.get_active_subnet(preserve_weight=True)
return subnet, config
@staticmethod
def save_net_config(path, net, config_name='net.config'):
""" dump run_config and net_config to the model_folder """
net_save_path = os.path.join(path, config_name)
json.dump(net.config, open(net_save_path, 'w'), indent=4)
print('Network configs dump to %s' % net_save_path)
@staticmethod
def save_net(path, net, model_name):
""" dump net weight as checkpoint """
if isinstance(net, torch.nn.DataParallel):
checkpoint = {'state_dict': net.module.state_dict()}
else:
checkpoint = {'state_dict': net.state_dict()}
model_path = os.path.join(path, model_name)
torch.save(checkpoint, model_path)
print('Network model dump to %s' % model_path)
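A hedged sketch of the evaluator's sampling path; the supernet checkpoint path matches the one used by main.py below and must exist locally:

```python
# sample() with no argument draws a random subnet; passing a generated config
# (as get_architecture does) activates that specific architecture instead.
evaluator = OFAEvaluator(num_gen_arch=200, img_size=224, drop_path=0.2,
                         model_path='../.torch/ofa_nets/ofa_mbv3_d234_e346_k357_w1.0')
subnet, config = evaluator.sample()
print(config['d'], config['ks'][:4])
```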

View File

@@ -0,0 +1,169 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim
from evaluator import OFAEvaluator
from torchprofile import profile_macs
from codebase.networks import NSGANetV2
from parser import get_parse
from eval_utils import get_dataset
args = get_parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device_list = [int(_) for _ in args.gpu.split(',')]
args.n_gpus = len(device_list)
args.device = torch.device("cuda:0")
if args.seed is None or args.seed < 0: args.seed = random.randint(1, 100000)
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
evaluator = OFAEvaluator(args.num_gen_arch, args.img_size, args.drop_path,
model_path='../.torch/ofa_nets/ofa_mbv3_d234_e346_k357_w1.0')
args.save_path = os.path.join(args.save_path, f'evaluation/{args.data_name}')
if args.model_config.startswith('flops@'):
args.save_path += f'-nsganetV2-{args.model_config}-{args.seed}'
else:
args.save_path += f'-metaD2A-{args.bound}-{args.seed}'
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
args.data_path = os.path.join(args.data_path, args.data_name)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
if not torch.cuda.is_available():
logging.info('no gpu device available')
sys.exit(1)
logging.info("args = %s", args)
def set_architecture(n_cls):
if args.model_config.startswith('flops@'):
names = {'cifar10': 'CIFAR-10', 'cifar100': 'CIFAR-100',
'aircraft100': 'Aircraft', 'pets': 'Pets'}
p = os.path.join('./searched-architectures/{}/net-{}/net.subnet'.
format(names[args.data_name], args.model_config))
g = json.load(open(p))
else:
g, acc = evaluator.get_architecture(bound=args.bound)
subnet, config = evaluator.sample(g)
net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=args.drop_path)
net.load_state_dict(subnet.state_dict())
NSGANetV2.reset_classifier(
net, last_channel=net.classifier.in_features,
n_classes=n_cls, dropout_rate=args.drop)
# calculate #Parameters and #FLOPs
inputs = torch.randn(1, 3, args.img_size, args.img_size)
flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
net_name = "net_flops@{:.0f}".format(flops)
logging.info('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
OFAEvaluator.save_net_config(args.save_path, net, net_name + '.config')
if args.n_gpus > 1:
net = nn.DataParallel(net) # data parallel in case more than 1 gpu available
net = net.to(args.device)
return net, net_name
def train(train_queue, net, criterion, optimizer):
net.train()
train_loss, correct, total = 0, 0, 0
for step, (inputs, targets) in enumerate(train_queue):
# upsample by bicubic to match imagenet training size
inputs, targets = inputs.to(args.device), targets.to(args.device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if step % args.report_freq == 0:
logging.info('train %03d %e %f', step, train_loss / total, 100. * correct / total)
logging.info('train acc %f', 100. * correct / total)
return train_loss / total, 100. * correct / total
def infer(valid_queue, net, criterion, early_stop=False):
net.eval()
test_loss, correct, total = 0, 0, 0
with torch.no_grad():
for step, (inputs, targets) in enumerate(valid_queue):
inputs, targets = inputs.to(args.device), targets.to(args.device)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if step % args.report_freq == 0:
logging.info('valid %03d %e %f', step, test_loss / total, 100. * correct / total)
if early_stop and step == 10:
break
acc = 100. * correct / total
logging.info('valid acc %f', 100. * correct / total)
return test_loss / total, acc
def main():
best_acc, top_checkpoints = 0, []
train_queue, valid_queue, n_cls = get_dataset(
args.data_name, args.batch_size, args.data_path, args.num_workers,
args.img_size, args.autoaugment, args.cutout, args.cutout_length)
net, net_name = set_architecture(n_cls)
parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
criterion = nn.CrossEntropyLoss().to(args.device)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
for epoch in range(args.epochs):
logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
train(train_queue, net, criterion, optimizer)
_, valid_acc = infer(valid_queue, net, criterion)
# checkpoint saving
if len(top_checkpoints) < args.topk:
OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
else:
idx = np.argmin([x[1] for x in top_checkpoints])
if valid_acc > top_checkpoints[idx][1]:
OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
# remove the idx
os.remove(top_checkpoints[idx][0])
top_checkpoints.pop(idx)
print(top_checkpoints)
if valid_acc > best_acc:
OFAEvaluator.save_net(args.save_path, net, net_name + '.best')
best_acc = valid_acc
scheduler.step()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,43 @@
import argparse
def get_parse():
parser = argparse.ArgumentParser(description='MetaD2A vs NSGANETv2')
parser.add_argument('--save-path', type=str, default='../results', help='the path of save directory')
parser.add_argument('--data-path', type=str, default='../data', help='the path of save directory')
parser.add_argument('--data-name', type=str, default=None, help='meta-test dataset name')
parser.add_argument('--num-gen-arch', type=int, default=200,
help='the number of candidate architectures generated by the generator')
parser.add_argument('--bound', type=int, default=None)
# original setting
parser.add_argument('--seed', type=int, default=-1, help='random seed')
parser.add_argument('--batch-size', type=int, default=96, help='batch size')
parser.add_argument('--num_workers', type=int, default=2, help='number of workers for data loading')
parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
parser.add_argument('--lr', type=float, default=0.01, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=4e-5, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--epochs', type=int, default=150, help='num of training epochs')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--autoaugment', action='store_true', default=True, help='use auto augmentation')
parser.add_argument('--topk', type=int, default=10, help='top k checkpoints to save')
parser.add_argument('--evaluate', action='store_true', default=False, help='evaluate a pretrained model')
# model related
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet101")')
parser.add_argument('--model-config', type=str, default='search',
help='location of a json file of specific model declaration')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--drop', type=float, default=0.2,
help='dropout rate')
parser.add_argument('--drop-path', type=float, default=0.2, metavar='PCT',
help='Drop path rate (default: 0.2)')
parser.add_argument('--img-size', type=int, default=224,
help='input resolution (192 -> 256)')
args = parser.parse_args()
return args

View File

@@ -0,0 +1,261 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.evaluator import OFAEvaluator
from torchprofile import profile_macs
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.parser import get_parse
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.eval_utils import get_dataset
from transfer_nag_lib.MetaD2A_nas_bench_201.metad2a_utils import reset_seed
from transfer_nag_lib.ofa_net import OFASubNet
# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# device_list = [int(_) for _ in args.gpu.split(',')]
# args.n_gpus = len(device_list)
# args.device = torch.device("cuda:0")
# if args.seed is None or args.seed < 0: args.seed = random.randint(1, 100000)
# torch.cuda.manual_seed(args.seed)
# torch.manual_seed(args.seed)
# np.random.seed(args.seed)
# random.seed(args.seed)
# args.save_path = os.path.join(args.save_path, f'evaluation/{args.data_name}')
# if args.model_config.startswith('flops@'):
# args.save_path += f'-nsganetV2-{args.model_config}-{args.seed}'
# else:
# args.save_path += f'-metaD2A-{args.bound}-{args.seed}'
# if not os.path.exists(args.save_path):
# os.makedirs(args.save_path)
# args.data_path = os.path.join(args.data_path, args.data_name)
# log_format = '%(asctime)s %(message)s'
# logging.basicConfig(stream=sys.stdout, level=print,
# format=log_format, datefmt='%m/%d %I:%M:%S %p')
# fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
# fh.setFormatter(logging.Formatter(log_format))
# logging.getLogger().addHandler(fh)
# if not torch.cuda.is_available():
# print('no gpu self.args.device available')
# sys.exit(1)
# print("args = %s", args)
def set_architecture(n_cls, evaluator, drop_path, drop, img_size, n_gpus, device, save_path, model_str):
# g, acc = evaluator.get_architecture(model_str)
g = OFASubNet(model_str).get_op_dict()
subnet, config = evaluator.sample(g)
net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=drop_path)
net.load_state_dict(subnet.state_dict())
NSGANetV2.reset_classifier(
net, last_channel=net.classifier.in_features,
n_classes=n_cls, dropout_rate=drop)
# calculate #Parameters and #FLOPs
inputs = torch.randn(1, 3, img_size, img_size)
flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
net_name = "net_flops@{:.0f}".format(flops)
print('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
# OFAEvaluator.save_net_config(save_path, net, net_name + '.config')
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
net = nn.DataParallel(net)
net = net.to(device)
return net, net_name, params, flops
def train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq):
net.train()
train_loss, correct, total = 0, 0, 0
for step, (inputs, targets) in enumerate(train_queue):
# upsample by bicubic to match imagenet training size
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if step % report_freq == 0:
print(f'train step {step:03d} loss {train_loss / total:.4f} train acc {100. * correct / total:.4f}')
print(f'train acc {100. * correct / total:.4f}')
return train_loss / total, 100. * correct / total
def infer(valid_queue, net, criterion, device, report_freq, early_stop=False):
net.eval()
test_loss, correct, total = 0, 0, 0
with torch.no_grad():
for step, (inputs, targets) in enumerate(valid_queue):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if step % report_freq == 0:
print(f'valid {step:03d} {test_loss / total:.4f} {100. * correct / total:.4f}')
if early_stop and step == 10:
break
acc = 100. * correct / total
print('valid acc {:.4f}'.format(100. * correct / total))
return test_loss / total, acc
def train_single_model(save_path, workers, datasets, xpaths, splits, use_less,
seed, model_str, device,
lr=0.01,
momentum=0.9,
weight_decay=4e-5,
report_freq=50,
epochs=150,
grad_clip=5,
cutout=True,
cutout_length=16,
autoaugment=True,
drop=0.2,
drop_path=0.2,
img_size=224,
batch_size=96,
):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
reset_seed(seed)
# save_dir = Path(save_dir)
# logger = Logger(str(save_dir), 0, False)
os.makedirs(save_path, exist_ok=True)
to_save_name = save_path + '/seed-{:04d}.pth'.format(seed)
print(to_save_name)
# args = get_parse()
num_gen_arch = None
evaluator = OFAEvaluator(num_gen_arch, img_size, drop_path,
model_path='/home/data/GTAD/checkpoints/ofa/ofa_net/ofa_mbv3_d234_e346_k357_w1.0')
train_queue, valid_queue, n_cls = get_dataset(datasets, batch_size,
xpaths, workers, img_size, autoaugment, cutout, cutout_length)
net, net_name, params, flops = set_architecture(n_cls, evaluator,
drop_path, drop, img_size, n_gpus=1, device=device, save_path=save_path, model_str=model_str)
# net.to(device)
parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss().to(device)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
# assert epochs == 1
max_valid_acc = 0
max_epoch = 0
for epoch in range(epochs):
print('epoch {:d} lr {:.4f}'.format(epoch, scheduler.get_lr()[0]))
train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq)
_, valid_acc = infer(valid_queue, net, criterion, device, report_freq)
torch.save(valid_acc, to_save_name)
print(f'seed {seed:04d} last acc {valid_acc:.4f} max acc {max_valid_acc:.4f}')
if max_valid_acc < valid_acc:
max_valid_acc = valid_acc
max_epoch = epoch
# parent_path = os.path.abspath(os.path.join(save_path, os.pardir))
# with open(parent_path + '/accuracy.txt', 'a+') as f:
# f.write(f'{model_str} seed {seed:04d} {valid_acc:.4f}\n')
return valid_acc, max_valid_acc, params, flops
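# Illustrative invocation (all values below are assumptions for a local run):
#   acc, max_acc, params, flops = train_single_model(
#       save_path='./results/pets', workers=2, datasets='pets',
#       xpaths='../data/pets', splits=None, use_less=False, seed=1,
#       model_str='<generated architecture string>', device=torch.device('cuda:0'))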
################ NAS BENCH 201 #####################
# def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
# seeds, model_str, arch_config):
# assert torch.cuda.is_available(), 'CUDA is not available.'
# torch.backends.cudnn.enabled = True
# torch.backends.cudnn.deterministic = True
# torch.set_num_threads(workers)
# save_dir = Path(save_dir)
# logger = Logger(str(save_dir), 0, False)
# if model_str in CellArchitectures:
# arch = CellArchitectures[model_str]
# logger.log(
# 'The model string is found in pre-defined architecture dict : {:}'.format(model_str))
# else:
# try:
# arch = CellStructure.str2structure(model_str)
# except:
# raise ValueError(
# 'Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
# assert arch.check_valid_op(get_search_spaces(
# 'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
# # assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
# logger.log('Start train-evaluate {:}'.format(arch.tostr()))
# logger.log('arch_config : {:}'.format(arch_config))
# start_time, seed_time = time.time(), AverageMeter()
# for _is, seed in enumerate(seeds):
# logger.log(
# '\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds),
# seed))
# to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
# if to_save_name.exists():
# logger.log(
# 'Find the existing file {:}, directly load!'.format(to_save_name))
# checkpoint = torch.load(to_save_name)
# else:
# logger.log(
# 'Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
# checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
# seed, arch_config, workers, logger)
# torch.save(checkpoint, to_save_name)
# # log information
# logger.log('{:}'.format(checkpoint['info']))
# all_dataset_keys = checkpoint['all_dataset_keys']
# for dataset_key in all_dataset_keys:
# logger.log('\n{:} dataset : {:} {:}'.format(
# '-' * 15, dataset_key, '-' * 15))
# dataset_info = checkpoint[dataset_key]
# # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
# logger.log('Flops = {:} MB, Params = {:} MB'.format(
# dataset_info['flop'], dataset_info['param']))
# logger.log('config : {:}'.format(dataset_info['config']))
# logger.log('Training State (finish) = {:}'.format(
# dataset_info['finish-train']))
# last_epoch = dataset_info['total_epoch'] - 1
# train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
# valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
# # measure elapsed time
# seed_time.update(time.time() - start_time)
# start_time = time.time()
# need_time = 'Time Left: {:}'.format(convert_secs2time(
# seed_time.avg * (len(seeds) - _is - 1), True))
# logger.log(
# '\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed,
# need_time))
# logger.close()
# ###################
if __name__ == '__main__':
train_single_model() # NOTE: train_single_model above is commented out; hook this up to the active training routine before running this file directly

View File

@@ -0,0 +1,5 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .generator import Generator

View File

@@ -0,0 +1,204 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import random
from tqdm import tqdm
import numpy as np
import time
import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import load_graph_config, decode_ofa_mbv3_to_igraph, decode_igraph_to_ofa_mbv3
from utils import Accumulator, Log
from utils import load_model, save_model
from loader import get_meta_train_loader, get_meta_test_loader
from .generator_model import GeneratorModel
class Generator:
def __init__(self, args):
self.args = args
self.batch_size = args.batch_size
self.data_path = args.data_path
self.num_sample = args.num_sample
self.max_epoch = args.max_epoch
self.save_epoch = args.save_epoch
self.model_path = args.model_path
self.save_path = args.save_path
self.model_name = args.model_name
self.test = args.test
self.device = args.device
graph_config = load_graph_config(
args.graph_data_name, args.nvt, args.data_path)
self.model = GeneratorModel(args, graph_config)
self.model.to(self.device)
if self.test:
self.data_name = args.data_name
self.num_class = args.num_class
self.load_epoch = args.load_epoch
self.num_gen_arch = args.num_gen_arch
load_model(self.model, self.model_path, self.load_epoch)
else:
self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',
factor=0.1, patience=10, verbose=True)
self.mtrloader = get_meta_train_loader(
self.batch_size, self.data_path, self.num_sample)
self.mtrlog = Log(self.args, open(os.path.join(
self.save_path, self.model_name, 'meta_train_generator.log'), 'w'))
self.mtrlog.print_args()
self.mtrlogger = Accumulator('loss', 'recon_loss', 'kld')
self.mvallogger = Accumulator('loss', 'recon_loss', 'kld')
def meta_train(self):
sttime = time.time()
for epoch in range(1, self.max_epoch + 1):
self.mtrlog.ep_sttime = time.time()
loss = self.meta_train_epoch(epoch)
self.scheduler.step(loss)
self.mtrlog.print(self.mtrlogger, epoch, tag='train')
self.meta_validation()
self.mtrlog.print(self.mvallogger, epoch, tag='valid')
if epoch % self.save_epoch == 0:
save_model(epoch, self.model, self.model_path)
self.mtrlog.save_time_log()
def meta_train_epoch(self, epoch):
self.model.to(self.device)
self.model.train()
self.mtrloader.dataset.set_mode('train')
pbar = tqdm(self.mtrloader)
for batch in pbar:
for x, g, acc in batch:
self.optimizer.zero_grad()
g = decode_ofa_mbv3_to_igraph(g)[0]
x_ = x.unsqueeze(0).to(self.device)
mu, logvar = self.model.set_encode(x_)
loss, recon, kld = self.model.loss(mu.unsqueeze(0), logvar.unsqueeze(0), [g])
loss.backward()
self.optimizer.step()
cnt = len(x)
self.mtrlogger.accum([loss.item() / cnt,
recon.item() / cnt,
kld.item() / cnt])
return self.mtrlogger.get('loss')
def meta_validation(self):
self.model.to(self.device)
self.model.eval()
self.mtrloader.dataset.set_mode('valid')
pbar = tqdm(self.mtrloader)
for batch in pbar:
for x, g, acc in batch:
with torch.no_grad():
g = decode_ofa_mbv3_to_igraph(g)[0]
x_ = x.unsqueeze(0).to(self.device)
mu, logvar = self.model.set_encode(x_)
loss, recon, kld = self.model.loss(mu.unsqueeze(0), logvar.unsqueeze(0), [g])
cnt = len(x)
self.mvallogger.accum([loss.item() / cnt,
recon.item() / cnt,
kld.item() / cnt])
return self.mvallogger.get('loss')
def meta_test(self, predictor):
if self.data_name == 'all':
for data_name in ['cifar100', 'cifar10', 'mnist', 'svhn', 'aircraft30', 'aircraft100', 'pets']:
self.meta_test_per_dataset(data_name, predictor)
else:
self.meta_test_per_dataset(self.data_name, predictor)
def meta_test_per_dataset(self, data_name, predictor):
# meta_test_path = os.path.join(
# self.save_path, 'meta_test', data_name, 'generated_arch')
meta_test_path = os.path.join(
self.save_path, 'meta_test', data_name, f'{self.num_gen_arch}', 'generated_arch')
if not os.path.exists(meta_test_path):
os.makedirs(meta_test_path)
meta_test_loader = get_meta_test_loader(
self.data_path, data_name, self.num_sample, self.num_class)
print(f'==> generate architectures for {data_name}')
runs = 10 if data_name in ['cifar10', 'cifar100'] else 1
# num_gen_arch = 500 if data_name in ['cifar100'] else self.num_gen_arch
elapsed_time = []
for run in range(1, runs + 1):
print(f'==> run {run}/{runs}')
elapsed_time.append(self.generate_architectures(
meta_test_loader, data_name,
meta_test_path, run, self.num_gen_arch, predictor))
print(f'==> done\n')
# time_path = os.path.join(self.save_path, 'meta_test', data_name, 'time.txt')
time_path = os.path.join(self.save_path, 'meta_test', data_name, f'{self.num_gen_arch}', 'time.txt')
with open(time_path, 'w') as f_time:
msg = f'generator elapsed time {np.mean(elapsed_time):.2f}s'
print(f'==> save time in {time_path}')
f_time.write(msg + '\n')
print(msg)
def generate_architectures(self, meta_test_loader, data_name,
meta_test_path, run, num_gen_arch, predictor):
self.model.eval()
self.model.to(self.device)
architecture_string_lst, pred_acc_lst = [], []
total_cnt, valid_cnt = 0, 0
flag = False
start = time.time()
with torch.no_grad():
for x in meta_test_loader:
x_ = x.unsqueeze(0).to(self.device)
mu, logvar = self.model.set_encode(x_)
z = self.model.reparameterize(mu.unsqueeze(0), logvar.unsqueeze(0))
g_recon = self.model.graph_decode(z)
pred_acc = predictor.forward(x_, g_recon)
architecture_string = decode_igraph_to_ofa_mbv3(g_recon[0])
total_cnt += 1
if architecture_string is not None:
if architecture_string not in architecture_string_lst:
valid_cnt += 1
architecture_string_lst.append(architecture_string)
pred_acc_lst.append(pred_acc.item())
if valid_cnt == num_gen_arch:
flag = True
break
if flag:
break
elapsed = time.time() - start
pred_acc_lst, architecture_string_lst = zip(*sorted(zip(pred_acc_lst,
architecture_string_lst),
key=lambda x: x[0], reverse=True))
spath = os.path.join(meta_test_path, f"run_{run}.txt")
with open(spath, 'w') as f:
print(f'==> save generated architectures in {spath}')
msg = f'elapsed time: {elapsed:6.2f}s '
print(msg)
f.write(msg + '\n')
for i, architecture_string in enumerate(architecture_string_lst):
f.write(f"{architecture_string}\n")
return elapsed

View File

@@ -0,0 +1,396 @@
######################################################################################
# Copyright (c) muhanzhang, D-VAE, NeurIPS 2019 [GitHub D-VAE]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import igraph
from set_encoder.setenc_models import SetPool
class GeneratorModel(nn.Module):
def __init__(self, args, graph_config):
super(GeneratorModel, self).__init__()
self.max_n = graph_config['max_n'] # maximum number of vertices
self.nvt = graph_config['num_vertex_type'] # number of vertex types
self.START_TYPE = graph_config['START_TYPE']
self.END_TYPE = graph_config['END_TYPE']
self.hs = args.hs # hidden state size of each vertex
self.nz = args.nz # size of latent representation z
self.gs = args.hs # size of graph state
self.bidir = True # whether to use bidirectional encoding
self.vid = True
self.device = None
self.num_sample = args.num_sample
if self.vid:
self.vs = self.hs + self.max_n # vertex state size = hidden state + vid
else:
self.vs = self.hs
# 0. encoding-related
self.grue_forward = nn.GRUCell(self.nvt, self.hs) # encoder GRU
self.grue_backward = nn.GRUCell(self.nvt, self.hs) # backward encoder GRU
self.enc_g_mu = nn.Linear(self.gs, self.nz) # latent mean
self.enc_g_var = nn.Linear(self.gs, self.nz) # latent var
self.fc1 = nn.Linear(self.gs, self.nz) # latent mean
self.fc2 = nn.Linear(self.gs, self.nz) # latent logvar
# 1. decoding-related
self.grud = nn.GRUCell(self.nvt, self.hs) # decoder GRU
self.fc3 = nn.Linear(self.nz, self.hs) # from latent z to initial hidden state h0
self.add_vertex = nn.Sequential(
nn.Linear(self.hs, self.hs * 2),
nn.ReLU(),
nn.Linear(self.hs * 2, self.nvt)
) # which type of new vertex to add f(h0, hg)
self.add_edge = nn.Sequential(
nn.Linear(self.hs * 2, self.hs * 4),
nn.ReLU(),
nn.Linear(self.hs * 4, 1)
) # whether to add edge between v_i and v_new, f(hvi, hnew)
self.decoding_gate = nn.Sequential(
nn.Linear(self.vs, self.hs),
nn.Sigmoid()
)
self.decoding_mapper = nn.Sequential(
nn.Linear(self.vs, self.hs, bias=False),
) # disable bias to ensure padded zeros also mapped to zeros
# 2. gate-related
self.gate_forward = nn.Sequential(
nn.Linear(self.vs, self.hs),
nn.Sigmoid()
)
self.gate_backward = nn.Sequential(
nn.Linear(self.vs, self.hs),
nn.Sigmoid()
)
self.mapper_forward = nn.Sequential(
nn.Linear(self.vs, self.hs, bias=False),
) # disable bias to ensure padded zeros also mapped to zeros
self.mapper_backward = nn.Sequential(
nn.Linear(self.vs, self.hs, bias=False),
)
# 3. bidir-related, to unify sizes
if self.bidir:
self.hv_unify = nn.Sequential(
nn.Linear(self.hs * 2, self.hs),
)
self.hg_unify = nn.Sequential(
nn.Linear(self.gs * 2, self.gs),
)
# 4. other
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.logsoftmax1 = nn.LogSoftmax(1)
# 6. predictor
self.intra_setpool = SetPool(dim_input=512,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.inter_setpool = SetPool(dim_input=self.nz,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.set_fc = nn.Sequential(
nn.Linear(512, self.nz),
nn.ReLU())
def get_device(self):
if self.device is None:
self.device = next(self.parameters()).device
return self.device
def _get_zeros(self, n, length):
return torch.zeros(n, length).to(self.get_device()) # get a zero hidden state
def _get_zero_hidden(self, n=1):
return self._get_zeros(n, self.hs) # get a zero hidden state
def _one_hot(self, idx, length):
if type(idx) in [list, range]:
if idx == []:
return None
idx = torch.LongTensor(idx).unsqueeze(0).t()
x = torch.zeros((len(idx), length)
).scatter_(1, idx, 1).to(self.get_device())
else:
idx = torch.LongTensor([idx]).unsqueeze(0)
x = torch.zeros((1, length)
).scatter_(1, idx, 1).to(self.get_device())
return x
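# Illustrative example (hypothetical values): _one_hot([0, 2], 4) yields the
# [2, 4] tensor [[1, 0, 0, 0], [0, 0, 1, 0]], while _one_hot(3, 4) yields the
# single row [[0, 0, 0, 1]].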
def _gated(self, h, gate, mapper):
return gate(h) * mapper(h)
def _collate_fn(self, G):
return [g.copy() for g in G]
def _propagate_to(self, G, v, propagator,
H=None, reverse=False, gate=None, mapper=None):
# propagate messages to vertex index v for all graphs in G
# return the new messages (states) at v
G = [g for g in G if g.vcount() > v]
if len(G) == 0:
return
if H is not None:
idx = [i for i, g in enumerate(G) if g.vcount() > v]
H = H[idx]
v_types = [g.vs[v]['type'] for g in G]
X = self._one_hot(v_types, self.nvt)
H_name = 'H_forward' # name of the hidden states attribute
H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
if self.vid:
vids = [self._one_hot(g.predecessors(v), self.max_n) for g in G]
if reverse:
H_name = 'H_backward' # name of the hidden states attribute
H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
if self.vid:
vids = [self._one_hot(g.successors(v), self.max_n) for g in G]
gate, mapper = self.gate_backward, self.mapper_backward
else:
H_name = 'H_forward' # name of the hidden states attribute
H_pred = [
[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
if self.vid:
vids = [
self._one_hot(g.predecessors(v), self.max_n) for g in G]
if gate is None:
gate, mapper = self.gate_forward, self.mapper_forward
if self.vid:
H_pred = [[torch.cat(
[x[i], y[i:i + 1]], 1) for i in range(len(x))
] for x, y in zip(H_pred, vids)]
# if h is not provided, use gated sum of v's predecessors' states as the input hidden state
if H is None:
max_n_pred = max([len(x) for x in H_pred]) # maximum number of predecessors
if max_n_pred == 0:
H = self._get_zero_hidden(len(G))
else:
H_pred = [torch.cat(h_pred +
[self._get_zeros(max_n_pred - len(h_pred),
self.vs)], 0).unsqueeze(0)
for h_pred in H_pred] # pad all to same length
H_pred = torch.cat(H_pred, 0) # batch * max_n_pred * vs
H = self._gated(H_pred, gate, mapper).sum(1) # batch * hs
Hv = propagator(X, H)
for i, g in enumerate(G):
g.vs[v][H_name] = Hv[i:i + 1]
return Hv
def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
# perform a series of propagation_to steps starting from v following a topo order
# assume the original vertex indices are in a topological order
if reverse:
prop_order = range(v, -1, -1)
else:
prop_order = range(v, self.max_n)
Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse) # the initial vertex
for v_ in prop_order[1:]:
self._propagate_to(G, v_, propagator, reverse=reverse)
return Hv
def _update_v(self, G, v, H0=None):
# perform a forward propagation step at v when decoding to update v's state
# self._propagate_to(G, v, self.grud, H0, reverse=False)
self._propagate_to(G, v, self.grud, H0,
reverse=False, gate=self.decoding_gate,
mapper=self.decoding_mapper)
return
def _get_vertex_state(self, G, v):
# get the vertex states at v
Hv = []
for g in G:
if v >= g.vcount():
hv = self._get_zero_hidden()
else:
hv = g.vs[v]['H_forward']
Hv.append(hv)
Hv = torch.cat(Hv, 0)
return Hv
def _get_graph_state(self, G, decode=False):
# get the graph states
# when decoding, use the last generated vertex's state as the graph state
# when encoding, use the ending vertex state or unify the starting and ending vertex states
Hg = []
for g in G:
hg = g.vs[g.vcount() - 1]['H_forward']
if self.bidir and not decode: # decoding never uses backward propagation
hg_b = g.vs[0]['H_backward']
hg = torch.cat([hg, hg_b], 1)
Hg.append(hg)
Hg = torch.cat(Hg, 0)
if self.bidir and not decode:
Hg = self.hg_unify(Hg)
return Hg
def graph_encode(self, G):
# encode graphs G into latent vectors
if type(G) != list:
G = [G]
self._propagate_from(G, 0, self.grue_forward,
H0=self._get_zero_hidden(len(G)), reverse=False)
if self.bidir:
self._propagate_from(G, self.max_n - 1, self.grue_backward,
H0=self._get_zero_hidden(len(G)), reverse=True)
Hg = self._get_graph_state(G)
mu, logvar = self.enc_g_mu(Hg), self.enc_g_var(Hg)
return mu, logvar
def set_encode(self, X):
proto_batch = []
for x in X: # X.shape: [32, 400, 512]
cls_protos = self.intra_setpool(
x.view(-1, self.num_sample, 512)).squeeze(1)
proto_batch.append(
self.inter_setpool(cls_protos.unsqueeze(0)))
v = torch.stack(proto_batch).squeeze()
mu, logvar = self.fc1(v), self.fc2(v)
return mu, logvar
def reparameterize(self, mu, logvar, eps_scale=0.01):
# return z ~ N(mu, std)
if self.training:
std = logvar.mul(0.5).exp_()
eps = torch.randn_like(std) * eps_scale
return eps.mul(std).add_(mu)
else:
return mu
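# Sketch of the reparameterization trick above (shapes are illustrative):
# for mu, logvar of shape [B, nz], training-time sampling computes
# z = mu + (eps_scale * N(0, I)) * exp(0.5 * logvar),
# so gradients flow through mu and logvar while the noise stays stochastic;
# at eval time the posterior mean mu is returned deterministically.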
def _get_edge_score(self, Hvi, H, H0):
# compute scores for edges from vi based on Hvi, H (current vertex) and H0
# in most cases, H0 need not be explicitly included since Hvi and H contain its information
return self.sigmoid(self.add_edge(torch.cat([Hvi, H], -1)))
def graph_decode(self, z, stochastic=True):
# decode latent vectors z back to graphs
# if stochastic=True, stochastically sample each action from the predicted distribution;
# otherwise, select argmax action deterministically.
H0 = self.tanh(self.fc3(z)) # or relu activation, similar performance
G = [igraph.Graph(directed=True) for _ in range(len(z))]
for g in G:
g.add_vertex(type=self.START_TYPE)
self._update_v(G, 0, H0)
finished = [False] * len(G)
for idx in range(1, self.max_n):
# decide the type of the next added vertex
if idx == self.max_n - 1: # force the last node to be end_type
new_types = [self.END_TYPE] * len(G)
else:
Hg = self._get_graph_state(G, decode=True)
type_scores = self.add_vertex(Hg)
if stochastic:
type_probs = F.softmax(type_scores, 1
).cpu().detach().numpy()
new_types = [np.random.choice(range(self.nvt),
p=type_probs[i]) for i in range(len(G))]
else:
new_types = torch.argmax(type_scores, 1)
new_types = new_types.flatten().tolist()
for i, g in enumerate(G):
if not finished[i]:
g.add_vertex(type=new_types[i])
self._update_v(G, idx)
# decide connections
edge_scores = []
for vi in range(idx - 1, -1, -1):
Hvi = self._get_vertex_state(G, vi)
H = self._get_vertex_state(G, idx)
ei_score = self._get_edge_score(Hvi, H, H0)
if stochastic:
random_score = torch.rand_like(ei_score)
decisions = random_score < ei_score
else:
decisions = ei_score > 0.5
for i, g in enumerate(G):
if finished[i]:
continue
if new_types[i] == self.END_TYPE:
# if new node is end_type, connect it to all loose-end vertices (out_degree==0)
end_vertices = set([
v.index for v in g.vs.select(_outdegree_eq=0)
if v.index != g.vcount() - 1])
for v in end_vertices:
g.add_edge(v, g.vcount() - 1)
finished[i] = True
continue
if decisions[i, 0]:
g.add_edge(vi, g.vcount() - 1)
self._update_v(G, idx)
for g in G:
del g.vs['H_forward'] # delete hidden states to save GPU memory
return G
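# Usage sketch (hypothetical call, assuming model and z share a device):
# z = torch.randn(4, model.nz) # 4 latent samples
# G = model.graph_decode(z) # 4 igraph DAGs, one per latent sample
# Each graph can then be mapped back to an OFA config via decode_igraph_to_ofa_mbv3.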
def loss(self, mu, logvar, G_true, beta=0.005):
# compute the loss of decoding mu and logvar to true graphs using teacher forcing
# ensure when computing the loss of step i, steps 0 to i-1 are correct
z = self.reparameterize(mu, logvar)
H0 = self.tanh(self.fc3(z)) # or relu activation, similar performance
G = [igraph.Graph(directed=True) for _ in range(len(z))]
for g in G:
g.add_vertex(type=self.START_TYPE)
self._update_v(G, 0, H0)
res = 0 # log likelihood
for v_true in range(1, self.max_n):
# calculate the likelihood of adding true types of nodes
# use start type to denote padding vertices since start type only appears for vertex 0
# and will never be a true type for later vertices, thus it's free to use
true_types = [g_true.vs[v_true]['type']
if v_true < g_true.vcount()
else self.START_TYPE for g_true in G_true]
Hg = self._get_graph_state(G, decode=True)
type_scores = self.add_vertex(Hg)
# vertex log likelihood
vll = self.logsoftmax1(type_scores)[
np.arange(len(G)), true_types].sum()
res = res + vll
for i, g in enumerate(G):
if true_types[i] != self.START_TYPE:
g.add_vertex(type=true_types[i])
self._update_v(G, v_true)
# calculate the likelihood of adding true edges
true_edges = []
for i, g_true in enumerate(G_true):
true_edges.append(g_true.get_adjlist(igraph.IN)[v_true]
if v_true < g_true.vcount() else [])
edge_scores = []
for vi in range(v_true - 1, -1, -1):
Hvi = self._get_vertex_state(G, vi)
H = self._get_vertex_state(G, v_true)
ei_score = self._get_edge_score(Hvi, H, H0)
edge_scores.append(ei_score)
for i, g in enumerate(G):
if vi in true_edges[i]:
g.add_edge(vi, v_true)
self._update_v(G, v_true)
edge_scores = torch.cat(edge_scores[::-1], 1)
ground_truth = torch.zeros_like(edge_scores)
idx1 = [i for i, x in enumerate(true_edges)
for _ in range(len(x))]
idx2 = [xx for x in true_edges for xx in x]
ground_truth[idx1, idx2] = 1.0
# edges log-likelihood
ell = - F.binary_cross_entropy(
edge_scores, ground_truth, reduction='sum')
res = res + ell
res = -res # convert likelihood to loss
kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return res + beta * kld, res, kld

View File

@@ -0,0 +1,37 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunk_size = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
file_name = 'ckpt_120.pt'
dir_path = 'results/generator/model'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
print(f"Downloading {file_name}\n")
download_file('https://www.dropbox.com/s/zss9yt034hen45h/ckpt_120.pt?dl=1', file_name)
print("Downloading done.\n")
else:
print(f"{file_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunk_size = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
file_name = 'collected_database.pt'
dir_path = 'data/generator/processed'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
print(f"Downloading generator {file_name}\n")
download_file('https://www.dropbox.com/s/zgip4aq0w2pkj49/generator_collected_database.pt?dl=1', file_name)
print("Downloading done.\n")
else:
print(f"{file_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,43 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunk_size = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
dir_path = 'data/pets'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
full_name = os.path.join(dir_path, 'test15.pth')
if not os.path.exists(full_name):
print(f"Downloading {full_name}\n")
download_file('https://www.dropbox.com/s/kzmrwyyk5iaugv0/test15.pth?dl=1', full_name)
print("Downloading done.\n")
else:
print(f"{full_name} has already been downloaded. Did not download twice.\n")
full_name = os.path.join(dir_path, 'train85.pth')
if not os.path.exists(full_name):
print(f"Downloading {full_name}\n")
download_file('https://www.dropbox.com/s/w7mikpztkamnw9s/train85.pth?dl=1', full_name)
print("Downloading done.\n")
else:
print(f"{full_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,35 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunk_size = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
file_name = 'ckpt_max_corr.pt'
dir_path = 'results/predictor/model'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
print(f"Downloading {file_name}\n")
download_file('https://www.dropbox.com/s/ycm4jaojgswp0zm/ckpt_max_corr.pt?dl=1', file_name)
print("Downloading done.\n")
else:
print(f"{file_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunk_size = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
file_name = 'collected_database.pt'
dir_path = 'data/predictor/processed'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
print(f"Downloading predictor {file_name}\n")
download_file('https://www.dropbox.com/s/ycm4jaojgswp0zm/ckpt_max_corr.pt?dl=1', file_name)
print("Downloading done.\n")
else:
print(f"{file_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,47 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunk_size = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
dir_path = 'data'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
def get_preprocessed_data(file_name, url):
print(f"Downloading {file_name} datasets\n")
full_name = os.path.join(dir_path, file_name)
download_file(url, full_name)
print("Downloading done.\n")
for file_name, url in [
('imgnet32bylabel.pt', 'https://www.dropbox.com/s/7r3hpugql8qgi9d/imgnet32bylabel.pt?dl=1'),
('aircraft100bylabel.pt', 'https://www.dropbox.com/s/nn6mlrk1jijg108/aircraft100bylabel.pt?dl=1'),
('cifar100bylabel.pt', 'https://www.dropbox.com/s/y0xahxgzj29kffk/cifar100bylabel.pt?dl=1'),
('cifar10bylabel.pt', 'https://www.dropbox.com/s/wt1pcwi991xyhwr/cifar10bylabel.pt?dl=1'),
('petsbylabel.pt', 'https://www.dropbox.com/s/mxh6qz3grhy7wcn/petsbylabel.pt?dl=1'),
('mnistbylabel.pt', 'https://www.dropbox.com/s/86rbuic7a7y34e4/mnistbylabel.pt?dl=1'),
('svhnbylabel.pt', 'https://www.dropbox.com/s/yywaelhrsl6egvd/svhnbylabel.pt?dl=1')
]:
get_preprocessed_data(file_name, url)

View File

@@ -0,0 +1,149 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import torch
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=False):
dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
print(f'==> The number of tasks for meta-training: {len(dataset)}')
loader = DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1,
collate_fn=collate_fn)
return loader
def get_meta_test_loader(data_path, data_name, num_sample, num_class=None, is_pred=False):
dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
print(f'==> Meta-Test dataset {data_name}')
loader = DataLoader(dataset=dataset,
batch_size=100,
shuffle=False,
num_workers=1)
return loader
class MetaTrainDatabase(Dataset):
def __init__(self, data_path, num_sample, is_pred=False):
self.mode = 'train'
self.acc_norm = True
self.num_sample = num_sample
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
self.dpath = '{}/{}/processed/'.format(data_path, 'predictor' if is_pred else 'generator')
self.dname = 'database_219152_14.0K'
if not os.path.exists(self.dpath + f'{self.dname}_train.pt'):
raise ValueError(f'missing processed database split: {self.dpath}{self.dname}_train.pt')
database = torch.load(self.dpath + f'{self.dname}.pt')
rand_idx = torch.randperm(len(database))
test_len = int(len(database) * 0.15)
idxlst = {'test': rand_idx[:test_len],
'valid': rand_idx[test_len:2 * test_len],
'train': rand_idx[2 * test_len:]}
for m in ['train', 'valid', 'test']:
acc, cls, net, flops = [], [], [], []
for idx in tqdm(idxlst[m].tolist(), desc=f'data-{m}'):
acc.append(database[idx]['top1'])
net.append(database[idx]['net'])
cls.append(database[idx]['class'])
flops.append(database[idx]['flops'])
if m == 'train':
mean = torch.mean(torch.tensor(acc)).item()
std = torch.std(torch.tensor(acc)).item()
torch.save({'acc': acc,
'class': cls,
'net': net,
'flops': flops,
'mean': mean,
'std': std},
self.dpath + f'{self.dname}_{m}.pt')
self.set_mode(self.mode)
def set_mode(self, mode):
self.mode = mode
data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
self.acc = data['acc']
self.cls = data['class']
self.net = data['net']
self.flops = data['flops']
self.mean = data['mean']
self.std = data['std']
def __len__(self):
return len(self.acc)
def __getitem__(self, index):
data = []
classes = self.cls[index]
acc = self.acc[index]
graph = self.net[index]
for i, cls in enumerate(classes):
cx = self.x[cls.item()][0]
ridx = torch.randperm(len(cx))
data.append(cx[ridx[:self.num_sample]])
x = torch.cat(data)
if self.acc_norm:
acc = ((acc - self.mean) / self.std) / 100.0
else:
acc = acc / 100.0
return x, graph, torch.tensor(acc).view(1, 1)
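# Note: with acc_norm enabled, targets are standardized with the train-split
# mean/std and then divided by 100, so the predictor regresses values on a
# small, comparable scale across datasets.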
class MetaTestDataset(Dataset):
def __init__(self, data_path, data_name, num_sample, num_class=None):
self.num_sample = num_sample
self.data_name = data_name
if data_name == 'aircraft':
data_name = 'aircraft100'
num_class_dict = {
'cifar100': 100,
'cifar10': 10,
'mnist': 10,
'aircraft100': 30,
'svhn': 10,
'pets': 37
}
# 'aircraft30': 30,
# 'aircraft100': 100,
if num_class is not None:
self.num_class = num_class
else:
self.num_class = num_class_dict[data_name]
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
def __len__(self):
return 1000000
def __getitem__(self, index):
data = []
classes = list(range(self.num_class))
for cls in classes:
cx = self.x[cls][0]
ridx = torch.randperm(len(cx))
data.append(cx[ridx[:self.num_sample]])
x = torch.cat(data)
return x
def collate_fn(batch):
# x = torch.stack([item[0] for item in batch])
# graph = [item[1] for item in batch]
# acc = torch.stack([item[2] for item in batch])
return batch

View File

@@ -0,0 +1,48 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
import random
import numpy as np
import torch
from parser import get_parser
from generator import Generator
from predictor import Predictor
def main():
args = get_parser()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
args.device = torch.device("cuda:0")
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
args.model_path = os.path.join(args.save_path, args.model_name, 'model')
if not os.path.exists(args.model_path):
os.makedirs(args.model_path)
if args.model_name == 'generator':
g = Generator(args)
if args.test:
args.model_path = os.path.join(args.save_path, 'predictor', 'model')
hs = args.hs
args.hs = 512
p = Predictor(args)
args.model_path = os.path.join(args.save_path, args.model_name, 'model')
args.hs = hs
g.meta_test(p)
else:
g.meta_train()
elif args.model_name == 'predictor':
p = Predictor(args)
p.meta_train()
else:
raise ValueError('You should select generator|predictor|train_arch')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,344 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import time
import igraph
import random
import numpy as np
import scipy.stats
import argparse
import torch
def load_graph_config(graph_data_name, nvt, data_path):
max_n = 20
graph_config = {}
graph_config['num_vertex_type'] = nvt + 2 # original types + start/end types
graph_config['max_n'] = max_n + 2 # maximum number of nodes
graph_config['START_TYPE'] = 0 # predefined start vertex type
graph_config['END_TYPE'] = 1 # predefined end vertex type
return graph_config
type_dict = {'2-3-3': 0, '2-3-4': 1, '2-3-6': 2,
'2-5-3': 3, '2-5-4': 4, '2-5-6': 5,
'2-7-3': 6, '2-7-4': 7, '2-7-6': 8,
'3-3-3': 9, '3-3-4': 10, '3-3-6': 11,
'3-5-3': 12, '3-5-4': 13, '3-5-6': 14,
'3-7-3': 15, '3-7-4': 16, '3-7-6': 17,
'4-3-3': 18, '4-3-4': 19, '4-3-6': 20,
'4-5-3': 21, '4-5-4': 22, '4-5-6': 23,
'4-7-3': 24, '4-7-4': 25, '4-7-6': 26}
edge_dict = {2: (2, 3, 3), 3: (2, 3, 4), 4: (2, 3, 6),
5: (2, 5, 3), 6: (2, 5, 4), 7: (2, 5, 6),
8: (2, 7, 3), 9: (2, 7, 4), 10: (2, 7, 6),
11: (3, 3, 3), 12: (3, 3, 4), 13: (3, 3, 6),
14: (3, 5, 3), 15: (3, 5, 4), 16: (3, 5, 6),
17: (3, 7, 3), 18: (3, 7, 4), 19: (3, 7, 6),
20: (4, 3, 3), 21: (4, 3, 4), 22: (4, 3, 6),
23: (4, 5, 3), 24: (4, 5, 4), 25: (4, 5, 6),
26: (4, 7, 3), 27: (4, 7, 4), 28: (4, 7, 6)}
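# Note on the two tables above: type_dict maps a 'depth-kernel-expand' string to a
# layer type in [0, 26]; igraph node types shift this by +2 (types 0 and 1 are
# reserved for the start/end vertices), and edge_dict inverts the shifted id back
# to a (d, ks, e) tuple, e.g. type_dict['2-3-3'] == 0 and edge_dict[2] == (2, 3, 3).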
def decode_ofa_mbv3_to_igraph(matrix):
# 5 stages, 4 layers for each stage
# d: 2, 3, 4
# e: 3, 4, 6
# k: 3, 5, 7
# stage_depth to one hot
num_stage = 5
num_layer = 4
node_types = torch.zeros(num_stage * num_layer)
depths = []
for i in range(num_stage):
for j in range(num_layer):
depths.append(matrix['d'][i])
for i, (ks, e, d) in enumerate(zip(
matrix['ks'], matrix['e'], depths)):
node_types[i] = type_dict[f'{d}-{ks}-{e}']
n = num_stage * num_layer
g = igraph.Graph(directed=True)
g.add_vertices(n + 2) # + in/out nodes
g.vs[0]['type'] = 0
for i, v in enumerate(node_types):
g.vs[i + 1]['type'] = v + 2 # in node: 0, out node: 1
g.add_edge(i, i + 1)
g.vs[n + 1]['type'] = 1
g.add_edge(n, n + 1)
return g, n + 2
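# Minimal usage sketch (the config dict below is illustrative, not from the repo):
# cfg = {'d': [2, 3, 4, 2, 3], 'ks': [3] * 20, 'e': [4] * 20}
# g, n_nodes = decode_ofa_mbv3_to_igraph(cfg) # chain DAG with n_nodes == 22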
def decode_ofa_mbv3_str_to_igraph(gen_str):
# 5 stages, 4 layers for each stage
# d: 2, 3, 4
# e: 3, 4, 6
# k: 3, 5, 7
# stage_depth to one hot
num_stage = 5
num_layer = 4
node_types = torch.zeros(num_stage * num_layer)
split_str = gen_str.split('_')
for i, s in enumerate(split_str):
if s == '0-0-0':
node_types[i] = random.randint(0, 26)
else:
node_types[i] = type_dict[s]
n = num_stage * num_layer
g = igraph.Graph(directed=True)
g.add_vertices(n + 2) # + in/out nodes
g.vs[0]['type'] = 0
for i, v in enumerate(node_types):
g.vs[i + 1]['type'] = v + 2 # in node: 0, out node: 1
g.add_edge(i, i + 1)
g.vs[n + 1]['type'] = 1
g.add_edge(n, n + 1)
return g
def is_valid_ofa_mbv3(g, START_TYPE=0, END_TYPE=1):
# first need to be a valid DAG computation graph
res = is_valid_DAG(g, START_TYPE, END_TYPE)
# in addition, the graph must have exactly 20 layer nodes plus the in/out nodes
res = res and len(g.vs['type']) == 22
if not res:
return res
# all four layers of a stage must fall in the same depth group (d in {2, 3, 4})
for i in range(5):
stage_types = g.vs['type'][1:-1][i * 4:(i + 1) * 4]
group = (stage_types[0] - 2) // 9
if group not in (0, 1, 2):
raise ValueError
for t in stage_types[1:]:
res = res and (t - 2) // 9 == group
return res
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
res = g.is_dag()
n_start, n_end = 0, 0
for v in g.vs:
if v['type'] == START_TYPE:
n_start += 1
elif v['type'] == END_TYPE:
n_end += 1
if v.indegree() == 0 and v['type'] != START_TYPE:
return False
if v.outdegree() == 0 and v['type'] != END_TYPE:
return False
return res and n_start == 1 and n_end == 1
def decode_igraph_to_ofa_mbv3(g):
if not is_valid_ofa_mbv3(g, START_TYPE=0, END_TYPE=1):
return None
graph = {'ks': [], 'e': [], 'd': [4, 4, 4, 4, 4]}
for i, node_type in enumerate(g.vs['type'][1:-1]):
node_type = int(node_type)
d, ks, e = edge_dict[node_type]
graph['ks'].append(ks)
graph['e'].append(e)
graph['d'][i // 4] = d
return graph
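# decode_ofa_mbv3_to_igraph and decode_igraph_to_ofa_mbv3 act as near-inverses:
# ks/e are recovered per layer and d per stage, since all four layers of a stage
# carry the same depth value in their node type.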
class Accumulator():
def __init__(self, *args):
self.args = args
self.argdict = {}
for i, arg in enumerate(args):
self.argdict[arg] = i
self.sums = [0] * len(args)
self.cnt = 0
def accum(self, val):
val = [val] if type(val) is not list else val
val = [v for v in val if v is not None]
assert (len(val) == len(self.args))
for i in range(len(val)):
if torch.is_tensor(val[i]):
val[i] = val[i].item()
self.sums[i] += val[i]
self.cnt += 1
def clear(self):
self.sums = [0] * len(self.args)
self.cnt = 0
def get(self, arg, avg=True):
i = self.argdict.get(arg, -1)
assert i != -1, f'unknown argument: {arg}'
if avg:
return self.sums[i] / (self.cnt + 1e-8)
else:
return self.sums[i]
def print_(self, header=None, time=None,
logfile=None, do_not_print=[], as_int=[],
avg=True):
msg = '' if header is None else header + ': '
if time is not None:
msg += ('(%.3f secs), ' % time)
args = [arg for arg in self.args if arg not in do_not_print]
for arg in args:
val = self.sums[self.argdict[arg]]
if avg:
val /= (self.cnt + 1e-8)
if arg in as_int:
msg += ('%s %d, ' % (arg, int(val)))
else:
msg += ('%s %.4f, ' % (arg, val))
print(msg)
if logfile is not None:
logfile.write(msg + '\n')
logfile.flush()
def add_scalars(self, summary, header=None, tag_scalar=None,
step=None, avg=True, args=None):
for arg in self.args:
val = self.sums[self.argdict[arg]]
if avg:
val /= (self.cnt + 1e-8)
else:
val = val
tag = f'{header}/{arg}' if header is not None else arg
if tag_scalar is not None:
summary.add_scalars(main_tag=tag,
tag_scalar_dict={tag_scalar: val},
global_step=step)
else:
summary.add_scalar(tag=tag,
scalar_value=val,
global_step=step)
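# Usage sketch (hypothetical values): an Accumulator keeps running sums and
# returns per-step averages:
# logger = Accumulator('loss', 'kld')
# logger.accum([0.5, 0.1]); logger.accum([0.3, 0.2])
# logger.get('loss') # ~0.4 (sum over 2 accum calls / count)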
class Log:
def __init__(self, args, logf, summary=None):
self.args = args
self.logf = logf
self.summary = summary
self.stime = time.time()
self.ep_sttime = None
def print(self, logger, epoch, tag=None, avg=True):
if tag == 'train':
ct = time.time() - self.ep_sttime
tt = time.time() - self.stime
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
print(msg)
self.logf.write(msg + '\n')
logger.print_(header=tag, logfile=self.logf, avg=avg)
if self.summary is not None:
logger.add_scalars(
self.summary, header=tag, step=epoch, avg=avg)
logger.clear()
def print_args(self):
argdict = vars(self.args)
print(argdict)
for k, v in argdict.items():
self.logf.write(k + ': ' + str(v) + '\n')
self.logf.write('\n')
def set_time(self):
self.stime = time.time()
def save_time_log(self):
ct = time.time() - self.stime
msg = f'({ct:6.2f}s) meta-training phase done'
print(msg)
self.logf.write(msg + '\n')
def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
if tag == 'train':
ct = time.time() - self.ep_sttime
tt = time.time() - self.stime
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
self.logf.write(msg + '\n')
print(msg)
self.logf.flush()
# msg = f'ep {epoch:3d} ep time {time.time() - ep_sttime:8.2f} '
# msg += f'time {time.time() - sttime:6.2f} '
if max_corr_dict is not None:
max_corr = max_corr_dict['corr']
max_loss = max_corr_dict['loss']
msg = f'{tag}: loss {loss:.6f} ({max_loss:.6f}) '
msg += f'corr {corr:.4f} ({max_corr:.4f})'
else:
msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
self.logf.write(msg + '\n')
print(msg)
self.logf.flush()
def max_corr_log(self, max_corr_dict):
corr = max_corr_dict['corr']
loss = max_corr_dict['loss']
epoch = max_corr_dict['epoch']
msg = f'[epoch {epoch}] max correlation: {corr:.4f}, loss: {loss:.6f}'
self.logf.write(msg + '\n')
print(msg)
self.logf.flush()
def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
msg = f'[{tag}] Ep {epoch} loss {loss.item() / len(y):0.4f} '
msg += f'pacc {y_pred[0]:0.4f}'
msg += f'({y_pred[0] * 100.0 * acc_std + acc_mean:0.4f}) '
msg += f'acc {y[0]:0.4f}({y[0] * 100 * acc_std + acc_mean:0.4f})'
return msg
def load_model(model, model_path, load_epoch=None, load_max_pt=None):
if load_max_pt is not None:
ckpt_path = os.path.join(model_path, load_max_pt)
else:
ckpt_path = os.path.join(model_path, f'ckpt_{load_epoch}.pt')
print(f"==> load checkpoint for MetaD2A predictor: {ckpt_path} ...")
model.cpu()
model.load_state_dict(torch.load(ckpt_path))
def save_model(epoch, model, model_path, max_corr=None):
print("==> save current model...")
if max_corr is not None:
torch.save(model.cpu().state_dict(),
os.path.join(model_path, 'ckpt_max_corr.pt'))
else:
torch.save(model.cpu().state_dict(),
os.path.join(model_path, f'ckpt_{epoch}.pt'))
def mean_confidence_interval(data, confidence=0.95):
a = 1.0 * np.array(data)
n = len(a)
m, se = np.mean(a), scipy.stats.sem(a)
h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
return m, h
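# Usage sketch (illustrative numbers): for accuracies from repeated runs,
# m, h = mean_confidence_interval([70.1, 69.8, 70.5])
# gives the mean m and half-width h of the 95% Student-t interval, i.e. m +/- h.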

View File

@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .imagenet import *

View File

@@ -0,0 +1,56 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import numpy as np
import torch
__all__ = ['DataProvider']
class DataProvider:
SUB_SEED = 937162211 # random seed for sampling subset
VALID_SEED = 2147483647 # random seed for the validation set
@staticmethod
def name():
""" Return name of the dataset """
raise NotImplementedError
@property
def data_shape(self):
""" Return shape as python list of one data entry """
raise NotImplementedError
@property
def n_classes(self):
""" Return `int` of num classes """
raise NotImplementedError
@property
def save_path(self):
""" local path to save the data """
raise NotImplementedError
@property
def data_url(self):
""" link to download the data """
raise NotImplementedError
@staticmethod
def random_sample_valid_set(train_size, valid_size):
assert train_size > valid_size
g = torch.Generator()
g.manual_seed(DataProvider.VALID_SEED) # set random seed before sampling validation set
rand_indexes = torch.randperm(train_size, generator=g).tolist()
valid_indexes = rand_indexes[:valid_size]
train_indexes = rand_indexes[valid_size:]
return train_indexes, valid_indexes
@staticmethod
def labels_to_one_hot(n_classes, labels):
new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
return new_labels
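# Example (hypothetical labels): labels_to_one_hot(3, np.array([0, 2])) returns
# float32 rows [[1., 0., 0.], [0., 0., 1.]], one column per class.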

View File

@@ -0,0 +1,225 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import warnings
import os
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from .base_provider import DataProvider
from ofa_local.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
__all__ = ['ImagenetDataProvider']
class ImagenetDataProvider(DataProvider):
DEFAULT_PATH = '/dataset/imagenet'
def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = 'None' if distort_color is None else distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
from ofa_local.utils.my_dataloader import MyDataLoader
assert isinstance(self.image_size, list)
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size) # active resolution for test
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_dataset = self.train_dataset(self.build_train_transform())
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, True, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, True, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'imagenet'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 1000
@property
def save_path(self):
if self._save_path is None:
self._save_path = self.DEFAULT_PATH
if not os.path.exists(self._save_path):
self._save_path = os.path.expanduser('~/dataset/imagenet')
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
return datasets.ImageFolder(self.train_path, _transforms)
def test_dataset(self, _transforms):
return datasets.ImageFolder(self.valid_path, _transforms)
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'val')
@property
def normalize(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
# random_resize_crop -> random_horizontal_flip
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
# color augmentation (optional)
color_transform = None
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting BN running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,6 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .dynamic_layers import *
from .dynamic_op import *

View File

@@ -0,0 +1,632 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import torch
import torch.nn as nn
from collections import OrderedDict
from ofa_local.utils.layers import MBConvLayer, ConvLayer, IdentityLayer, set_layer_from_config
from ofa_local.utils.layers import ResNetBottleneckBlock, LinearLayer
from ofa_local.utils import MyModule, val2list, get_net_device, build_activation, make_divisible, SEModule, MyNetwork
from .dynamic_op import DynamicSeparableConv2d, DynamicConv2d, DynamicBatchNorm2d, DynamicSE, DynamicGroupNorm
from .dynamic_op import DynamicLinear
__all__ = [
'adjust_bn_according_to_idx', 'copy_bn',
'DynamicMBConvLayer', 'DynamicConvLayer', 'DynamicLinearLayer', 'DynamicResNetBottleneckBlock'
]
def adjust_bn_according_to_idx(bn, idx):
bn.weight.data = torch.index_select(bn.weight.data, 0, idx)
bn.bias.data = torch.index_select(bn.bias.data, 0, idx)
if type(bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
bn.running_mean.data = torch.index_select(bn.running_mean.data, 0, idx)
bn.running_var.data = torch.index_select(bn.running_var.data, 0, idx)
def copy_bn(target_bn, src_bn):
feature_dim = target_bn.num_channels if isinstance(target_bn, nn.GroupNorm) else target_bn.num_features
target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim])
target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim])
if type(src_bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
target_bn.running_mean.data.copy_(src_bn.running_mean.data[:feature_dim])
target_bn.running_var.data.copy_(src_bn.running_var.data[:feature_dim])
class DynamicLinearLayer(MyModule):
def __init__(self, in_features_list, out_features, bias=True, dropout_rate=0):
super(DynamicLinearLayer, self).__init__()
self.in_features_list = in_features_list
self.out_features = out_features
self.bias = bias
self.dropout_rate = dropout_rate
if self.dropout_rate > 0:
self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
else:
self.dropout = None
self.linear = DynamicLinear(
max_in_features=max(self.in_features_list), max_out_features=self.out_features, bias=self.bias
)
def forward(self, x):
if self.dropout is not None:
x = self.dropout(x)
return self.linear(x)
@property
def module_str(self):
return 'DyLinear(%d, %d)' % (max(self.in_features_list), self.out_features)
@property
def config(self):
return {
'name': DynamicLinear.__name__,
'in_features_list': self.in_features_list,
'out_features': self.out_features,
'bias': self.bias,
'dropout_rate': self.dropout_rate,
}
@staticmethod
def build_from_config(config):
return DynamicLinearLayer(**config)
def get_active_subnet(self, in_features, preserve_weight=True):
sub_layer = LinearLayer(in_features, self.out_features, self.bias, dropout_rate=self.dropout_rate)
sub_layer = sub_layer.to(get_net_device(self))
if not preserve_weight:
return sub_layer
sub_layer.linear.weight.data.copy_(
self.linear.get_active_weight(self.out_features, in_features).data
)
if self.bias:
sub_layer.linear.bias.data.copy_(
self.linear.get_active_bias(self.out_features).data
)
return sub_layer
def get_active_subnet_config(self, in_features):
return {
'name': LinearLayer.__name__,
'in_features': in_features,
'out_features': self.out_features,
'bias': self.bias,
'dropout_rate': self.dropout_rate,
}
class DynamicMBConvLayer(MyModule):
def __init__(self, in_channel_list, out_channel_list,
kernel_size_list=3, expand_ratio_list=6, stride=1, act_func='relu6', use_se=False):
super(DynamicMBConvLayer, self).__init__()
self.in_channel_list = in_channel_list
self.out_channel_list = out_channel_list
self.kernel_size_list = val2list(kernel_size_list)
self.expand_ratio_list = val2list(expand_ratio_list)
self.stride = stride
self.act_func = act_func
self.use_se = use_se
# build modules
max_middle_channel = make_divisible(
round(max(self.in_channel_list) * max(self.expand_ratio_list)), MyNetwork.CHANNEL_DIVISIBLE)
if max(self.expand_ratio_list) == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func)),
]))
self.depth_conv = nn.Sequential(OrderedDict([
('conv', DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, self.stride)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func))
]))
if self.use_se:
self.depth_conv.add_module('se', DynamicSE(max_middle_channel))
self.point_linear = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))),
('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
]))
self.active_kernel_size = max(self.kernel_size_list)
self.active_expand_ratio = max(self.expand_ratio_list)
self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
in_channel = x.size(1)
if self.inverted_bottleneck is not None:
self.inverted_bottleneck.conv.active_out_channel = \
make_divisible(round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE)
self.depth_conv.conv.active_kernel_size = self.active_kernel_size
self.point_linear.conv.active_out_channel = self.active_out_channel
if self.inverted_bottleneck is not None:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x
@property
def module_str(self):
if self.use_se:
return 'SE(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
else:
return '(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
@property
def config(self):
return {
'name': DynamicMBConvLayer.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'kernel_size_list': self.kernel_size_list,
'expand_ratio_list': self.expand_ratio_list,
'stride': self.stride,
'act_func': self.act_func,
'use_se': self.use_se,
}
@staticmethod
def build_from_config(config):
return DynamicMBConvLayer(**config)
############################################################################################
@property
def in_channels(self):
return max(self.in_channel_list)
@property
def out_channels(self):
return max(self.out_channel_list)
def active_middle_channel(self, in_channel):
return make_divisible(round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE)
############################################################################################
def get_active_subnet(self, in_channel, preserve_weight=True):
# build the new layer
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
sub_layer = sub_layer.to(get_net_device(self))
if not preserve_weight:
return sub_layer
middle_channel = self.active_middle_channel(in_channel)
# copy weight from current layer
if sub_layer.inverted_bottleneck is not None:
sub_layer.inverted_bottleneck.conv.weight.data.copy_(
self.inverted_bottleneck.conv.get_active_filter(middle_channel, in_channel).data,
)
copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)
sub_layer.depth_conv.conv.weight.data.copy_(
self.depth_conv.conv.get_active_filter(middle_channel, self.active_kernel_size).data
)
copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)
if self.use_se:
se_mid = make_divisible(middle_channel // SEModule.REDUCTION, divisor=MyNetwork.CHANNEL_DIVISIBLE)
sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
self.depth_conv.se.get_active_reduce_weight(se_mid, middle_channel).data
)
sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(
self.depth_conv.se.get_active_reduce_bias(se_mid).data
)
sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
self.depth_conv.se.get_active_expand_weight(se_mid, middle_channel).data
)
sub_layer.depth_conv.se.fc.expand.bias.data.copy_(
self.depth_conv.se.get_active_expand_bias(middle_channel).data
)
sub_layer.point_linear.conv.weight.data.copy_(
self.point_linear.conv.get_active_filter(self.active_out_channel, middle_channel).data
)
copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)
return sub_layer
def get_active_subnet_config(self, in_channel):
return {
'name': MBConvLayer.__name__,
'in_channels': in_channel,
'out_channels': self.active_out_channel,
'kernel_size': self.active_kernel_size,
'stride': self.stride,
'expand_ratio': self.active_expand_ratio,
'mid_channels': self.active_middle_channel(in_channel),
'act_func': self.act_func,
'use_se': self.use_se,
}
def re_organize_middle_weights(self, expand_ratio_stage=0):
importance = torch.sum(torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3))
if isinstance(self.depth_conv.bn, DynamicGroupNorm):
channel_per_group = self.depth_conv.bn.channel_per_group
importance_chunks = torch.split(importance, channel_per_group)
for chunk in importance_chunks:
chunk.data.fill_(torch.mean(chunk))
importance = torch.cat(importance_chunks, dim=0)
if expand_ratio_stage > 0:
sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
sorted_expand_list.sort(reverse=True)
target_width_list = [
make_divisible(round(max(self.in_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
for expand in sorted_expand_list
]
right = len(importance)
base = - len(target_width_list) * 1e5
for i in range(expand_ratio_stage + 1):
left = target_width_list[i]
importance[left:right] += base
base += 1e5
right = left
sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
self.point_linear.conv.conv.weight.data = torch.index_select(
self.point_linear.conv.conv.weight.data, 1, sorted_idx
)
adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
self.depth_conv.conv.conv.weight.data = torch.index_select(
self.depth_conv.conv.conv.weight.data, 0, sorted_idx
)
if self.use_se:
# se expand: output dim 0 reorganize
se_expand = self.depth_conv.se.fc.expand
se_expand.weight.data = torch.index_select(se_expand.weight.data, 0, sorted_idx)
se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
# se reduce: input dim 1 reorganize
se_reduce = self.depth_conv.se.fc.reduce
se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 1, sorted_idx)
# middle weight reorganize
se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)
se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)
if self.inverted_bottleneck is not None:
adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
)
return None
else:
return sorted_idx
class DynamicConvLayer(MyModule):
def __init__(self, in_channel_list, out_channel_list, kernel_size=3, stride=1, dilation=1,
use_bn=True, act_func='relu6'):
super(DynamicConvLayer, self).__init__()
self.in_channel_list = in_channel_list
self.out_channel_list = out_channel_list
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.use_bn = use_bn
self.act_func = act_func
self.conv = DynamicConv2d(
max_in_channels=max(self.in_channel_list), max_out_channels=max(self.out_channel_list),
kernel_size=self.kernel_size, stride=self.stride, dilation=self.dilation,
)
if self.use_bn:
self.bn = DynamicBatchNorm2d(max(self.out_channel_list))
self.act = build_activation(self.act_func)
self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
self.conv.active_out_channel = self.active_out_channel
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
x = self.act(x)
return x
@property
def module_str(self):
return 'DyConv(O%d, K%d, S%d)' % (self.active_out_channel, self.kernel_size, self.stride)
@property
def config(self):
return {
'name': DynamicConvLayer.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'kernel_size': self.kernel_size,
'stride': self.stride,
'dilation': self.dilation,
'use_bn': self.use_bn,
'act_func': self.act_func,
}
@staticmethod
def build_from_config(config):
return DynamicConvLayer(**config)
############################################################################################
@property
def in_channels(self):
return max(self.in_channel_list)
@property
def out_channels(self):
return max(self.out_channel_list)
############################################################################################
def get_active_subnet(self, in_channel, preserve_weight=True):
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
sub_layer = sub_layer.to(get_net_device(self))
if not preserve_weight:
return sub_layer
sub_layer.conv.weight.data.copy_(self.conv.get_active_filter(self.active_out_channel, in_channel).data)
if self.use_bn:
copy_bn(sub_layer.bn, self.bn.bn)
return sub_layer
def get_active_subnet_config(self, in_channel):
return {
'name': ConvLayer.__name__,
'in_channels': in_channel,
'out_channels': self.active_out_channel,
'kernel_size': self.kernel_size,
'stride': self.stride,
'dilation': self.dilation,
'use_bn': self.use_bn,
'act_func': self.act_func,
}
class DynamicResNetBottleneckBlock(MyModule):
def __init__(self, in_channel_list, out_channel_list, expand_ratio_list=0.25,
kernel_size=3, stride=1, act_func='relu', downsample_mode='avgpool_conv'):
super(DynamicResNetBottleneckBlock, self).__init__()
self.in_channel_list = in_channel_list
self.out_channel_list = out_channel_list
self.expand_ratio_list = val2list(expand_ratio_list)
self.kernel_size = kernel_size
self.stride = stride
self.act_func = act_func
self.downsample_mode = downsample_mode
# build modules
max_middle_channel = make_divisible(
round(max(self.out_channel_list) * max(self.expand_ratio_list)), MyNetwork.CHANNEL_DIVISIBLE)
self.conv1 = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func, inplace=True)),
]))
self.conv2 = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max_middle_channel, max_middle_channel, kernel_size, stride)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func, inplace=True))
]))
self.conv3 = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))),
('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
]))
if self.stride == 1 and self.in_channel_list == self.out_channel_list:
self.downsample = IdentityLayer(max(self.in_channel_list), max(self.out_channel_list))
elif self.downsample_mode == 'conv':
self.downsample = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max(self.in_channel_list), max(self.out_channel_list), stride=stride)),
('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
]))
elif self.downsample_mode == 'avgpool_conv':
self.downsample = nn.Sequential(OrderedDict([
('avg_pool', nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0, ceil_mode=True)),
('conv', DynamicConv2d(max(self.in_channel_list), max(self.out_channel_list))),
('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
]))
else:
raise NotImplementedError
self.final_act = build_activation(self.act_func, inplace=True)
self.active_expand_ratio = max(self.expand_ratio_list)
self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
feature_dim = self.active_middle_channels
self.conv1.conv.active_out_channel = feature_dim
self.conv2.conv.active_out_channel = feature_dim
self.conv3.conv.active_out_channel = self.active_out_channel
if not isinstance(self.downsample, IdentityLayer):
self.downsample.conv.active_out_channel = self.active_out_channel
residual = self.downsample(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = x + residual
x = self.final_act(x)
return x
@property
def module_str(self):
return '(%s, %s)' % (
'%dx%d_BottleneckConv_in->%d->%d_S%d' % (
self.kernel_size, self.kernel_size, self.active_middle_channels, self.active_out_channel, self.stride
),
'Identity' if isinstance(self.downsample, IdentityLayer) else self.downsample_mode,
)
@property
def config(self):
return {
'name': DynamicResNetBottleneckBlock.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'expand_ratio_list': self.expand_ratio_list,
'kernel_size': self.kernel_size,
'stride': self.stride,
'act_func': self.act_func,
'downsample_mode': self.downsample_mode,
}
@staticmethod
def build_from_config(config):
return DynamicResNetBottleneckBlock(**config)
############################################################################################
@property
def in_channels(self):
return max(self.in_channel_list)
@property
def out_channels(self):
return max(self.out_channel_list)
@property
def active_middle_channels(self):
feature_dim = round(self.active_out_channel * self.active_expand_ratio)
feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
return feature_dim
############################################################################################
def get_active_subnet(self, in_channel, preserve_weight=True):
# build the new layer
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
sub_layer = sub_layer.to(get_net_device(self))
if not preserve_weight:
return sub_layer
# copy weight from current layer
sub_layer.conv1.conv.weight.data.copy_(
self.conv1.conv.get_active_filter(self.active_middle_channels, in_channel).data)
copy_bn(sub_layer.conv1.bn, self.conv1.bn.bn)
sub_layer.conv2.conv.weight.data.copy_(
self.conv2.conv.get_active_filter(self.active_middle_channels, self.active_middle_channels).data)
copy_bn(sub_layer.conv2.bn, self.conv2.bn.bn)
sub_layer.conv3.conv.weight.data.copy_(
self.conv3.conv.get_active_filter(self.active_out_channel, self.active_middle_channels).data)
copy_bn(sub_layer.conv3.bn, self.conv3.bn.bn)
if not isinstance(self.downsample, IdentityLayer):
sub_layer.downsample.conv.weight.data.copy_(
self.downsample.conv.get_active_filter(self.active_out_channel, in_channel).data)
copy_bn(sub_layer.downsample.bn, self.downsample.bn.bn)
return sub_layer
def get_active_subnet_config(self, in_channel):
return {
'name': ResNetBottleneckBlock.__name__,
'in_channels': in_channel,
'out_channels': self.active_out_channel,
'kernel_size': self.kernel_size,
'stride': self.stride,
'expand_ratio': self.active_expand_ratio,
'mid_channels': self.active_middle_channels,
'act_func': self.act_func,
'groups': 1,
'downsample_mode': self.downsample_mode,
}
def re_organize_middle_weights(self, expand_ratio_stage=0):
# conv3 -> conv2
importance = torch.sum(torch.abs(self.conv3.conv.conv.weight.data), dim=(0, 2, 3))
if isinstance(self.conv2.bn, DynamicGroupNorm):
channel_per_group = self.conv2.bn.channel_per_group
importance_chunks = torch.split(importance, channel_per_group)
for chunk in importance_chunks:
chunk.data.fill_(torch.mean(chunk))
importance = torch.cat(importance_chunks, dim=0)
if expand_ratio_stage > 0:
sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
sorted_expand_list.sort(reverse=True)
target_width_list = [
make_divisible(round(max(self.out_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
for expand in sorted_expand_list
]
right = len(importance)
base = - len(target_width_list) * 1e5
for i in range(expand_ratio_stage + 1):
left = target_width_list[i]
importance[left:right] += base
base += 1e5
right = left
sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
self.conv3.conv.conv.weight.data = torch.index_select(self.conv3.conv.conv.weight.data, 1, sorted_idx)
adjust_bn_according_to_idx(self.conv2.bn.bn, sorted_idx)
self.conv2.conv.conv.weight.data = torch.index_select(self.conv2.conv.conv.weight.data, 0, sorted_idx)
# conv2 -> conv1
importance = torch.sum(torch.abs(self.conv2.conv.conv.weight.data), dim=(0, 2, 3))
if isinstance(self.conv1.bn, DynamicGroupNorm):
channel_per_group = self.conv1.bn.channel_per_group
importance_chunks = torch.split(importance, channel_per_group)
for chunk in importance_chunks:
chunk.data.fill_(torch.mean(chunk))
importance = torch.cat(importance_chunks, dim=0)
if expand_ratio_stage > 0:
sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
sorted_expand_list.sort(reverse=True)
target_width_list = [
make_divisible(round(max(self.out_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
for expand in sorted_expand_list
]
right = len(importance)
base = - len(target_width_list) * 1e5
for i in range(expand_ratio_stage + 1):
left = target_width_list[i]
importance[left:right] += base
base += 1e5
right = left
sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
self.conv2.conv.conv.weight.data = torch.index_select(self.conv2.conv.conv.weight.data, 1, sorted_idx)
adjust_bn_according_to_idx(self.conv1.bn.bn, sorted_idx)
self.conv1.conv.conv.weight.data = torch.index_select(self.conv1.conv.conv.weight.data, 0, sorted_idx)
return None
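
As a reading aid, a minimal sketch of how these elastic layers are driven in practice: set the `active_*` attributes, forward through the shared weights, then detach a static sub-layer. The import path follows the headers above; the channel and ratio values are illustrative only.

```python
import torch
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicMBConvLayer

# Elastic MBConv block covering kernel sizes {3, 5, 7} and expand ratios {3, 4, 6}.
layer = DynamicMBConvLayer(
    in_channel_list=[24], out_channel_list=[40],
    kernel_size_list=[3, 5, 7], expand_ratio_list=[3, 4, 6],
    stride=1, act_func='h_swish', use_se=True,
)

# Select one sub-configuration; forward() reads these 'active_*' attributes.
layer.active_kernel_size = 5
layer.active_expand_ratio = 4
layer.active_out_channel = 40

x = torch.randn(2, 24, 14, 14)
print(layer(x).shape)  # torch.Size([2, 40, 14, 14]) -- only the active weight slice is used

# Sort middle channels by L1 importance of the pointwise conv (progressive shrinking).
layer.re_organize_middle_weights(expand_ratio_stage=0)

# Extract a standalone static MBConvLayer with the active weights copied in.
static_layer = layer.get_active_subnet(in_channel=24, preserve_weight=True)
```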

View File

@@ -0,0 +1,314 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.nn.parameter import Parameter
from ofa_local.utils import get_same_padding, sub_filter_start_end, make_divisible, SEModule, MyNetwork, MyConv2d
__all__ = ['DynamicSeparableConv2d', 'DynamicConv2d', 'DynamicGroupConv2d',
'DynamicBatchNorm2d', 'DynamicGroupNorm', 'DynamicSE', 'DynamicLinear']
class DynamicSeparableConv2d(nn.Module):
KERNEL_TRANSFORM_MODE = 1 # None or 1
def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
super(DynamicSeparableConv2d, self).__init__()
self.max_in_channels = max_in_channels
self.kernel_size_list = kernel_size_list
self.stride = stride
self.dilation = dilation
self.conv = nn.Conv2d(
self.max_in_channels, self.max_in_channels, max(self.kernel_size_list), self.stride,
groups=self.max_in_channels, bias=False,
)
self._ks_set = list(set(self.kernel_size_list))
self._ks_set.sort() # e.g., [3, 5, 7]
if self.KERNEL_TRANSFORM_MODE is not None:
# register scaling parameters
# 7to5_matrix, 5to3_matrix
scale_params = {}
for i in range(len(self._ks_set) - 1):
ks_small = self._ks_set[i]
ks_larger = self._ks_set[i + 1]
param_name = '%dto%d' % (ks_larger, ks_small)
# noinspection PyArgumentList
scale_params['%s_matrix' % param_name] = Parameter(torch.eye(ks_small ** 2))
for name, param in scale_params.items():
self.register_parameter(name, param)
self.active_kernel_size = max(self.kernel_size_list)
def get_active_filter(self, in_channel, kernel_size):
out_channel = in_channel
max_kernel_size = max(self.kernel_size_list)
start, end = sub_filter_start_end(max_kernel_size, kernel_size)
filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
start_filter = self.conv.weight[:out_channel, :in_channel, :, :] # start with max kernel
for i in range(len(self._ks_set) - 1, 0, -1):
src_ks = self._ks_set[i]
if src_ks <= kernel_size:
break
target_ks = self._ks_set[i - 1]
start, end = sub_filter_start_end(src_ks, target_ks)
_input_filter = start_filter[:, :, start:end, start:end]
_input_filter = _input_filter.contiguous()
_input_filter = _input_filter.view(_input_filter.size(0), _input_filter.size(1), -1)
_input_filter = _input_filter.view(-1, _input_filter.size(2))
_input_filter = F.linear(
_input_filter, self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)),
)
_input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks ** 2)
_input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks, target_ks)
start_filter = _input_filter
filters = start_filter
return filters
def forward(self, x, kernel_size=None):
if kernel_size is None:
kernel_size = self.active_kernel_size
in_channel = x.size(1)
filters = self.get_active_filter(in_channel, kernel_size).contiguous()
padding = get_same_padding(kernel_size)
filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
y = F.conv2d(
x, filters, None, self.stride, padding, self.dilation, in_channel
)
return y
class DynamicConv2d(nn.Module):
def __init__(self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1):
super(DynamicConv2d, self).__init__()
self.max_in_channels = max_in_channels
self.max_out_channels = max_out_channels
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.conv = nn.Conv2d(
self.max_in_channels, self.max_out_channels, self.kernel_size, stride=self.stride, bias=False,
)
self.active_out_channel = self.max_out_channels
def get_active_filter(self, out_channel, in_channel):
return self.conv.weight[:out_channel, :in_channel, :, :]
def forward(self, x, out_channel=None):
if out_channel is None:
out_channel = self.active_out_channel
in_channel = x.size(1)
filters = self.get_active_filter(out_channel, in_channel).contiguous()
padding = get_same_padding(self.kernel_size)
filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, 1)
return y
class DynamicGroupConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size_list, groups_list, stride=1, dilation=1):
super(DynamicGroupConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size_list = kernel_size_list
self.groups_list = groups_list
self.stride = stride
self.dilation = dilation
self.conv = nn.Conv2d(
self.in_channels, self.out_channels, max(self.kernel_size_list), self.stride,
groups=min(self.groups_list), bias=False,
)
self.active_kernel_size = max(self.kernel_size_list)
self.active_groups = min(self.groups_list)
def get_active_filter(self, kernel_size, groups):
start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size)
filters = self.conv.weight[:, :, start:end, start:end]
sub_filters = torch.chunk(filters, groups, dim=0)
sub_in_channels = self.in_channels // groups
sub_ratio = filters.size(1) // sub_in_channels
filter_crops = []
for i, sub_filter in enumerate(sub_filters):
part_id = i % sub_ratio
start = part_id * sub_in_channels
filter_crops.append(sub_filter[:, start:start + sub_in_channels, :, :])
filters = torch.cat(filter_crops, dim=0)
return filters
def forward(self, x, kernel_size=None, groups=None):
if kernel_size is None:
kernel_size = self.active_kernel_size
if groups is None:
groups = self.active_groups
filters = self.get_active_filter(kernel_size, groups).contiguous()
padding = get_same_padding(kernel_size)
filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
y = F.conv2d(
x, filters, None, self.stride, padding, self.dilation, groups,
)
return y
class DynamicBatchNorm2d(nn.Module):
SET_RUNNING_STATISTICS = False
def __init__(self, max_feature_dim):
super(DynamicBatchNorm2d, self).__init__()
self.max_feature_dim = max_feature_dim
self.bn = nn.BatchNorm2d(self.max_feature_dim)
@staticmethod
def bn_forward(x, bn: nn.BatchNorm2d, feature_dim):
if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
return bn(x)
else:
exponential_average_factor = 0.0
if bn.training and bn.track_running_stats:
if bn.num_batches_tracked is not None:
bn.num_batches_tracked += 1
if bn.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = bn.momentum
return F.batch_norm(
x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim],
bn.bias[:feature_dim], bn.training or not bn.track_running_stats,
exponential_average_factor, bn.eps,
)
def forward(self, x):
feature_dim = x.size(1)
y = self.bn_forward(x, self.bn, feature_dim)
return y
class DynamicGroupNorm(nn.GroupNorm):
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None):
super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
self.channel_per_group = channel_per_group
def forward(self, x):
n_channels = x.size(1)
n_groups = n_channels // self.channel_per_group
return F.group_norm(x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps)
@property
def bn(self):
return self
class DynamicSE(SEModule):
def __init__(self, max_channel):
super(DynamicSE, self).__init__(max_channel)
def get_active_reduce_weight(self, num_mid, in_channel, groups=None):
if groups is None or groups == 1:
return self.fc.reduce.weight[:num_mid, :in_channel, :, :]
else:
assert in_channel % groups == 0
sub_in_channels = in_channel // groups
sub_filters = torch.chunk(self.fc.reduce.weight[:num_mid, :, :, :], groups, dim=1)
return torch.cat([
sub_filter[:, :sub_in_channels, :, :] for sub_filter in sub_filters
], dim=1)
def get_active_reduce_bias(self, num_mid):
return self.fc.reduce.bias[:num_mid] if self.fc.reduce.bias is not None else None
def get_active_expand_weight(self, num_mid, in_channel, groups=None):
if groups is None or groups == 1:
return self.fc.expand.weight[:in_channel, :num_mid, :, :]
else:
assert in_channel % groups == 0
sub_in_channels = in_channel // groups
sub_filters = torch.chunk(self.fc.expand.weight[:, :num_mid, :, :], groups, dim=0)
return torch.cat([
sub_filter[:sub_in_channels, :, :, :] for sub_filter in sub_filters
], dim=0)
def get_active_expand_bias(self, in_channel, groups=None):
if groups is None or groups == 1:
return self.fc.expand.bias[:in_channel] if self.fc.expand.bias is not None else None
else:
assert in_channel % groups == 0
sub_in_channels = in_channel // groups
sub_bias_list = torch.chunk(self.fc.expand.bias, groups, dim=0)
return torch.cat([
sub_bias[:sub_in_channels] for sub_bias in sub_bias_list
], dim=0)
def forward(self, x, groups=None):
in_channel = x.size(1)
num_mid = make_divisible(in_channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE)
y = x.mean(3, keepdim=True).mean(2, keepdim=True)
# reduce
reduce_filter = self.get_active_reduce_weight(num_mid, in_channel, groups=groups).contiguous()
reduce_bias = self.get_active_reduce_bias(num_mid)
y = F.conv2d(y, reduce_filter, reduce_bias, 1, 0, 1, 1)
# relu
y = self.fc.relu(y)
# expand
expand_filter = self.get_active_expand_weight(num_mid, in_channel, groups=groups).contiguous()
expand_bias = self.get_active_expand_bias(in_channel, groups=groups)
y = F.conv2d(y, expand_filter, expand_bias, 1, 0, 1, 1)
# hard sigmoid
y = self.fc.h_sigmoid(y)
return x * y
class DynamicLinear(nn.Module):
def __init__(self, max_in_features, max_out_features, bias=True):
super(DynamicLinear, self).__init__()
self.max_in_features = max_in_features
self.max_out_features = max_out_features
self.bias = bias
self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias)
self.active_out_features = self.max_out_features
def get_active_weight(self, out_features, in_features):
return self.linear.weight[:out_features, :in_features]
def get_active_bias(self, out_features):
return self.linear.bias[:out_features] if self.bias else None
def forward(self, x, out_features=None):
if out_features is None:
out_features = self.active_out_features
in_features = x.size(1)
weight = self.get_active_weight(out_features, in_features).contiguous()
bias = self.get_active_bias(out_features)
y = F.linear(x, weight, bias)
return y
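
A small sketch of the elastic-kernel mechanism above: only the largest kernel's weight tensor is stored, smaller kernels are center-cropped from it, and because `KERNEL_TRANSFORM_MODE` is enabled each crop is additionally passed through a learned `%dto%d_matrix` linear transform. The import path is assumed from the package layout shown above.

```python
import torch
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_op import DynamicSeparableConv2d

conv = DynamicSeparableConv2d(max_in_channels=16, kernel_size_list=[3, 5, 7], stride=1)
x = torch.randn(1, 16, 8, 8)

for ks in [7, 5, 3]:
    conv.active_kernel_size = ks
    # 'same' padding keeps the spatial size for every kernel choice
    print(ks, conv(x).shape)  # (ks, torch.Size([1, 16, 8, 8]))

# The learned kernel-transformation matrices registered in __init__:
print([name for name, _ in conv.named_parameters() if name.endswith('_matrix')])
# e.g. ['5to3_matrix', '7to5_matrix']
```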

View File

@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .ofa_proxyless import OFAProxylessNASNets
from .ofa_mbv3 import OFAMobileNetV3
from .ofa_resnets import OFAResNets

View File

@@ -0,0 +1,336 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import random
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicMBConvLayer
from ofa_local.utils.layers import ConvLayer, IdentityLayer, LinearLayer, MBConvLayer, ResidualBlock
from ofa_local.imagenet_classification.networks import MobileNetV3
from ofa_local.utils import make_divisible, val2list, MyNetwork
from ofa_local.utils.layers import set_layer_from_config
import gin
__all__ = ['OFAMobileNetV3']
@gin.configurable
class OFAMobileNetV3(MobileNetV3):
def __init__(self, n_classes=1000, bn_param=(0.1, 1e-5), dropout_rate=0.1, base_stage_width=None, width_mult=1.0,
ks_list=3, expand_ratio_list=6, depth_list=4, dropblock=False, block_size=0):
self.width_mult = width_mult
self.ks_list = val2list(ks_list, 1)
self.expand_ratio_list = val2list(expand_ratio_list, 1)
self.depth_list = val2list(depth_list, 1)
self.ks_list.sort()
self.expand_ratio_list.sort()
self.depth_list.sort()
base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280]
final_expand_width = make_divisible(base_stage_width[-2] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
last_channel = make_divisible(base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
stride_stages = [1, 2, 2, 2, 1, 2]
act_stages = ['relu', 'relu', 'relu', 'h_swish', 'h_swish', 'h_swish']
se_stages = [False, False, True, False, True, True]
n_block_list = [1] + [max(self.depth_list)] * 5
width_list = []
for base_width in base_stage_width[:-2]:
width = make_divisible(base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
width_list.append(width)
input_channel, first_block_dim = width_list[0], width_list[1]
# first conv layer
first_conv = ConvLayer(3, input_channel, kernel_size=3, stride=2, act_func='h_swish')
first_block_conv = MBConvLayer(
in_channels=input_channel, out_channels=first_block_dim, kernel_size=3, stride=stride_stages[0],
expand_ratio=1, act_func=act_stages[0], use_se=se_stages[0],
)
first_block = ResidualBlock(
first_block_conv,
IdentityLayer(first_block_dim, first_block_dim) if input_channel == first_block_dim else None,
dropout_rate, dropblock, block_size
)
# inverted residual blocks
self.block_group_info = []
blocks = [first_block]
_block_index = 1
feature_dim = first_block_dim
for width, n_block, s, act_func, use_se in zip(width_list[2:], n_block_list[1:],
stride_stages[1:], act_stages[1:], se_stages[1:]):
self.block_group_info.append([_block_index + i for i in range(n_block)])
_block_index += n_block
output_channel = width
for i in range(n_block):
if i == 0:
stride = s
else:
stride = 1
mobile_inverted_conv = DynamicMBConvLayer(
in_channel_list=val2list(feature_dim), out_channel_list=val2list(output_channel),
kernel_size_list=ks_list, expand_ratio_list=expand_ratio_list,
stride=stride, act_func=act_func, use_se=use_se,
)
if stride == 1 and feature_dim == output_channel:
shortcut = IdentityLayer(feature_dim, feature_dim)
else:
shortcut = None
blocks.append(ResidualBlock(mobile_inverted_conv, shortcut,
dropout_rate, dropblock, block_size))
feature_dim = output_channel
# final expand layer, feature mix layer & classifier
final_expand_layer = ConvLayer(feature_dim, final_expand_width, kernel_size=1, act_func='h_swish')
feature_mix_layer = ConvLayer(
final_expand_width, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
)
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
super(OFAMobileNetV3, self).__init__(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])
# runtime_depth
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]
""" MyNetwork required methods """
@staticmethod
def name():
return 'OFAMobileNetV3'
def forward(self, x):
# first conv
x = self.first_conv(x)
# first block
x = self.blocks[0](x)
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
x = self.blocks[idx](x)
x = self.final_expand_layer(x)
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = self.feature_mix_layer(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
_str += self.blocks[0].module_str + '\n'
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
_str += self.blocks[idx].module_str + '\n'
_str += self.final_expand_layer.module_str + '\n'
_str += self.feature_mix_layer.module_str + '\n'
_str += self.classifier.module_str + '\n'
return _str
@property
def config(self):
return {
'name': OFAMobileNetV3.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'final_expand_layer': self.final_expand_layer.config,
'feature_mix_layer': self.feature_mix_layer.config,
'classifier': self.classifier.config,
}
@staticmethod
def build_from_config(config):
raise ValueError('does not support this function')
@property
def grouped_block_index(self):
return self.block_group_info
def load_state_dict(self, state_dict, **kwargs):
model_dict = self.state_dict()
for key in state_dict:
if '.mobile_inverted_conv.' in key:
new_key = key.replace('.mobile_inverted_conv.', '.conv.')
else:
new_key = key
if new_key in model_dict:
pass
elif '.bn.bn.' in new_key:
new_key = new_key.replace('.bn.bn.', '.bn.')
elif '.conv.conv.weight' in new_key:
new_key = new_key.replace('.conv.conv.weight', '.conv.weight')
elif '.linear.linear.' in new_key:
new_key = new_key.replace('.linear.linear.', '.linear.')
##############################################################################
elif '.linear.' in new_key:
new_key = new_key.replace('.linear.', '.linear.linear.')
elif 'bn.' in new_key:
new_key = new_key.replace('bn.', 'bn.bn.')
elif 'conv.weight' in new_key:
new_key = new_key.replace('conv.weight', 'conv.conv.weight')
else:
raise ValueError(new_key)
assert new_key in model_dict, '%s' % new_key
model_dict[new_key] = state_dict[key]
super(OFAMobileNetV3, self).load_state_dict(model_dict)
""" set, sample and get active sub-networks """
def set_max_net(self):
self.set_active_subnet(ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list))
def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
ks = val2list(ks, len(self.blocks) - 1)
expand_ratio = val2list(e, len(self.blocks) - 1)
depth = val2list(d, len(self.block_group_info))
for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
if k is not None:
block.conv.active_kernel_size = k
if e is not None:
block.conv.active_expand_ratio = e
for i, d in enumerate(depth):
if d is not None:
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)
def set_constraint(self, include_list, constraint_type='depth'):
if constraint_type == 'depth':
self.__dict__['_depth_include_list'] = include_list.copy()
elif constraint_type == 'expand_ratio':
self.__dict__['_expand_include_list'] = include_list.copy()
elif constraint_type == 'kernel_size':
self.__dict__['_ks_include_list'] = include_list.copy()
else:
raise NotImplementedError
def clear_constraint(self):
self.__dict__['_depth_include_list'] = None
self.__dict__['_expand_include_list'] = None
self.__dict__['_ks_include_list'] = None
def sample_active_subnet(self):
ks_candidates = self.ks_list if self.__dict__.get('_ks_include_list', None) is None \
else self.__dict__['_ks_include_list']
expand_candidates = self.expand_ratio_list if self.__dict__.get('_expand_include_list', None) is None \
else self.__dict__['_expand_include_list']
depth_candidates = self.depth_list if self.__dict__.get('_depth_include_list', None) is None else \
self.__dict__['_depth_include_list']
# sample kernel size
ks_setting = []
if not isinstance(ks_candidates[0], list):
ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
for k_set in ks_candidates:
k = random.choice(k_set)
ks_setting.append(k)
# sample expand ratio
expand_setting = []
if not isinstance(expand_candidates[0], list):
expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
for e_set in expand_candidates:
e = random.choice(e_set)
expand_setting.append(e)
# sample depth
depth_setting = []
if not isinstance(depth_candidates[0], list):
depth_candidates = [depth_candidates for _ in range(len(self.block_group_info))]
for d_set in depth_candidates:
d = random.choice(d_set)
depth_setting.append(d)
self.set_active_subnet(ks_setting, expand_setting, depth_setting)
return {
'ks': ks_setting,
'e': expand_setting,
'd': depth_setting,
}
def get_active_subnet(self, preserve_weight=True):
first_conv = copy.deepcopy(self.first_conv)
blocks = [copy.deepcopy(self.blocks[0])]
final_expand_layer = copy.deepcopy(self.final_expand_layer)
feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
classifier = copy.deepcopy(self.classifier)
input_channel = blocks[0].conv.out_channels
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(ResidualBlock(
self.blocks[idx].conv.get_active_subnet(input_channel, preserve_weight),
copy.deepcopy(self.blocks[idx].shortcut),
copy.deepcopy(self.blocks[idx].dropout_rate),
copy.deepcopy(self.blocks[idx].dropblock),
copy.deepcopy(self.blocks[idx].block_size),
))
input_channel = stage_blocks[-1].conv.out_channels
blocks += stage_blocks
_subnet = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
_subnet.set_bn_param(**self.get_bn_param())
return _subnet
def get_active_net_config(self):
# first conv
first_conv_config = self.first_conv.config
first_block_config = self.blocks[0].config
final_expand_config = self.final_expand_layer.config
feature_mix_layer_config = self.feature_mix_layer.config
classifier_config = self.classifier.config
block_config_list = [first_block_config]
input_channel = first_block_config['conv']['out_channels']
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append({
'name': ResidualBlock.__name__,
'conv': self.blocks[idx].conv.get_active_subnet_config(input_channel),
'shortcut': self.blocks[idx].shortcut.config if self.blocks[idx].shortcut is not None else None,
})
input_channel = self.blocks[idx].conv.active_out_channel
block_config_list += stage_blocks
return {
'name': MobileNetV3.__name__,
'bn': self.get_bn_param(),
'first_conv': first_conv_config,
'blocks': block_config_list,
'final_expand_layer': final_expand_config,
'feature_mix_layer': feature_mix_layer_config,
'classifier': classifier_config,
}
""" Width Related Methods """
def re_organize_middle_weights(self, expand_ratio_stage=0):
for block in self.blocks[1:]:
block.conv.re_organize_middle_weights(expand_ratio_stage)
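
Putting the file together, a minimal sketch of the sample-then-extract loop this supernet supports (assuming the `ofa_local` package from this repo is importable; the search-space values are illustrative):

```python
import torch
from ofa_local.imagenet_classification.elastic_nn.networks import OFAMobileNetV3

supernet = OFAMobileNetV3(
    n_classes=10, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
)

# Randomly choose kernel size / expand ratio per block and depth per stage.
arch_config = supernet.sample_active_subnet()
print(arch_config['d'])   # e.g. [3, 2, 4, 2, 3] -- one depth per stage

# Forwarding the supernet executes only the active sub-network.
logits = supernet(torch.randn(2, 3, 224, 224))
print(logits.shape)       # torch.Size([2, 10])

# Detach a standalone MobileNetV3 with the active weights copied in.
subnet = supernet.get_active_subnet(preserve_weight=True)
```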

View File

@@ -0,0 +1,331 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import random
from ofa_local.utils import make_divisible, val2list, MyNetwork
from ofa_local.imagenet_classification.elastic_nn.modules import DynamicMBConvLayer
from ofa_local.utils.layers import ConvLayer, IdentityLayer, LinearLayer, MBConvLayer, ResidualBlock
from ofa_local.imagenet_classification.networks.proxyless_nets import ProxylessNASNets
__all__ = ['OFAProxylessNASNets']
class OFAProxylessNASNets(ProxylessNASNets):
def __init__(self, n_classes=1000, bn_param=(0.1, 1e-3), dropout_rate=0.1, base_stage_width=None, width_mult=1.0,
ks_list=3, expand_ratio_list=6, depth_list=4):
self.width_mult = width_mult
self.ks_list = val2list(ks_list, 1)
self.expand_ratio_list = val2list(expand_ratio_list, 1)
self.depth_list = val2list(depth_list, 1)
self.ks_list.sort()
self.expand_ratio_list.sort()
self.depth_list.sort()
if base_stage_width == 'google':
# MobileNetV2 Stage Width
base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280]
else:
# ProxylessNAS Stage Width
base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280]
input_channel = make_divisible(base_stage_width[0] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
first_block_width = make_divisible(base_stage_width[1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
last_channel = make_divisible(base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
# first conv layer
first_conv = ConvLayer(
3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act'
)
# first block
first_block_conv = MBConvLayer(
in_channels=input_channel, out_channels=first_block_width, kernel_size=3, stride=1,
expand_ratio=1, act_func='relu6',
)
first_block = ResidualBlock(first_block_conv, None)
input_channel = first_block_width
# inverted residual blocks
self.block_group_info = []
blocks = [first_block]
_block_index = 1
stride_stages = [2, 2, 2, 1, 2, 1]
n_block_list = [max(self.depth_list)] * 5 + [1]
width_list = []
for base_width in base_stage_width[2:-1]:
width = make_divisible(base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
width_list.append(width)
for width, n_block, s in zip(width_list, n_block_list, stride_stages):
self.block_group_info.append([_block_index + i for i in range(n_block)])
_block_index += n_block
output_channel = width
for i in range(n_block):
if i == 0:
stride = s
else:
stride = 1
mobile_inverted_conv = DynamicMBConvLayer(
in_channel_list=val2list(input_channel, 1), out_channel_list=val2list(output_channel, 1),
kernel_size_list=ks_list, expand_ratio_list=expand_ratio_list, stride=stride, act_func='relu6',
)
if stride == 1 and input_channel == output_channel:
shortcut = IdentityLayer(input_channel, input_channel)
else:
shortcut = None
mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut)
blocks.append(mb_inverted_block)
input_channel = output_channel
# 1x1_conv before global average pooling
feature_mix_layer = ConvLayer(
input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6',
)
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
super(OFAProxylessNASNets, self).__init__(first_conv, blocks, feature_mix_layer, classifier)
# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])
# runtime_depth
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]
""" MyNetwork required methods """
@staticmethod
def name():
return 'OFAProxylessNASNets'
def forward(self, x):
# first conv
x = self.first_conv(x)
# first block
x = self.blocks[0](x)
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
x = self.blocks[idx](x)
# feature_mix_layer
x = self.feature_mix_layer(x)
x = x.mean(3).mean(2)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
_str += self.blocks[0].module_str + '\n'
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
_str += self.blocks[idx].module_str + '\n'
_str += self.feature_mix_layer.module_str + '\n'
_str += self.classifier.module_str + '\n'
return _str
@property
def config(self):
return {
'name': OFAProxylessNASNets.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'feature_mix_layer': None if self.feature_mix_layer is None else self.feature_mix_layer.config,
'classifier': self.classifier.config,
}
@staticmethod
def build_from_config(config):
raise ValueError('does not support this function')
@property
def grouped_block_index(self):
return self.block_group_info
def load_state_dict(self, state_dict, **kwargs):
model_dict = self.state_dict()
for key in state_dict:
if '.mobile_inverted_conv.' in key:
new_key = key.replace('.mobile_inverted_conv.', '.conv.')
else:
new_key = key
if new_key in model_dict:
pass
elif '.bn.bn.' in new_key:
new_key = new_key.replace('.bn.bn.', '.bn.')
elif '.conv.conv.weight' in new_key:
new_key = new_key.replace('.conv.conv.weight', '.conv.weight')
elif '.linear.linear.' in new_key:
new_key = new_key.replace('.linear.linear.', '.linear.')
##############################################################################
elif '.linear.' in new_key:
new_key = new_key.replace('.linear.', '.linear.linear.')
elif 'bn.' in new_key:
new_key = new_key.replace('bn.', 'bn.bn.')
elif 'conv.weight' in new_key:
new_key = new_key.replace('conv.weight', 'conv.conv.weight')
else:
raise ValueError(new_key)
assert new_key in model_dict, '%s' % new_key
model_dict[new_key] = state_dict[key]
super(OFAProxylessNASNets, self).load_state_dict(model_dict)
""" set, sample and get active sub-networks """
def set_max_net(self):
self.set_active_subnet(ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list))
def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
ks = val2list(ks, len(self.blocks) - 1)
expand_ratio = val2list(e, len(self.blocks) - 1)
depth = val2list(d, len(self.block_group_info))
for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
if k is not None:
block.conv.active_kernel_size = k
if e is not None:
block.conv.active_expand_ratio = e
for i, d in enumerate(depth):
if d is not None:
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)
def set_constraint(self, include_list, constraint_type='depth'):
if constraint_type == 'depth':
self.__dict__['_depth_include_list'] = include_list.copy()
elif constraint_type == 'expand_ratio':
self.__dict__['_expand_include_list'] = include_list.copy()
elif constraint_type == 'kernel_size':
self.__dict__['_ks_include_list'] = include_list.copy()
else:
raise NotImplementedError
def clear_constraint(self):
self.__dict__['_depth_include_list'] = None
self.__dict__['_expand_include_list'] = None
self.__dict__['_ks_include_list'] = None
def sample_active_subnet(self):
ks_candidates = self.ks_list if self.__dict__.get('_ks_include_list', None) is None \
else self.__dict__['_ks_include_list']
expand_candidates = self.expand_ratio_list if self.__dict__.get('_expand_include_list', None) is None \
else self.__dict__['_expand_include_list']
depth_candidates = self.depth_list if self.__dict__.get('_depth_include_list', None) is None else \
self.__dict__['_depth_include_list']
# sample kernel size
ks_setting = []
if not isinstance(ks_candidates[0], list):
ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
for k_set in ks_candidates:
k = random.choice(k_set)
ks_setting.append(k)
# sample expand ratio
expand_setting = []
if not isinstance(expand_candidates[0], list):
expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
for e_set in expand_candidates:
e = random.choice(e_set)
expand_setting.append(e)
# sample depth
depth_setting = []
if not isinstance(depth_candidates[0], list):
depth_candidates = [depth_candidates for _ in range(len(self.block_group_info))]
for d_set in depth_candidates:
d = random.choice(d_set)
depth_setting.append(d)
depth_setting[-1] = 1
self.set_active_subnet(ks_setting, expand_setting, depth_setting)
return {
'ks': ks_setting,
'e': expand_setting,
'd': depth_setting,
}
def get_active_subnet(self, preserve_weight=True):
first_conv = copy.deepcopy(self.first_conv)
blocks = [copy.deepcopy(self.blocks[0])]
feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
classifier = copy.deepcopy(self.classifier)
input_channel = blocks[0].conv.out_channels
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(ResidualBlock(
self.blocks[idx].conv.get_active_subnet(input_channel, preserve_weight),
copy.deepcopy(self.blocks[idx].shortcut)
))
input_channel = stage_blocks[-1].conv.out_channels
blocks += stage_blocks
_subnet = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
_subnet.set_bn_param(**self.get_bn_param())
return _subnet
def get_active_net_config(self):
first_conv_config = self.first_conv.config
first_block_config = self.blocks[0].config
feature_mix_layer_config = self.feature_mix_layer.config
classifier_config = self.classifier.config
block_config_list = [first_block_config]
input_channel = first_block_config['conv']['out_channels']
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append({
'name': ResidualBlock.__name__,
'conv': self.blocks[idx].conv.get_active_subnet_config(input_channel),
'shortcut': self.blocks[idx].shortcut.config if self.blocks[idx].shortcut is not None else None,
})
try:
input_channel = self.blocks[idx].conv.active_out_channel
except Exception:
input_channel = self.blocks[idx].conv.out_channels
block_config_list += stage_blocks
return {
'name': ProxylessNASNets.__name__,
'bn': self.get_bn_param(),
'first_conv': first_conv_config,
'blocks': block_config_list,
'feature_mix_layer': feature_mix_layer_config,
'classifier': classifier_config,
}
""" Width Related Methods """
def re_organize_middle_weights(self, expand_ratio_stage=0):
for block in self.blocks[1:]:
block.conv.re_organize_middle_weights(expand_ratio_stage)
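
The `load_state_dict` override above (identical in `OFAMobileNetV3`) exists because checkpoints may use either the plain layer naming or the nested dynamic naming. A stripped-down sketch of the remapping, showing only the plain-to-dynamic direction (the helper and example keys are hypothetical):

```python
def remap_key(key, model_keys):
    """Map a checkpoint key onto this model's (possibly nested) parameter names."""
    key = key.replace('.mobile_inverted_conv.', '.conv.')
    if key in model_keys:
        return key
    # plain -> dynamic nesting, e.g. 'bn.weight' -> 'bn.bn.weight'
    for old, new in [('.linear.', '.linear.linear.'),
                     ('bn.', 'bn.bn.'),
                     ('conv.weight', 'conv.conv.weight')]:
        if old in key:
            return key.replace(old, new)
    raise ValueError(key)

model_keys = {'blocks.1.conv.depth_conv.bn.bn.weight'}
print(remap_key('blocks.1.mobile_inverted_conv.depth_conv.bn.weight', model_keys))
# -> blocks.1.conv.depth_conv.bn.bn.weight
```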

View File

@@ -0,0 +1,267 @@
import random
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicConvLayer, DynamicLinearLayer
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicResNetBottleneckBlock
from ofa_local.utils.layers import IdentityLayer, ResidualBlock
from ofa_local.imagenet_classification.networks import ResNets
from ofa_local.utils import make_divisible, val2list, MyNetwork
__all__ = ['OFAResNets']
class OFAResNets(ResNets):
def __init__(self, n_classes=1000, bn_param=(0.1, 1e-5), dropout_rate=0,
depth_list=2, expand_ratio_list=0.25, width_mult_list=1.0):
self.depth_list = val2list(depth_list)
self.expand_ratio_list = val2list(expand_ratio_list)
self.width_mult_list = val2list(width_mult_list)
# sort
self.depth_list.sort()
self.expand_ratio_list.sort()
self.width_mult_list.sort()
input_channel = [
make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE) for width_mult in self.width_mult_list
]
mid_input_channel = [
make_divisible(channel // 2, MyNetwork.CHANNEL_DIVISIBLE) for channel in input_channel
]
stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
for i, width in enumerate(stage_width_list):
stage_width_list[i] = [
make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE) for width_mult in self.width_mult_list
]
n_block_list = [base_depth + max(self.depth_list) for base_depth in ResNets.BASE_DEPTH_LIST]
stride_list = [1, 2, 2, 2]
# build input stem
input_stem = [
DynamicConvLayer(val2list(3), mid_input_channel, 3, stride=2, use_bn=True, act_func='relu'),
ResidualBlock(
DynamicConvLayer(mid_input_channel, mid_input_channel, 3, stride=1, use_bn=True, act_func='relu'),
IdentityLayer(mid_input_channel, mid_input_channel)
),
DynamicConvLayer(mid_input_channel, input_channel, 3, stride=1, use_bn=True, act_func='relu')
]
# blocks
blocks = []
for d, width, s in zip(n_block_list, stage_width_list, stride_list):
for i in range(d):
stride = s if i == 0 else 1
bottleneck_block = DynamicResNetBottleneckBlock(
input_channel, width, expand_ratio_list=self.expand_ratio_list,
kernel_size=3, stride=stride, act_func='relu', downsample_mode='avgpool_conv',
)
blocks.append(bottleneck_block)
input_channel = width
# classifier
classifier = DynamicLinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
super(OFAResNets, self).__init__(input_stem, blocks, classifier)
# set bn param
self.set_bn_param(*bn_param)
# runtime_depth
self.input_stem_skipping = 0
self.runtime_depth = [0] * len(n_block_list)
@property
def ks_list(self):
return [3]
@staticmethod
def name():
return 'OFAResNets'
def forward(self, x):
for layer in self.input_stem:
if self.input_stem_skipping > 0 \
and isinstance(layer, ResidualBlock) and isinstance(layer.shortcut, IdentityLayer):
pass
else:
x = layer(x)
x = self.max_pooling(x)
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[:len(block_idx) - depth_param]
for idx in active_idx:
x = self.blocks[idx](x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = ''
for layer in self.input_stem:
if self.input_stem_skipping > 0 \
and isinstance(layer, ResidualBlock) and isinstance(layer.shortcut, IdentityLayer):
pass
else:
_str += layer.module_str + '\n'
_str += 'max_pooling(ks=3, stride=2)\n'
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[:len(block_idx) - depth_param]
for idx in active_idx:
_str += self.blocks[idx].module_str + '\n'
_str += self.global_avg_pool.__repr__() + '\n'
_str += self.classifier.module_str
return _str
@property
def config(self):
return {
'name': OFAResNets.__name__,
'bn': self.get_bn_param(),
'input_stem': [
layer.config for layer in self.input_stem
],
'blocks': [
block.config for block in self.blocks
],
'classifier': self.classifier.config,
}
@staticmethod
def build_from_config(config):
raise ValueError('does not support this function')
def load_state_dict(self, state_dict, **kwargs):
model_dict = self.state_dict()
for key in state_dict:
new_key = key
if new_key in model_dict:
pass
elif '.linear.' in new_key:
new_key = new_key.replace('.linear.', '.linear.linear.')
elif 'bn.' in new_key:
new_key = new_key.replace('bn.', 'bn.bn.')
elif 'conv.weight' in new_key:
new_key = new_key.replace('conv.weight', 'conv.conv.weight')
else:
raise ValueError(new_key)
assert new_key in model_dict, '%s' % new_key
model_dict[new_key] = state_dict[key]
super(OFAResNets, self).load_state_dict(model_dict)
""" set, sample and get active sub-networks """
def set_max_net(self):
self.set_active_subnet(d=max(self.depth_list), e=max(self.expand_ratio_list), w=len(self.width_mult_list) - 1)
def set_active_subnet(self, d=None, e=None, w=None, **kwargs):
depth = val2list(d, len(ResNets.BASE_DEPTH_LIST) + 1)
expand_ratio = val2list(e, len(self.blocks))
width_mult = val2list(w, len(ResNets.BASE_DEPTH_LIST) + 2)
for block, e in zip(self.blocks, expand_ratio):
if e is not None:
block.active_expand_ratio = e
if width_mult[0] is not None:
self.input_stem[1].conv.active_out_channel = self.input_stem[0].active_out_channel = \
self.input_stem[0].out_channel_list[width_mult[0]]
if width_mult[1] is not None:
self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[width_mult[1]]
if depth[0] is not None:
self.input_stem_skipping = (depth[0] != max(self.depth_list))
for stage_id, (block_idx, d, w) in enumerate(zip(self.grouped_block_index, depth[1:], width_mult[2:])):
if d is not None:
self.runtime_depth[stage_id] = max(self.depth_list) - d
if w is not None:
for idx in block_idx:
self.blocks[idx].active_out_channel = self.blocks[idx].out_channel_list[w]
def sample_active_subnet(self):
# sample expand ratio
expand_setting = []
for block in self.blocks:
expand_setting.append(random.choice(block.expand_ratio_list))
# sample depth
depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])]
for stage_id in range(len(ResNets.BASE_DEPTH_LIST)):
depth_setting.append(random.choice(self.depth_list))
# sample width_mult
width_mult_setting = [
random.choice(list(range(len(self.input_stem[0].out_channel_list)))),
random.choice(list(range(len(self.input_stem[2].out_channel_list)))),
]
for stage_id, block_idx in enumerate(self.grouped_block_index):
stage_first_block = self.blocks[block_idx[0]]
width_mult_setting.append(
random.choice(list(range(len(stage_first_block.out_channel_list))))
)
arch_config = {
'd': depth_setting,
'e': expand_setting,
'w': width_mult_setting
}
self.set_active_subnet(**arch_config)
return arch_config
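# Materializes the currently active sub-network as a standalone `ResNets`
# model, copying the active slices of the super-network weights when
# `preserve_weight=True`. Illustrative usage (names assumed from this file):
#   cfg = ofa_net.sample_active_subnet()
#   subnet = ofa_net.get_active_subnet(preserve_weight=True)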
def get_active_subnet(self, preserve_weight=True):
input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)]
if self.input_stem_skipping <= 0:
input_stem.append(ResidualBlock(
self.input_stem[1].conv.get_active_subnet(self.input_stem[0].active_out_channel, preserve_weight),
IdentityLayer(self.input_stem[0].active_out_channel, self.input_stem[0].active_out_channel)
))
input_stem.append(self.input_stem[2].get_active_subnet(self.input_stem[0].active_out_channel, preserve_weight))
input_channel = self.input_stem[2].active_out_channel
blocks = []
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[:len(block_idx) - depth_param]
for idx in active_idx:
blocks.append(self.blocks[idx].get_active_subnet(input_channel, preserve_weight))
input_channel = self.blocks[idx].active_out_channel
classifier = self.classifier.get_active_subnet(input_channel, preserve_weight)
subnet = ResNets(input_stem, blocks, classifier)
subnet.set_bn_param(**self.get_bn_param())
return subnet
def get_active_net_config(self):
input_stem_config = [self.input_stem[0].get_active_subnet_config(3)]
if self.input_stem_skipping <= 0:
input_stem_config.append({
'name': ResidualBlock.__name__,
'conv': self.input_stem[1].conv.get_active_subnet_config(self.input_stem[0].active_out_channel),
# store the layer's config (not the module itself) so the returned dict stays a plain, serializable config
'shortcut': IdentityLayer(self.input_stem[0].active_out_channel, self.input_stem[0].active_out_channel).config,
})
input_stem_config.append(self.input_stem[2].get_active_subnet_config(self.input_stem[0].active_out_channel))
input_channel = self.input_stem[2].active_out_channel
blocks_config = []
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[:len(block_idx) - depth_param]
for idx in active_idx:
blocks_config.append(self.blocks[idx].get_active_subnet_config(input_channel))
input_channel = self.blocks[idx].active_out_channel
classifier_config = self.classifier.get_active_subnet_config(input_channel)
return {
'name': ResNets.__name__,
'bn': self.get_bn_param(),
'input_stem': input_stem_config,
'blocks': blocks_config,
'classifier': classifier_config,
}
""" Width Related Methods """
def re_organize_middle_weights(self, expand_ratio_stage=0):
for block in self.blocks:
block.re_organize_middle_weights(expand_ratio_stage)

View File

@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .progressive_shrinking import *

View File

@@ -0,0 +1,320 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import torch.nn as nn
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from ofa.utils import AverageMeter, cross_entropy_loss_with_soft_target
from ofa.utils import DistributedMetric, list_mean, subset_mean, val2list, MyRandomResizedCrop
from ofa.imagenet_classification.run_manager import DistributedRunManager
__all__ = [
'validate', 'train_one_epoch', 'train', 'load_models',
'train_elastic_depth', 'train_elastic_expand', 'train_elastic_width_mult',
]
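# Validates the dynamic network on the cross product of the given image
# sizes, depths, expand ratios, kernel sizes and width multipliers; each
# setting is activated, BN statistics are reset, and the sub-network is
# evaluated independently. The returned metrics are means over all settings.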
def validate(run_manager, epoch=0, is_test=False, image_size_list=None,
ks_list=None, expand_ratio_list=None, depth_list=None, width_mult_list=None, additional_setting=None):
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
dynamic_net.eval()
if image_size_list is None:
image_size_list = val2list(run_manager.run_config.data_provider.image_size, 1)
if ks_list is None:
ks_list = dynamic_net.ks_list
if expand_ratio_list is None:
expand_ratio_list = dynamic_net.expand_ratio_list
if depth_list is None:
depth_list = dynamic_net.depth_list
if width_mult_list is None:
if 'width_mult_list' in dynamic_net.__dict__:
width_mult_list = list(range(len(dynamic_net.width_mult_list)))
else:
width_mult_list = [0]
subnet_settings = []
for d in depth_list:
for e in expand_ratio_list:
for k in ks_list:
for w in width_mult_list:
for img_size in image_size_list:
subnet_settings.append([{
'image_size': img_size,
'd': d,
'e': e,
'ks': k,
'w': w,
}, 'R%s-D%s-E%s-K%s-W%s' % (img_size, d, e, k, w)])
if additional_setting is not None:
subnet_settings += additional_setting
losses_of_subnets, top1_of_subnets, top5_of_subnets = [], [], []
valid_log = ''
for setting, name in subnet_settings:
run_manager.write_log('-' * 30 + ' Validate %s ' % name + '-' * 30, 'train', should_print=False)
run_manager.run_config.data_provider.assign_active_img_size(setting.pop('image_size'))
dynamic_net.set_active_subnet(**setting)
run_manager.write_log(dynamic_net.module_str, 'train', should_print=False)
run_manager.reset_running_statistics(dynamic_net)
loss, (top1, top5) = run_manager.validate(epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net)
losses_of_subnets.append(loss)
top1_of_subnets.append(top1)
top5_of_subnets.append(top5)
valid_log += '%s (%.3f), ' % (name, top1)
return list_mean(losses_of_subnets), list_mean(top1_of_subnets), list_mean(top5_of_subnets), valid_log
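# One epoch of progressive-shrinking training: for every batch,
# `dynamic_batch_size` sub-networks are sampled (with a deterministic
# per-batch seed) and their gradients are accumulated before a single
# optimizer step; an optional teacher model provides soft targets for
# knowledge distillation, weighted by `kd_ratio`.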
def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
dynamic_net = run_manager.network
distributed = isinstance(run_manager, DistributedRunManager)
# switch to train mode
dynamic_net.train()
if distributed:
run_manager.run_config.train_loader.sampler.set_epoch(epoch)
MyRandomResizedCrop.EPOCH = epoch
nBatch = len(run_manager.run_config.train_loader)
data_time = AverageMeter()
losses = DistributedMetric('train_loss') if distributed else AverageMeter()
metric_dict = run_manager.get_metric_dict()
with tqdm(total=nBatch,
desc='Train Epoch #{}'.format(epoch + 1),
disable=distributed and not run_manager.is_root) as t:
end = time.time()
for i, (images, labels) in enumerate(run_manager.run_config.train_loader):
MyRandomResizedCrop.BATCH = i
data_time.update(time.time() - end)
if epoch < warmup_epochs:
new_lr = run_manager.run_config.warmup_adjust_learning_rate(
run_manager.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
)
else:
new_lr = run_manager.run_config.adjust_learning_rate(
run_manager.optimizer, epoch - warmup_epochs, i, nBatch
)
images, labels = images.cuda(), labels.cuda()
target = labels
# soft target
if args.kd_ratio > 0:
args.teacher_model.train()
with torch.no_grad():
soft_logits = args.teacher_model(images).detach()
soft_label = F.softmax(soft_logits, dim=1)
# clean gradients
dynamic_net.zero_grad()
loss_of_subnets = []
# compute output
subnet_str = ''
for _ in range(args.dynamic_batch_size):
# set random seed before sampling
subnet_seed = int('%d%.3d%.3d' % (epoch * nBatch + i, _, 0))
random.seed(subnet_seed)
subnet_settings = dynamic_net.sample_active_subnet()
subnet_str += '%d: ' % _ + ','.join(['%s_%s' % (
key, '%.1f' % subset_mean(val, 0) if isinstance(val, list) else val
) for key, val in subnet_settings.items()]) + ' || '
output = run_manager.net(images)
if args.kd_ratio == 0:
loss = run_manager.train_criterion(output, labels)
loss_type = 'ce'
else:
if args.kd_type == 'ce':
kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
else:
kd_loss = F.mse_loss(output, soft_logits)
loss = args.kd_ratio * kd_loss + run_manager.train_criterion(output, labels)
loss_type = '%.1fkd-%s & ce' % (args.kd_ratio, args.kd_type)
# measure accuracy and record loss
loss_of_subnets.append(loss)
run_manager.update_metric(metric_dict, output, target)
loss.backward()
run_manager.optimizer.step()
losses.update(list_mean(loss_of_subnets), images.size(0))
t.set_postfix({
'loss': losses.avg.item(),
**run_manager.get_metric_vals(metric_dict, return_dict=True),
'R': images.size(2),
'lr': new_lr,
'loss_type': loss_type,
'seed': str(subnet_seed),
'str': subnet_str,
'data_time': data_time.avg,
})
t.update(1)
end = time.time()
return losses.avg.item(), run_manager.get_metric_vals(metric_dict)
def train(run_manager, args, validate_func=None):
distributed = isinstance(run_manager, DistributedRunManager)
if validate_func is None:
validate_func = validate
for epoch in range(run_manager.start_epoch, run_manager.run_config.n_epochs + args.warmup_epochs):
train_loss, (train_top1, train_top5) = train_one_epoch(
run_manager, args, epoch, args.warmup_epochs, args.warmup_lr)
if (epoch + 1) % args.validation_frequency == 0:
val_loss, val_acc, val_acc5, _val_log = validate_func(run_manager, epoch=epoch, is_test=False)
# best_acc
is_best = val_acc > run_manager.best_acc
run_manager.best_acc = max(run_manager.best_acc, val_acc)
if not distributed or run_manager.is_root:
val_log = 'Valid [{0}/{1}] loss={2:.3f}, top-1={3:.3f} ({4:.3f})'. \
format(epoch + 1 - args.warmup_epochs, run_manager.run_config.n_epochs, val_loss, val_acc,
run_manager.best_acc)
val_log += ', Train top-1 {top1:.3f}, Train loss {loss:.3f}\t'.format(top1=train_top1, loss=train_loss)
val_log += _val_log
run_manager.write_log(val_log, 'valid', should_print=False)
run_manager.save_model({
'epoch': epoch,
'best_acc': run_manager.best_acc,
'optimizer': run_manager.optimizer.state_dict(),
'state_dict': run_manager.network.state_dict(),
}, is_best=is_best)
def load_models(run_manager, dynamic_net, model_path=None):
# specify init path
init = torch.load(model_path, map_location='cpu')['state_dict']
dynamic_net.load_state_dict(init)
run_manager.write_log('Loaded init from %s' % model_path, 'valid')
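# The three `train_elastic_*` helpers each implement one stage of progressive
# shrinking: they load the checkpoint from the previous stage, widen the
# validation grid to cover the newly unlocked dimension (depth, expand ratio
# or width multiplier), and then hand off to the shared training loop.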
def train_elastic_depth(train_func, run_manager, args, validate_func_dict):
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
depth_stage_list = dynamic_net.depth_list.copy()
depth_stage_list.sort(reverse=True)
n_stages = len(depth_stage_list) - 1
current_stage = n_stages - 1
# load pretrained models
if run_manager.start_epoch == 0 and not args.resume:
validate_func_dict['depth_list'] = sorted(dynamic_net.depth_list)
load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
# validate after loading weights
run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
validate(run_manager, is_test=True, **validate_func_dict), 'valid')
else:
assert args.resume
run_manager.write_log(
'-' * 30 + 'Supporting Elastic Depth: %s -> %s' %
(depth_stage_list[:current_stage + 1], depth_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
)
# add depth list constraints
if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.expand_ratio_list)) == 1:
validate_func_dict['depth_list'] = depth_stage_list
else:
validate_func_dict['depth_list'] = sorted({min(depth_stage_list), max(depth_stage_list)})
# train
train_func(
run_manager, args,
lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
)
def train_elastic_expand(train_func, run_manager, args, validate_func_dict):
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
expand_stage_list = dynamic_net.expand_ratio_list.copy()
expand_stage_list.sort(reverse=True)
n_stages = len(expand_stage_list) - 1
current_stage = n_stages - 1
# load pretrained models
if run_manager.start_epoch == 0 and not args.resume:
validate_func_dict['expand_ratio_list'] = sorted(dynamic_net.expand_ratio_list)
load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
dynamic_net.re_organize_middle_weights(expand_ratio_stage=current_stage)
run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
validate(run_manager, is_test=True, **validate_func_dict), 'valid')
else:
assert args.resume
run_manager.write_log(
'-' * 30 + 'Supporting Elastic Expand Ratio: %s -> %s' %
(expand_stage_list[:current_stage + 1], expand_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
)
if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.depth_list)) == 1:
validate_func_dict['expand_ratio_list'] = expand_stage_list
else:
validate_func_dict['expand_ratio_list'] = sorted({min(expand_stage_list), max(expand_stage_list)})
# train
train_func(
run_manager, args,
lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
)
def train_elastic_width_mult(train_func, run_manager, args, validate_func_dict):
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
width_stage_list = dynamic_net.width_mult_list.copy()
width_stage_list.sort(reverse=True)
n_stages = len(width_stage_list) - 1
current_stage = n_stages - 1
if run_manager.start_epoch == 0 and not args.resume:
load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
if current_stage == 0:
dynamic_net.re_organize_middle_weights(expand_ratio_stage=len(dynamic_net.expand_ratio_list) - 1)
run_manager.write_log('reorganize_middle_weights (expand_ratio_stage=%d)'
% (len(dynamic_net.expand_ratio_list) - 1), 'valid')
try:
dynamic_net.re_organize_outer_weights()
run_manager.write_log('reorganize_outer_weights', 'valid')
except Exception:
pass
run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
validate(run_manager, is_test=True, **validate_func_dict), 'valid')
else:
assert args.resume
run_manager.write_log(
'-' * 30 + 'Supporting Elastic Width Mult: %s -> %s' %
(width_stage_list[:current_stage + 1], width_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
)
validate_func_dict['width_mult_list'] = sorted({0, len(width_stage_list) - 1})
# train
train_func(
run_manager, args,
lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
)

View File

@@ -0,0 +1,70 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import torch.nn.functional as F
import torch.nn as nn
import torch
from ofa_local.utils import AverageMeter, get_net_device, DistributedTensor
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d
__all__ = ['set_running_statistics']
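# Recomputes BatchNorm running statistics for the currently active
# sub-network: a copy of the model is run forward-only with patched BN layers
# that record per-batch mean/variance, and the averaged statistics are then
# written back into the original model's BN buffers. This is needed because
# the super-network's stored statistics do not match any single sub-network.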
def set_running_statistics(model, data_loader, distributed=False):
bn_mean = {}
bn_var = {}
forward_model = copy.deepcopy(model)
for name, m in forward_model.named_modules():
if isinstance(m, nn.BatchNorm2d):
if distributed:
bn_mean[name] = DistributedTensor(name + '#mean')
bn_var[name] = DistributedTensor(name + '#var')
else:
bn_mean[name] = AverageMeter()
bn_var[name] = AverageMeter()
def new_forward(bn, mean_est, var_est):
def lambda_forward(x):
batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1
batch_var = (x - batch_mean) * (x - batch_mean)
batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
batch_mean = torch.squeeze(batch_mean)
batch_var = torch.squeeze(batch_var)
mean_est.update(batch_mean.data, x.size(0))
var_est.update(batch_var.data, x.size(0))
# bn forward using calculated mean & var
_feature_dim = batch_mean.size(0)
return F.batch_norm(
x, batch_mean, batch_var, bn.weight[:_feature_dim],
bn.bias[:_feature_dim], False,
0.0, bn.eps,
)
return lambda_forward
m.forward = new_forward(m, bn_mean[name], bn_var[name])
if len(bn_mean) == 0:
# skip if there is no batch normalization layers in the network
return
with torch.no_grad():
DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True
for images, labels in data_loader:
images = images.to(get_net_device(forward_model))
forward_model(images)
DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False
for name, m in model.named_modules():
if name in bn_mean and bn_mean[name].count > 0:
feature_dim = bn_mean[name].avg.size(0)
assert isinstance(m, nn.BatchNorm2d)
m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
m.running_var.data[:feature_dim].copy_(bn_var[name].avg)

View File

@@ -0,0 +1,18 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .proxyless_nets import *
from .mobilenet_v3 import *
from .resnets import *
def get_net_by_name(name):
if name == ProxylessNASNets.__name__:
return ProxylessNASNets
elif name == MobileNetV3.__name__:
return MobileNetV3
elif name == ResNets.__name__:
return ResNets
else:
raise ValueError('unrecognized type of network: %s' % name)

View File

@@ -0,0 +1,218 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import torch.nn as nn
from ofa_local.utils.layers import set_layer_from_config, MBConvLayer, ConvLayer, IdentityLayer, LinearLayer, ResidualBlock
from ofa_local.utils import MyNetwork, make_divisible, MyGlobalAvgPool2d
__all__ = ['MobileNetV3', 'MobileNetV3Large']
class MobileNetV3(MyNetwork):
def __init__(self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier):
super(MobileNetV3, self).__init__()
self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.final_expand_layer = final_expand_layer
self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=True)
self.feature_mix_layer = feature_mix_layer
self.classifier = classifier
def forward(self, x):
x = self.first_conv(x)
for block in self.blocks:
x = block(x)
x = self.final_expand_layer(x)
x = self.global_avg_pool(x) # global average pooling
x = self.feature_mix_layer(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
for block in self.blocks:
_str += block.module_str + '\n'
_str += self.final_expand_layer.module_str + '\n'
_str += self.global_avg_pool.__repr__() + '\n'
_str += self.feature_mix_layer.module_str + '\n'
_str += self.classifier.module_str
return _str
@property
def config(self):
return {
'name': MobileNetV3.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'final_expand_layer': self.final_expand_layer.config,
'feature_mix_layer': self.feature_mix_layer.config,
'classifier': self.classifier.config,
}
@staticmethod
def build_from_config(config):
first_conv = set_layer_from_config(config['first_conv'])
final_expand_layer = set_layer_from_config(config['final_expand_layer'])
feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
classifier = set_layer_from_config(config['classifier'])
blocks = []
for block_config in config['blocks']:
blocks.append(ResidualBlock.build_from_config(block_config))
net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
if 'bn' in config:
net.set_bn_param(**config['bn'])
else:
net.set_bn_param(momentum=0.1, eps=1e-5)
return net
def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.conv, MBConvLayer) and isinstance(m.shortcut, IdentityLayer):
m.conv.point_linear.bn.weight.data.zero_()
@property
def grouped_block_index(self):
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks[1:], 1):
if block.shortcut is None and len(block_index_list) > 0:
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list
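# `cfg` maps a stage id to a list of block specs, each of the form
# [kernel_size, mid_channel, out_channel, use_se, act_func, stride,
# expand_ratio]; see the hard-coded table in `MobileNetV3Large` below.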
@staticmethod
def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
# first conv layer
first_conv = ConvLayer(
3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
)
# build mobile blocks
feature_dim = input_channel
blocks = []
for stage_id, block_config_list in cfg.items():
for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
mb_conv = MBConvLayer(
feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
)
if stride == 1 and out_channel == feature_dim:
shortcut = IdentityLayer(out_channel, out_channel)
else:
shortcut = None
blocks.append(ResidualBlock(mb_conv, shortcut))
feature_dim = out_channel
# final expand layer
final_expand_layer = ConvLayer(
feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
)
# feature mix layer
feature_mix_layer = ConvLayer(
feature_dim * 6, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
)
# classifier
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
@staticmethod
def adjust_cfg(cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):
for i, (stage_id, block_config_list) in enumerate(cfg.items()):
for block_config in block_config_list:
if ks is not None and stage_id != '0':
block_config[0] = ks
if expand_ratio is not None and stage_id != '0':
block_config[-1] = expand_ratio
block_config[1] = None
if stage_width_list is not None:
block_config[2] = stage_width_list[i]
if depth_param is not None and stage_id != '0':
new_block_config_list = [block_config_list[0]]
new_block_config_list += [copy.deepcopy(block_config_list[-1]) for _ in range(depth_param - 1)]
cfg[stage_id] = new_block_config_list
return cfg
def load_state_dict(self, state_dict, **kwargs):
current_state_dict = self.state_dict()
for key in state_dict:
if key not in current_state_dict:
assert '.mobile_inverted_conv.' in key
new_key = key.replace('.mobile_inverted_conv.', '.conv.')
else:
new_key = key
current_state_dict[new_key] = state_dict[key]
super(MobileNetV3, self).load_state_dict(current_state_dict)
class MobileNetV3Large(MobileNetV3):
def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0.2,
ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):
input_channel = 16
last_channel = 1280
input_channel = make_divisible(input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
last_channel = make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE) \
if width_mult > 1.0 else last_channel
cfg = {
# k, exp, c, se, nl, s, e,
'0': [
[3, 16, 16, False, 'relu', 1, 1],
],
'1': [
[3, 64, 24, False, 'relu', 2, None], # 4
[3, 72, 24, False, 'relu', 1, None], # 3
],
'2': [
[5, 72, 40, True, 'relu', 2, None], # 3
[5, 120, 40, True, 'relu', 1, None], # 3
[5, 120, 40, True, 'relu', 1, None], # 3
],
'3': [
[3, 240, 80, False, 'h_swish', 2, None], # 6
[3, 200, 80, False, 'h_swish', 1, None], # 2.5
[3, 184, 80, False, 'h_swish', 1, None], # 2.3
[3, 184, 80, False, 'h_swish', 1, None], # 2.3
],
'4': [
[3, 480, 112, True, 'h_swish', 1, None], # 6
[3, 672, 112, True, 'h_swish', 1, None], # 6
],
'5': [
[5, 672, 160, True, 'h_swish', 2, None], # 6
[5, 960, 160, True, 'h_swish', 1, None], # 6
[5, 960, 160, True, 'h_swish', 1, None], # 6
]
}
cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list)
# apply the width multiplier: scale the `exp` (index 1) and `c` (index 2) entries of each block config
for stage_id, block_config_list in cfg.items():
for block_config in block_config_list:
if block_config[1] is not None:
block_config[1] = make_divisible(block_config[1] * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
block_config[2] = make_divisible(block_config[2] * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
first_conv, blocks, final_expand_layer, feature_mix_layer, classifier = self.build_net_via_cfg(
cfg, input_channel, last_channel, n_classes, dropout_rate
)
super(MobileNetV3Large, self).__init__(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
# set bn param
self.set_bn_param(*bn_param)

View File

@@ -0,0 +1,210 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import json
import torch.nn as nn
from ofa_local.utils.layers import set_layer_from_config, MBConvLayer, ConvLayer, IdentityLayer, LinearLayer, ResidualBlock
from ofa_local.utils import download_url, make_divisible, val2list, MyNetwork, MyGlobalAvgPool2d
__all__ = ['proxyless_base', 'ProxylessNASNets', 'MobileNetV2']
def proxyless_base(net_config=None, n_classes=None, bn_param=None, dropout_rate=None,
local_path='~/.torch/proxylessnas/'):
assert net_config is not None, 'Please input a network config'
if 'http' in net_config:
net_config_path = download_url(net_config, local_path)
else:
net_config_path = net_config
with open(net_config_path, 'r') as f:
    net_config_json = json.load(f)
if n_classes is not None:
net_config_json['classifier']['out_features'] = n_classes
if dropout_rate is not None:
net_config_json['classifier']['dropout_rate'] = dropout_rate
net = ProxylessNASNets.build_from_config(net_config_json)
if bn_param is not None:
net.set_bn_param(*bn_param)
return net
class ProxylessNASNets(MyNetwork):
def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
super(ProxylessNASNets, self).__init__()
self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.feature_mix_layer = feature_mix_layer
self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
self.classifier = classifier
def forward(self, x):
x = self.first_conv(x)
for block in self.blocks:
x = block(x)
if self.feature_mix_layer is not None:
x = self.feature_mix_layer(x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
for block in self.blocks:
_str += block.module_str + '\n'
_str += self.feature_mix_layer.module_str + '\n'
_str += self.global_avg_pool.__repr__() + '\n'
_str += self.classifier.module_str
return _str
@property
def config(self):
return {
'name': ProxylessNASNets.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'feature_mix_layer': None if self.feature_mix_layer is None else self.feature_mix_layer.config,
'classifier': self.classifier.config,
}
@staticmethod
def build_from_config(config):
first_conv = set_layer_from_config(config['first_conv'])
feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
classifier = set_layer_from_config(config['classifier'])
blocks = []
for block_config in config['blocks']:
blocks.append(ResidualBlock.build_from_config(block_config))
net = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
if 'bn' in config:
net.set_bn_param(**config['bn'])
else:
net.set_bn_param(momentum=0.1, eps=1e-3)
return net
def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.conv, MBConvLayer) and isinstance(m.shortcut, IdentityLayer):
m.conv.point_linear.bn.weight.data.zero_()
@property
def grouped_block_index(self):
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks[1:], 1):
if block.shortcut is None and len(block_index_list) > 0:
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list
def load_state_dict(self, state_dict, **kwargs):
current_state_dict = self.state_dict()
for key in state_dict:
if key not in current_state_dict:
assert '.mobile_inverted_conv.' in key
new_key = key.replace('.mobile_inverted_conv.', '.conv.')
else:
new_key = key
current_state_dict[new_key] = state_dict[key]
super(ProxylessNASNets, self).load_state_dict(current_state_dict)
class MobileNetV2(ProxylessNASNets):
def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-3), dropout_rate=0.2,
ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):
ks = 3 if ks is None else ks
expand_ratio = 6 if expand_ratio is None else expand_ratio
input_channel = 32
last_channel = 1280
input_channel = make_divisible(input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
last_channel = make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE) \
if width_mult > 1.0 else last_channel
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[expand_ratio, 24, 2, 2],
[expand_ratio, 32, 3, 2],
[expand_ratio, 64, 4, 2],
[expand_ratio, 96, 3, 1],
[expand_ratio, 160, 3, 2],
[expand_ratio, 320, 1, 1],
]
if depth_param is not None:
assert isinstance(depth_param, int)
for i in range(1, len(inverted_residual_setting) - 1):
inverted_residual_setting[i][2] = depth_param
if stage_width_list is not None:
for i in range(len(inverted_residual_setting)):
inverted_residual_setting[i][1] = stage_width_list[i]
ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
_pt = 0
# first conv layer
first_conv = ConvLayer(
3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act'
)
# inverted residual blocks
blocks = []
for t, c, n, s in inverted_residual_setting:
output_channel = make_divisible(c * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
for i in range(n):
if i == 0:
stride = s
else:
stride = 1
if t == 1:
kernel_size = 3
else:
kernel_size = ks[_pt]
_pt += 1
mobile_inverted_conv = MBConvLayer(
in_channels=input_channel, out_channels=output_channel, kernel_size=kernel_size, stride=stride,
expand_ratio=t,
)
if stride == 1:
if input_channel == output_channel:
shortcut = IdentityLayer(input_channel, input_channel)
else:
shortcut = None
else:
shortcut = None
blocks.append(
ResidualBlock(mobile_inverted_conv, shortcut)
)
input_channel = output_channel
# 1x1_conv before global average pooling
feature_mix_layer = ConvLayer(
input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6', ops_order='weight_bn_act',
)
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
super(MobileNetV2, self).__init__(first_conv, blocks, feature_mix_layer, classifier)
# set bn param
self.set_bn_param(*bn_param)

View File

@@ -0,0 +1,192 @@
import torch.nn as nn
from ofa_local.utils.layers import set_layer_from_config, ConvLayer, IdentityLayer, LinearLayer
from ofa_local.utils.layers import ResNetBottleneckBlock, ResidualBlock
from ofa_local.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d
__all__ = ['ResNets', 'ResNet50', 'ResNet50D']
class ResNets(MyNetwork):
BASE_DEPTH_LIST = [2, 2, 4, 2]
STAGE_WIDTH_LIST = [256, 512, 1024, 2048]
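# BASE_DEPTH_LIST is the minimum number of bottleneck blocks per stage;
# `depth_param` (where supported) adds extra blocks on top of these minima,
# and STAGE_WIDTH_LIST gives the base output width of each stage.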
def __init__(self, input_stem, blocks, classifier):
super(ResNets, self).__init__()
self.input_stem = nn.ModuleList(input_stem)
self.max_pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
self.blocks = nn.ModuleList(blocks)
self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
self.classifier = classifier
def forward(self, x):
for layer in self.input_stem:
x = layer(x)
x = self.max_pooling(x)
for block in self.blocks:
x = block(x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = ''
for layer in self.input_stem:
_str += layer.module_str + '\n'
_str += 'max_pooling(ks=3, stride=2)\n'
for block in self.blocks:
_str += block.module_str + '\n'
_str += self.global_avg_pool.__repr__() + '\n'
_str += self.classifier.module_str
return _str
@property
def config(self):
return {
'name': ResNets.__name__,
'bn': self.get_bn_param(),
'input_stem': [
layer.config for layer in self.input_stem
],
'blocks': [
block.config for block in self.blocks
],
'classifier': self.classifier.config,
}
@staticmethod
def build_from_config(config):
classifier = set_layer_from_config(config['classifier'])
input_stem = []
for layer_config in config['input_stem']:
input_stem.append(set_layer_from_config(layer_config))
blocks = []
for block_config in config['blocks']:
blocks.append(set_layer_from_config(block_config))
net = ResNets(input_stem, blocks, classifier)
if 'bn' in config:
net.set_bn_param(**config['bn'])
else:
net.set_bn_param(momentum=0.1, eps=1e-5)
return net
def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResNetBottleneckBlock) and isinstance(m.downsample, IdentityLayer):
m.conv3.bn.weight.data.zero_()
@property
def grouped_block_index(self):
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks):
if not isinstance(block.downsample, IdentityLayer) and len(block_index_list) > 0:
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list
def load_state_dict(self, state_dict, **kwargs):
super(ResNets, self).load_state_dict(state_dict)
class ResNet50(ResNets):
def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0,
expand_ratio=None, depth_param=None):
expand_ratio = 0.25 if expand_ratio is None else expand_ratio
input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
for i, width in enumerate(stage_width_list):
stage_width_list[i] = make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
depth_list = [3, 4, 6, 3]
if depth_param is not None:
for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
depth_list[i] = depth + depth_param
stride_list = [1, 2, 2, 2]
# build input stem
input_stem = [ConvLayer(
3, input_channel, kernel_size=7, stride=2, use_bn=True, act_func='relu', ops_order='weight_bn_act',
)]
# blocks
blocks = []
for d, width, s in zip(depth_list, stage_width_list, stride_list):
for i in range(d):
stride = s if i == 0 else 1
bottleneck_block = ResNetBottleneckBlock(
input_channel, width, kernel_size=3, stride=stride, expand_ratio=expand_ratio,
act_func='relu', downsample_mode='conv',
)
blocks.append(bottleneck_block)
input_channel = width
# classifier
classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
super(ResNet50, self).__init__(input_stem, blocks, classifier)
# set bn param
self.set_bn_param(*bn_param)
class ResNet50D(ResNets):
def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0,
expand_ratio=None, depth_param=None):
expand_ratio = 0.25 if expand_ratio is None else expand_ratio
input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
mid_input_channel = make_divisible(input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE)
stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
for i, width in enumerate(stage_width_list):
stage_width_list[i] = make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
depth_list = [3, 4, 6, 3]
if depth_param is not None:
for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
depth_list[i] = depth + depth_param
stride_list = [1, 2, 2, 2]
# build input stem
input_stem = [
ConvLayer(3, mid_input_channel, 3, stride=2, use_bn=True, act_func='relu'),
ResidualBlock(
ConvLayer(mid_input_channel, mid_input_channel, 3, stride=1, use_bn=True, act_func='relu'),
IdentityLayer(mid_input_channel, mid_input_channel)
),
ConvLayer(mid_input_channel, input_channel, 3, stride=1, use_bn=True, act_func='relu')
]
# blocks
blocks = []
for d, width, s in zip(depth_list, stage_width_list, stride_list):
for i in range(d):
stride = s if i == 0 else 1
bottleneck_block = ResNetBottleneckBlock(
input_channel, width, kernel_size=3, stride=stride, expand_ratio=expand_ratio,
act_func='relu', downsample_mode='avgpool_conv',
)
blocks.append(bottleneck_block)
input_channel = width
# classifier
classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
super(ResNet50D, self).__init__(input_stem, blocks, classifier)
# set bn param
self.set_bn_param(*bn_param)

View File

@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .run_config import *
from .run_manager import *
from .distributed_run_manager import *

View File

@@ -0,0 +1,381 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import json
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from ofa_local.utils import cross_entropy_with_label_smoothing, cross_entropy_loss_with_soft_target, write_log, init_models
from ofa_local.utils import DistributedMetric, list_mean, get_net_info, accuracy, AverageMeter, mix_labels, mix_images
from ofa_local.utils import MyRandomResizedCrop
__all__ = ['DistributedRunManager']
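# Horovod-based variant of the run manager: each process owns one GPU
# replica, gradients are averaged through `hvd.DistributedOptimizer`, and
# only the root rank writes logs and checkpoints.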
class DistributedRunManager:
def __init__(self, path, net, run_config, hvd_compression, backward_steps=1, is_root=False, init=True):
import horovod.torch as hvd
self.path = path
self.net = net
self.run_config = run_config
self.is_root = is_root
self.best_acc = 0.0
self.start_epoch = 0
os.makedirs(self.path, exist_ok=True)
self.net.cuda()
cudnn.benchmark = True
if init and self.is_root:
init_models(self.net, self.run_config.model_init)
if self.is_root:
# print net info
net_info = get_net_info(self.net, self.run_config.data_provider.data_shape)
with open('%s/net_info.txt' % self.path, 'w') as fout:
fout.write(json.dumps(net_info, indent=4) + '\n')
try:
fout.write(self.net.module_str + '\n')
except Exception:
fout.write('%s does not support `module_str`' % type(self.net))
fout.write('%s\n' % self.run_config.data_provider.train.dataset.transform)
fout.write('%s\n' % self.run_config.data_provider.test.dataset.transform)
fout.write('%s\n' % self.net)
# criterion
if isinstance(self.run_config.mixup_alpha, float):
self.train_criterion = cross_entropy_loss_with_soft_target
elif self.run_config.label_smoothing > 0:
self.train_criterion = lambda pred, target: \
cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
else:
self.train_criterion = nn.CrossEntropyLoss()
self.test_criterion = nn.CrossEntropyLoss()
# optimizer
if self.run_config.no_decay_keys:
keys = self.run_config.no_decay_keys.split('#')
net_params = [
self.net.get_parameters(keys, mode='exclude'), # parameters with weight decay
self.net.get_parameters(keys, mode='include'), # parameters without weight decay
]
else:
# noinspection PyBroadException
try:
net_params = self.network.weight_parameters()
except Exception:
net_params = []
for param in self.network.parameters():
if param.requires_grad:
net_params.append(param)
self.optimizer = self.run_config.build_optimizer(net_params)
self.optimizer = hvd.DistributedOptimizer(
self.optimizer, named_parameters=self.net.named_parameters(), compression=hvd_compression,
backward_passes_per_step=backward_steps,
)
""" save path and log path """
@property
def save_path(self):
if self.__dict__.get('_save_path', None) is None:
save_path = os.path.join(self.path, 'checkpoint')
os.makedirs(save_path, exist_ok=True)
self.__dict__['_save_path'] = save_path
return self.__dict__['_save_path']
@property
def logs_path(self):
if self.__dict__.get('_logs_path', None) is None:
logs_path = os.path.join(self.path, 'logs')
os.makedirs(logs_path, exist_ok=True)
self.__dict__['_logs_path'] = logs_path
return self.__dict__['_logs_path']
@property
def network(self):
return self.net
@network.setter
def network(self, new_val):
self.net = new_val
def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
if self.is_root:
write_log(self.logs_path, log_str, prefix, should_print, mode)
""" save & load model & save_config & broadcast """
def save_config(self, extra_run_config=None, extra_net_config=None):
if self.is_root:
run_save_path = os.path.join(self.path, 'run.config')
if not os.path.isfile(run_save_path):
run_config = self.run_config.config
if extra_run_config is not None:
run_config.update(extra_run_config)
json.dump(run_config, open(run_save_path, 'w'), indent=4)
print('Run configs dumped to %s' % run_save_path)
try:
net_save_path = os.path.join(self.path, 'net.config')
net_config = self.net.config
if extra_net_config is not None:
net_config.update(extra_net_config)
json.dump(net_config, open(net_save_path, 'w'), indent=4)
print('Network configs dumped to %s' % net_save_path)
except Exception:
print('%s does not support net config' % type(self.net))
def save_model(self, checkpoint=None, is_best=False, model_name=None):
if self.is_root:
if checkpoint is None:
checkpoint = {'state_dict': self.net.state_dict()}
if model_name is None:
model_name = 'checkpoint.pth.tar'
latest_fname = os.path.join(self.save_path, 'latest.txt')
model_path = os.path.join(self.save_path, model_name)
with open(latest_fname, 'w') as _fout:
_fout.write(model_path + '\n')
torch.save(checkpoint, model_path)
if is_best:
best_path = os.path.join(self.save_path, 'model_best.pth.tar')
torch.save({'state_dict': checkpoint['state_dict']}, best_path)
def load_model(self, model_fname=None):
if self.is_root:
latest_fname = os.path.join(self.save_path, 'latest.txt')
if model_fname is None and os.path.exists(latest_fname):
with open(latest_fname, 'r') as fin:
model_fname = fin.readline()
if model_fname[-1] == '\n':
model_fname = model_fname[:-1]
# noinspection PyBroadException
try:
if model_fname is None or not os.path.exists(model_fname):
model_fname = '%s/checkpoint.pth.tar' % self.save_path
with open(latest_fname, 'w') as fout:
fout.write(model_fname + '\n')
print("=> loading checkpoint '{}'".format(model_fname))
checkpoint = torch.load(model_fname, map_location='cpu')
except Exception:
self.write_log('failed to load checkpoint from %s' % self.save_path, 'valid')
return
self.net.load_state_dict(checkpoint['state_dict'])
if 'epoch' in checkpoint:
self.start_epoch = checkpoint['epoch'] + 1
if 'best_acc' in checkpoint:
self.best_acc = checkpoint['best_acc']
if 'optimizer' in checkpoint:
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.write_log("=> loaded checkpoint '{}'".format(model_fname), 'valid')
# noinspection PyArgumentList
def broadcast(self):
import horovod.torch as hvd
self.start_epoch = hvd.broadcast(torch.LongTensor(1).fill_(self.start_epoch)[0], 0, name='start_epoch').item()
self.best_acc = hvd.broadcast(torch.Tensor(1).fill_(self.best_acc)[0], 0, name='best_acc').item()
hvd.broadcast_parameters(self.net.state_dict(), 0)
hvd.broadcast_optimizer_state(self.optimizer, 0)
""" metric related """
def get_metric_dict(self):
return {
'top1': DistributedMetric('top1'),
'top5': DistributedMetric('top5'),
}
def update_metric(self, metric_dict, output, labels):
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
metric_dict['top1'].update(acc1[0], output.size(0))
metric_dict['top5'].update(acc5[0], output.size(0))
def get_metric_vals(self, metric_dict, return_dict=False):
if return_dict:
return {
key: metric_dict[key].avg.item() for key in metric_dict
}
else:
return [metric_dict[key].avg.item() for key in metric_dict]
def get_metric_names(self):
return 'top1', 'top5'
""" train & validate """
def validate(self, epoch=0, is_test=False, run_str='', net=None, data_loader=None, no_logs=False):
if net is None:
net = self.net
if data_loader is None:
if is_test:
data_loader = self.run_config.test_loader
else:
data_loader = self.run_config.valid_loader
net.eval()
losses = DistributedMetric('val_loss')
metric_dict = self.get_metric_dict()
with torch.no_grad():
with tqdm(total=len(data_loader),
desc='Validate Epoch #{} {}'.format(epoch + 1, run_str),
disable=no_logs or not self.is_root) as t:
for i, (images, labels) in enumerate(data_loader):
images, labels = images.cuda(), labels.cuda()
# compute output
output = net(images)
loss = self.test_criterion(output, labels)
# measure accuracy and record loss
losses.update(loss, images.size(0))
self.update_metric(metric_dict, output, labels)
t.set_postfix({
'loss': losses.avg.item(),
**self.get_metric_vals(metric_dict, return_dict=True),
'img_size': images.size(2),
})
t.update(1)
return losses.avg.item(), self.get_metric_vals(metric_dict)
def validate_all_resolution(self, epoch=0, is_test=False, net=None):
if net is None:
net = self.net
if isinstance(self.run_config.data_provider.image_size, list):
img_size_list, loss_list, top1_list, top5_list = [], [], [], []
for img_size in self.run_config.data_provider.image_size:
img_size_list.append(img_size)
self.run_config.data_provider.assign_active_img_size(img_size)
self.reset_running_statistics(net=net)
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
loss_list.append(loss)
top1_list.append(top1)
top5_list.append(top5)
return img_size_list, loss_list, top1_list, top5_list
else:
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]
def train_one_epoch(self, args, epoch, warmup_epochs=5, warmup_lr=0):
self.net.train()
self.run_config.train_loader.sampler.set_epoch(epoch) # required by distributed sampler
MyRandomResizedCrop.EPOCH = epoch # required by elastic resolution
nBatch = len(self.run_config.train_loader)
losses = DistributedMetric('train_loss')
metric_dict = self.get_metric_dict()
data_time = AverageMeter()
with tqdm(total=nBatch,
desc='Train Epoch #{}'.format(epoch + 1),
disable=not self.is_root) as t:
end = time.time()
for i, (images, labels) in enumerate(self.run_config.train_loader):
MyRandomResizedCrop.BATCH = i
data_time.update(time.time() - end)
if epoch < warmup_epochs:
new_lr = self.run_config.warmup_adjust_learning_rate(
self.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
)
else:
new_lr = self.run_config.adjust_learning_rate(self.optimizer, epoch - warmup_epochs, i, nBatch)
images, labels = images.cuda(), labels.cuda()
target = labels
if isinstance(self.run_config.mixup_alpha, float):
# transform data
random.seed(int('%d%.3d' % (i, epoch)))
lam = random.betavariate(self.run_config.mixup_alpha, self.run_config.mixup_alpha)
images = mix_images(images, lam)
labels = mix_labels(
labels, lam, self.run_config.data_provider.n_classes, self.run_config.label_smoothing
)
# soft target
if args.teacher_model is not None:
args.teacher_model.train()
with torch.no_grad():
soft_logits = args.teacher_model(images).detach()
soft_label = F.softmax(soft_logits, dim=1)
# compute output
output = self.net(images)
if args.teacher_model is None:
loss = self.train_criterion(output, labels)
loss_type = 'ce'
else:
if args.kd_type == 'ce':
kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
else:
kd_loss = F.mse_loss(output, soft_logits)
loss = args.kd_ratio * kd_loss + self.train_criterion(output, labels)
loss_type = '%.1fkd+ce' % args.kd_ratio
# update
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# measure accuracy and record loss
losses.update(loss, images.size(0))
self.update_metric(metric_dict, output, target)
t.set_postfix({
'loss': losses.avg.item(),
**self.get_metric_vals(metric_dict, return_dict=True),
'img_size': images.size(2),
'lr': new_lr,
'loss_type': loss_type,
'data_time': data_time.avg,
})
t.update(1)
end = time.time()
return losses.avg.item(), self.get_metric_vals(metric_dict)
def train(self, args, warmup_epochs=5, warmup_lr=0):
for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epochs):
train_loss, (train_top1, train_top5) = self.train_one_epoch(args, epoch, warmup_epochs, warmup_lr)
img_size, val_loss, val_top1, val_top5 = self.validate_all_resolution(epoch, is_test=False)
is_best = list_mean(val_top1) > self.best_acc
self.best_acc = max(self.best_acc, list_mean(val_top1))
if self.is_root:
val_log = '[{0}/{1}]\tloss {2:.3f}\t{6} acc {3:.3f} ({4:.3f})\t{7} acc {5:.3f}\t' \
'Train {6} {top1:.3f}\tloss {train_loss:.3f}\t'. \
format(epoch + 1 - warmup_epochs, self.run_config.n_epochs, list_mean(val_loss),
list_mean(val_top1), self.best_acc, list_mean(val_top5), *self.get_metric_names(),
top1=train_top1, train_loss=train_loss)
for i_s, v_a in zip(img_size, val_top1):
val_log += '(%d, %.3f), ' % (i_s, v_a)
self.write_log(val_log, prefix='valid', should_print=False)
self.save_model({
'epoch': epoch,
'best_acc': self.best_acc,
'optimizer': self.optimizer.state_dict(),
'state_dict': self.net.state_dict(),
}, is_best=is_best)
def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics
if net is None:
net = self.net
if data_loader is None:
data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
set_running_statistics(net, data_loader)

View File

@@ -0,0 +1,161 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from ofa_local.utils import calc_learning_rate, build_optimizer
from ofa_local.imagenet_classification.data_providers import ImagenetDataProvider
__all__ = ['RunConfig', 'ImagenetRunConfig', 'DistributedImageNetRunConfig']
class RunConfig:
def __init__(self, n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha, model_init, validation_frequency, print_frequency):
self.n_epochs = n_epochs
self.init_lr = init_lr
self.lr_schedule_type = lr_schedule_type
self.lr_schedule_param = lr_schedule_param
self.dataset = dataset
self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.valid_size = valid_size
self.opt_type = opt_type
self.opt_param = opt_param
self.weight_decay = weight_decay
self.label_smoothing = label_smoothing
self.no_decay_keys = no_decay_keys
self.mixup_alpha = mixup_alpha
self.model_init = model_init
self.validation_frequency = validation_frequency
self.print_frequency = print_frequency
@property
def config(self):
config = {}
for key in self.__dict__:
if not key.startswith('_'):
config[key] = self.__dict__[key]
return config
def copy(self):
return RunConfig(**self.config)
""" learning rate """
def adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
""" adjust learning of a given optimizer and return the new learning rate """
new_lr = calc_learning_rate(epoch, self.init_lr, self.n_epochs, batch, nBatch, self.lr_schedule_type)
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
return new_lr
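# Linear warmup: the learning rate is interpolated from `warmup_lr` up to
# `init_lr` over `T_total` iterations, i.e.
#   lr(t) = warmup_lr + (init_lr - warmup_lr) * t / T_total
# with t counted in batches across the warmup epochs.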
def warmup_adjust_learning_rate(self, optimizer, T_total, nBatch, epoch, batch=0, warmup_lr=0):
T_cur = epoch * nBatch + batch + 1
new_lr = T_cur / T_total * (self.init_lr - warmup_lr) + warmup_lr
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
return new_lr
""" data provider """
@property
def data_provider(self):
raise NotImplementedError
@property
def train_loader(self):
return self.data_provider.train
@property
def valid_loader(self):
return self.data_provider.valid
@property
def test_loader(self):
return self.data_provider.test
def random_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
return self.data_provider.build_sub_train_loader(n_images, batch_size, num_worker, num_replicas, rank)
""" optimizer """
def build_optimizer(self, net_params):
return build_optimizer(net_params,
self.opt_type, self.opt_param, self.init_lr, self.weight_decay, self.no_decay_keys)
class ImagenetRunConfig(RunConfig):
def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='imagenet', train_batch_size=256, test_batch_size=500, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, **kwargs):
super(ImagenetRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == ImagenetDataProvider.name():
DataProviderClass = ImagenetDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class DistributedImageNetRunConfig(ImagenetRunConfig):
def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='imagenet', train_batch_size=64, test_batch_size=64, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=8, resize_scale=0.08, distort_color='tf', image_size=224,
**kwargs):
super(DistributedImageNetRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha, model_init, validation_frequency, print_frequency, n_worker, resize_scale, distort_color,
image_size, **kwargs
)
self._num_replicas = kwargs['num_replicas']
self._rank = kwargs['rank']
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == ImagenetDataProvider.name():
DataProviderClass = ImagenetDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
num_replicas=self._num_replicas, rank=self._rank,
)
return self.__dict__['_data_provider']

View File

@@ -0,0 +1,375 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import random
import time
import json
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm
from ofa_local.utils import get_net_info, cross_entropy_loss_with_soft_target, cross_entropy_with_label_smoothing
from ofa_local.utils import AverageMeter, accuracy, write_log, mix_images, mix_labels, init_models
from ofa_local.utils import MyRandomResizedCrop
__all__ = ['RunManager']
class RunManager:
def __init__(self, path, net, run_config, init=True, measure_latency=None, no_gpu=False):
self.path = path
self.net = net
self.run_config = run_config
self.best_acc = 0
self.start_epoch = 0
os.makedirs(self.path, exist_ok=True)
# move network to GPU if available
if torch.cuda.is_available() and (not no_gpu):
self.device = torch.device('cuda:0')
self.net = self.net.to(self.device)
cudnn.benchmark = True
else:
self.device = torch.device('cpu')
# initialize model (default)
if init:
init_models(self.net, run_config.model_init)  # the network must be passed as the first argument
# net info
net_info = get_net_info(self.net, self.run_config.data_provider.data_shape, measure_latency, True)
with open('%s/net_info.txt' % self.path, 'w') as fout:
fout.write(json.dumps(net_info, indent=4) + '\n')
# noinspection PyBroadException
try:
fout.write(self.network.module_str + '\n')
except Exception:
pass
fout.write('%s\n' % self.run_config.data_provider.train.dataset.transform)
fout.write('%s\n' % self.run_config.data_provider.test.dataset.transform)
fout.write('%s\n' % self.network)
# criterion
if isinstance(self.run_config.mixup_alpha, float):
self.train_criterion = cross_entropy_loss_with_soft_target
elif self.run_config.label_smoothing > 0:
self.train_criterion = \
lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
else:
self.train_criterion = nn.CrossEntropyLoss()
self.test_criterion = nn.CrossEntropyLoss()
# optimizer
if self.run_config.no_decay_keys:
keys = self.run_config.no_decay_keys.split('#')
net_params = [
self.network.get_parameters(keys, mode='exclude'), # parameters with weight decay
self.network.get_parameters(keys, mode='include'), # parameters without weight decay
]
else:
# noinspection PyBroadException
try:
net_params = self.network.weight_parameters()
except Exception:
net_params = []
for param in self.network.parameters():
if param.requires_grad:
net_params.append(param)
self.optimizer = self.run_config.build_optimizer(net_params)
self.net = torch.nn.DataParallel(self.net)
""" save path and log path """
@property
def save_path(self):
if self.__dict__.get('_save_path', None) is None:
save_path = os.path.join(self.path, 'checkpoint')
os.makedirs(save_path, exist_ok=True)
self.__dict__['_save_path'] = save_path
return self.__dict__['_save_path']
@property
def logs_path(self):
if self.__dict__.get('_logs_path', None) is None:
logs_path = os.path.join(self.path, 'logs')
os.makedirs(logs_path, exist_ok=True)
self.__dict__['_logs_path'] = logs_path
return self.__dict__['_logs_path']
@property
def network(self):
return self.net.module if isinstance(self.net, nn.DataParallel) else self.net
def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
write_log(self.logs_path, log_str, prefix, should_print, mode)
""" save and load models """
def save_model(self, checkpoint=None, is_best=False, model_name=None):
if checkpoint is None:
checkpoint = {'state_dict': self.network.state_dict()}
if model_name is None:
model_name = 'checkpoint.pth.tar'
checkpoint['dataset'] = self.run_config.dataset # add `dataset` info to the checkpoint
latest_fname = os.path.join(self.save_path, 'latest.txt')
model_path = os.path.join(self.save_path, model_name)
with open(latest_fname, 'w') as fout:
fout.write(model_path + '\n')
torch.save(checkpoint, model_path)
if is_best:
best_path = os.path.join(self.save_path, 'model_best.pth.tar')
torch.save({'state_dict': checkpoint['state_dict']}, best_path)
def load_model(self, model_fname=None):
latest_fname = os.path.join(self.save_path, 'latest.txt')
if model_fname is None and os.path.exists(latest_fname):
with open(latest_fname, 'r') as fin:
model_fname = fin.readline()
if model_fname[-1] == '\n':
model_fname = model_fname[:-1]
# noinspection PyBroadException
try:
if model_fname is None or not os.path.exists(model_fname):
model_fname = '%s/checkpoint.pth.tar' % self.save_path
with open(latest_fname, 'w') as fout:
fout.write(model_fname + '\n')
print("=> loading checkpoint '{}'".format(model_fname))
checkpoint = torch.load(model_fname, map_location='cpu')
except Exception:
            print('failed to load checkpoint from %s' % self.save_path)
return {}
self.network.load_state_dict(checkpoint['state_dict'])
if 'epoch' in checkpoint:
self.start_epoch = checkpoint['epoch'] + 1
if 'best_acc' in checkpoint:
self.best_acc = checkpoint['best_acc']
if 'optimizer' in checkpoint:
self.optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}'".format(model_fname))
return checkpoint
def save_config(self, extra_run_config=None, extra_net_config=None):
""" dump run_config and net_config to the model_folder """
run_save_path = os.path.join(self.path, 'run.config')
if not os.path.isfile(run_save_path):
run_config = self.run_config.config
if extra_run_config is not None:
run_config.update(extra_run_config)
json.dump(run_config, open(run_save_path, 'w'), indent=4)
            print('Run configs dumped to %s' % run_save_path)
try:
net_save_path = os.path.join(self.path, 'net.config')
net_config = self.network.config
if extra_net_config is not None:
net_config.update(extra_net_config)
json.dump(net_config, open(net_save_path, 'w'), indent=4)
            print('Network configs dumped to %s' % net_save_path)
except Exception:
            print('%s does not support net config' % type(self.network))
""" metric related """
def get_metric_dict(self):
return {
'top1': AverageMeter(),
'top5': AverageMeter(),
}
def update_metric(self, metric_dict, output, labels):
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
metric_dict['top1'].update(acc1[0].item(), output.size(0))
metric_dict['top5'].update(acc5[0].item(), output.size(0))
def get_metric_vals(self, metric_dict, return_dict=False):
if return_dict:
return {
key: metric_dict[key].avg for key in metric_dict
}
else:
return [metric_dict[key].avg for key in metric_dict]
def get_metric_names(self):
return 'top1', 'top5'
""" train and test """
def validate(self, epoch=0, is_test=False, run_str='', net=None, data_loader=None, no_logs=False, train_mode=False):
if net is None:
net = self.net
if not isinstance(net, nn.DataParallel):
net = nn.DataParallel(net)
if data_loader is None:
data_loader = self.run_config.test_loader if is_test else self.run_config.valid_loader
if train_mode:
net.train()
else:
net.eval()
losses = AverageMeter()
metric_dict = self.get_metric_dict()
with torch.no_grad():
with tqdm(total=len(data_loader),
desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
for i, (images, labels) in enumerate(data_loader):
images, labels = images.to(self.device), labels.to(self.device)
# compute output
output = net(images)
loss = self.test_criterion(output, labels)
# measure accuracy and record loss
self.update_metric(metric_dict, output, labels)
losses.update(loss.item(), images.size(0))
t.set_postfix({
'loss': losses.avg,
**self.get_metric_vals(metric_dict, return_dict=True),
'img_size': images.size(2),
})
t.update(1)
return losses.avg, self.get_metric_vals(metric_dict)
def validate_all_resolution(self, epoch=0, is_test=False, net=None):
if net is None:
net = self.network
if isinstance(self.run_config.data_provider.image_size, list):
img_size_list, loss_list, top1_list, top5_list = [], [], [], []
for img_size in self.run_config.data_provider.image_size:
img_size_list.append(img_size)
self.run_config.data_provider.assign_active_img_size(img_size)
self.reset_running_statistics(net=net)
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
loss_list.append(loss)
top1_list.append(top1)
top5_list.append(top5)
return img_size_list, loss_list, top1_list, top5_list
else:
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]
def train_one_epoch(self, args, epoch, warmup_epochs=0, warmup_lr=0):
# switch to train mode
self.net.train()
MyRandomResizedCrop.EPOCH = epoch # required by elastic resolution
nBatch = len(self.run_config.train_loader)
losses = AverageMeter()
metric_dict = self.get_metric_dict()
data_time = AverageMeter()
with tqdm(total=nBatch,
desc='{} Train Epoch #{}'.format(self.run_config.dataset, epoch + 1)) as t:
end = time.time()
for i, (images, labels) in enumerate(self.run_config.train_loader):
MyRandomResizedCrop.BATCH = i
data_time.update(time.time() - end)
if epoch < warmup_epochs:
new_lr = self.run_config.warmup_adjust_learning_rate(
self.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
)
else:
new_lr = self.run_config.adjust_learning_rate(self.optimizer, epoch - warmup_epochs, i, nBatch)
images, labels = images.to(self.device), labels.to(self.device)
target = labels
if isinstance(self.run_config.mixup_alpha, float):
# transform data
lam = random.betavariate(self.run_config.mixup_alpha, self.run_config.mixup_alpha)
images = mix_images(images, lam)
labels = mix_labels(
labels, lam, self.run_config.data_provider.n_classes, self.run_config.label_smoothing
)
# soft target
if args.teacher_model is not None:
args.teacher_model.train()
with torch.no_grad():
soft_logits = args.teacher_model(images).detach()
soft_label = F.softmax(soft_logits, dim=1)
# compute output
output = self.net(images)
loss = self.train_criterion(output, labels)
if args.teacher_model is None:
loss_type = 'ce'
else:
if args.kd_type == 'ce':
kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
else:
kd_loss = F.mse_loss(output, soft_logits)
loss = args.kd_ratio * kd_loss + loss
loss_type = '%.1fkd+ce' % args.kd_ratio
# compute gradient and do SGD step
self.net.zero_grad() # or self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# measure accuracy and record loss
losses.update(loss.item(), images.size(0))
self.update_metric(metric_dict, output, target)
t.set_postfix({
'loss': losses.avg,
**self.get_metric_vals(metric_dict, return_dict=True),
'img_size': images.size(2),
'lr': new_lr,
'loss_type': loss_type,
'data_time': data_time.avg,
})
t.update(1)
end = time.time()
return losses.avg, self.get_metric_vals(metric_dict)
def train(self, args, warmup_epoch=0, warmup_lr=0):
for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch):
train_loss, (train_top1, train_top5) = self.train_one_epoch(args, epoch, warmup_epoch, warmup_lr)
if (epoch + 1) % self.run_config.validation_frequency == 0:
img_size, val_loss, val_acc, val_acc5 = self.validate_all_resolution(epoch=epoch, is_test=False)
is_best = np.mean(val_acc) > self.best_acc
self.best_acc = max(self.best_acc, np.mean(val_acc))
val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\t{5} {3:.3f} ({4:.3f})'. \
format(epoch + 1 - warmup_epoch, self.run_config.n_epochs,
np.mean(val_loss), np.mean(val_acc), self.best_acc, self.get_metric_names()[0])
val_log += '\t{2} {0:.3f}\tTrain {1} {top1:.3f}\tloss {train_loss:.3f}\t'. \
format(np.mean(val_acc5), *self.get_metric_names(), top1=train_top1, train_loss=train_loss)
for i_s, v_a in zip(img_size, val_acc):
val_log += '(%d, %.3f), ' % (i_s, v_a)
self.write_log(val_log, prefix='valid', should_print=False)
else:
is_best = False
self.save_model({
'epoch': epoch,
'best_acc': self.best_acc,
'optimizer': self.optimizer.state_dict(),
'state_dict': self.network.state_dict(),
}, is_best=is_best)
def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics
if net is None:
net = self.network
if data_loader is None:
data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
set_running_statistics(net, data_loader)
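
# Usage sketch (names are assumptions wired from this file): given a built
# `net` and a `run_config` such as DistributedImageNetRunConfig above,
# validation takes two calls; `init=False` skips weight re-initialization.
#
#   run_manager = RunManager('./exp/demo', net, run_config, init=False)
#   run_manager.save_config()
#   loss, (top1, top5) = run_manager.validate(is_test=True)
#   print('loss %.3f | top1 %.2f | top5 %.2f' % (loss, top1, top5))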

View File

@@ -0,0 +1,87 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import json
import torch
from ofa_local.utils import download_url
from ofa_local.imagenet_classification.networks import get_net_by_name, proxyless_base
from ofa_local.imagenet_classification.elastic_nn.networks import OFAMobileNetV3, OFAProxylessNASNets, OFAResNets
__all__ = [
'ofa_specialized', 'ofa_net',
'proxylessnas_net', 'proxylessnas_mobile', 'proxylessnas_cpu', 'proxylessnas_gpu',
]
def ofa_specialized(net_id, pretrained=True):
url_base = 'https://hanlab.mit.edu/files/OnceForAll/ofa_specialized/'
net_config = json.load(open(
download_url(url_base + net_id + '/net.config', model_dir='.torch/ofa_specialized/%s/' % net_id)
))
net = get_net_by_name(net_config['name']).build_from_config(net_config)
image_size = json.load(open(
download_url(url_base + net_id + '/run.config', model_dir='.torch/ofa_specialized/%s/' % net_id)
))['image_size']
if pretrained:
init = torch.load(
download_url(url_base + net_id + '/init', model_dir='.torch/ofa_specialized/%s/' % net_id),
map_location='cpu'
)['state_dict']
net.load_state_dict(init)
return net, image_size
def ofa_net(net_id, pretrained=True):
if net_id == 'ofa_proxyless_d234_e346_k357_w1.3':
net = OFAProxylessNASNets(
dropout_rate=0, width_mult=1.3, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
)
elif net_id == 'ofa_mbv3_d234_e346_k357_w1.0':
net = OFAMobileNetV3(
dropout_rate=0, width_mult=1.0, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
)
elif net_id == 'ofa_mbv3_d234_e346_k357_w1.2':
net = OFAMobileNetV3(
dropout_rate=0, width_mult=1.2, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
)
elif net_id == 'ofa_resnet50':
net = OFAResNets(
dropout_rate=0, depth_list=[0, 1, 2], expand_ratio_list=[0.2, 0.25, 0.35], width_mult_list=[0.65, 0.8, 1.0]
)
net_id = 'ofa_resnet50_d=0+1+2_e=0.2+0.25+0.35_w=0.65+0.8+1.0'
else:
raise ValueError('Not supported: %s' % net_id)
if pretrained:
url_base = 'https://hanlab.mit.edu/files/OnceForAll/ofa_nets/'
init = torch.load(
download_url(url_base + net_id, model_dir='.torch/ofa_nets'),
map_location='cpu')['state_dict']
net.load_state_dict(init)
return net
def proxylessnas_net(net_id, pretrained=True):
net = proxyless_base(
net_config='https://hanlab.mit.edu/files/proxylessNAS/%s.config' % net_id,
)
    if pretrained:
        net.load_state_dict(torch.load(
            download_url('https://hanlab.mit.edu/files/proxylessNAS/%s.pth' % net_id), map_location='cpu'
        )['state_dict'])
    return net
def proxylessnas_mobile(pretrained=True):
return proxylessnas_net('proxyless_mobile', pretrained)
def proxylessnas_cpu(pretrained=True):
return proxylessnas_net('proxyless_cpu', pretrained)
def proxylessnas_gpu(pretrained=True):
return proxylessnas_net('proxyless_gpu', pretrained)
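
# Usage sketch: `pretrained=False` skips the weight download, so this runs
# offline; the subnet choice below is one point of the ks/e/d grid that the
# net IDs above encode (scalar values broadcast to all blocks/stages).
#
#   supernet = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=False)
#   supernet.set_active_subnet(ks=7, e=6, d=4)
#   subnet = supernet.get_active_subnet(preserve_weight=False)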

View File

@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .acc_dataset import *
from .acc_predictor import *
from .arch_encoder import *

View File

@@ -0,0 +1,181 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data
from ofa.utils import list_mean
__all__ = ['net_setting2id', 'net_id2setting', 'AccuracyDataset']
def net_setting2id(net_setting):
return json.dumps(net_setting)
def net_id2setting(net_id):
return json.loads(net_id)
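# Round-trip example: a subnet setting is a plain dict, and its id is just
# the canonical JSON string, so e.g.
#   net_id2setting(net_setting2id({'ks': [3, 5], 'd': [2]}))
# returns {'ks': [3, 5], 'd': [2]} unchanged.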
class RegDataset(torch.utils.data.Dataset):
def __init__(self, inputs, targets):
super(RegDataset, self).__init__()
self.inputs = inputs
self.targets = targets
def __getitem__(self, index):
return self.inputs[index], self.targets[index]
def __len__(self):
return self.inputs.size(0)
class AccuracyDataset:
def __init__(self, path):
self.path = path
os.makedirs(self.path, exist_ok=True)
@property
def net_id_path(self):
return os.path.join(self.path, 'net_id.dict')
@property
def acc_src_folder(self):
return os.path.join(self.path, 'src')
@property
def acc_dict_path(self):
return os.path.join(self.path, 'acc.dict')
# TODO: support parallel building
def build_acc_dataset(self, run_manager, ofa_network, n_arch=1000, image_size_list=None):
# load net_id_list, random sample if not exist
if os.path.isfile(self.net_id_path):
net_id_list = json.load(open(self.net_id_path))
else:
net_id_list = set()
while len(net_id_list) < n_arch:
net_setting = ofa_network.sample_active_subnet()
net_id = net_setting2id(net_setting)
net_id_list.add(net_id)
net_id_list = list(net_id_list)
net_id_list.sort()
json.dump(net_id_list, open(self.net_id_path, 'w'), indent=4)
image_size_list = [128, 160, 192, 224] if image_size_list is None else image_size_list
with tqdm(total=len(net_id_list) * len(image_size_list), desc='Building Acc Dataset') as t:
for image_size in image_size_list:
# load val dataset into memory
val_dataset = []
run_manager.run_config.data_provider.assign_active_img_size(image_size)
for images, labels in run_manager.run_config.valid_loader:
val_dataset.append((images, labels))
# save path
os.makedirs(self.acc_src_folder, exist_ok=True)
acc_save_path = os.path.join(self.acc_src_folder, '%d.dict' % image_size)
acc_dict = {}
# load existing acc dict
if os.path.isfile(acc_save_path):
existing_acc_dict = json.load(open(acc_save_path, 'r'))
else:
existing_acc_dict = {}
for net_id in net_id_list:
net_setting = net_id2setting(net_id)
key = net_setting2id({**net_setting, 'image_size': image_size})
if key in existing_acc_dict:
acc_dict[key] = existing_acc_dict[key]
t.set_postfix({
'net_id': net_id,
'image_size': image_size,
'info_val': acc_dict[key],
'status': 'loading',
})
t.update()
continue
ofa_network.set_active_subnet(**net_setting)
run_manager.reset_running_statistics(ofa_network)
net_setting_str = ','.join(['%s_%s' % (
key, '%.1f' % list_mean(val) if isinstance(val, list) else val
) for key, val in net_setting.items()])
loss, (top1, top5) = run_manager.validate(
run_str=net_setting_str, net=ofa_network, data_loader=val_dataset, no_logs=True,
)
info_val = top1
t.set_postfix({
'net_id': net_id,
'image_size': image_size,
'info_val': info_val,
})
t.update()
acc_dict.update({
key: info_val
})
json.dump(acc_dict, open(acc_save_path, 'w'), indent=4)
def merge_acc_dataset(self, image_size_list=None):
# load existing data
merged_acc_dict = {}
for fname in os.listdir(self.acc_src_folder):
if '.dict' not in fname:
continue
image_size = int(fname.split('.dict')[0])
if image_size_list is not None and image_size not in image_size_list:
print('Skip ', fname)
continue
full_path = os.path.join(self.acc_src_folder, fname)
partial_acc_dict = json.load(open(full_path))
merged_acc_dict.update(partial_acc_dict)
print('loaded %s' % full_path)
json.dump(merged_acc_dict, open(self.acc_dict_path, 'w'), indent=4)
return merged_acc_dict
def build_acc_data_loader(self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16):
# load data
acc_dict = json.load(open(self.acc_dict_path))
X_all = []
Y_all = []
with tqdm(total=len(acc_dict), desc='Loading data') as t:
for k, v in acc_dict.items():
dic = json.loads(k)
X_all.append(arch_encoder.arch2feature(dic))
Y_all.append(v / 100.) # range: 0 - 1
t.update()
base_acc = np.mean(Y_all)
# convert to torch tensor
X_all = torch.tensor(X_all, dtype=torch.float)
Y_all = torch.tensor(Y_all)
# random shuffle
shuffle_idx = torch.randperm(len(X_all))
X_all = X_all[shuffle_idx]
Y_all = Y_all[shuffle_idx]
# split data
idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample
val_idx = X_all.size(0) // 5 * 4
X_train, Y_train = X_all[:idx], Y_all[:idx]
X_test, Y_test = X_all[val_idx:], Y_all[val_idx:]
print('Train Size: %d,' % len(X_train), 'Valid Size: %d' % len(X_test))
# build data loader
train_dataset = RegDataset(X_train, Y_train)
val_dataset = RegDataset(X_test, Y_test)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False, num_workers=n_workers
)
valid_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=n_workers
)
return train_loader, valid_loader, base_acc
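
# Usage sketch (the path and encoder are illustrative assumptions): merge
# the per-resolution src/*.dict files, then turn the accuracy table into
# regression loaders for predictor training.
#
#   dataset = AccuracyDataset('./acc_dataset')
#   dataset.merge_acc_dataset()
#   encoder = MobileNetArchEncoder(image_size_list=[128, 160, 192, 224])
#   train_loader, valid_loader, base_acc = dataset.build_acc_data_loader(encoder)
#   print('mean accuracy over the table: %.4f' % base_acc)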

View File

@@ -0,0 +1,50 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import numpy as np
import torch
import torch.nn as nn
__all__ = ['AccuracyPredictor']
class AccuracyPredictor(nn.Module):
def __init__(self, arch_encoder, hidden_size=400, n_layers=3,
checkpoint_path=None, device='cuda:0'):
super(AccuracyPredictor, self).__init__()
self.arch_encoder = arch_encoder
self.hidden_size = hidden_size
self.n_layers = n_layers
self.device = device
# build layers
layers = []
for i in range(self.n_layers):
layers.append(nn.Sequential(
nn.Linear(self.arch_encoder.n_dim if i == 0 else self.hidden_size, self.hidden_size),
nn.ReLU(inplace=True),
))
layers.append(nn.Linear(self.hidden_size, 1, bias=False))
self.layers = nn.Sequential(*layers)
self.base_acc = nn.Parameter(torch.zeros(1, device=self.device), requires_grad=False)
if checkpoint_path is not None and os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if 'state_dict' in checkpoint:
checkpoint = checkpoint['state_dict']
self.load_state_dict(checkpoint)
print('Loaded checkpoint from %s' % checkpoint_path)
self.layers = self.layers.to(self.device)
def forward(self, x):
        y = self.layers(x).squeeze(dim=-1)
return y + self.base_acc
def predict_acc(self, arch_dict_list):
X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
X = torch.tensor(np.array(X)).float().to(self.device)
return self.forward(X)
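
# Usage sketch: the encoder defines the input feature layout; with no
# checkpoint the weights are random, so the outputs below only demonstrate
# the call shape. `MobileNetArchEncoder` is assumed importable from the
# arch_encoder module in this package.
#
#   encoder = MobileNetArchEncoder()
#   predictor = AccuracyPredictor(encoder, device='cpu')
#   archs = [encoder.random_sample_arch() for _ in range(4)]
#   print(predictor.predict_acc(archs))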

View File

@@ -0,0 +1,315 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import random
import numpy as np
from ofa.imagenet_classification.networks import ResNets
__all__ = ['MobileNetArchEncoder', 'ResNetArchEncoder']
class MobileNetArchEncoder:
SPACE_TYPE = 'mbv3'
def __init__(self, image_size_list=None, ks_list=None, expand_list=None, depth_list=None, n_stage=None):
self.image_size_list = [224] if image_size_list is None else image_size_list
self.ks_list = [3, 5, 7] if ks_list is None else ks_list
self.expand_list = [3, 4, 6] if expand_list is None else [int(expand) for expand in expand_list]
self.depth_list = [2, 3, 4] if depth_list is None else depth_list
if n_stage is not None:
self.n_stage = n_stage
elif self.SPACE_TYPE == 'mbv2':
self.n_stage = 6
elif self.SPACE_TYPE == 'mbv3':
self.n_stage = 5
else:
raise NotImplementedError
# build info dict
self.n_dim = 0
self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
self._build_info_dict(target='r')
self.k_info = dict(id2val=[], val2id=[], L=[], R=[])
self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
self._build_info_dict(target='k')
self._build_info_dict(target='e')
@property
def max_n_blocks(self):
if self.SPACE_TYPE == 'mbv3':
return self.n_stage * max(self.depth_list)
elif self.SPACE_TYPE == 'mbv2':
return (self.n_stage - 1) * max(self.depth_list) + 1
else:
raise NotImplementedError
def _build_info_dict(self, target):
if target == 'r':
target_dict = self.r_info
target_dict['L'].append(self.n_dim)
for img_size in self.image_size_list:
target_dict['val2id'][img_size] = self.n_dim
target_dict['id2val'][self.n_dim] = img_size
self.n_dim += 1
target_dict['R'].append(self.n_dim)
else:
if target == 'k':
target_dict = self.k_info
choices = self.ks_list
elif target == 'e':
target_dict = self.e_info
choices = self.expand_list
else:
raise NotImplementedError
for i in range(self.max_n_blocks):
target_dict['val2id'].append({})
target_dict['id2val'].append({})
target_dict['L'].append(self.n_dim)
for k in choices:
target_dict['val2id'][i][k] = self.n_dim
target_dict['id2val'][i][self.n_dim] = k
self.n_dim += 1
target_dict['R'].append(self.n_dim)
def arch2feature(self, arch_dict):
ks, e, d, r = arch_dict['ks'], arch_dict['e'], arch_dict['d'], arch_dict['image_size']
feature = np.zeros(self.n_dim)
for i in range(self.max_n_blocks):
nowd = i % max(self.depth_list)
stg = i // max(self.depth_list)
if nowd < d[stg]:
feature[self.k_info['val2id'][i][ks[i]]] = 1
feature[self.e_info['val2id'][i][e[i]]] = 1
feature[self.r_info['val2id'][r]] = 1
return feature
def feature2arch(self, feature):
img_sz = self.r_info['id2val'][
int(np.argmax(feature[self.r_info['L'][0]:self.r_info['R'][0]])) + self.r_info['L'][0]
]
assert img_sz in self.image_size_list
arch_dict = {'ks': [], 'e': [], 'd': [], 'image_size': img_sz}
d = 0
for i in range(self.max_n_blocks):
skip = True
for j in range(self.k_info['L'][i], self.k_info['R'][i]):
if feature[j] == 1:
arch_dict['ks'].append(self.k_info['id2val'][i][j])
skip = False
break
for j in range(self.e_info['L'][i], self.e_info['R'][i]):
if feature[j] == 1:
arch_dict['e'].append(self.e_info['id2val'][i][j])
assert not skip
skip = False
break
if skip:
arch_dict['e'].append(0)
arch_dict['ks'].append(0)
else:
d += 1
if (i + 1) % max(self.depth_list) == 0 or (i + 1) == self.max_n_blocks:
arch_dict['d'].append(d)
d = 0
return arch_dict
def random_sample_arch(self):
return {
'ks': random.choices(self.ks_list, k=self.max_n_blocks),
'e': random.choices(self.expand_list, k=self.max_n_blocks),
'd': random.choices(self.depth_list, k=self.n_stage),
'image_size': random.choice(self.image_size_list)
}
def mutate_resolution(self, arch_dict, mutate_prob):
if random.random() < mutate_prob:
arch_dict['image_size'] = random.choice(self.image_size_list)
return arch_dict
def mutate_arch(self, arch_dict, mutate_prob):
for i in range(self.max_n_blocks):
if random.random() < mutate_prob:
arch_dict['ks'][i] = random.choice(self.ks_list)
arch_dict['e'][i] = random.choice(self.expand_list)
for i in range(self.n_stage):
if random.random() < mutate_prob:
arch_dict['d'][i] = random.choice(self.depth_list)
return arch_dict
class ResNetArchEncoder:
def __init__(self, image_size_list=None, depth_list=None, expand_list=None, width_mult_list=None,
base_depth_list=None):
self.image_size_list = [224] if image_size_list is None else image_size_list
self.expand_list = [0.2, 0.25, 0.35] if expand_list is None else expand_list
self.depth_list = [0, 1, 2] if depth_list is None else depth_list
self.width_mult_list = [0.65, 0.8, 1.0] if width_mult_list is None else width_mult_list
self.base_depth_list = ResNets.BASE_DEPTH_LIST if base_depth_list is None else base_depth_list
"""" build info dict """
self.n_dim = 0
# resolution
self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
self._build_info_dict(target='r')
# input stem skip
self.input_stem_d_info = dict(id2val={}, val2id={}, L=[], R=[])
self._build_info_dict(target='input_stem_d')
# width_mult
self.width_mult_info = dict(id2val=[], val2id=[], L=[], R=[])
self._build_info_dict(target='width_mult')
# expand ratio
self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
self._build_info_dict(target='e')
@property
def n_stage(self):
return len(self.base_depth_list)
@property
def max_n_blocks(self):
return sum(self.base_depth_list) + self.n_stage * max(self.depth_list)
def _build_info_dict(self, target):
if target == 'r':
target_dict = self.r_info
target_dict['L'].append(self.n_dim)
for img_size in self.image_size_list:
target_dict['val2id'][img_size] = self.n_dim
target_dict['id2val'][self.n_dim] = img_size
self.n_dim += 1
target_dict['R'].append(self.n_dim)
elif target == 'input_stem_d':
target_dict = self.input_stem_d_info
target_dict['L'].append(self.n_dim)
for skip in [0, 1]:
target_dict['val2id'][skip] = self.n_dim
target_dict['id2val'][self.n_dim] = skip
self.n_dim += 1
target_dict['R'].append(self.n_dim)
elif target == 'e':
target_dict = self.e_info
choices = self.expand_list
for i in range(self.max_n_blocks):
target_dict['val2id'].append({})
target_dict['id2val'].append({})
target_dict['L'].append(self.n_dim)
for e in choices:
target_dict['val2id'][i][e] = self.n_dim
target_dict['id2val'][i][self.n_dim] = e
self.n_dim += 1
target_dict['R'].append(self.n_dim)
elif target == 'width_mult':
target_dict = self.width_mult_info
choices = list(range(len(self.width_mult_list)))
for i in range(self.n_stage + 2):
target_dict['val2id'].append({})
target_dict['id2val'].append({})
target_dict['L'].append(self.n_dim)
for w in choices:
target_dict['val2id'][i][w] = self.n_dim
target_dict['id2val'][i][self.n_dim] = w
self.n_dim += 1
target_dict['R'].append(self.n_dim)
def arch2feature(self, arch_dict):
d, e, w, r = arch_dict['d'], arch_dict['e'], arch_dict['w'], arch_dict['image_size']
input_stem_skip = 1 if d[0] > 0 else 0
d = d[1:]
feature = np.zeros(self.n_dim)
feature[self.r_info['val2id'][r]] = 1
feature[self.input_stem_d_info['val2id'][input_stem_skip]] = 1
for i in range(self.n_stage + 2):
feature[self.width_mult_info['val2id'][i][w[i]]] = 1
start_pt = 0
for i, base_depth in enumerate(self.base_depth_list):
depth = base_depth + d[i]
for j in range(start_pt, start_pt + depth):
feature[self.e_info['val2id'][j][e[j]]] = 1
start_pt += max(self.depth_list) + base_depth
return feature
def feature2arch(self, feature):
img_sz = self.r_info['id2val'][
int(np.argmax(feature[self.r_info['L'][0]:self.r_info['R'][0]])) + self.r_info['L'][0]
]
input_stem_skip = self.input_stem_d_info['id2val'][
int(np.argmax(feature[self.input_stem_d_info['L'][0]:self.input_stem_d_info['R'][0]])) +
self.input_stem_d_info['L'][0]
] * 2
assert img_sz in self.image_size_list
arch_dict = {'d': [input_stem_skip], 'e': [], 'w': [], 'image_size': img_sz}
for i in range(self.n_stage + 2):
arch_dict['w'].append(
self.width_mult_info['id2val'][i][
int(np.argmax(feature[self.width_mult_info['L'][i]:self.width_mult_info['R'][i]])) +
self.width_mult_info['L'][i]
]
)
d = 0
skipped = 0
stage_id = 0
for i in range(self.max_n_blocks):
skip = True
for j in range(self.e_info['L'][i], self.e_info['R'][i]):
if feature[j] == 1:
arch_dict['e'].append(self.e_info['id2val'][i][j])
skip = False
break
if skip:
arch_dict['e'].append(0)
skipped += 1
else:
d += 1
if i + 1 == self.max_n_blocks or (skipped + d) % \
(max(self.depth_list) + self.base_depth_list[stage_id]) == 0:
arch_dict['d'].append(d - self.base_depth_list[stage_id])
d, skipped = 0, 0
stage_id += 1
return arch_dict
def random_sample_arch(self):
return {
'd': [random.choice([0, 2])] + random.choices(self.depth_list, k=self.n_stage),
'e': random.choices(self.expand_list, k=self.max_n_blocks),
'w': random.choices(list(range(len(self.width_mult_list))), k=self.n_stage + 2),
'image_size': random.choice(self.image_size_list)
}
def mutate_resolution(self, arch_dict, mutate_prob):
if random.random() < mutate_prob:
arch_dict['image_size'] = random.choice(self.image_size_list)
return arch_dict
def mutate_arch(self, arch_dict, mutate_prob):
# input stem skip
if random.random() < mutate_prob:
arch_dict['d'][0] = random.choice([0, 2])
# depth
for i in range(1, len(arch_dict['d'])):
if random.random() < mutate_prob:
arch_dict['d'][i] = random.choice(self.depth_list)
# width_mult
for i in range(len(arch_dict['w'])):
if random.random() < mutate_prob:
arch_dict['w'][i] = random.choice(list(range(len(self.width_mult_list))))
# expand ratio
        for i in range(len(arch_dict['e'])):
            if random.random() < mutate_prob:
                arch_dict['e'][i] = random.choice(self.expand_list)
        return arch_dict
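
if __name__ == '__main__':
    # Minimal self-check (assumes the defaults above): encoding then
    # decoding an architecture should recover its depths and resolution;
    # kernel/expand entries of skipped blocks come back as 0 by design.
    encoder = MobileNetArchEncoder(image_size_list=[192, 224])
    arch = encoder.random_sample_arch()
    decoded = encoder.feature2arch(encoder.arch2feature(arch))
    assert decoded['d'] == arch['d'] and decoded['image_size'] == arch['image_size']
    print('round-trip OK:', decoded['d'], decoded['image_size'])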

View File

@@ -0,0 +1,71 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import copy
from .latency_lookup_table import *
class BaseEfficiencyModel:
def __init__(self, ofa_net):
self.ofa_net = ofa_net
def get_active_subnet_config(self, arch_dict):
arch_dict = copy.deepcopy(arch_dict)
image_size = arch_dict.pop('image_size')
self.ofa_net.set_active_subnet(**arch_dict)
active_net_config = self.ofa_net.get_active_net_config()
return active_net_config, image_size
def get_efficiency(self, arch_dict):
raise NotImplementedError
class ProxylessNASFLOPsModel(BaseEfficiencyModel):
def get_efficiency(self, arch_dict):
active_net_config, image_size = self.get_active_subnet_config(arch_dict)
return ProxylessNASLatencyTable.count_flops_given_config(active_net_config, image_size)
class Mbv3FLOPsModel(BaseEfficiencyModel):
def get_efficiency(self, arch_dict):
active_net_config, image_size = self.get_active_subnet_config(arch_dict)
return MBv3LatencyTable.count_flops_given_config(active_net_config, image_size)
class ResNet50FLOPsModel(BaseEfficiencyModel):
def get_efficiency(self, arch_dict):
active_net_config, image_size = self.get_active_subnet_config(arch_dict)
return ResNet50LatencyTable.count_flops_given_config(active_net_config, image_size)
class ProxylessNASLatencyModel(BaseEfficiencyModel):
def __init__(self, ofa_net, lookup_table_path_dict):
super(ProxylessNASLatencyModel, self).__init__(ofa_net)
self.latency_tables = {}
for image_size, path in lookup_table_path_dict.items():
self.latency_tables[image_size] = ProxylessNASLatencyTable(
local_dir='/tmp/.ofa_latency_tools/', url=os.path.join(path, '%d_lookup_table.yaml' % image_size))
def get_efficiency(self, arch_dict):
active_net_config, image_size = self.get_active_subnet_config(arch_dict)
return self.latency_tables[image_size].predict_network_latency_given_config(active_net_config, image_size)
class Mbv3LatencyModel(BaseEfficiencyModel):
def __init__(self, ofa_net, lookup_table_path_dict):
super(Mbv3LatencyModel, self).__init__(ofa_net)
self.latency_tables = {}
for image_size, path in lookup_table_path_dict.items():
self.latency_tables[image_size] = MBv3LatencyTable(
local_dir='/tmp/.ofa_latency_tools/', url=os.path.join(path, '%d_lookup_table.yaml' % image_size))
def get_efficiency(self, arch_dict):
active_net_config, image_size = self.get_active_subnet_config(arch_dict)
return self.latency_tables[image_size].predict_network_latency_given_config(active_net_config, image_size)
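
# Usage sketch (import path assumed from this repo's layout): pair an OFA
# supernet with a FLOPs model and score one architecture dict.
#
#   from model_zoo import ofa_net  # wherever the ofa_net helper above lives
#   supernet = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=False)
#   flops_model = Mbv3FLOPsModel(supernet)
#   arch = {'ks': [3] * 20, 'e': [4] * 20, 'd': [2] * 5, 'image_size': 224}
#   print('%.1f MFLOPs' % flops_model.get_efficiency(arch))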

View File

@@ -0,0 +1,387 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import yaml
from ofa.utils import download_url, make_divisible, MyNetwork
__all__ = ['count_conv_flop', 'ProxylessNASLatencyTable', 'MBv3LatencyTable', 'ResNet50LatencyTable']
def count_conv_flop(out_size, in_channels, out_channels, kernel_size, groups):
out_h = out_w = out_size
delta_ops = in_channels * out_channels * kernel_size * kernel_size * out_h * out_w / groups
return delta_ops
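# Worked example: a 3x3 conv mapping 16 -> 32 channels on a 112x112 output
# with groups=1 costs 16 * 32 * 3 * 3 * 112 * 112 ≈ 57.8M multiply-adds.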
class LatencyTable(object):
def __init__(self, local_dir='~/.ofa/latency_tools/',
url='https://hanlab.mit.edu/files/proxylessNAS/LatencyTools/mobile_trim.yaml'):
if url.startswith('http'):
fname = download_url(url, local_dir, overwrite=True)
else:
fname = url
with open(fname, 'r') as fp:
            self.lut = yaml.safe_load(fp)
@staticmethod
def repr_shape(shape):
if isinstance(shape, (list, tuple)):
return 'x'.join(str(_) for _ in shape)
elif isinstance(shape, str):
return shape
        else:
            raise TypeError('unsupported shape type: %s' % type(shape))
def query(self, **kwargs):
raise NotImplementedError
def predict_network_latency(self, net, image_size):
raise NotImplementedError
def predict_network_latency_given_config(self, net_config, image_size):
raise NotImplementedError
@staticmethod
def count_flops_given_config(net_config, image_size=224):
raise NotImplementedError
class ProxylessNASLatencyTable(LatencyTable):
def query(self, l_type: str, input_shape, output_shape, expand=None, ks=None, stride=None, id_skip=None):
"""
:param l_type:
Layer type must be one of the followings
1. `Conv`: The initial 3x3 conv with stride 2.
2. `Conv_1`: feature_mix_layer
3. `Logits`: All operations after `Conv_1`.
4. `expanded_conv`: MobileInvertedResidual
:param input_shape: input shape (h, w, #channels)
:param output_shape: output shape (h, w, #channels)
:param expand: expansion ratio
:param ks: kernel size
:param stride:
:param id_skip: indicate whether has the residual connection
"""
infos = [l_type, 'input:%s' % self.repr_shape(input_shape), 'output:%s' % self.repr_shape(output_shape), ]
if l_type in ('expanded_conv',):
assert None not in (expand, ks, stride, id_skip)
infos += ['expand:%d' % expand, 'kernel:%d' % ks, 'stride:%d' % stride, 'idskip:%d' % id_skip]
key = '-'.join(infos)
return self.lut[key]['mean']
def predict_network_latency(self, net, image_size=224):
predicted_latency = 0
# first conv
predicted_latency += self.query(
'Conv', [image_size, image_size, 3],
[(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels]
)
# blocks
fsize = (image_size + 1) // 2
for block in net.blocks:
mb_conv = block.conv
shortcut = block.shortcut
if mb_conv is None:
continue
if shortcut is None:
idskip = 0
else:
idskip = 1
            out_fz = int((fsize - 1) / mb_conv.stride + 1)  # ceil(fsize / stride)
block_latency = self.query(
'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
expand=mb_conv.expand_ratio, ks=mb_conv.kernel_size, stride=mb_conv.stride, id_skip=idskip
)
predicted_latency += block_latency
fsize = out_fz
# feature mix layer
predicted_latency += self.query(
'Conv_1', [fsize, fsize, net.feature_mix_layer.in_channels],
[fsize, fsize, net.feature_mix_layer.out_channels]
)
# classifier
predicted_latency += self.query(
'Logits', [fsize, fsize, net.classifier.in_features], [net.classifier.out_features] # 1000
)
return predicted_latency
def predict_network_latency_given_config(self, net_config, image_size=224):
predicted_latency = 0
# first conv
predicted_latency += self.query(
'Conv', [image_size, image_size, 3],
[(image_size + 1) // 2, (image_size + 1) // 2, net_config['first_conv']['out_channels']]
)
# blocks
fsize = (image_size + 1) // 2
for block in net_config['blocks']:
mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
shortcut = block['shortcut']
if mb_conv is None:
continue
if shortcut is None:
idskip = 0
else:
idskip = 1
out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
block_latency = self.query(
'expanded_conv', [fsize, fsize, mb_conv['in_channels']], [out_fz, out_fz, mb_conv['out_channels']],
expand=mb_conv['expand_ratio'], ks=mb_conv['kernel_size'], stride=mb_conv['stride'], id_skip=idskip
)
predicted_latency += block_latency
fsize = out_fz
# feature mix layer
predicted_latency += self.query(
'Conv_1', [fsize, fsize, net_config['feature_mix_layer']['in_channels']],
[fsize, fsize, net_config['feature_mix_layer']['out_channels']]
)
# classifier
predicted_latency += self.query(
'Logits', [fsize, fsize, net_config['classifier']['in_features']],
[net_config['classifier']['out_features']] # 1000
)
return predicted_latency
@staticmethod
def count_flops_given_config(net_config, image_size=224):
flops = 0
# first conv
flops += count_conv_flop((image_size + 1) // 2, 3, net_config['first_conv']['out_channels'], 3, 1)
# blocks
fsize = (image_size + 1) // 2
for block in net_config['blocks']:
mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
if mb_conv is None:
continue
out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
if mb_conv['mid_channels'] is None:
mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
if mb_conv['expand_ratio'] != 1:
# inverted bottleneck
flops += count_conv_flop(fsize, mb_conv['in_channels'], mb_conv['mid_channels'], 1, 1)
# depth conv
flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['mid_channels'],
mb_conv['kernel_size'], mb_conv['mid_channels'])
# point linear
flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['out_channels'], 1, 1)
fsize = out_fz
# feature mix layer
flops += count_conv_flop(fsize, net_config['feature_mix_layer']['in_channels'],
net_config['feature_mix_layer']['out_channels'], 1, 1)
# classifier
flops += count_conv_flop(1, net_config['classifier']['in_features'],
net_config['classifier']['out_features'], 1, 1)
return flops / 1e6 # MFLOPs
class MBv3LatencyTable(LatencyTable):
def query(self, l_type: str, input_shape, output_shape, mid=None, ks=None, stride=None, id_skip=None,
se=None, h_swish=None):
infos = [l_type, 'input:%s' % self.repr_shape(input_shape), 'output:%s' % self.repr_shape(output_shape), ]
if l_type in ('expanded_conv',):
assert None not in (mid, ks, stride, id_skip, se, h_swish)
infos += ['expand:%d' % mid, 'kernel:%d' % ks, 'stride:%d' % stride, 'idskip:%d' % id_skip,
'se:%d' % se, 'hs:%d' % h_swish]
key = '-'.join(infos)
return self.lut[key]['mean']
def predict_network_latency(self, net, image_size=224):
predicted_latency = 0
# first conv
predicted_latency += self.query(
'Conv', [image_size, image_size, 3],
[(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels]
)
# blocks
fsize = (image_size + 1) // 2
for block in net.blocks:
mb_conv = block.conv
shortcut = block.shortcut
if mb_conv is None:
continue
if shortcut is None:
idskip = 0
else:
idskip = 1
out_fz = int((fsize - 1) / mb_conv.stride + 1)
block_latency = self.query(
'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
mid=mb_conv.depth_conv.conv.in_channels, ks=mb_conv.kernel_size, stride=mb_conv.stride, id_skip=idskip,
se=1 if mb_conv.use_se else 0, h_swish=1 if mb_conv.act_func == 'h_swish' else 0,
)
predicted_latency += block_latency
fsize = out_fz
# final expand layer
predicted_latency += self.query(
'Conv_1', [fsize, fsize, net.final_expand_layer.in_channels],
[fsize, fsize, net.final_expand_layer.out_channels],
)
# global average pooling
predicted_latency += self.query(
'AvgPool2D', [fsize, fsize, net.final_expand_layer.out_channels],
[1, 1, net.final_expand_layer.out_channels],
)
# feature mix layer
predicted_latency += self.query(
'Conv_2', [1, 1, net.feature_mix_layer.in_channels],
[1, 1, net.feature_mix_layer.out_channels]
)
# classifier
predicted_latency += self.query(
'Logits', [1, 1, net.classifier.in_features], [net.classifier.out_features]
)
return predicted_latency
def predict_network_latency_given_config(self, net_config, image_size=224):
predicted_latency = 0
# first conv
predicted_latency += self.query(
'Conv', [image_size, image_size, 3],
[(image_size + 1) // 2, (image_size + 1) // 2, net_config['first_conv']['out_channels']]
)
# blocks
fsize = (image_size + 1) // 2
for block in net_config['blocks']:
mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
shortcut = block['shortcut']
if mb_conv is None:
continue
if shortcut is None:
idskip = 0
else:
idskip = 1
out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
if mb_conv['mid_channels'] is None:
mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
block_latency = self.query(
'expanded_conv', [fsize, fsize, mb_conv['in_channels']], [out_fz, out_fz, mb_conv['out_channels']],
mid=mb_conv['mid_channels'], ks=mb_conv['kernel_size'], stride=mb_conv['stride'], id_skip=idskip,
se=1 if mb_conv['use_se'] else 0, h_swish=1 if mb_conv['act_func'] == 'h_swish' else 0,
)
predicted_latency += block_latency
fsize = out_fz
# final expand layer
predicted_latency += self.query(
'Conv_1', [fsize, fsize, net_config['final_expand_layer']['in_channels']],
[fsize, fsize, net_config['final_expand_layer']['out_channels']],
)
# global average pooling
predicted_latency += self.query(
'AvgPool2D', [fsize, fsize, net_config['final_expand_layer']['out_channels']],
[1, 1, net_config['final_expand_layer']['out_channels']],
)
# feature mix layer
predicted_latency += self.query(
'Conv_2', [1, 1, net_config['feature_mix_layer']['in_channels']],
[1, 1, net_config['feature_mix_layer']['out_channels']]
)
# classifier
predicted_latency += self.query(
'Logits', [1, 1, net_config['classifier']['in_features']], [net_config['classifier']['out_features']]
)
return predicted_latency
@staticmethod
def count_flops_given_config(net_config, image_size=224):
flops = 0
# first conv
flops += count_conv_flop((image_size + 1) // 2, 3, net_config['first_conv']['out_channels'], 3, 1)
# blocks
fsize = (image_size + 1) // 2
for block in net_config['blocks']:
mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
if mb_conv is None:
continue
out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
if mb_conv['mid_channels'] is None:
mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
if mb_conv['expand_ratio'] != 1:
# inverted bottleneck
flops += count_conv_flop(fsize, mb_conv['in_channels'], mb_conv['mid_channels'], 1, 1)
# depth conv
flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['mid_channels'],
mb_conv['kernel_size'], mb_conv['mid_channels'])
if mb_conv['use_se']:
# SE layer
se_mid = make_divisible(mb_conv['mid_channels'] // 4, divisor=MyNetwork.CHANNEL_DIVISIBLE)
flops += count_conv_flop(1, mb_conv['mid_channels'], se_mid, 1, 1)
flops += count_conv_flop(1, se_mid, mb_conv['mid_channels'], 1, 1)
# point linear
flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['out_channels'], 1, 1)
fsize = out_fz
# final expand layer
flops += count_conv_flop(fsize, net_config['final_expand_layer']['in_channels'],
net_config['final_expand_layer']['out_channels'], 1, 1)
# feature mix layer
flops += count_conv_flop(1, net_config['feature_mix_layer']['in_channels'],
net_config['feature_mix_layer']['out_channels'], 1, 1)
# classifier
flops += count_conv_flop(1, net_config['classifier']['in_features'],
net_config['classifier']['out_features'], 1, 1)
return flops / 1e6 # MFLOPs
class ResNet50LatencyTable(LatencyTable):
def query(self, **kwargs):
raise NotImplementedError
def predict_network_latency(self, net, image_size):
raise NotImplementedError
def predict_network_latency_given_config(self, net_config, image_size):
raise NotImplementedError
@staticmethod
def count_flops_given_config(net_config, image_size=224):
flops = 0
# input stem
for layer_config in net_config['input_stem']:
if layer_config['name'] != 'ConvLayer':
layer_config = layer_config['conv']
in_channel = layer_config['in_channels']
out_channel = layer_config['out_channels']
out_image_size = int((image_size - 1) / layer_config['stride'] + 1)
flops += count_conv_flop(out_image_size, in_channel, out_channel,
layer_config['kernel_size'], layer_config.get('groups', 1))
image_size = out_image_size
# max pooling
image_size = int((image_size - 1) / 2 + 1)
# ResNetBottleneckBlocks
for block_config in net_config['blocks']:
in_channel = block_config['in_channels']
out_channel = block_config['out_channels']
out_image_size = int((image_size - 1) / block_config['stride'] + 1)
mid_channel = block_config['mid_channels'] if block_config['mid_channels'] is not None \
else round(out_channel * block_config['expand_ratio'])
mid_channel = make_divisible(mid_channel, MyNetwork.CHANNEL_DIVISIBLE)
# conv1
flops += count_conv_flop(image_size, in_channel, mid_channel, 1, 1)
# conv2
flops += count_conv_flop(out_image_size, mid_channel, mid_channel,
block_config['kernel_size'], block_config['groups'])
# conv3
flops += count_conv_flop(out_image_size, mid_channel, out_channel, 1, 1)
# downsample
if block_config['stride'] == 1 and in_channel == out_channel:
pass
else:
flops += count_conv_flop(out_image_size, in_channel, out_channel, 1, 1)
image_size = out_image_size
# final classifier
flops += count_conv_flop(1, net_config['classifier']['in_features'],
net_config['classifier']['out_features'], 1, 1)
return flops / 1e6 # MFLOPs

View File

@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .evolution import *

View File

@@ -0,0 +1,134 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import random
import numpy as np
from tqdm import tqdm
__all__ = ['EvolutionFinder']
class EvolutionFinder:
def __init__(self, efficiency_predictor, accuracy_predictor, **kwargs):
self.efficiency_predictor = efficiency_predictor
self.accuracy_predictor = accuracy_predictor
# evolution hyper-parameters
self.arch_mutate_prob = kwargs.get('arch_mutate_prob', 0.1)
self.resolution_mutate_prob = kwargs.get('resolution_mutate_prob', 0.5)
self.population_size = kwargs.get('population_size', 100)
self.max_time_budget = kwargs.get('max_time_budget', 500)
self.parent_ratio = kwargs.get('parent_ratio', 0.25)
self.mutation_ratio = kwargs.get('mutation_ratio', 0.5)
@property
def arch_manager(self):
return self.accuracy_predictor.arch_encoder
def update_hyper_params(self, new_param_dict):
self.__dict__.update(new_param_dict)
def random_valid_sample(self, constraint):
while True:
sample = self.arch_manager.random_sample_arch()
efficiency = self.efficiency_predictor.get_efficiency(sample)
if efficiency <= constraint:
return sample, efficiency
def mutate_sample(self, sample, constraint):
while True:
new_sample = copy.deepcopy(sample)
self.arch_manager.mutate_resolution(new_sample, self.resolution_mutate_prob)
self.arch_manager.mutate_arch(new_sample, self.arch_mutate_prob)
efficiency = self.efficiency_predictor.get_efficiency(new_sample)
if efficiency <= constraint:
return new_sample, efficiency
def crossover_sample(self, sample1, sample2, constraint):
while True:
new_sample = copy.deepcopy(sample1)
for key in new_sample.keys():
if not isinstance(new_sample[key], list):
new_sample[key] = random.choice([sample1[key], sample2[key]])
else:
for i in range(len(new_sample[key])):
new_sample[key][i] = random.choice([sample1[key][i], sample2[key][i]])
efficiency = self.efficiency_predictor.get_efficiency(new_sample)
if efficiency <= constraint:
return new_sample, efficiency
def run_evolution_search(self, constraint, verbose=False, **kwargs):
"""Run a single roll-out of regularized evolution to a fixed time budget."""
self.update_hyper_params(kwargs)
mutation_numbers = int(round(self.mutation_ratio * self.population_size))
parents_size = int(round(self.parent_ratio * self.population_size))
best_valids = [-100]
population = [] # (validation, sample, latency) tuples
child_pool = []
efficiency_pool = []
best_info = None
if verbose:
print('Generate random population...')
for _ in range(self.population_size):
sample, efficiency = self.random_valid_sample(constraint)
child_pool.append(sample)
efficiency_pool.append(efficiency)
accs = self.accuracy_predictor.predict_acc(child_pool)
for i in range(mutation_numbers):
population.append((accs[i].item(), child_pool[i], efficiency_pool[i]))
if verbose:
print('Start Evolution...')
# After the population is seeded, proceed with evolving the population.
with tqdm(total=self.max_time_budget, desc='Searching with constraint (%s)' % constraint,
disable=(not verbose)) as t:
for i in range(self.max_time_budget):
parents = sorted(population, key=lambda x: x[0])[::-1][:parents_size]
acc = parents[0][0]
t.set_postfix({
'acc': parents[0][0]
})
if not verbose and (i + 1) % 100 == 0:
print('Iter: {} Acc: {}'.format(i + 1, parents[0][0]))
if acc > best_valids[-1]:
best_valids.append(acc)
best_info = parents[0]
else:
best_valids.append(best_valids[-1])
population = parents
child_pool = []
efficiency_pool = []
for j in range(mutation_numbers):
par_sample = population[np.random.randint(parents_size)][1]
# Mutate
new_sample, efficiency = self.mutate_sample(par_sample, constraint)
child_pool.append(new_sample)
efficiency_pool.append(efficiency)
for j in range(self.population_size - mutation_numbers):
par_sample1 = population[np.random.randint(parents_size)][1]
par_sample2 = population[np.random.randint(parents_size)][1]
# Crossover
new_sample, efficiency = self.crossover_sample(par_sample1, par_sample2, constraint)
child_pool.append(new_sample)
efficiency_pool.append(efficiency)
accs = self.accuracy_predictor.predict_acc(child_pool)
for j in range(self.population_size):
population.append((accs[j].item(), child_pool[j], efficiency_pool[j]))
t.update(1)
return best_valids, best_info
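
# Usage sketch: the two collaborators only need `get_efficiency(arch_dict)`
# and `predict_acc(arch_dict_list)` (plus an `.arch_encoder` attribute); the
# concrete objects below are assumptions wired from the accuracy/efficiency
# predictor files above.
#
#   finder = EvolutionFinder(flops_model, predictor,
#                            population_size=50, max_time_budget=100)
#   best_valids, best_info = finder.run_evolution_search(constraint=400,
#                                                        verbose=True)
#   best_acc, best_arch, best_flops = best_info  # constraint in MFLOPs here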

View File

@@ -0,0 +1,5 @@
from .accuracy_predictor import AccuracyPredictor
from .flops_table import FLOPsTable
from .latency_table import LatencyTable
from .evolution_finder import EvolutionFinder, ArchManager
from .imagenet_eval_helper import evaluate_ofa_subnet, evaluate_ofa_specialized

View File

@@ -0,0 +1,85 @@
import torch.nn as nn
import torch
import copy
from ofa.utils import download_url
# Helper for constructing the one-hot vectors.
def construct_maps(keys):
d = dict()
keys = list(set(keys))
for k in keys:
if k not in d:
d[k] = len(list(d.keys()))
return d
ks_map = construct_maps(keys=(3, 5, 7))
ex_map = construct_maps(keys=(3, 4, 6))
dp_map = construct_maps(keys=(2, 3, 4))
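# For example, ks_map typically becomes {3: 0, 5: 1, 7: 2}; the exact index
# assignment depends on set iteration order (deterministic for these small
# ints in CPython, though not guaranteed by the language spec).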
class AccuracyPredictor:
def __init__(self, pretrained=True, device='cuda:0'):
self.device = device
self.model = nn.Sequential(
nn.Linear(128, 400),
nn.ReLU(),
nn.Linear(400, 400),
nn.ReLU(),
nn.Linear(400, 400),
nn.ReLU(),
nn.Linear(400, 1),
)
if pretrained:
# load pretrained model
fname = download_url("https://hanlab.mit.edu/files/OnceForAll/tutorial/acc_predictor.pth")
self.model.load_state_dict(
torch.load(fname, map_location=torch.device('cpu'))
)
self.model = self.model.to(self.device)
# TODO: merge it with serialization utils.
@torch.no_grad()
def predict_accuracy(self, population):
all_feats = []
for sample in population:
ks_list = copy.deepcopy(sample['ks'])
ex_list = copy.deepcopy(sample['e'])
d_list = copy.deepcopy(sample['d'])
r = copy.deepcopy(sample['r'])[0]
feats = AccuracyPredictor.spec2feats(ks_list, ex_list, d_list, r).reshape(1, -1).to(self.device)
all_feats.append(feats)
all_feats = torch.cat(all_feats, 0)
pred = self.model(all_feats).cpu()
return pred
@staticmethod
def spec2feats(ks_list, ex_list, d_list, r):
# This function converts a network config to a feature vector (128-D).
start = 0
end = 4
for d in d_list:
for j in range(start+d, end):
ks_list[j] = 0
ex_list[j] = 0
start += 4
end += 4
# convert to onehot
ks_onehot = [0 for _ in range(60)]
ex_onehot = [0 for _ in range(60)]
r_onehot = [0 for _ in range(8)]
for i in range(20):
start = i * 3
if ks_list[i] != 0:
ks_onehot[start + ks_map[ks_list[i]]] = 1
if ex_list[i] != 0:
ex_onehot[start + ex_map[ex_list[i]]] = 1
r_onehot[(r - 112) // 16] = 1
return torch.Tensor(ks_onehot + ex_onehot + r_onehot)
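
# Usage sketch: with `pretrained=False` the weights stay random (no
# download), so the printed value only demonstrates the 128-D encoding and
# the call shape; 'cpu' keeps the demo GPU-free.
#
#   predictor = AccuracyPredictor(pretrained=False, device='cpu')
#   sample = {'ks': [3] * 20, 'e': [4] * 20, 'd': [3] * 5, 'r': [176]}
#   print(predictor.predict_accuracy([sample]))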

View File

@@ -0,0 +1,213 @@
import copy
import random
from tqdm import tqdm
import numpy as np
__all__ = ['EvolutionFinder']
class ArchManager:
def __init__(self):
self.num_blocks = 20
self.num_stages = 5
self.kernel_sizes = [3, 5, 7]
self.expand_ratios = [3, 4, 6]
self.depths = [2, 3, 4]
self.resolutions = [160, 176, 192, 208, 224]
def random_sample(self):
        d = []
e = []
ks = []
for i in range(self.num_stages):
d.append(random.choice(self.depths))
for i in range(self.num_blocks):
e.append(random.choice(self.expand_ratios))
ks.append(random.choice(self.kernel_sizes))
sample = {
'wid': None,
'ks': ks,
'e': e,
'd': d,
'r': [random.choice(self.resolutions)]
}
return sample
def random_resample(self, sample, i):
        assert 0 <= i < self.num_blocks
sample['ks'][i] = random.choice(self.kernel_sizes)
sample['e'][i] = random.choice(self.expand_ratios)
def random_resample_depth(self, sample, i):
        assert 0 <= i < self.num_stages
sample['d'][i] = random.choice(self.depths)
def random_resample_resolution(self, sample):
sample['r'][0] = random.choice(self.resolutions)
class EvolutionFinder:
valid_constraint_range = {
'flops': [150, 600],
'note10': [15, 60],
}
def __init__(self, constraint_type, efficiency_constraint,
efficiency_predictor, accuracy_predictor, **kwargs):
self.constraint_type = constraint_type
        if constraint_type not in self.valid_constraint_range:
self.invite_reset_constraint_type()
self.efficiency_constraint = efficiency_constraint
        if not (self.valid_constraint_range[constraint_type][0] <= efficiency_constraint <=
                self.valid_constraint_range[constraint_type][1]):
self.invite_reset_constraint()
self.efficiency_predictor = efficiency_predictor
self.accuracy_predictor = accuracy_predictor
self.arch_manager = ArchManager()
self.num_blocks = self.arch_manager.num_blocks
self.num_stages = self.arch_manager.num_stages
self.mutate_prob = kwargs.get('mutate_prob', 0.1)
self.population_size = kwargs.get('population_size', 100)
self.max_time_budget = kwargs.get('max_time_budget', 500)
self.parent_ratio = kwargs.get('parent_ratio', 0.25)
self.mutation_ratio = kwargs.get('mutation_ratio', 0.5)
def invite_reset_constraint_type(self):
print('Invalid constraint type! Please input one of:', list(self.valid_constraint_range.keys()))
new_type = input()
while new_type not in self.valid_constraint_range.keys():
print('Invalid constraint type! Please input one of:', list(self.valid_constraint_range.keys()))
new_type = input()
self.constraint_type = new_type
def invite_reset_constraint(self):
print('Invalid constraint_value! Please input an integer in interval: [%d, %d]!' % (
self.valid_constraint_range[self.constraint_type][0],
self.valid_constraint_range[self.constraint_type][1])
)
new_cons = input()
while (not new_cons.isdigit()) or (int(new_cons) > self.valid_constraint_range[self.constraint_type][1]) or \
(int(new_cons) < self.valid_constraint_range[self.constraint_type][0]):
print('Invalid constraint_value! Please input an integer in interval: [%d, %d]!' % (
self.valid_constraint_range[self.constraint_type][0],
self.valid_constraint_range[self.constraint_type][1])
)
new_cons = input()
new_cons = int(new_cons)
self.efficiency_constraint = new_cons
def set_efficiency_constraint(self, new_constraint):
self.efficiency_constraint = new_constraint
def random_sample(self):
constraint = self.efficiency_constraint
while True:
sample = self.arch_manager.random_sample()
efficiency = self.efficiency_predictor.predict_efficiency(sample)
if efficiency <= constraint:
return sample, efficiency
def mutate_sample(self, sample):
constraint = self.efficiency_constraint
while True:
new_sample = copy.deepcopy(sample)
if random.random() < self.mutate_prob:
self.arch_manager.random_resample_resolution(new_sample)
for i in range(self.num_blocks):
if random.random() < self.mutate_prob:
self.arch_manager.random_resample(new_sample, i)
for i in range(self.num_stages):
if random.random() < self.mutate_prob:
self.arch_manager.random_resample_depth(new_sample, i)
efficiency = self.efficiency_predictor.predict_efficiency(new_sample)
if efficiency <= constraint:
return new_sample, efficiency
def crossover_sample(self, sample1, sample2):
constraint = self.efficiency_constraint
while True:
new_sample = copy.deepcopy(sample1)
for key in new_sample.keys():
if not isinstance(new_sample[key], list):
continue
for i in range(len(new_sample[key])):
new_sample[key][i] = random.choice([sample1[key][i], sample2[key][i]])
efficiency = self.efficiency_predictor.predict_efficiency(new_sample)
if efficiency <= constraint:
return new_sample, efficiency
def run_evolution_search(self, verbose=False):
"""Run a single roll-out of regularized evolution to a fixed time budget."""
max_time_budget = self.max_time_budget
population_size = self.population_size
mutation_numbers = int(round(self.mutation_ratio * population_size))
parents_size = int(round(self.parent_ratio * population_size))
constraint = self.efficiency_constraint
best_valids = [-100]
        population = []  # (predicted accuracy, sample, efficiency) tuples
child_pool = []
efficiency_pool = []
best_info = None
if verbose:
print('Generate random population...')
for _ in range(population_size):
sample, efficiency = self.random_sample()
child_pool.append(sample)
efficiency_pool.append(efficiency)
accs = self.accuracy_predictor.predict_accuracy(child_pool)
        for i in range(population_size):  # seed the whole initial population
population.append((accs[i].item(), child_pool[i], efficiency_pool[i]))
if verbose:
print('Start Evolution...')
# After the population is seeded, proceed with evolving the population.
for iter in tqdm(range(max_time_budget), desc='Searching with %s constraint (%s)' % (self.constraint_type, self.efficiency_constraint)):
parents = sorted(population, key=lambda x: x[0])[::-1][:parents_size]
acc = parents[0][0]
if verbose:
print('Iter: {} Acc: {}'.format(iter - 1, parents[0][0]))
if acc > best_valids[-1]:
best_valids.append(acc)
best_info = parents[0]
else:
best_valids.append(best_valids[-1])
population = parents
child_pool = []
efficiency_pool = []
for i in range(mutation_numbers):
par_sample = population[np.random.randint(parents_size)][1]
# Mutate
new_sample, efficiency = self.mutate_sample(par_sample)
child_pool.append(new_sample)
efficiency_pool.append(efficiency)
for i in range(population_size - mutation_numbers):
par_sample1 = population[np.random.randint(parents_size)][1]
par_sample2 = population[np.random.randint(parents_size)][1]
# Crossover
new_sample, efficiency = self.crossover_sample(par_sample1, par_sample2)
child_pool.append(new_sample)
efficiency_pool.append(efficiency)
accs = self.accuracy_predictor.predict_accuracy(child_pool)
for i in range(population_size):
population.append((accs[i].item(), child_pool[i], efficiency_pool[i]))
return best_valids, best_info
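For reference, here is a minimal sketch of driving this search loop. The names are assumptions: `acc_predictor` stands for any trained accuracy predictor exposing `predict_accuracy(archs)`, and `FLOPsTable` is the lookup table defined in the next file (pass `load_efficiency_table` to reuse a saved LUT instead of rebuilding it).

```python
flops_table = FLOPsTable(pred_type='flops', device='cuda:0')
finder = EvolutionFinder(
    constraint_type='flops',
    efficiency_constraint=400,         # MFLOPs budget, within [150, 600]
    efficiency_predictor=flops_table,
    accuracy_predictor=acc_predictor,  # assumed: trained accuracy predictor
)
best_valids, best_info = finder.run_evolution_search(verbose=True)
predicted_acc, best_sample, flops = best_info
print('predicted top-1: %.2f at %.0f MFLOPs' % (predicted_acc, flops))
```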

View File

@@ -0,0 +1,224 @@
import time
import copy
import torch
import torch.nn as nn
import numpy as np
from ofa.utils.layers import *
__all__ = ['FLOPsTable']
def rm_bn_from_net(net):
for m in net.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
m.forward = lambda x: x
class FLOPsTable:
def __init__(self, pred_type='flops', device='cuda:0', multiplier=1.2, batch_size=64, load_efficiency_table=None):
assert pred_type in ['flops', 'latency']
self.multiplier = multiplier
self.pred_type = pred_type
self.device = device
self.batch_size = batch_size
self.efficiency_dict = {}
if load_efficiency_table is not None:
self.efficiency_dict = np.load(load_efficiency_table, allow_pickle=True).item()
else:
self.build_lut(batch_size)
@torch.no_grad()
def measure_single_layer_latency(self, layer: nn.Module, input_size: tuple, warmup_steps=10, measure_steps=50):
total_time = 0
inputs = torch.randn(*input_size, device=self.device)
layer.eval()
rm_bn_from_net(layer)
network = layer.to(self.device)
torch.cuda.synchronize()
for i in range(warmup_steps):
network(inputs)
torch.cuda.synchronize()
torch.cuda.synchronize()
st = time.time()
for i in range(measure_steps):
network(inputs)
torch.cuda.synchronize()
ed = time.time()
total_time += ed - st
latency = total_time / measure_steps * 1000
return latency
@torch.no_grad()
def measure_single_layer_flops(self, layer: nn.Module, input_size: tuple):
import thop
inputs = torch.randn(*input_size, device=self.device)
network = layer.to(self.device)
layer.eval()
rm_bn_from_net(layer)
flops, params = thop.profile(network, (inputs,), verbose=False)
return flops / 1e6
    def build_lut(self, batch_size=1, resolutions=(160, 176, 192, 208, 224)):
for resolution in resolutions:
self.build_single_lut(batch_size, resolution)
np.save('local_lut.npy', self.efficiency_dict)
def build_single_lut(self, batch_size=1, base_resolution=224):
print('Building the %s lookup table (resolution=%d)...' % (self.pred_type, base_resolution))
# block, input_size, in_channels, out_channels, expand_ratio, kernel_size, stride, act, se
configurations = [
(ConvLayer, base_resolution, 3, 16, 3, 2, 'relu'),
(ResidualBlock, base_resolution // 2, 16, 16, [1], [3, 5, 7], 1, 'relu', False),
(ResidualBlock, base_resolution // 2, 16, 24, [3, 4, 6], [3, 5, 7], 2, 'relu', False),
(ResidualBlock, base_resolution // 4, 24, 24, [3, 4, 6], [3, 5, 7], 1, 'relu', False),
(ResidualBlock, base_resolution // 4, 24, 24, [3, 4, 6], [3, 5, 7], 1, 'relu', False),
(ResidualBlock, base_resolution // 4, 24, 24, [3, 4, 6], [3, 5, 7], 1, 'relu', False),
(ResidualBlock, base_resolution // 4, 24, 40, [3, 4, 6], [3, 5, 7], 2, 'relu', True),
(ResidualBlock, base_resolution // 8, 40, 40, [3, 4, 6], [3, 5, 7], 1, 'relu', True),
(ResidualBlock, base_resolution // 8, 40, 40, [3, 4, 6], [3, 5, 7], 1, 'relu', True),
(ResidualBlock, base_resolution // 8, 40, 40, [3, 4, 6], [3, 5, 7], 1, 'relu', True),
(ResidualBlock, base_resolution // 8, 40, 80, [3, 4, 6], [3, 5, 7], 2, 'h_swish', False),
(ResidualBlock, base_resolution // 16, 80, 80, [3, 4, 6], [3, 5, 7], 1, 'h_swish', False),
(ResidualBlock, base_resolution // 16, 80, 80, [3, 4, 6], [3, 5, 7], 1, 'h_swish', False),
(ResidualBlock, base_resolution // 16, 80, 80, [3, 4, 6], [3, 5, 7], 1, 'h_swish', False),
(ResidualBlock, base_resolution // 16, 80, 112, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
(ResidualBlock, base_resolution // 16, 112, 112, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
(ResidualBlock, base_resolution // 16, 112, 112, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
(ResidualBlock, base_resolution // 16, 112, 112, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
(ResidualBlock, base_resolution // 16, 112, 160, [3, 4, 6], [3, 5, 7], 2, 'h_swish', True),
(ResidualBlock, base_resolution // 32, 160, 160, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
(ResidualBlock, base_resolution // 32, 160, 160, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
(ResidualBlock, base_resolution // 32, 160, 160, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
(ConvLayer, base_resolution // 32, 160, 960, 1, 1, 'h_swish'),
(ConvLayer, 1, 960, 1280, 1, 1, 'h_swish'),
(LinearLayer, 1, 1280, 1000, 1, 1),
]
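        # The rows above mirror the OFA-MobileNetV3 macro-architecture: a stem
        # conv, 21 inverted-residual rows (one fixed first block plus 5 stages
        # of 4 searchable blocks each), two final convs, and the classifier.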
efficiency_dict = {
'mobile_inverted_blocks': [],
'other_blocks': {}
}
for layer_idx in range(len(configurations)):
config = configurations[layer_idx]
op_type = config[0]
if op_type == ResidualBlock:
_, input_size, in_channels, out_channels, expand_list, ks_list, stride, act, se = config
in_channels = int(round(in_channels * self.multiplier))
out_channels = int(round(out_channels * self.multiplier))
template_config = {
'name': ResidualBlock.__name__,
'mobile_inverted_conv': {
'name': MBConvLayer.__name__,
'in_channels': in_channels,
'out_channels': out_channels,
                        'kernel_size': 0,   # placeholder, set per ks below
                        'stride': stride,
                        'expand_ratio': 0,  # placeholder, set per e below
# 'mid_channels': None,
'act_func': act,
'use_se': se,
},
'shortcut': {
'name': IdentityLayer.__name__,
'in_channels': in_channels,
'out_channels': out_channels,
} if (in_channels == out_channels and stride == 1) else None
}
sub_dict = {}
for ks in ks_list:
for e in expand_list:
build_config = copy.deepcopy(template_config)
build_config['mobile_inverted_conv']['expand_ratio'] = e
build_config['mobile_inverted_conv']['kernel_size'] = ks
layer = ResidualBlock.build_from_config(build_config)
input_shape = (batch_size, in_channels, input_size, input_size)
if self.pred_type == 'flops':
measure_result = self.measure_single_layer_flops(layer, input_shape) / batch_size
elif self.pred_type == 'latency':
measure_result = self.measure_single_layer_latency(layer, input_shape)
sub_dict[(ks, e)] = measure_result
efficiency_dict['mobile_inverted_blocks'].append(sub_dict)
elif op_type == ConvLayer:
_, input_size, in_channels, out_channels, kernel_size, stride, activation = config
in_channels = int(round(in_channels * self.multiplier))
out_channels = int(round(out_channels * self.multiplier))
build_config = {
# 'name': ConvLayer.__name__,
'in_channels': in_channels,
'out_channels': out_channels,
'kernel_size': kernel_size,
'stride': stride,
'dilation': 1,
'groups': 1,
'bias': False,
'use_bn': True,
'has_shuffle': False,
'act_func': activation,
}
layer = ConvLayer.build_from_config(build_config)
input_shape = (batch_size, in_channels, input_size, input_size)
if self.pred_type == 'flops':
measure_result = self.measure_single_layer_flops(layer, input_shape) / batch_size
elif self.pred_type == 'latency':
measure_result = self.measure_single_layer_latency(layer, input_shape)
efficiency_dict['other_blocks'][layer_idx] = measure_result
elif op_type == LinearLayer:
_, input_size, in_channels, out_channels, kernel_size, stride = config
in_channels = int(round(in_channels * self.multiplier))
out_channels = int(round(out_channels * self.multiplier))
build_config = {
# 'name': LinearLayer.__name__,
'in_features': in_channels,
'out_features': out_channels
}
layer = LinearLayer.build_from_config(build_config)
input_shape = (batch_size, in_channels)
if self.pred_type == 'flops':
measure_result = self.measure_single_layer_flops(layer, input_shape) / batch_size
elif self.pred_type == 'latency':
measure_result = self.measure_single_layer_latency(layer, input_shape)
efficiency_dict['other_blocks'][layer_idx] = measure_result
else:
raise NotImplementedError
self.efficiency_dict[base_resolution] = efficiency_dict
print('Built the %s lookup table (resolution=%d)!' % (self.pred_type, base_resolution))
return efficiency_dict
def predict_efficiency(self, sample):
input_size = sample.get('r', [224])
input_size = input_size[0]
assert 'ks' in sample and 'e' in sample and 'd' in sample
assert len(sample['ks']) == len(sample['e']) and len(sample['ks']) == 20
assert len(sample['d']) == 5
total_stats = 0.
        for i in range(20):
            # 5 stages x 4 candidate blocks: skip blocks that sit deeper
            # than the sampled depth of their stage.
            stage = i // 4
            depth_max = sample['d'][stage]
            depth = i % 4 + 1
            if depth > depth_max:
                continue
ks, e = sample['ks'][i], sample['e'][i]
total_stats += self.efficiency_dict[input_size]['mobile_inverted_blocks'][i + 1][(ks, e)]
for key in self.efficiency_dict[input_size]['other_blocks']:
total_stats += self.efficiency_dict[input_size]['other_blocks'][key]
total_stats += self.efficiency_dict[input_size]['mobile_inverted_blocks'][0][(3, 1)]
return total_stats
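A sketch of the sample layout `predict_efficiency` expects. The values are illustrative, and `local_lut.npy` assumes a table previously saved by `build_lut`:

```python
table = FLOPsTable(pred_type='flops', load_efficiency_table='local_lut.npy')
sample = {
    'wid': None,
    'ks': [3] * 20,        # kernel size per block (5 stages x 4 blocks)
    'e': [4] * 20,         # expand ratio per block
    'd': [2, 3, 4, 2, 3],  # active depth per stage; deeper blocks are skipped
    'r': [192],            # input resolution, selects the per-resolution LUT
}
print('%.1f MFLOPs' % table.predict_efficiency(sample))
```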

View File

@@ -0,0 +1,241 @@
import os.path as osp
import numpy as np
import math
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
from torchvision import transforms, datasets
from ofa.utils import AverageMeter, accuracy
from ofa.model_zoo import ofa_specialized
from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics
def evaluate_ofa_subnet(ofa_net, path, net_config, data_loader, batch_size, device='cuda:0'):
assert 'ks' in net_config and 'd' in net_config and 'e' in net_config
assert len(net_config['ks']) == 20 and len(net_config['e']) == 20 and len(net_config['d']) == 5
ofa_net.set_active_subnet(ks=net_config['ks'], d=net_config['d'], e=net_config['e'])
subnet = ofa_net.get_active_subnet().to(device)
calib_bn(subnet, path, net_config['r'][0], batch_size)
top1 = validate(subnet, path, net_config['r'][0], data_loader, batch_size, device)
return top1
def calib_bn(net, path, image_size, batch_size, num_images=2000):
# print('Creating dataloader for resetting BN running statistics...')
dataset = datasets.ImageFolder(
osp.join(
path,
'train'),
transforms.Compose([
transforms.RandomResizedCrop(image_size),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=32. / 255., saturation=0.5),
transforms.ToTensor(),
transforms.Normalize(
mean=[
0.485,
0.456,
0.406],
std=[
0.229,
0.224,
0.225]
),
])
)
chosen_indexes = np.random.choice(list(range(len(dataset))), num_images)
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
data_loader = torch.utils.data.DataLoader(
dataset,
sampler=sub_sampler,
batch_size=batch_size,
num_workers=16,
pin_memory=True,
drop_last=False,
)
# print('Resetting BN running statistics (this may take 10-20 seconds)...')
set_running_statistics(net, data_loader)
def validate(net, path, image_size, data_loader, batch_size=100, device='cuda:0'):
if 'cuda' in device:
net = torch.nn.DataParallel(net).to(device)
else:
net = net.to(device)
data_loader.dataset.transform = transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
cudnn.benchmark = True
criterion = nn.CrossEntropyLoss().to(device)
net.eval()
net = net.to(device)
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
with torch.no_grad():
with tqdm(total=len(data_loader), desc='Validate') as t:
for i, (images, labels) in enumerate(data_loader):
images, labels = images.to(device), labels.to(device)
# compute output
output = net(images)
loss = criterion(output, labels)
# measure accuracy and record loss
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss.item(), images.size(0))
top1.update(acc1[0].item(), images.size(0))
top5.update(acc5[0].item(), images.size(0))
t.set_postfix({
'loss': losses.avg,
'top1': top1.avg,
'top5': top5.avg,
'img_size': images.size(2),
})
t.update(1)
print('Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (losses.avg, top1.avg, top5.avg))
return top1.avg
def evaluate_ofa_specialized(path, data_loader, batch_size=100, device='cuda:0'):
def select_platform_name():
valid_platform_name = [
'pixel1', 'pixel2', 'note10', 'note8', 's7edge', 'lg-g8', '1080ti', 'v100', 'tx2', 'cpu', 'flops'
]
print("Please select a hardware platform from ('pixel1', 'pixel2', 'note10', 'note8', 's7edge', 'lg-g8', '1080ti', 'v100', 'tx2', 'cpu', 'flops')!\n")
while True:
platform_name = input()
platform_name = platform_name.lower()
if platform_name in valid_platform_name:
return platform_name
print("Platform name is invalid! Please select in ('pixel1', 'pixel2', 'note10', 'note8', 's7edge', 'lg-g8', '1080ti', 'v100', 'tx2', 'cpu', 'flops')!\n")
def select_netid(platform_name):
platform_efficiency_map = {
'pixel1': {
143: 'pixel1_lat@143ms_top1@80.1_finetune@75',
132: 'pixel1_lat@132ms_top1@79.8_finetune@75',
79: 'pixel1_lat@79ms_top1@78.7_finetune@75',
58: 'pixel1_lat@58ms_top1@76.9_finetune@75',
40: 'pixel1_lat@40ms_top1@74.9_finetune@25',
28: 'pixel1_lat@28ms_top1@73.3_finetune@25',
20: 'pixel1_lat@20ms_top1@71.4_finetune@25',
},
'pixel2': {
62: 'pixel2_lat@62ms_top1@75.8_finetune@25',
50: 'pixel2_lat@50ms_top1@74.7_finetune@25',
35: 'pixel2_lat@35ms_top1@73.4_finetune@25',
25: 'pixel2_lat@25ms_top1@71.5_finetune@25',
},
'note10': {
64: 'note10_lat@64ms_top1@80.2_finetune@75',
50: 'note10_lat@50ms_top1@79.7_finetune@75',
41: 'note10_lat@41ms_top1@79.3_finetune@75',
30: 'note10_lat@30ms_top1@78.4_finetune@75',
22: 'note10_lat@22ms_top1@76.6_finetune@25',
16: 'note10_lat@16ms_top1@75.5_finetune@25',
11: 'note10_lat@11ms_top1@73.6_finetune@25',
8: 'note10_lat@8ms_top1@71.4_finetune@25',
},
'note8': {
65: 'note8_lat@65ms_top1@76.1_finetune@25',
49: 'note8_lat@49ms_top1@74.9_finetune@25',
31: 'note8_lat@31ms_top1@72.8_finetune@25',
22: 'note8_lat@22ms_top1@70.4_finetune@25',
},
's7edge': {
88: 's7edge_lat@88ms_top1@76.3_finetune@25',
58: 's7edge_lat@58ms_top1@74.7_finetune@25',
41: 's7edge_lat@41ms_top1@73.1_finetune@25',
29: 's7edge_lat@29ms_top1@70.5_finetune@25',
},
'lg-g8': {
24: 'LG-G8_lat@24ms_top1@76.4_finetune@25',
16: 'LG-G8_lat@16ms_top1@74.7_finetune@25',
11: 'LG-G8_lat@11ms_top1@73.0_finetune@25',
8: 'LG-G8_lat@8ms_top1@71.1_finetune@25',
},
'1080ti': {
27: '1080ti_gpu64@27ms_top1@76.4_finetune@25',
22: '1080ti_gpu64@22ms_top1@75.3_finetune@25',
15: '1080ti_gpu64@15ms_top1@73.8_finetune@25',
12: '1080ti_gpu64@12ms_top1@72.6_finetune@25',
},
'v100': {
11: 'v100_gpu64@11ms_top1@76.1_finetune@25',
9: 'v100_gpu64@9ms_top1@75.3_finetune@25',
6: 'v100_gpu64@6ms_top1@73.0_finetune@25',
5: 'v100_gpu64@5ms_top1@71.6_finetune@25',
},
'tx2': {
96: 'tx2_gpu16@96ms_top1@75.8_finetune@25',
80: 'tx2_gpu16@80ms_top1@75.4_finetune@25',
47: 'tx2_gpu16@47ms_top1@72.9_finetune@25',
35: 'tx2_gpu16@35ms_top1@70.3_finetune@25',
},
'cpu': {
17: 'cpu_lat@17ms_top1@75.7_finetune@25',
15: 'cpu_lat@15ms_top1@74.6_finetune@25',
11: 'cpu_lat@11ms_top1@72.0_finetune@25',
10: 'cpu_lat@10ms_top1@71.1_finetune@25',
},
'flops': {
595: 'flops@595M_top1@80.0_finetune@75',
482: 'flops@482M_top1@79.6_finetune@75',
389: 'flops@389M_top1@79.1_finetune@75',
}
}
sub_efficiency_map = platform_efficiency_map[platform_name]
if not platform_name == 'flops':
print("Now, please specify a latency constraint for model specialization among", sorted(list(sub_efficiency_map.keys())), 'ms. (Please just input the number.) \n')
else:
print("Now, please specify a FLOPs constraint for model specialization among", sorted(list(sub_efficiency_map.keys())), 'MFLOPs. (Please just input the number.) \n')
while True:
efficiency_constraint = input()
if not efficiency_constraint.isdigit():
print('Sorry, please input an integer! \n')
continue
efficiency_constraint = int(efficiency_constraint)
if not efficiency_constraint in sub_efficiency_map.keys():
print('Sorry, please choose a value from: ', sorted(list(sub_efficiency_map.keys())), '.\n')
continue
return sub_efficiency_map[efficiency_constraint]
platform_name = select_platform_name()
net_id = select_netid(platform_name)
net, image_size = ofa_specialized(net_id=net_id, pretrained=True)
validate(net, path, image_size, data_loader, batch_size, device)
return net_id
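A sketch of evaluating one searched subnet with the helpers above. The following makes several assumptions: `path` points at an ImageNet-style folder with `train/` and `val/` splits, `net_config` is a sample such as the one returned by the evolution search, and the width-1.2 MobileNetV3 supernet ID from the OFA model zoo is used.

```python
from ofa.model_zoo import ofa_net

ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.2', pretrained=True)
data_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(osp.join(path, 'val')),  # transform is set inside validate()
    batch_size=100, num_workers=16, pin_memory=True)
top1 = evaluate_ofa_subnet(ofa_network, path, net_config,
                           data_loader, batch_size=100, device='cuda:0')
```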

View File

@@ -0,0 +1,164 @@
import yaml
from ofa.utils import download_url
class LatencyEstimator(object):
def __init__(self, local_dir='~/.hancai/latency_tools/',
url='https://hanlab.mit.edu/files/proxylessNAS/LatencyTools/mobile_trim.yaml'):
if url.startswith('http'):
fname = download_url(url, local_dir, overwrite=True)
else:
fname = url
with open(fname, 'r') as fp:
            self.lut = yaml.safe_load(fp)
@staticmethod
def repr_shape(shape):
if isinstance(shape, (list, tuple)):
return 'x'.join(str(_) for _ in shape)
elif isinstance(shape, str):
return shape
        else:
            raise TypeError('unsupported shape type: %s' % type(shape))
def query(self, l_type: str, input_shape, output_shape, mid=None, ks=None, stride=None, id_skip=None,
se=None, h_swish=None):
infos = [l_type, 'input:%s' % self.repr_shape(input_shape), 'output:%s' % self.repr_shape(output_shape), ]
if l_type in ('expanded_conv',):
assert None not in (mid, ks, stride, id_skip, se, h_swish)
infos += ['expand:%d' % mid, 'kernel:%d' % ks, 'stride:%d' % stride, 'idskip:%d' % id_skip,
'se:%d' % se, 'hs:%d' % h_swish]
key = '-'.join(infos)
return self.lut[key]['mean']
def predict_network_latency(self, net, image_size=224):
predicted_latency = 0
# first conv
predicted_latency += self.query(
'Conv', [image_size, image_size, 3],
[(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels]
)
# blocks
fsize = (image_size + 1) // 2
for block in net.blocks:
mb_conv = block.mobile_inverted_conv
shortcut = block.shortcut
if mb_conv is None:
continue
if shortcut is None:
idskip = 0
else:
idskip = 1
out_fz = int((fsize - 1) / mb_conv.stride + 1)
block_latency = self.query(
'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
mid=mb_conv.depth_conv.conv.in_channels, ks=mb_conv.kernel_size, stride=mb_conv.stride, id_skip=idskip,
se=1 if mb_conv.use_se else 0, h_swish=1 if mb_conv.act_func == 'h_swish' else 0,
)
predicted_latency += block_latency
fsize = out_fz
# final expand layer
predicted_latency += self.query(
'Conv_1', [fsize, fsize, net.final_expand_layer.in_channels],
[fsize, fsize, net.final_expand_layer.out_channels],
)
# global average pooling
predicted_latency += self.query(
'AvgPool2D', [fsize, fsize, net.final_expand_layer.out_channels],
[1, 1, net.final_expand_layer.out_channels],
)
# feature mix layer
predicted_latency += self.query(
'Conv_2', [1, 1, net.feature_mix_layer.in_channels],
[1, 1, net.feature_mix_layer.out_channels]
)
# classifier
predicted_latency += self.query(
'Logits', [1, 1, net.classifier.in_features], [net.classifier.out_features]
)
return predicted_latency
def predict_network_latency_given_spec(self, spec):
image_size = spec['r'][0]
predicted_latency = 0
# first conv
predicted_latency += self.query(
'Conv', [image_size, image_size, 3],
[(image_size + 1) // 2, (image_size + 1) // 2, 24]
)
# blocks
fsize = (image_size + 1) // 2
# first block
predicted_latency += self.query(
'expanded_conv', [fsize, fsize, 24], [fsize, fsize, 24],
mid=24, ks=3, stride=1, id_skip=1, se=0, h_swish=0,
)
in_channel = 24
stride_stages = [2, 2, 2, 1, 2]
width_stages = [32, 48, 96, 136, 192]
act_stages = ['relu', 'relu', 'h_swish', 'h_swish', 'h_swish']
se_stages = [False, True, False, True, True]
for i in range(20):
stage = i // 4
depth_max = spec['d'][stage]
depth = i % 4 + 1
if depth > depth_max:
continue
ks, e = spec['ks'][i], spec['e'][i]
if i % 4 == 0:
stride = stride_stages[stage]
idskip = 0
else:
stride = 1
idskip = 1
out_channel = width_stages[stage]
out_fz = int((fsize - 1) / stride + 1)
mid_channel = round(in_channel * e)
block_latency = self.query(
'expanded_conv', [fsize, fsize, in_channel], [out_fz, out_fz, out_channel],
mid=mid_channel, ks=ks, stride=stride, id_skip=idskip,
se=1 if se_stages[stage] else 0, h_swish=1 if act_stages[stage] == 'h_swish' else 0,
)
predicted_latency += block_latency
fsize = out_fz
in_channel = out_channel
# final expand layer
predicted_latency += self.query(
'Conv_1', [fsize, fsize, 192],
[fsize, fsize, 1152],
)
# global average pooling
predicted_latency += self.query(
'AvgPool2D', [fsize, fsize, 1152],
[1, 1, 1152],
)
# feature mix layer
predicted_latency += self.query(
'Conv_2', [1, 1, 1152],
[1, 1, 1536]
)
# classifier
predicted_latency += self.query(
'Logits', [1, 1, 1536], [1000]
)
return predicted_latency
class LatencyTable:
def __init__(self, device='note10', resolutions=(160, 176, 192, 208, 224)):
self.latency_tables = {}
for image_size in resolutions:
self.latency_tables[image_size] = LatencyEstimator(
url='https://hanlab.mit.edu/files/OnceForAll/tutorial/latency_table@%s/%d_lookup_table.yaml' % (
device, image_size)
)
print('Built latency table for image size: %d.' % image_size)
def predict_efficiency(self, spec: dict):
return self.latency_tables[spec['r'][0]].predict_network_latency_given_spec(spec)
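The same `predict_efficiency` interface lets measured latency replace FLOPs as the search constraint. A sketch, assuming the public per-resolution OFA latency tables are reachable at the URL above and `acc_predictor` as before:

```python
latency_table = LatencyTable(device='note10')  # downloads per-resolution LUTs
finder = EvolutionFinder(
    constraint_type='note10',
    efficiency_constraint=30,  # ms, within the valid [15, 60] range
    efficiency_predictor=latency_table,
    accuracy_predictor=acc_predictor,  # assumed trained predictor
)
```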

View File

@@ -0,0 +1,10 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .pytorch_modules import *
from .pytorch_utils import *
from .my_modules import *
from .flops_counter import *
from .common_tools import *
from .my_dataloader import *

View File

@@ -0,0 +1,284 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import numpy as np
import os
import sys
import torch
try:
from urllib import urlretrieve
except ImportError:
from urllib.request import urlretrieve
__all__ = [
'sort_dict', 'get_same_padding',
'get_split_list', 'list_sum', 'list_mean', 'list_join',
'subset_mean', 'sub_filter_start_end', 'min_divisible_value', 'val2list',
'download_url',
'write_log', 'pairwise_accuracy', 'accuracy',
'AverageMeter', 'MultiClassAverageMeter',
'DistributedMetric', 'DistributedTensor',
]
def sort_dict(src_dict, reverse=False, return_dict=True):
output = sorted(src_dict.items(), key=lambda x: x[1], reverse=reverse)
if return_dict:
return dict(output)
else:
return output
def get_same_padding(kernel_size):
if isinstance(kernel_size, tuple):
assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
p1 = get_same_padding(kernel_size[0])
p2 = get_same_padding(kernel_size[1])
return p1, p2
assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`'
assert kernel_size % 2 > 0, 'kernel size should be odd number'
return kernel_size // 2
def get_split_list(in_dim, child_num, accumulate=False):
in_dim_list = [in_dim // child_num] * child_num
for _i in range(in_dim % child_num):
in_dim_list[_i] += 1
if accumulate:
for i in range(1, child_num):
in_dim_list[i] += in_dim_list[i - 1]
return in_dim_list
def list_sum(x):
return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
def list_mean(x):
return list_sum(x) / len(x)
def list_join(val_list, sep='\t'):
return sep.join([str(val) for val in val_list])
def subset_mean(val_list, sub_indexes):
sub_indexes = val2list(sub_indexes, 1)
return list_mean([val_list[idx] for idx in sub_indexes])
def sub_filter_start_end(kernel_size, sub_kernel_size):
center = kernel_size // 2
dev = sub_kernel_size // 2
start, end = center - dev, center + dev + 1
assert end - start == sub_kernel_size
return start, end
def min_divisible_value(n1, v1):
""" make sure v1 is divisible by n1, otherwise decrease v1 """
if v1 >= n1:
return n1
while n1 % v1 != 0:
v1 -= 1
return v1
def val2list(val, repeat_time=1):
if isinstance(val, list) or isinstance(val, np.ndarray):
return val
elif isinstance(val, tuple):
return list(val)
else:
return [val for _ in range(repeat_time)]
def download_url(url, model_dir='~/.torch/', overwrite=False):
    target_file = url.split('/')[-1]
    model_dir = os.path.expanduser(model_dir)
    try:
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        cached_file = os.path.join(model_dir, target_file)
if not os.path.exists(cached_file) or overwrite:
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
urlretrieve(url, cached_file)
return cached_file
except Exception as e:
# remove lock file so download can be executed next time.
os.remove(os.path.join(model_dir, 'download.lock'))
sys.stderr.write('Failed to download from url %s' % url + '\n' + str(e) + '\n')
return None
def write_log(logs_path, log_str, prefix='valid', should_print=True, mode='a'):
    """ prefix: valid, train, test """
    if not os.path.exists(logs_path):
        os.makedirs(logs_path, exist_ok=True)
if prefix in ['valid', 'test']:
with open(os.path.join(logs_path, 'valid_console.txt'), mode) as fout:
fout.write(log_str + '\n')
fout.flush()
if prefix in ['valid', 'test', 'train']:
with open(os.path.join(logs_path, 'train_console.txt'), mode) as fout:
if prefix in ['valid', 'test']:
fout.write('=' * 10)
fout.write(log_str + '\n')
fout.flush()
else:
with open(os.path.join(logs_path, '%s.txt' % prefix), mode) as fout:
fout.write(log_str + '\n')
fout.flush()
if should_print:
print(log_str)
def pairwise_accuracy(la, lb, n_samples=200000):
n = len(la)
assert n == len(lb)
total = 0
count = 0
for _ in range(n_samples):
i = np.random.randint(n)
j = np.random.randint(n)
while i == j:
j = np.random.randint(n)
if la[i] >= la[j] and lb[i] >= lb[j]:
count += 1
if la[i] < la[j] and lb[i] < lb[j]:
count += 1
total += 1
return float(count) / total
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
class AverageMeter(object):
"""
Computes and stores the average and current value
Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class MultiClassAverageMeter:
""" Multi Binary Classification Tasks """
def __init__(self, num_classes, balanced=False, **kwargs):
super(MultiClassAverageMeter, self).__init__()
self.num_classes = num_classes
self.balanced = balanced
self.counts = []
for k in range(self.num_classes):
self.counts.append(np.ndarray((2, 2), dtype=np.float32))
self.reset()
def reset(self):
for k in range(self.num_classes):
self.counts[k].fill(0)
def add(self, outputs, targets):
outputs = outputs.data.cpu().numpy()
targets = targets.data.cpu().numpy()
for k in range(self.num_classes):
output = np.argmax(outputs[:, k, :], axis=1)
target = targets[:, k]
x = output + 2 * target
bincount = np.bincount(x.astype(np.int32), minlength=2 ** 2)
self.counts[k] += bincount.reshape((2, 2))
def value(self):
mean = 0
for k in range(self.num_classes):
if self.balanced:
value = np.mean((self.counts[k] / np.maximum(np.sum(self.counts[k], axis=1), 1)[:, None]).diagonal())
else:
value = np.sum(self.counts[k].diagonal()) / np.maximum(np.sum(self.counts[k]), 1)
mean += value / self.num_classes * 100.
return mean
class DistributedMetric(object):
"""
Horovod: average metrics from distributed training.
"""
def __init__(self, name):
self.name = name
self.sum = torch.zeros(1)[0]
self.count = torch.zeros(1)[0]
def update(self, val, delta_n=1):
import horovod.torch as hvd
val *= delta_n
self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
self.count += delta_n
@property
def avg(self):
return self.sum / self.count
class DistributedTensor(object):
def __init__(self, name):
self.name = name
self.sum = None
self.count = torch.zeros(1)[0]
self.synced = False
def update(self, val, delta_n=1):
val *= delta_n
if self.sum is None:
self.sum = val.detach()
else:
self.sum += val.detach()
self.count += delta_n
@property
def avg(self):
import horovod.torch as hvd
if not self.synced:
self.sum = hvd.allreduce(self.sum, name=self.name)
self.synced = True
return self.sum / self.count
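A minimal sketch of the `AverageMeter`/`accuracy` pattern these helpers support (`model` and `loader` are hypothetical placeholders):

```python
top1 = AverageMeter()
for images, labels in loader:              # assumed: a torch DataLoader
    output = model(images)                 # assumed: a classification model
    acc1, = accuracy(output, labels, topk=(1,))
    top1.update(acc1[0].item(), images.size(0))
print('top-1: %.2f%%' % top1.avg)
```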

View File

@@ -0,0 +1,97 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import torch
import torch.nn as nn
from .my_modules import MyConv2d
__all__ = ['profile']
def count_convNd(m, _, y):
cin = m.in_channels
kernel_ops = m.weight.size()[2] * m.weight.size()[3]
ops_per_element = kernel_ops
output_elements = y.nelement()
# cout x oW x oH
total_ops = cin * output_elements * ops_per_element // m.groups
m.total_ops = torch.zeros(1).fill_(total_ops)
def count_linear(m, _, __):
total_ops = m.in_features * m.out_features
m.total_ops = torch.zeros(1).fill_(total_ops)
register_hooks = {
nn.Conv1d: count_convNd,
nn.Conv2d: count_convNd,
nn.Conv3d: count_convNd,
MyConv2d: count_convNd,
######################################
nn.Linear: count_linear,
######################################
nn.Dropout: None,
nn.Dropout2d: None,
nn.Dropout3d: None,
nn.BatchNorm2d: None,
}
def profile(model, input_size, custom_ops=None):
handler_collection = []
custom_ops = {} if custom_ops is None else custom_ops
def add_hooks(m_):
if len(list(m_.children())) > 0:
return
m_.register_buffer('total_ops', torch.zeros(1))
m_.register_buffer('total_params', torch.zeros(1))
for p in m_.parameters():
m_.total_params += torch.zeros(1).fill_(p.numel())
m_type = type(m_)
fn = None
if m_type in custom_ops:
fn = custom_ops[m_type]
elif m_type in register_hooks:
fn = register_hooks[m_type]
if fn is not None:
_handler = m_.register_forward_hook(fn)
handler_collection.append(_handler)
original_device = model.parameters().__next__().device
training = model.training
model.eval()
model.apply(add_hooks)
x = torch.zeros(input_size).to(original_device)
with torch.no_grad():
model(x)
total_ops = 0
total_params = 0
for m in model.modules():
if len(list(m.children())) > 0: # skip for non-leaf module
continue
total_ops += m.total_ops
total_params += m.total_params
total_ops = total_ops.item()
total_params = total_params.item()
model.train(training).to(original_device)
for handler in handler_collection:
handler.remove()
return total_ops, total_params
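A sketch of `profile` on a toy network: it attaches forward hooks to leaf modules, runs one dummy forward pass, and sums the per-module counters (this relies on the `torch` / `nn` imports at the top of this file).

```python
toy = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                    nn.Conv2d(16, 32, 3, padding=1))
macs, params = profile(toy, input_size=(1, 3, 32, 32))
print('%.2fM MACs, %.2fK params' % (macs / 1e6, params / 1e3))
```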

View File

@@ -0,0 +1,727 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli
from collections import OrderedDict
from ofa_local.utils import get_same_padding, min_divisible_value, SEModule, ShuffleLayer
from ofa_local.utils import MyNetwork, MyModule
from ofa_local.utils import build_activation, make_divisible
__all__ = [
'set_layer_from_config',
'ConvLayer', 'IdentityLayer', 'LinearLayer', 'MultiHeadLinearLayer', 'ZeroLayer', 'MBConvLayer',
'ResidualBlock', 'ResNetBottleneckBlock',
]
class DropBlock(nn.Module):
def __init__(self, block_size):
super(DropBlock, self).__init__()
self.block_size = block_size
def forward(self, x, gamma):
# shape: (bsize, channels, height, width)
if self.training:
batch_size, channels, height, width = x.shape
bernoulli = Bernoulli(gamma)
mask = bernoulli.sample(
(batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
# print((x.sample[-2], x.sample[-1]))
block_mask = self._compute_block_mask(mask)
# print (block_mask.size())
# print (x.size())
countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
count_ones = block_mask.sum()
return block_mask * x * (countM / count_ones)
else:
return x
def _compute_block_mask(self, mask):
left_padding = int((self.block_size - 1) / 2)
right_padding = int(self.block_size / 2)
batch_size, channels, height, width = mask.shape
# print ("mask", mask[0][0])
non_zero_idxs = mask.nonzero()
nr_blocks = non_zero_idxs.shape[0]
offsets = torch.stack(
[
torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1),
# - left_padding,
torch.arange(self.block_size).repeat(self.block_size), # - left_padding
]
).t().cuda()
offsets = torch.cat((torch.zeros(self.block_size ** 2, 2).cuda().long(), offsets.long()), 1)
if nr_blocks > 0:
non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
offsets = offsets.long()
block_idxs = non_zero_idxs + offsets
# block_idxs += left_padding
padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
else:
padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
block_mask = 1 - padded_mask # [:height, :width]
return block_mask
def set_layer_from_config(layer_config):
if layer_config is None:
return None
name2layer = {
ConvLayer.__name__: ConvLayer,
IdentityLayer.__name__: IdentityLayer,
LinearLayer.__name__: LinearLayer,
MultiHeadLinearLayer.__name__: MultiHeadLinearLayer,
ZeroLayer.__name__: ZeroLayer,
MBConvLayer.__name__: MBConvLayer,
'MBInvertedConvLayer': MBConvLayer,
##########################################################
ResidualBlock.__name__: ResidualBlock,
ResNetBottleneckBlock.__name__: ResNetBottleneckBlock,
}
layer_name = layer_config.pop('name')
layer = name2layer[layer_name]
return layer.build_from_config(layer_config)
class My2DLayer(MyModule):
def __init__(self, in_channels, out_channels,
use_bn=True, act_func='relu', dropout_rate=0, ops_order='weight_bn_act'):
super(My2DLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.use_bn = use_bn
self.act_func = act_func
self.dropout_rate = dropout_rate
self.ops_order = ops_order
""" modules """
modules = {}
# batch norm
if self.use_bn:
if self.bn_before_weight:
modules['bn'] = nn.BatchNorm2d(in_channels)
else:
modules['bn'] = nn.BatchNorm2d(out_channels)
else:
modules['bn'] = None
# activation
modules['act'] = build_activation(self.act_func, self.ops_list[0] != 'act' and self.use_bn)
# dropout
if self.dropout_rate > 0:
modules['dropout'] = nn.Dropout2d(self.dropout_rate, inplace=True)
else:
modules['dropout'] = None
# weight
modules['weight'] = self.weight_op()
# add modules
for op in self.ops_list:
if modules[op] is None:
continue
elif op == 'weight':
# dropout before weight operation
if modules['dropout'] is not None:
self.add_module('dropout', modules['dropout'])
for key in modules['weight']:
self.add_module(key, modules['weight'][key])
else:
self.add_module(op, modules[op])
@property
def ops_list(self):
return self.ops_order.split('_')
@property
def bn_before_weight(self):
for op in self.ops_list:
if op == 'bn':
return True
elif op == 'weight':
return False
raise ValueError('Invalid ops_order: %s' % self.ops_order)
def weight_op(self):
raise NotImplementedError
""" Methods defined in MyModule """
def forward(self, x):
# similar to nn.Sequential
for module in self._modules.values():
x = module(x)
return x
@property
def module_str(self):
raise NotImplementedError
@property
def config(self):
return {
'in_channels': self.in_channels,
'out_channels': self.out_channels,
'use_bn': self.use_bn,
'act_func': self.act_func,
'dropout_rate': self.dropout_rate,
'ops_order': self.ops_order,
}
@staticmethod
def build_from_config(config):
raise NotImplementedError
class ConvLayer(My2DLayer):
def __init__(self, in_channels, out_channels,
kernel_size=3, stride=1, dilation=1, groups=1, bias=False, has_shuffle=False, use_se=False,
use_bn=True, act_func='relu', dropout_rate=0, ops_order='weight_bn_act'):
# default normal 3x3_Conv with bn and relu
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.groups = groups
self.bias = bias
self.has_shuffle = has_shuffle
self.use_se = use_se
super(ConvLayer, self).__init__(in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order)
if self.use_se:
self.add_module('se', SEModule(self.out_channels))
def weight_op(self):
padding = get_same_padding(self.kernel_size)
if isinstance(padding, int):
padding *= self.dilation
else:
padding[0] *= self.dilation
padding[1] *= self.dilation
weight_dict = OrderedDict({
'conv': nn.Conv2d(
self.in_channels, self.out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=padding,
dilation=self.dilation, groups=min_divisible_value(self.in_channels, self.groups), bias=self.bias
)
})
if self.has_shuffle and self.groups > 1:
weight_dict['shuffle'] = ShuffleLayer(self.groups)
return weight_dict
@property
def module_str(self):
if isinstance(self.kernel_size, int):
kernel_size = (self.kernel_size, self.kernel_size)
else:
kernel_size = self.kernel_size
if self.groups == 1:
if self.dilation > 1:
conv_str = '%dx%d_DilatedConv' % (kernel_size[0], kernel_size[1])
else:
conv_str = '%dx%d_Conv' % (kernel_size[0], kernel_size[1])
else:
if self.dilation > 1:
conv_str = '%dx%d_DilatedGroupConv' % (kernel_size[0], kernel_size[1])
else:
conv_str = '%dx%d_GroupConv' % (kernel_size[0], kernel_size[1])
conv_str += '_O%d' % self.out_channels
if self.use_se:
conv_str = 'SE_' + conv_str
conv_str += '_' + self.act_func.upper()
if self.use_bn:
if isinstance(self.bn, nn.GroupNorm):
conv_str += '_GN%d' % self.bn.num_groups
elif isinstance(self.bn, nn.BatchNorm2d):
conv_str += '_BN'
return conv_str
@property
def config(self):
return {
'name': ConvLayer.__name__,
'kernel_size': self.kernel_size,
'stride': self.stride,
'dilation': self.dilation,
'groups': self.groups,
'bias': self.bias,
'has_shuffle': self.has_shuffle,
'use_se': self.use_se,
**super(ConvLayer, self).config
}
@staticmethod
def build_from_config(config):
return ConvLayer(**config)
class IdentityLayer(My2DLayer):
def __init__(self, in_channels, out_channels,
use_bn=False, act_func=None, dropout_rate=0, ops_order='weight_bn_act'):
super(IdentityLayer, self).__init__(in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order)
def weight_op(self):
return None
@property
def module_str(self):
return 'Identity'
@property
def config(self):
return {
'name': IdentityLayer.__name__,
**super(IdentityLayer, self).config,
}
@staticmethod
def build_from_config(config):
return IdentityLayer(**config)
class LinearLayer(MyModule):
def __init__(self, in_features, out_features, bias=True,
use_bn=False, act_func=None, dropout_rate=0, ops_order='weight_bn_act'):
super(LinearLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.bias = bias
self.use_bn = use_bn
self.act_func = act_func
self.dropout_rate = dropout_rate
self.ops_order = ops_order
""" modules """
modules = {}
# batch norm
if self.use_bn:
if self.bn_before_weight:
modules['bn'] = nn.BatchNorm1d(in_features)
else:
modules['bn'] = nn.BatchNorm1d(out_features)
else:
modules['bn'] = None
# activation
modules['act'] = build_activation(self.act_func, self.ops_list[0] != 'act')
# dropout
if self.dropout_rate > 0:
modules['dropout'] = nn.Dropout(self.dropout_rate, inplace=True)
else:
modules['dropout'] = None
# linear
modules['weight'] = {'linear': nn.Linear(self.in_features, self.out_features, self.bias)}
# add modules
for op in self.ops_list:
if modules[op] is None:
continue
elif op == 'weight':
if modules['dropout'] is not None:
self.add_module('dropout', modules['dropout'])
for key in modules['weight']:
self.add_module(key, modules['weight'][key])
else:
self.add_module(op, modules[op])
@property
def ops_list(self):
return self.ops_order.split('_')
@property
def bn_before_weight(self):
for op in self.ops_list:
if op == 'bn':
return True
elif op == 'weight':
return False
raise ValueError('Invalid ops_order: %s' % self.ops_order)
def forward(self, x):
for module in self._modules.values():
x = module(x)
return x
@property
def module_str(self):
return '%dx%d_Linear' % (self.in_features, self.out_features)
@property
def config(self):
return {
'name': LinearLayer.__name__,
'in_features': self.in_features,
'out_features': self.out_features,
'bias': self.bias,
'use_bn': self.use_bn,
'act_func': self.act_func,
'dropout_rate': self.dropout_rate,
'ops_order': self.ops_order,
}
@staticmethod
def build_from_config(config):
return LinearLayer(**config)
class MultiHeadLinearLayer(MyModule):
def __init__(self, in_features, out_features, num_heads=1, bias=True, dropout_rate=0):
super(MultiHeadLinearLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.num_heads = num_heads
self.bias = bias
self.dropout_rate = dropout_rate
if self.dropout_rate > 0:
self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
else:
self.dropout = None
self.layers = nn.ModuleList()
for k in range(num_heads):
layer = nn.Linear(in_features, out_features, self.bias)
self.layers.append(layer)
def forward(self, inputs):
if self.dropout is not None:
inputs = self.dropout(inputs)
outputs = []
for layer in self.layers:
output = layer.forward(inputs)
outputs.append(output)
outputs = torch.stack(outputs, dim=1)
return outputs
@property
def module_str(self):
return self.__repr__()
@property
def config(self):
return {
'name': MultiHeadLinearLayer.__name__,
'in_features': self.in_features,
'out_features': self.out_features,
'num_heads': self.num_heads,
'bias': self.bias,
'dropout_rate': self.dropout_rate,
}
@staticmethod
def build_from_config(config):
return MultiHeadLinearLayer(**config)
def __repr__(self):
return 'MultiHeadLinear(in_features=%d, out_features=%d, num_heads=%d, bias=%s, dropout_rate=%s)' % (
self.in_features, self.out_features, self.num_heads, self.bias, self.dropout_rate
)
class ZeroLayer(MyModule):
def __init__(self):
super(ZeroLayer, self).__init__()
def forward(self, x):
raise ValueError
@property
def module_str(self):
return 'Zero'
@property
def config(self):
return {
'name': ZeroLayer.__name__,
}
@staticmethod
def build_from_config(config):
return ZeroLayer()
class MBConvLayer(MyModule):
def __init__(self, in_channels, out_channels,
kernel_size=3, stride=1, expand_ratio=6, mid_channels=None, act_func='relu6', use_se=False,
groups=None):
super(MBConvLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.expand_ratio = expand_ratio
self.mid_channels = mid_channels
self.act_func = act_func
self.use_se = use_se
self.groups = groups
if self.mid_channels is None:
feature_dim = round(self.in_channels * self.expand_ratio)
else:
feature_dim = self.mid_channels
if self.expand_ratio == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
('bn', nn.BatchNorm2d(feature_dim)),
('act', build_activation(self.act_func, inplace=True)),
]))
pad = get_same_padding(self.kernel_size)
groups = feature_dim if self.groups is None else min_divisible_value(feature_dim, self.groups)
depth_conv_modules = [
('conv', nn.Conv2d(feature_dim, feature_dim, kernel_size, stride, pad, groups=groups, bias=False)),
('bn', nn.BatchNorm2d(feature_dim)),
('act', build_activation(self.act_func, inplace=True))
]
if self.use_se:
depth_conv_modules.append(('se', SEModule(feature_dim)))
self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))
self.point_linear = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
('bn', nn.BatchNorm2d(out_channels)),
]))
def forward(self, x):
if self.inverted_bottleneck:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x
@property
def module_str(self):
if self.mid_channels is None:
expand_ratio = self.expand_ratio
else:
expand_ratio = self.mid_channels // self.in_channels
layer_str = '%dx%d_MBConv%d_%s' % (self.kernel_size, self.kernel_size, expand_ratio, self.act_func.upper())
if self.use_se:
layer_str = 'SE_' + layer_str
layer_str += '_O%d' % self.out_channels
if self.groups is not None:
layer_str += '_G%d' % self.groups
if isinstance(self.point_linear.bn, nn.GroupNorm):
layer_str += '_GN%d' % self.point_linear.bn.num_groups
elif isinstance(self.point_linear.bn, nn.BatchNorm2d):
layer_str += '_BN'
return layer_str
@property
def config(self):
return {
'name': MBConvLayer.__name__,
'in_channels': self.in_channels,
'out_channels': self.out_channels,
'kernel_size': self.kernel_size,
'stride': self.stride,
'expand_ratio': self.expand_ratio,
'mid_channels': self.mid_channels,
'act_func': self.act_func,
'use_se': self.use_se,
'groups': self.groups,
}
@staticmethod
def build_from_config(config):
return MBConvLayer(**config)
class ResidualBlock(MyModule):
    def __init__(self, conv, shortcut, dropout_rate=0, dropblock=False, block_size=1):
super(ResidualBlock, self).__init__()
self.conv = conv
self.shortcut = shortcut
# hayeon
self.num_batches_tracked = 0
self.dropout_rate = dropout_rate
self.dropblock = dropblock
self.block_size = block_size
self.DropBlock = DropBlock(block_size=self.block_size)
def forward(self, x):
# hayeon
self.num_batches_tracked += 1
if self.conv is None or isinstance(self.conv, ZeroLayer):
res = x
elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
res = self.conv(x)
else:
res = self.conv(x) + self.shortcut(x)
# hayeon
if self.dropout_rate > 0:
if self.dropblock:
feat_size = res.size()[2]
                keep_rate = max(1.0 - self.dropout_rate / (20 * 2000) * self.num_batches_tracked, 1.0 - self.dropout_rate)
gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2
res = self.DropBlock(res, gamma=gamma)
else:
res = F.dropout(res, p=self.dropout_rate, training=self.training, inplace=True)
return res
@property
def module_str(self):
return '(%s, %s)' % (
self.conv.module_str if self.conv is not None else None,
self.shortcut.module_str if self.shortcut is not None else None
)
@property
def config(self):
return {
'name': ResidualBlock.__name__,
'conv': self.conv.config if self.conv is not None else None,
'shortcut': self.shortcut.config if self.shortcut is not None else None,
}
@staticmethod
def build_from_config(config):
conv_config = config['conv'] if 'conv' in config else config['mobile_inverted_conv']
conv = set_layer_from_config(conv_config)
shortcut = set_layer_from_config(config['shortcut'])
return ResidualBlock(conv, shortcut)
@property
def mobile_inverted_conv(self):
return self.conv
class ResNetBottleneckBlock(MyModule):
def __init__(self, in_channels, out_channels,
kernel_size=3, stride=1, expand_ratio=0.25, mid_channels=None, act_func='relu', groups=1,
downsample_mode='avgpool_conv'):
super(ResNetBottleneckBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.expand_ratio = expand_ratio
self.mid_channels = mid_channels
self.act_func = act_func
self.groups = groups
self.downsample_mode = downsample_mode
if self.mid_channels is None:
feature_dim = round(self.out_channels * self.expand_ratio)
else:
feature_dim = self.mid_channels
feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
self.mid_channels = feature_dim
# build modules
self.conv1 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
('bn', nn.BatchNorm2d(feature_dim)),
('act', build_activation(self.act_func, inplace=True)),
]))
pad = get_same_padding(self.kernel_size)
self.conv2 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(feature_dim, feature_dim, kernel_size, stride, pad, groups=groups, bias=False)),
('bn', nn.BatchNorm2d(feature_dim)),
('act', build_activation(self.act_func, inplace=True))
]))
self.conv3 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(feature_dim, self.out_channels, 1, 1, 0, bias=False)),
('bn', nn.BatchNorm2d(self.out_channels)),
]))
if stride == 1 and in_channels == out_channels:
self.downsample = IdentityLayer(in_channels, out_channels)
elif self.downsample_mode == 'conv':
self.downsample = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(in_channels, out_channels, 1, stride, 0, bias=False)),
('bn', nn.BatchNorm2d(out_channels)),
]))
elif self.downsample_mode == 'avgpool_conv':
self.downsample = nn.Sequential(OrderedDict([
('avg_pool', nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0, ceil_mode=True)),
('conv', nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)),
('bn', nn.BatchNorm2d(out_channels)),
]))
else:
raise NotImplementedError
self.final_act = build_activation(self.act_func, inplace=True)
def forward(self, x):
residual = self.downsample(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = x + residual
x = self.final_act(x)
return x
@property
def module_str(self):
return '(%s, %s)' % (
'%dx%d_BottleneckConv_%d->%d->%d_S%d_G%d' % (
self.kernel_size, self.kernel_size, self.in_channels, self.mid_channels, self.out_channels,
self.stride, self.groups
),
'Identity' if isinstance(self.downsample, IdentityLayer) else self.downsample_mode,
)
@property
def config(self):
return {
'name': ResNetBottleneckBlock.__name__,
'in_channels': self.in_channels,
'out_channels': self.out_channels,
'kernel_size': self.kernel_size,
'stride': self.stride,
'expand_ratio': self.expand_ratio,
'mid_channels': self.mid_channels,
'act_func': self.act_func,
'groups': self.groups,
'downsample_mode': self.downsample_mode,
}
@staticmethod
def build_from_config(config):
return ResNetBottleneckBlock(**config)
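A sketch of the config round-trip these layers share, which is how `FLOPsTable.build_single_lut` instantiates candidate blocks. The shapes are illustrative, and it assumes the `ofa_local.utils` helpers imported at the top of this file are available:

```python
conv = MBConvLayer(in_channels=24, out_channels=40, kernel_size=5,
                   stride=2, expand_ratio=4, act_func='h_swish', use_se=True)
rebuilt = set_layer_from_config(conv.config)  # config dict -> fresh layer
x = torch.randn(1, 24, 56, 56)
assert rebuilt(x).shape == (1, 40, 28, 28)    # stride 2 halves the feature map
```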

View File

@@ -0,0 +1,4 @@
from .my_data_loader import *
from .my_data_worker import *
from .my_distributed_sampler import *
from .my_random_resize_crop import *

View File

@@ -0,0 +1,962 @@
r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
To support these two classes, in `./_utils` we define many utility methods and
functions to be run in multiprocessing. E.g., the data loading worker loop is
in `./_utils/worker.py`.
"""
import threading
import itertools
import warnings
import multiprocessing as python_multiprocessing
import torch
import torch.multiprocessing as multiprocessing
from torch._utils import ExceptionWrapper
from torch._six import queue, string_classes
from torch.utils.data.dataset import IterableDataset
from torch.utils.data import Sampler, SequentialSampler, RandomSampler, BatchSampler
from torch.utils.data import _utils
from .my_data_worker import worker_loop
__all__ = ['MyDataLoader']
get_worker_info = _utils.worker.get_worker_info
# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate = _utils.collate.default_collate
class _DatasetKind(object):
Map = 0
Iterable = 1
@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
class _InfiniteConstantSampler(Sampler):
r"""Analogous to ``itertools.repeat(None, None)``.
Used as sampler for :class:`~torch.utils.data.IterableDataset`.
Arguments:
data_source (Dataset): dataset to sample from
"""
def __init__(self):
super(_InfiniteConstantSampler, self).__init__(None)
def __iter__(self):
while True:
yield None
class MyDataLoader(object):
r"""
Data loader. Combines a dataset and a sampler, and provides an iterable over
the given dataset.
The :class:`~torch.utils.data.DataLoader` supports both map-style and
iterable-style datasets with single- or multi-process loading, customizing
loading order and optional automatic batching (collation) and memory pinning.
See :py:mod:`torch.utils.data` documentation page for more details.
Arguments:
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: ``1``).
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: ``False``).
sampler (Sampler, optional): defines the strategy to draw samples from
the dataset. If specified, :attr:`shuffle` must be ``False``.
batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of
indices at a time. Mutually exclusive with :attr:`batch_size`,
:attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
num_workers (int, optional): how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``0``)
collate_fn (callable, optional): merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
into CUDA pinned memory before returning them. If your data elements
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
see the example below.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: ``False``)
timeout (numeric, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative. (default: ``0``)
worker_init_fn (callable, optional): If not ``None``, this will be called on each
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
input, after seeding and before data loading. (default: ``None``)
.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
cannot be an unpicklable object, e.g., a lambda function. See
:ref:`multiprocessing-best-practices` on more details related
to multiprocessing in PyTorch.
.. note:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
``len(dataset)`` (if implemented) is returned instead, regardless
              of multi-process loading configurations, because PyTorch trusts
user :attr:`dataset` code in correctly handling multi-process
loading to avoid duplicate data. See `Dataset Types`_ for more
details on these two types of datasets and how
:class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
"""
__initialized = False
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None):
torch._C._log_api_usage_once("python.data_loader")
if num_workers < 0:
raise ValueError('num_workers option should be non-negative; '
'use num_workers=0 to disable multiprocessing.')
if timeout < 0:
raise ValueError('timeout option should be non-negative')
self.dataset = dataset
self.num_workers = num_workers
self.pin_memory = pin_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.multiprocessing_context = multiprocessing_context
# Arg-check dataset related before checking samplers because we want to
# tell users that iterable-style datasets are incompatible with custom
# samplers first, so that they don't learn that this combo doesn't work
# after spending time fixing the custom sampler errors.
if isinstance(dataset, IterableDataset):
self._dataset_kind = _DatasetKind.Iterable
# NOTE [ Custom Samplers and `IterableDataset` ]
#
# `IterableDataset` does not support custom `batch_sampler` or
# `sampler` since the key is irrelevant (unless we support
# generator-style dataset one day...).
#
# For `sampler`, we always create a dummy sampler. This is an
# infinite sampler even when the dataset may have an implemented
# finite `__len__` because in multi-process data loading, naive
# settings will return duplicated data (which may be desired), and
# thus using a sampler with length matching that of dataset will
            # cause data loss (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `IterableDataset` always uses an infinite sampler, an instance of
# `_InfiniteConstantSampler` defined above.
#
# A custom `batch_sampler` essentially only controls the batch size.
# However, it is unclear how useful it would be since an iterable-style
# dataset can handle that within itself. Moreover, it is pointless
# in multi-process data loading as the assignment order of batches
# to workers is an implementation detail so users can not control
# how to batchify each worker's iterable. Thus, we disable this
# option. If this turns out to be useful in future, we can re-enable
# this, and support custom samplers that specify the assignments to
# specific workers.
if shuffle is not False:
raise ValueError(
"DataLoader with IterableDataset: expected unspecified "
"shuffle option, but got shuffle={}".format(shuffle))
elif sampler is not None:
# See NOTE [ Custom Samplers and IterableDataset ]
raise ValueError(
"DataLoader with IterableDataset: expected unspecified "
"sampler option, but got sampler={}".format(sampler))
elif batch_sampler is not None:
# See NOTE [ Custom Samplers and IterableDataset ]
raise ValueError(
"DataLoader with IterableDataset: expected unspecified "
"batch_sampler option, but got batch_sampler={}".format(batch_sampler))
else:
self._dataset_kind = _DatasetKind.Map
if sampler is not None and shuffle:
raise ValueError('sampler option is mutually exclusive with '
'shuffle')
if batch_sampler is not None:
# auto_collation with custom batch_sampler
if batch_size != 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler option is mutually exclusive '
'with batch_size, shuffle, sampler, and '
'drop_last')
batch_size = None
drop_last = False
elif batch_size is None:
# no auto_collation
if shuffle or drop_last:
raise ValueError('batch_size=None option disables auto-batching '
'and is mutually exclusive with '
'shuffle, and drop_last')
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = batch_sampler
if collate_fn is None:
if self._auto_collation:
collate_fn = _utils.collate.default_collate
else:
collate_fn = _utils.collate.default_convert
self.collate_fn = collate_fn
self.__initialized = True
self._IterableDataset_len_called = None # See NOTE [ IterableDataset and __len__ ]
@property
def multiprocessing_context(self):
return self.__multiprocessing_context
@multiprocessing_context.setter
def multiprocessing_context(self, multiprocessing_context):
if multiprocessing_context is not None:
if self.num_workers > 0:
if not multiprocessing._supports_context:
raise ValueError('multiprocessing_context relies on Python >= 3.4, with '
'support for different start methods')
if isinstance(multiprocessing_context, string_classes):
valid_start_methods = multiprocessing.get_all_start_methods()
if multiprocessing_context not in valid_start_methods:
raise ValueError(
('multiprocessing_context option '
'should specify a valid start method in {}, but got '
'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
multiprocessing_context = multiprocessing.get_context(multiprocessing_context)
if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
raise ValueError(('multiprocessing_context option should be a valid context '
'object or a string specifying the start method, but got '
'multiprocessing_context={}').format(multiprocessing_context))
else:
raise ValueError(('multiprocessing_context can only be used with '
'multi-process loading (num_workers > 0), but got '
'num_workers={}').format(self.num_workers))
self.__multiprocessing_context = multiprocessing_context
def __setattr__(self, attr, val):
if self.__initialized and attr in ('batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset'):
raise ValueError('{} attribute should not be set after {} is '
'initialized'.format(attr, self.__class__.__name__))
super(MyDataLoader, self).__setattr__(attr, val)
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
@property
def _auto_collation(self):
return self.batch_sampler is not None
@property
def _index_sampler(self):
# The actual sampler used for generating indices for `_DatasetFetcher`
# (see _utils/fetch.py) to read data at each time. This would be
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
# We can't change `.sampler` and `.batch_sampler` attributes for BC
# reasons.
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
def __len__(self):
if self._dataset_kind == _DatasetKind.Iterable:
# NOTE [ IterableDataset and __len__ ]
#
# For `IterableDataset`, `__len__` could be inaccurate when one naively
# does multi-processing data loading, since the samples will be duplicated.
# However, no real use case should be actually using that behavior, so
# it should count as a user error. We should generally trust user
# code to do the proper thing (e.g., configure each replica differently
# in `__iter__`), and give us the correct `__len__` if they choose to
# implement it (this will still throw if the dataset does not implement
# a `__len__`).
#
        # To provide a further warning, we track if `__len__` was called on the
        # DataLoader, save the returned value in `self._IterableDataset_len_called`,
        # and warn if the iterator ends up yielding more than this number of samples.
length = self._IterableDataset_len_called = len(self.dataset)
return length
else:
return len(self._index_sampler)
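# A minimal usage sketch (illustrative; `TensorDataset` stands in for any
# map-style dataset). With `batch_size=4` and no custom sampler, the loader
# builds a `SequentialSampler` wrapped in a `BatchSampler`, so auto-collation
# is on and `default_collate` stacks samples into batched tensors:
#
#   >>> from torch.utils.data import TensorDataset
#   >>> ds = TensorDataset(torch.arange(10).float().unsqueeze(1))
#   >>> loader = MyDataLoader(ds, batch_size=4)
#   >>> [batch[0].shape for batch in loader]
#   [torch.Size([4, 1]), torch.Size([4, 1]), torch.Size([2, 1])]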
class _BaseDataLoaderIter(object):
def __init__(self, loader):
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._IterableDataset_len_called = loader._IterableDataset_len_called
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
self._num_yielded = 0
def __iter__(self):
return self
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
def _next_data(self):
raise NotImplementedError
def __next__(self):
data = self._next_data()
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
"samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
self._num_yielded)
if self._num_workers > 0:
warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
"IterableDataset replica at each worker. Please see "
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
warnings.warn(warn_msg)
return data
next = __next__ # Python 2 compatibility
def __len__(self):
return len(self._index_sampler)
def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
# Probably the best way to do this is by moving the sample pushing
# to a separate thread and then just sharing the data queue
# but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("{} cannot be pickled".format(self.__class__.__name__))
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
# NOTE [ Data Loader Multiprocessing Shutdown Logic ]
#
# Preliminary:
#
# Our data model looks like this (queues are indicated with curly brackets):
#
# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
# `pin_memory=False`.
#
#
# Terminating multiprocessing logic requires very careful design. In
# particular, we need to make sure that
#
# 1. The iterator gracefully exits the workers when its last reference is
# gone or it is depleted.
#
# In this case, the workers should be gracefully exited because the
# main process may still need to continue to run, and we want cleaning
# up code in the workers to be executed (e.g., releasing GPU memory).
# Naturally, we implement the shutdown logic in `__del__` of
# DataLoaderIterator.
#
# We delay the discussion on the logic in this case until later.
#
# 2. The iterator exits the workers when the loader process and/or worker
# processes exits normally or with error.
#
# We set all workers and `pin_memory_thread` to have `daemon=True`.
#
# You may ask, why can't we make the workers non-daemonic, and
# gracefully exit using the same logic as we have in `__del__` when the
# iterator gets deleted (see 1 above)?
#
# First of all, `__del__` is **not** guaranteed to be called when
# interpreter exits. Even if it is called, by the time it executes,
    #      many Python core library resources may already be freed, and even
# simple things like acquiring an internal lock of a queue may hang.
# Therefore, in this case, we actually need to prevent `__del__` from
# being executed, and rely on the automatic termination of daemonic
# children. Thus, we register an `atexit` hook that sets a global flag
# `_utils.python_exit_status`. Since `atexit` hooks are executed in the
# reverse order of registration, we are guaranteed that this flag is
# set before library resources we use are freed. (Hooks freeing those
# resources are registered at importing the Python core libraries at
# the top of this file.) So in `__del__`, we check if
# `_utils.python_exit_status` is set or `None` (freed), and perform
# no-op if so.
#
# Another problem with `__del__` is also related to the library cleanup
    #      calls. When a process ends, it shuts all its daemonic children
    #      down with a SIGTERM (instead of joining them without a timeout).
    #      Similarly for threads, but by a different mechanism. This fact,
# together with a few implementation details of multiprocessing, forces
# us to make workers daemonic. All of our problems arise when a
# DataLoader is used in a subprocess, and are caused by multiprocessing
# code which looks more or less like this:
#
# try:
# your_function_using_a_dataloader()
# finally:
# multiprocessing.util._exit_function()
#
# The joining/termination mentioned above happens inside
# `_exit_function()`. Now, if `your_function_using_a_dataloader()`
    #      throws, the stack trace stored in the exception will prevent the
    #      frame which uses `DataLoaderIter` from being freed. If the frame has any
# reference to the `DataLoaderIter` (e.g., in a method of the iter),
# its `__del__`, which starts the shutdown procedure, will not be
# called. That, in turn, means that workers aren't notified. Attempting
# to join in `_exit_function` will then result in a hang.
#
# For context, `_exit_function` is also registered as an `atexit` call.
# So it is unclear to me (@ssnl) why this is needed in a finally block.
# The code dates back to 2008 and there is no comment on the original
# PEP 371 or patch https://bugs.python.org/issue3050 (containing both
# the finally block and the `atexit` registration) that explains this.
#
# Another choice is to just shutdown workers with logic in 1 above
# whenever we see an error in `next`. This isn't ideal because
# a. It prevents users from using try-catch to resume data loading.
# b. It doesn't prevent hanging if users have references to the
# iterator.
#
# 3. All processes exit if any of them die unexpectedly by fatal signals.
#
# As shown above, the workers are set as daemonic children of the main
# process. However, automatic cleaning-up of such child processes only
# happens if the parent process exits gracefully (e.g., not via fatal
    #      signals like SIGKILL). So we must ensure that each process will exit
    #      even if the process that should send/receive data to/from it was
    #      killed, i.e.,
#
# a. A process won't hang when getting from a queue.
#
# Even with carefully designed data dependencies (i.e., a `put()`
# always corresponding to a `get()`), hanging on `get()` can still
# happen when data in queue is corrupted (e.g., due to
# `cancel_join_thread` or unexpected exit).
#
# For child exit, we set a timeout whenever we try to get data
# from `data_queue`, and check the workers' status on each timeout
# and error.
    #           See `_MultiProcessingDataLoaderIter._get_data()` and
    #           `_MultiProcessingDataLoaderIter._try_get_data()` for details.
#
# Additionally, for child exit on non-Windows platforms, we also
    #           register a SIGCHLD handler (which is not supported on Windows) on
# the main process, which checks if any of the workers fail in the
# (Python) handler. This is more efficient and faster in detecting
# worker failures, compared to only using the above mechanism.
# See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
#
# For `.get()` calls where the sender(s) is not the workers, we
# guard them with timeouts, and check the status of the sender
# when timeout happens:
# + in the workers, the `_utils.worker.ManagerWatchdog` class
# checks the status of the main process.
# + if `pin_memory=True`, when getting from `pin_memory_thread`,
# check `pin_memory_thread` status periodically until `.get()`
# returns or see that `pin_memory_thread` died.
#
# b. A process won't hang when putting into a queue;
#
# We use `mp.Queue` which has a separate background thread to put
# objects from an unbounded buffer array. The background thread is
# daemonic and usually automatically joined when the process
# exits.
#
# However, in case that the receiver has ended abruptly while
# reading from the pipe, the join will hang forever. Therefore,
# for both `worker_result_queue` (worker -> main process/pin_memory_thread)
# and each `index_queue` (main process -> worker), we use
# `q.cancel_join_thread()` in sender process before any `q.put` to
# prevent this automatic join.
#
    #         Moreover, having all queues call `cancel_join_thread` makes
# implementing graceful shutdown logic in `__del__` much easier.
# It won't need to get from any queue, which would also need to be
# guarded by periodic status checks.
#
# Nonetheless, `cancel_join_thread` must only be called when the
    #         queue is **not** going to be read from or written to by another
# process, because it may hold onto a lock or leave corrupted data
# in the queue, leading other readers/writers to hang.
#
# `pin_memory_thread`'s `data_queue` is a `queue.Queue` that does
# a blocking `put` if the queue is full. So there is no above
# problem, but we do need to wrap the `put` in a loop that breaks
# not only upon success, but also when the main process stops
# reading, i.e., is shutting down.
#
#
# Now let's get back to 1:
# how we gracefully exit the workers when the last reference to the
# iterator is gone.
#
# To achieve this, we implement the following logic along with the design
# choices mentioned above:
#
# `workers_done_event`:
# A `multiprocessing.Event` shared among the main process and all worker
# processes. This is used to signal the workers that the iterator is
# shutting down. After it is set, they will not send processed data to
# queues anymore, and only wait for the final `None` before exiting.
# `done_event` isn't strictly needed. I.e., we can just check for `None`
# from the input queue, but it allows us to skip wasting resources
# processing data if we are already shutting down.
#
# `pin_memory_thread_done_event`:
# A `threading.Event` for a similar purpose to that of
# `workers_done_event`, but is for the `pin_memory_thread`. The reason
# that separate events are needed is that `pin_memory_thread` reads from
# the output queue of the workers. But the workers, upon seeing that
    #     `workers_done_event` is set, only want to see the final `None`, and are
    #     not required to flush all data in the output queue (e.g., a worker may call
# `cancel_join_thread` on that queue if its `IterableDataset` iterator
# happens to exhaust coincidentally, which is out of the control of the
# main process). Thus, since we will exit `pin_memory_thread` before the
    #     workers (see below), two separate events are used.
#
# NOTE: In short, the protocol is that the main process will set these
    #   `done_event`s and then send the corresponding processes/threads a `None`,
# and that they may exit at any time after receiving the `None`.
#
# NOTE: Using `None` as the final signal is valid, since normal data will
# always be a 2-tuple with the 1st element being the index of the data
# transferred (different from dataset index/key), and the 2nd being
# either the dataset key or the data sample (depending on which part
# of the data model the queue is at).
#
# [ worker processes ]
# While loader process is alive:
# Get from `index_queue`.
    #         If got a `None`, exit.
    #         If got anything else,
# Check `workers_done_event`.
# If set, continue to next iteration
    #                      i.e., keep getting until we see the `None`, then exit.
# Otherwise, process data:
# If is fetching from an `IterableDataset` and the iterator
# is exhausted, send an `_IterableDatasetStopIteration`
# object to signal iteration end. The main process, upon
# receiving such an object, will send `None` to this
# worker and not use the corresponding `index_queue`
# anymore.
# If timed out,
    #            Whether or not `workers_done_event` is set (we still need to
    #            see the final `None`), we must continue to the next iteration.
# (outside loop)
# If `workers_done_event` is set, (this can be False with `IterableDataset`)
# `data_queue.cancel_join_thread()`. (Everything is ending here:
# main process won't read from it;
# other workers will also call
# `cancel_join_thread`.)
#
# [ pin_memory_thread ]
# # No need to check main thread. If this thread is alive, the main loader
# # thread must be alive, because this thread is set as daemonic.
# While `pin_memory_thread_done_event` is not set:
    #       Get from `worker_result_queue`.
# If timed out, continue to get in the next iteration.
# Otherwise, process data.
# While `pin_memory_thread_done_event` is not set:
# Put processed data to `data_queue` (a `queue.Queue` with blocking put)
# If timed out, continue to put in the next iteration.
    #         Otherwise, break, i.e., continue to the outer loop.
#
# NOTE: we don't check the status of the main thread because
# 1. if the process is killed by fatal signal, `pin_memory_thread`
# ends.
# 2. in other cases, either the cleaning-up in __del__ or the
# automatic exit of daemonic thread will take care of it.
# This won't busy-wait either because `.get(timeout)` does not
# busy-wait.
#
# [ main process ]
# In the DataLoader Iter's `__del__`
# b. Exit `pin_memory_thread`
# i. Set `pin_memory_thread_done_event`.
    #          ii.  Put `None` in `worker_result_queue`.
# iii. Join the `pin_memory_thread`.
# iv. `worker_result_queue.cancel_join_thread()`.
#
# c. Exit the workers.
# i. Set `workers_done_event`.
# ii. Put `None` in each worker's `index_queue`.
# iii. Join the workers.
# iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
#
    #        NOTE: (c) is better placed after (b) because (c) may leave corrupted
    #              data in `worker_result_queue`, which `pin_memory_thread`
    #              reads from, in which case the `pin_memory_thread` can only
    #              exit by timing out, which is slow. Nonetheless, the same thing
    #              happens if a worker is killed by a signal at an unfortunate
    #              time, but in other cases, we are better off having a
    #              non-corrupted `worker_result_queue` for `pin_memory_thread`.
    #
    #        NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
    #              can be omitted.
#
    # NB: The `done_event`s aren't strictly needed. E.g., we can just check for
# `None` from `index_queue`, but it allows us to skip wasting resources
# processing indices already in `index_queue` if we are already shutting
# down.
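    # A stripped-down sketch of the event-plus-sentinel protocol described
    # above (`consumer`, `q`, and `process` are illustrative names, not part
    # of this file). A consumer keeps draining its queue after `done_event`
    # is set and only exits on the final `None`:
    #
    #   >>> def consumer(q, done_event):
    #   ...     while True:
    #   ...         item = q.get()
    #   ...         if item is None:         # final signal: safe to exit now
    #   ...             break
    #   ...         if done_event.is_set():  # shutting down: drop, keep draining
    #   ...             continue
    #   ...         process(item)            # hypothetical work function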
def __init__(self, loader):
super(_MultiProcessingDataLoaderIter, self).__init__(loader)
assert self._num_workers > 0
if loader.multiprocessing_context is None:
multiprocessing_context = multiprocessing
else:
multiprocessing_context = loader.multiprocessing_context
self._worker_init_fn = loader.worker_init_fn
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
self._worker_result_queue = multiprocessing_context.Queue()
self._worker_pids_set = False
self._shutdown = False
self._send_idx = 0 # idx of the next task to be sent to workers
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
# map: task idx => - (worker_id,) if data isn't fetched (outstanding)
# \ (worker_id, data) if data is already fetched (out-of-order)
self._task_info = {}
self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
self._workers_done_event = multiprocessing_context.Event()
self._index_queues = []
self._workers = []
# A list of booleans representing whether each worker still has work to
# do, i.e., not having exhausted its iterable dataset object. It always
# contains all `True`s if not using an iterable-style dataset
# (i.e., if kind != Iterable).
self._workers_status = []
for i in range(self._num_workers):
index_queue = multiprocessing_context.Queue()
# index_queue.cancel_join_thread()
w = multiprocessing_context.Process(
target=worker_loop,
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed + i, self._worker_init_fn, i, self._num_workers))
w.daemon = True
            # NB: Process.start() actually takes some time as it needs to
# start a process and pass the arguments over via a pipe.
# Therefore, we only add a worker to self._workers list after
            # it has started, so that we do not call .join() if the program dies
            # before it starts, in which case __del__ would try to join it and get:
            #     AssertionError: can only join a started process.
w.start()
self._index_queues.append(index_queue)
self._workers.append(w)
self._workers_status.append(True)
if self._pin_memory:
self._pin_memory_thread_done_event = threading.Event()
self._data_queue = queue.Queue()
pin_memory_thread = threading.Thread(
target=_utils.pin_memory._pin_memory_loop,
args=(self._worker_result_queue, self._data_queue,
torch.cuda.current_device(),
self._pin_memory_thread_done_event))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
# pin_memory_thread once it is started.
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
_utils.signal_handling._set_SIGCHLD_handler()
self._worker_pids_set = True
# prime the prefetch loop
for _ in range(2 * self._num_workers):
self._try_put_index()
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
# Tries to fetch data from `self._data_queue` once for a given timeout.
# This can also be used as inner loop of fetching without timeout, with
# the sender status as the loop condition.
#
        # This raises a `RuntimeError` if any worker died unexpectedly. This error
# can come from either the SIGCHLD handler in `_utils/signal_handling.py`
# (only for non-Windows platforms), or the manual check below on errors
# and timeouts.
#
# Returns a 2-tuple:
        #   (bool: whether we successfully got data, any: the data if successful else None)
try:
data = self._data_queue.get(timeout=timeout)
return (True, data)
except Exception as e:
# At timeout and error, we manually check whether any worker has
# failed. Note that this is the only mechanism for Windows to detect
# worker failures.
failed_workers = []
for worker_id, w in enumerate(self._workers):
if self._workers_status[worker_id] and not w.is_alive():
failed_workers.append(w)
self._shutdown_worker(worker_id)
if len(failed_workers) > 0:
pids_str = ', '.join(str(w.pid) for w in failed_workers)
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
if isinstance(e, queue.Empty):
return (False, None)
raise
def _get_data(self):
# Fetches data from `self._data_queue`.
#
# We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
# which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
# in a loop. This is the only mechanism to detect worker failures for
# Windows. For other platforms, a SIGCHLD handler is also used for
# worker failure detection.
#
        # If `pin_memory=True`, we also need to check if `pin_memory_thread` has
        # died at timeouts.
if self._timeout > 0:
success, data = self._try_get_data(self._timeout)
if success:
return data
else:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
elif self._pin_memory:
while self._pin_memory_thread.is_alive():
success, data = self._try_get_data()
if success:
return data
else:
# while condition is false, i.e., pin_memory_thread died.
raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self._data_queue` is a `queue.Queue`. But we don't
# need to call `.task_done()` because we don't use `.join()`.
else:
while True:
success, data = self._try_get_data()
if success:
return data
def _next_data(self):
while True:
# If the worker responsible for `self._rcvd_idx` has already ended
# and was unable to fulfill this task (due to exhausting an `IterableDataset`),
# we try to advance `self._rcvd_idx` to find the next valid index.
#
# This part needs to run in the loop because both the `self._get_data()`
# call and `_IterableDatasetStopIteration` check below can mark
# extra worker(s) as dead.
while self._rcvd_idx < self._send_idx:
info = self._task_info[self._rcvd_idx]
worker_id = info[0]
if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active
break
del self._task_info[self._rcvd_idx]
self._rcvd_idx += 1
else:
# no valid `self._rcvd_idx` is found (i.e., didn't break)
self._shutdown_workers()
raise StopIteration
# Now `self._rcvd_idx` is the batch index we want to fetch
# Check if the next sample has already been generated
if len(self._task_info[self._rcvd_idx]) == 2:
data = self._task_info.pop(self._rcvd_idx)[1]
return self._process_data(data)
assert not self._shutdown and self._tasks_outstanding > 0
idx, data = self._get_data()
self._tasks_outstanding -= 1
if self._dataset_kind == _DatasetKind.Iterable:
# Check for _IterableDatasetStopIteration
if isinstance(data, _utils.worker._IterableDatasetStopIteration):
self._shutdown_worker(data.worker_id)
self._try_put_index()
continue
if idx != self._rcvd_idx:
# store out-of-order samples
self._task_info[idx] += (data,)
else:
del self._task_info[idx]
return self._process_data(data)
def _try_put_index(self):
assert self._tasks_outstanding < 2 * self._num_workers
try:
index = self._next_index()
except StopIteration:
return
for _ in range(self._num_workers): # find the next active worker, if any
worker_queue_idx = next(self._worker_queue_idx_cycle)
if self._workers_status[worker_queue_idx]:
break
else:
# not found (i.e., didn't break)
return
self._index_queues[worker_queue_idx].put((self._send_idx, index))
self._task_info[self._send_idx] = (worker_queue_idx,)
self._tasks_outstanding += 1
self._send_idx += 1
def _process_data(self, data):
self._rcvd_idx += 1
self._try_put_index()
if isinstance(data, ExceptionWrapper):
data.reraise()
return data
def _shutdown_worker(self, worker_id):
# Mark a worker as having finished its work and dead, e.g., due to
# exhausting an `IterableDataset`. This should be used only when this
# `_MultiProcessingDataLoaderIter` is going to continue running.
assert self._workers_status[worker_id]
# Signal termination to that specific worker.
q = self._index_queues[worker_id]
# Indicate that no more data will be put on this queue by the current
# process.
q.put(None)
        # Note that we don't actually join the worker here, nor do we remove the
        # worker's pid from the C side struct because (1) joining may be slow, and
        # (2) since we don't join, the worker may still raise an error, and we
        # prefer capturing those, rather than ignoring them, even though they
        # are raised after the worker has finished its job.
        # Joining is deferred to `_shutdown_workers`, which is called when
        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
        # when this iterator is garbage collected.
self._workers_status[worker_id] = False
def _shutdown_workers(self):
# Called when shutting down this `_MultiProcessingDataLoaderIter`.
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
# the logic of this function.
python_exit_status = _utils.python_exit_status
if python_exit_status is True or python_exit_status is None:
# See (2) of the note. If Python is shutting down, do no-op.
return
# Normal exit when last reference is gone / iterator is depleted.
# See (1) and the second half of the note.
if not self._shutdown:
self._shutdown = True
try:
# Exit `pin_memory_thread` first because exiting workers may leave
# corrupted data in `worker_result_queue` which `pin_memory_thread`
# reads from.
if hasattr(self, '_pin_memory_thread'):
# Use hasattr in case error happens before we set the attribute.
self._pin_memory_thread_done_event.set()
# Send something to pin_memory_thread in case it is waiting
# so that it can wake up and check `pin_memory_thread_done_event`
self._worker_result_queue.put((None, None))
self._pin_memory_thread.join()
self._worker_result_queue.close()
# Exit workers now.
self._workers_done_event.set()
for worker_id in range(len(self._workers)):
# Get number of workers from `len(self._workers)` instead of
# `self._num_workers` in case we error before starting all
# workers.
if self._workers_status[worker_id]:
self._shutdown_worker(worker_id)
for w in self._workers:
w.join()
for q in self._index_queues:
q.cancel_join_thread()
q.close()
finally:
# Even though all this function does is putting into queues that
# we have called `cancel_join_thread` on, weird things can
# happen when a worker is killed by a signal, e.g., hanging in
# `Event.set()`. So we need to guard this with SIGCHLD handler,
# and remove pids from the C side data structure only at the
# end.
#
# FIXME: Unfortunately, for Windows, we are missing a worker
# error detection mechanism here in this function, as it
# doesn't provide a SIGCHLD handler.
if self._worker_pids_set:
_utils.signal_handling._remove_worker_pids(id(self))
self._worker_pids_set = False
def __del__(self):
self._shutdown_workers()


@@ -0,0 +1,207 @@
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""
import torch
import random
import os
from collections import namedtuple
from torch._six import queue
from torch._utils import ExceptionWrapper
from torch.utils.data._utils import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS
from .my_random_resize_crop import MyRandomResizedCrop
__all__ = ['worker_loop']
if IS_WINDOWS:
import ctypes
from ctypes.wintypes import DWORD, BOOL, HANDLE
# On Windows, the parent ID of the worker process remains unchanged when the manager process
# is gone, and the only way to check it through OS is to let the worker have a process handle
# of the manager and ask if the process status has changed.
class ManagerWatchdog(object):
def __init__(self):
self.manager_pid = os.getppid()
self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
self.kernel32.OpenProcess.restype = HANDLE
self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
self.kernel32.WaitForSingleObject.restype = DWORD
# Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
SYNCHRONIZE = 0x00100000
self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
if not self.manager_handle:
raise ctypes.WinError(ctypes.get_last_error())
self.manager_dead = False
def is_alive(self):
if not self.manager_dead:
# Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
return not self.manager_dead
else:
class ManagerWatchdog(object):
def __init__(self):
self.manager_pid = os.getppid()
self.manager_dead = False
def is_alive(self):
if not self.manager_dead:
self.manager_dead = os.getppid() != self.manager_pid
return not self.manager_dead
_worker_info = None
class WorkerInfo(object):
__initialized = False
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
self.__initialized = True
def __setattr__(self, key, val):
if self.__initialized:
raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__))
return super(WorkerInfo, self).__setattr__(key, val)
def get_worker_info():
r"""Returns the information about the current
:class:`~torch.utils.data.DataLoader` iterator worker process.
When called in a worker, this returns an object guaranteed to have the
following attributes:
* :attr:`id`: the current worker id.
* :attr:`num_workers`: the total number of workers.
* :attr:`seed`: the random seed set for the current worker. This value is
determined by main process RNG and the worker id. See
:class:`~torch.utils.data.DataLoader`'s documentation for more details.
* :attr:`dataset`: the copy of the dataset object in **this** process. Note
that this will be a different object in a different process than the one
in the main process.
When called in the main process, this returns ``None``.
.. note::
When used in a :attr:`worker_init_fn` passed over to
:class:`~torch.utils.data.DataLoader`, this method can be useful to
set up each worker process differently, for instance, using ``worker_id``
to configure the ``dataset`` object to only read a specific fraction of a
sharded dataset, or use ``seed`` to seed other libraries used in dataset
code (e.g., NumPy).
"""
return _worker_info
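# A common pattern the docstring above alludes to -- calling
# `get_worker_info()` inside a `worker_init_fn` to shard an iterable dataset
# across workers. This is a sketch; an iterable dataset exposing integer
# `start`/`end` attributes is assumed:
#
#   >>> def worker_init_fn(worker_id):
#   ...     info = get_worker_info()
#   ...     per_worker = (info.dataset.end - info.dataset.start) // info.num_workers
#   ...     info.dataset.start += worker_id * per_worker
#   ...     info.dataset.end = info.dataset.start + per_worker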
r"""Dummy class used to signal the end of an IterableDataset"""
_IterableDatasetStopIteration = namedtuple('_IterableDatasetStopIteration', ['worker_id'])
def worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
num_workers):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal had already happened
# again.
# https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
signal_handling._set_worker_signal_handlers()
torch.set_num_threads(1)
random.seed(seed)
torch.manual_seed(seed)
global _worker_info
_worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
seed=seed, dataset=dataset)
from torch.utils.data import _DatasetKind
init_exception = None
try:
if init_fn is not None:
init_fn(worker_id)
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
except Exception:
init_exception = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id))
        # When using Iterable mode, some workers can exit earlier than others due
# to the IterableDataset behaving differently for different workers.
# When such things happen, an `_IterableDatasetStopIteration` object is
# sent over to the main process with the ID of this worker, so that the
# main process won't send more tasks to this worker, and will send
# `None` to this worker to properly exit it.
#
# Note that we cannot set `done_event` from a worker as it is shared
# among all processes. Instead, we set the `iteration_end` flag to
# signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing steps and just wait for
# `None`.
iteration_end = False
watchdog = ManagerWatchdog()
while watchdog.is_alive():
try:
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
if r is None:
# Received the final signal
assert done_event.is_set() or iteration_end
break
elif done_event.is_set() or iteration_end:
                # `done_event` is set, but the final signal (`None`) hasn't
                # arrived yet. Keep looping until it arrives, skipping the
                # processing steps.
continue
idx, index = r
""" Added """
MyRandomResizedCrop.sample_image_size(idx)
""" Added """
if init_exception is not None:
data = init_exception
init_exception = None
else:
try:
data = fetcher.fetch(index)
except Exception as e:
if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
data = _IterableDatasetStopIteration(worker_id)
# Set `iteration_end`
# (1) to save future `next(...)` calls, and
# (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
iteration_end = True
else:
# It is important that we don't store exc_info in a variable.
# `ExceptionWrapper` does the correct thing.
# See NOTE [ Python Traceback Reference Cycle Problem ]
data = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id))
data_queue.put((idx, data))
del data, idx, index, r # save memory
except KeyboardInterrupt:
# Main process will raise KeyboardInterrupt anyways.
pass
if done_event.is_set():
data_queue.cancel_join_thread()
data_queue.close()


@@ -0,0 +1,69 @@
import math
import torch
from torch.utils.data.distributed import DistributedSampler
__all__ = ['MyDistributedSampler', 'WeightedDistributedSampler']
class MyDistributedSampler(DistributedSampler):
""" Allow Subset Sampler in Distributed Training """
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True,
sub_index_list=None):
super(MyDistributedSampler, self).__init__(dataset, num_replicas, rank, shuffle)
        self.sub_index_list = sub_index_list  # numpy array of allowed dataset indices
self.num_samples = int(math.ceil(len(self.sub_index_list) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
print('Use MyDistributedSampler: %d, %d' % (self.num_samples, self.total_size))
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.sub_index_list), generator=g).tolist()
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
indices = self.sub_index_list[indices].tolist()
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
class WeightedDistributedSampler(DistributedSampler):
""" Allow Weighted Random Sampling in Distributed Training """
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True,
weights=None, replacement=True):
super(WeightedDistributedSampler, self).__init__(dataset, num_replicas, rank, shuffle)
self.weights = torch.as_tensor(weights, dtype=torch.double) if weights is not None else None
self.replacement = replacement
print('Use WeightedDistributedSampler')
def __iter__(self):
if self.weights is None:
return super(WeightedDistributedSampler, self).__iter__()
else:
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
# original: indices = torch.randperm(len(self.dataset), generator=g).tolist()
indices = torch.multinomial(self.weights, len(self.dataset), self.replacement, generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
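# A minimal usage sketch (names such as `train_dataset`, `world_size`,
# `rank`, and `epoch` are assumed). `sub_index_list` restricts sampling to a
# subset of the dataset while each rank still gets an equally sized,
# epoch-deterministic shard:
#
#   >>> import numpy as np
#   >>> sampler = MyDistributedSampler(
#   ...     train_dataset, num_replicas=world_size, rank=rank,
#   ...     sub_index_list=np.arange(10000))
#   >>> loader = torch.utils.data.DataLoader(
#   ...     train_dataset, batch_size=64, sampler=sampler)
#   >>> sampler.set_epoch(epoch)  # reseed the shuffle each epoch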


@@ -0,0 +1,136 @@
import time
import random
import math
import os
from PIL import Image
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
__all__ = ['MyRandomResizedCrop', 'MyResizeRandomCrop', 'MyResize']
_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Image.BILINEAR: 'PIL.Image.BILINEAR',
Image.BICUBIC: 'PIL.Image.BICUBIC',
Image.LANCZOS: 'PIL.Image.LANCZOS',
Image.HAMMING: 'PIL.Image.HAMMING',
Image.BOX: 'PIL.Image.BOX',
}
class MyRandomResizedCrop(transforms.RandomResizedCrop):
ACTIVE_SIZE = 224
IMAGE_SIZE_LIST = [224]
IMAGE_SIZE_SEG = 4
CONTINUOUS = False
SYNC_DISTRIBUTED = True
EPOCH = 0
BATCH = 0
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
if not isinstance(size, int):
size = size[0]
super(MyRandomResizedCrop, self).__init__(size, scale, ratio, interpolation)
def __call__(self, img):
i, j, h, w = self.get_params(img, self.scale, self.ratio)
return F.resized_crop(
img, i, j, h, w, (MyRandomResizedCrop.ACTIVE_SIZE, MyRandomResizedCrop.ACTIVE_SIZE), self.interpolation
)
@staticmethod
def get_candidate_image_size():
if MyRandomResizedCrop.CONTINUOUS:
min_size = min(MyRandomResizedCrop.IMAGE_SIZE_LIST)
max_size = max(MyRandomResizedCrop.IMAGE_SIZE_LIST)
candidate_sizes = []
for i in range(min_size, max_size + 1):
if i % MyRandomResizedCrop.IMAGE_SIZE_SEG == 0:
candidate_sizes.append(i)
else:
candidate_sizes = MyRandomResizedCrop.IMAGE_SIZE_LIST
relative_probs = None
return candidate_sizes, relative_probs
@staticmethod
def sample_image_size(batch_id=None):
if batch_id is None:
batch_id = MyRandomResizedCrop.BATCH
if MyRandomResizedCrop.SYNC_DISTRIBUTED:
_seed = int('%d%.3d' % (batch_id, MyRandomResizedCrop.EPOCH))
else:
_seed = os.getpid() + time.time()
random.seed(_seed)
candidate_sizes, relative_probs = MyRandomResizedCrop.get_candidate_image_size()
MyRandomResizedCrop.ACTIVE_SIZE = random.choices(candidate_sizes, weights=relative_probs)[0]
def __repr__(self):
interpolate_str = _pil_interpolation_to_str[self.interpolation]
format_string = self.__class__.__name__ + '(size={0}'.format(MyRandomResizedCrop.IMAGE_SIZE_LIST)
if MyRandomResizedCrop.CONTINUOUS:
format_string += '@continuous'
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0})'.format(interpolate_str)
return format_string
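# The class attributes above implement elastic input resolution: every
# transform instance reads the shared ACTIVE_SIZE, which `sample_image_size`
# re-draws per batch (seeded by batch/epoch id so distributed workers agree).
# A minimal sketch, with `pil_image` assumed to be any PIL image:
#
#   >>> MyRandomResizedCrop.IMAGE_SIZE_LIST = [128, 160, 192, 224]
#   >>> transform = MyRandomResizedCrop(224)
#   >>> MyRandomResizedCrop.sample_image_size(batch_id=0)
#   >>> out = transform(pil_image)
#   >>> out.size == (MyRandomResizedCrop.ACTIVE_SIZE,) * 2
#   True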
class MyResizeRandomCrop(object):
def __init__(self, interpolation=Image.BILINEAR,
use_padding=False, pad_if_needed=False, fill=0, padding_mode='constant'):
# resize
self.interpolation = interpolation
# random crop
self.use_padding = use_padding
self.pad_if_needed = pad_if_needed
self.fill = fill
self.padding_mode = padding_mode
def __call__(self, img):
crop_size = MyRandomResizedCrop.ACTIVE_SIZE
if not self.use_padding:
resize_size = int(math.ceil(crop_size / 0.875))
img = F.resize(img, resize_size, self.interpolation)
else:
img = F.resize(img, crop_size, self.interpolation)
padding_size = crop_size // 8
img = F.pad(img, padding_size, self.fill, self.padding_mode)
# pad the width if needed
if self.pad_if_needed and img.size[0] < crop_size:
img = F.pad(img, (crop_size - img.size[0], 0), self.fill, self.padding_mode)
# pad the height if needed
if self.pad_if_needed and img.size[1] < crop_size:
img = F.pad(img, (0, crop_size - img.size[1]), self.fill, self.padding_mode)
i, j, h, w = transforms.RandomCrop.get_params(img, (crop_size, crop_size))
return F.crop(img, i, j, h, w)
def __repr__(self):
return 'MyResizeRandomCrop(size=%s%s, interpolation=%s, use_padding=%s, fill=%s)' % (
MyRandomResizedCrop.IMAGE_SIZE_LIST, '@continuous' if MyRandomResizedCrop.CONTINUOUS else '',
_pil_interpolation_to_str[self.interpolation], self.use_padding, self.fill,
)
class MyResize(object):
def __init__(self, interpolation=Image.BILINEAR):
self.interpolation = interpolation
def __call__(self, img):
target_size = MyRandomResizedCrop.ACTIVE_SIZE
img = F.resize(img, target_size, self.interpolation)
return img
def __repr__(self):
return 'MyResize(size=%s%s, interpolation=%s)' % (
MyRandomResizedCrop.IMAGE_SIZE_LIST, '@continuous' if MyRandomResizedCrop.CONTINUOUS else '',
_pil_interpolation_to_str[self.interpolation]
)


@@ -0,0 +1,238 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import math
import torch.nn as nn
import torch.nn.functional as F
from .common_tools import min_divisible_value
__all__ = ['MyModule', 'MyNetwork', 'init_models', 'set_bn_param', 'get_bn_param', 'replace_bn_with_gn',
'MyConv2d', 'replace_conv2d_with_my_conv2d']
def set_bn_param(net, momentum, eps, gn_channel_per_group=None, ws_eps=None, **kwargs):
replace_bn_with_gn(net, gn_channel_per_group)
for m in net.modules():
if type(m) in [nn.BatchNorm1d, nn.BatchNorm2d]:
m.momentum = momentum
m.eps = eps
elif isinstance(m, nn.GroupNorm):
m.eps = eps
replace_conv2d_with_my_conv2d(net, ws_eps)
return
def get_bn_param(net):
ws_eps = None
for m in net.modules():
if isinstance(m, MyConv2d):
ws_eps = m.WS_EPS
break
for m in net.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
return {
'momentum': m.momentum,
'eps': m.eps,
'ws_eps': ws_eps,
}
elif isinstance(m, nn.GroupNorm):
return {
'momentum': None,
'eps': m.eps,
'gn_channel_per_group': m.num_channels // m.num_groups,
'ws_eps': ws_eps,
}
return None
def replace_bn_with_gn(model, gn_channel_per_group):
if gn_channel_per_group is None:
return
for m in model.modules():
to_replace_dict = {}
for name, sub_m in m.named_children():
if isinstance(sub_m, nn.BatchNorm2d):
num_groups = sub_m.num_features // min_divisible_value(sub_m.num_features, gn_channel_per_group)
gn_m = nn.GroupNorm(num_groups=num_groups, num_channels=sub_m.num_features, eps=sub_m.eps, affine=True)
# load weight
gn_m.weight.data.copy_(sub_m.weight.data)
gn_m.bias.data.copy_(sub_m.bias.data)
# load requires_grad
gn_m.weight.requires_grad = sub_m.weight.requires_grad
gn_m.bias.requires_grad = sub_m.bias.requires_grad
to_replace_dict[name] = gn_m
m._modules.update(to_replace_dict)
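# A quick sketch of the replacement above. With `gn_channel_per_group=8`, a
# BatchNorm2d over 32 channels becomes a GroupNorm with 32 / 8 = 4 groups,
# inheriting the affine parameters:
#
#   >>> net = nn.Sequential(nn.Conv2d(3, 32, 3), nn.BatchNorm2d(32))
#   >>> replace_bn_with_gn(net, gn_channel_per_group=8)
#   >>> net[1]
#   GroupNorm(4, 32, eps=1e-05, affine=True)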
def replace_conv2d_with_my_conv2d(net, ws_eps=None):
if ws_eps is None:
return
for m in net.modules():
to_update_dict = {}
for name, sub_module in m.named_children():
            if isinstance(sub_module, nn.Conv2d) and sub_module.bias is None:
# only replace conv2d layers that are followed by normalization layers (i.e., no bias)
to_update_dict[name] = sub_module
for name, sub_module in to_update_dict.items():
m._modules[name] = MyConv2d(
sub_module.in_channels, sub_module.out_channels, sub_module.kernel_size, sub_module.stride,
sub_module.padding, sub_module.dilation, sub_module.groups, sub_module.bias,
)
# load weight
m._modules[name].load_state_dict(sub_module.state_dict())
# load requires_grad
m._modules[name].weight.requires_grad = sub_module.weight.requires_grad
if sub_module.bias is not None:
m._modules[name].bias.requires_grad = sub_module.bias.requires_grad
# set ws_eps
for m in net.modules():
if isinstance(m, MyConv2d):
m.WS_EPS = ws_eps
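# A minimal sketch of enabling weight standardization. Only bias-free convs
# (i.e., those followed by a normalization layer) are swapped for `MyConv2d`,
# which then standardizes its weights on the fly:
#
#   >>> net = nn.Sequential(nn.Conv2d(3, 16, 3, bias=False), nn.BatchNorm2d(16))
#   >>> replace_conv2d_with_my_conv2d(net, ws_eps=1e-5)
#   >>> type(net[0]).__name__, net[0].WS_EPS
#   ('MyConv2d', 1e-05)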
def init_models(net, model_init='he_fout'):
"""
Conv2d,
BatchNorm2d, BatchNorm1d, GroupNorm
Linear,
"""
if isinstance(net, list):
for sub_net in net:
init_models(sub_net, model_init)
return
for m in net.modules():
if isinstance(m, nn.Conv2d):
if model_init == 'he_fout':
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif model_init == 'he_fin':
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
else:
raise NotImplementedError
if m.bias is not None:
m.bias.data.zero_()
elif type(m) in [nn.BatchNorm2d, nn.BatchNorm1d, nn.GroupNorm]:
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
stdv = 1. / math.sqrt(m.weight.size(1))
m.weight.data.uniform_(-stdv, stdv)
if m.bias is not None:
m.bias.data.zero_()
class MyConv2d(nn.Conv2d):
"""
Conv2d with Weight Standardization
https://github.com/joe-siyuan-qiao/WeightStandardization
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(MyConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
self.WS_EPS = None
def weight_standardization(self, weight):
if self.WS_EPS is not None:
weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
weight = weight - weight_mean
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + self.WS_EPS
weight = weight / std.expand_as(weight)
return weight
def forward(self, x):
if self.WS_EPS is None:
return super(MyConv2d, self).forward(x)
else:
return F.conv2d(x, self.weight_standardization(self.weight), self.bias,
self.stride, self.padding, self.dilation, self.groups)
def __repr__(self):
return super(MyConv2d, self).__repr__()[:-1] + ', ws_eps=%s)' % self.WS_EPS
class MyModule(nn.Module):
def forward(self, x):
raise NotImplementedError
@property
def module_str(self):
raise NotImplementedError
@property
def config(self):
raise NotImplementedError
@staticmethod
def build_from_config(config):
raise NotImplementedError
class MyNetwork(MyModule):
CHANNEL_DIVISIBLE = 8
def forward(self, x):
raise NotImplementedError
@property
def module_str(self):
raise NotImplementedError
@property
def config(self):
raise NotImplementedError
@staticmethod
def build_from_config(config):
raise NotImplementedError
def zero_last_gamma(self):
raise NotImplementedError
@property
def grouped_block_index(self):
raise NotImplementedError
""" implemented methods """
def set_bn_param(self, momentum, eps, gn_channel_per_group=None, **kwargs):
set_bn_param(self, momentum, eps, gn_channel_per_group, **kwargs)
def get_bn_param(self):
return get_bn_param(self)
def get_parameters(self, keys=None, mode='include'):
if keys is None:
for name, param in self.named_parameters():
if param.requires_grad: yield param
elif mode == 'include':
for name, param in self.named_parameters():
flag = False
for key in keys:
if key in name:
flag = True
break
if flag and param.requires_grad: yield param
elif mode == 'exclude':
for name, param in self.named_parameters():
flag = True
for key in keys:
if key in name:
flag = False
break
if flag and param.requires_grad: yield param
else:
raise ValueError('do not support: %s' % mode)
def weight_parameters(self):
return self.get_parameters()


@@ -0,0 +1,154 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from .my_modules import MyNetwork
__all__ = [
'make_divisible', 'build_activation', 'ShuffleLayer', 'MyGlobalAvgPool2d', 'Hswish', 'Hsigmoid', 'SEModule',
'MultiHeadCrossEntropyLoss'
]
def make_divisible(v, divisor, min_val=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_val:
:return:
"""
if min_val is None:
min_val = divisor
new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
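# Worked examples with divisor 8 (the channel divisor used throughout this
# repo):
#
#   >>> make_divisible(30, 8)  # rounds to the nearest multiple of 8
#   32
#   >>> make_divisible(22, 8)
#   24
#   >>> make_divisible(7, 8)   # never drops below the divisor itself
#   8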
def build_activation(act_func, inplace=True):
if act_func == 'relu':
return nn.ReLU(inplace=inplace)
elif act_func == 'relu6':
return nn.ReLU6(inplace=inplace)
elif act_func == 'tanh':
return nn.Tanh()
elif act_func == 'sigmoid':
return nn.Sigmoid()
elif act_func == 'h_swish':
return Hswish(inplace=inplace)
elif act_func == 'h_sigmoid':
return Hsigmoid(inplace=inplace)
elif act_func is None or act_func == 'none':
return None
else:
raise ValueError('do not support: %s' % act_func)
class ShuffleLayer(nn.Module):
def __init__(self, groups):
super(ShuffleLayer, self).__init__()
self.groups = groups
def forward(self, x):
batch_size, num_channels, height, width = x.size()
channels_per_group = num_channels // self.groups
# reshape
x = x.view(batch_size, self.groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batch_size, -1, height, width)
return x
def __repr__(self):
return 'ShuffleLayer(groups=%d)' % self.groups
class MyGlobalAvgPool2d(nn.Module):
def __init__(self, keep_dim=True):
super(MyGlobalAvgPool2d, self).__init__()
self.keep_dim = keep_dim
def forward(self, x):
return x.mean(3, keepdim=self.keep_dim).mean(2, keepdim=self.keep_dim)
def __repr__(self):
return 'MyGlobalAvgPool2d(keep_dim=%s)' % self.keep_dim
class Hswish(nn.Module):
def __init__(self, inplace=True):
super(Hswish, self).__init__()
self.inplace = inplace
def forward(self, x):
return x * F.relu6(x + 3., inplace=self.inplace) / 6.
def __repr__(self):
return 'Hswish()'
class Hsigmoid(nn.Module):
def __init__(self, inplace=True):
super(Hsigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
return F.relu6(x + 3., inplace=self.inplace) / 6.
def __repr__(self):
return 'Hsigmoid()'
class SEModule(nn.Module):
REDUCTION = 4
def __init__(self, channel, reduction=None):
super(SEModule, self).__init__()
self.channel = channel
self.reduction = SEModule.REDUCTION if reduction is None else reduction
num_mid = make_divisible(self.channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE)
self.fc = nn.Sequential(OrderedDict([
('reduce', nn.Conv2d(self.channel, num_mid, 1, 1, 0, bias=True)),
('relu', nn.ReLU(inplace=True)),
('expand', nn.Conv2d(num_mid, self.channel, 1, 1, 0, bias=True)),
('h_sigmoid', Hsigmoid(inplace=True)),
]))
def forward(self, x):
y = x.mean(3, keepdim=True).mean(2, keepdim=True)
y = self.fc(y)
return x * y
def __repr__(self):
return 'SE(channel=%d, reduction=%d)' % (self.channel, self.reduction)
class MultiHeadCrossEntropyLoss(nn.Module):
def forward(self, outputs, targets):
assert outputs.dim() == 3, outputs
assert targets.dim() == 2, targets
assert outputs.size(1) == targets.size(1), (outputs, targets)
num_heads = targets.size(1)
loss = 0
for k in range(num_heads):
loss += F.cross_entropy(outputs[:, k, :], targets[:, k]) / num_heads
return loss
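# Usage sketch (editor's illustration): each of the K heads contributes an
# equally weighted cross-entropy term, i.e. the loss is the mean over heads:
#   criterion = MultiHeadCrossEntropyLoss()
#   outputs = torch.randn(4, 3, 10)          # (batch, heads, classes)
#   targets = torch.randint(0, 10, (4, 3))   # (batch, heads)
#   loss = criterion(outputs, targets)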

View File

@@ -0,0 +1,218 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import math
import copy
import time
import torch
import torch.nn as nn
__all__ = [
'mix_images', 'mix_labels',
'label_smooth', 'cross_entropy_loss_with_soft_target', 'cross_entropy_with_label_smoothing',
'clean_num_batch_tracked', 'rm_bn_from_net',
'get_net_device', 'count_parameters', 'count_net_flops', 'measure_net_latency', 'get_net_info',
'build_optimizer', 'calc_learning_rate',
]
""" Mixup """
def mix_images(images, lam):
flipped_images = torch.flip(images, dims=[0]) # flip along the batch dimension
return lam * images + (1 - lam) * flipped_images
def mix_labels(target, lam, n_classes, label_smoothing=0.1):
onehot_target = label_smooth(target, n_classes, label_smoothing)
flipped_target = torch.flip(onehot_target, dims=[0])
return lam * onehot_target + (1 - lam) * flipped_target
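# Mixup usage sketch (editor's illustration; alpha=0.2 and n_classes=1000 are
# assumed values, and numpy is needed for the Beta draw):
#   lam = float(np.random.beta(0.2, 0.2))
#   mixed = mix_images(images, lam)                  # blend batch with its flip
#   soft = mix_labels(targets, lam, n_classes=1000)  # blend smoothed one-hots
#   loss = cross_entropy_loss_with_soft_target(model(mixed), soft)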
""" Label smooth """
def label_smooth(target, n_classes: int, label_smoothing=0.1):
# convert to one-hot
batch_size = target.size(0)
target = torch.unsqueeze(target, 1)
soft_target = torch.zeros((batch_size, n_classes), device=target.device)
soft_target.scatter_(1, target, 1)
# label smoothing
soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
return soft_target
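# Numeric check (editor's note): with n_classes=5 and label_smoothing=0.1, the
# true class gets 0.9 + 0.1/5 = 0.92 and every other class gets 0.02, so each
# row still sums to 1.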
def cross_entropy_loss_with_soft_target(pred, soft_target):
    logsoftmax = nn.LogSoftmax(dim=1)  # explicit dim: the implicit-dim form is deprecated
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
soft_target = label_smooth(target, pred.size(1), label_smoothing)
return cross_entropy_loss_with_soft_target(pred, soft_target)
""" BN related """
def clean_num_batch_tracked(net):
for m in net.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
if m.num_batches_tracked is not None:
m.num_batches_tracked.zero_()
def rm_bn_from_net(net):
for m in net.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
m.forward = lambda x: x
""" Network profiling """
def get_net_device(net):
return net.parameters().__next__().device
def count_parameters(net):
total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
return total_params
def count_net_flops(net, data_shape=(1, 3, 224, 224)):
from .flops_counter import profile
if isinstance(net, nn.DataParallel):
net = net.module
flop, _ = profile(copy.deepcopy(net), data_shape)
return flop
def measure_net_latency(net, l_type='gpu8', fast=True, input_shape=(3, 224, 224), clean=False):
if isinstance(net, nn.DataParallel):
net = net.module
# remove bn from graph
rm_bn_from_net(net)
# return `ms`
if 'gpu' in l_type:
l_type, batch_size = l_type[:3], int(l_type[3:])
else:
batch_size = 1
data_shape = [batch_size] + list(input_shape)
if l_type == 'cpu':
if fast:
n_warmup = 5
n_sample = 10
else:
n_warmup = 50
n_sample = 50
if get_net_device(net) != torch.device('cpu'):
if not clean:
print('move net to cpu for measuring cpu latency')
net = copy.deepcopy(net).cpu()
elif l_type == 'gpu':
if fast:
n_warmup = 5
n_sample = 10
else:
n_warmup = 50
n_sample = 50
else:
raise NotImplementedError
images = torch.zeros(data_shape, device=get_net_device(net))
measured_latency = {'warmup': [], 'sample': []}
net.eval()
with torch.no_grad():
for i in range(n_warmup):
inner_start_time = time.time()
net(images)
used_time = (time.time() - inner_start_time) * 1e3 # ms
measured_latency['warmup'].append(used_time)
if not clean:
print('Warmup %d: %.3f' % (i, used_time))
outer_start_time = time.time()
for i in range(n_sample):
net(images)
total_time = (time.time() - outer_start_time) * 1e3 # ms
measured_latency['sample'].append((total_time, n_sample))
return total_time / n_sample, measured_latency
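# Usage sketch (editor's illustration). Note that rm_bn_from_net() replaces BN
# forwards in place, so pass a throwaway copy if the network is still needed:
#   avg_ms, hist = measure_net_latency(net, l_type='gpu8')            # batch 8, 5 warmup + 10 timed runs
#   avg_ms, hist = measure_net_latency(net, l_type='cpu', fast=False) # batch 1, 50 warmup + 50 timed runs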
def get_net_info(net, input_shape=(3, 224, 224), measure_latency=None, print_info=True):
net_info = {}
if isinstance(net, nn.DataParallel):
net = net.module
# parameters
net_info['params'] = count_parameters(net) / 1e6
# flops
net_info['flops'] = count_net_flops(net, [1] + list(input_shape)) / 1e6
# latencies
latency_types = [] if measure_latency is None else measure_latency.split('#')
for l_type in latency_types:
latency, measured_latency = measure_net_latency(net, l_type, fast=False, input_shape=input_shape)
net_info['%s latency' % l_type] = {
'val': latency,
'hist': measured_latency
}
if print_info:
print(net)
print('Total training params: %.2fM' % (net_info['params']))
print('Total FLOPs: %.2fM' % (net_info['flops']))
for l_type in latency_types:
print('Estimated %s latency: %.3fms' % (l_type, net_info['%s latency' % l_type]['val']))
return net_info
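# Usage sketch (editor's illustration): '#'-separated latency types are
# measured one by one, e.g. a batch-64 GPU and a batch-1 CPU estimate:
#   info = get_net_info(net, input_shape=(3, 224, 224),
#                       measure_latency='gpu64#cpu', print_info=False)
#   info['params'], info['flops']  # both reported in millions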
""" optimizer """
def build_optimizer(net_params, opt_type, opt_param, init_lr, weight_decay, no_decay_keys, seperate=1.0):
if no_decay_keys is not None:
assert isinstance(net_params, list) and len(net_params) == 2
net_params = [
{'params': net_params[0], 'weight_decay': weight_decay},
{'params': net_params[1], 'weight_decay': 0},
]
elif seperate != 1.0:
net_params = [{'params': net_params[0], 'weight_decay': weight_decay, 'lr': init_lr * seperate},
{'params': net_params[1], 'weight_decay': weight_decay, 'lr': init_lr}]
else:
net_params = [{'params': net_params, 'weight_decay': weight_decay}]
if opt_type == 'sgd':
opt_param = {} if opt_param is None else opt_param
momentum, nesterov = opt_param.get('momentum', 0.9), opt_param.get('nesterov', True)
optimizer = torch.optim.SGD(net_params, init_lr, momentum=momentum, nesterov=nesterov)
elif opt_type == 'adam':
optimizer = torch.optim.Adam(net_params, init_lr)
else:
raise NotImplementedError
return optimizer
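# Usage sketch (editor's illustration; the hyperparameter values are assumed):
# with no_decay_keys set, net_params must be the pair
# [params_with_decay, params_without_decay]:
#   opt = build_optimizer([decay, no_decay], 'sgd',
#                         {'momentum': 0.9, 'nesterov': True},
#                         init_lr=0.05, weight_decay=4e-5, no_decay_keys='bn#bias')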
""" learning rate schedule """
def calc_learning_rate(epoch, init_lr, n_epochs, batch=0,
nBatch=None, lr_schedule_type='cosine', optimizer=None):
if lr_schedule_type == 'cosine':
t_total = n_epochs * nBatch
t_cur = epoch * nBatch + batch
lr = 0.5 * init_lr * (1 + math.cos(math.pi * t_cur / t_total))
elif lr_schedule_type == 'reduce':
for param_group in optimizer.param_groups:
lr = param_group['lr']
elif lr_schedule_type is None:
lr = init_lr
else:
        raise ValueError('unsupported lr schedule type: %s' % lr_schedule_type)
return lr
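# Cosine schedule check (editor's note): with init_lr=0.05, n_epochs=100 and
# nBatch=500, the lr starts at 0.05 (epoch 0, batch 0), is exactly 0.025 at
# the halfway step, and decays toward 0 at the final step.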

View File

@@ -0,0 +1,43 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import argparse
def str2bool(v):
    # original listed the bool True, which a lowered string can never equal;
    # str(v) also tolerates non-string inputs
    return str(v).lower() in ('t', 'true', '1')
def get_parser():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('--seed', type=int, default=333)
parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
parser.add_argument('--model_name', type=str, default=None, choices=['generator', 'predictor', 'train_arch'])
parser.add_argument('--save-path', type=str, default='results', help='the path of save directory')
    parser.add_argument('--data-path', type=str, default='data', help='the path of the data directory')
parser.add_argument('--save-epoch', type=int, default=20, help='how many epochs to wait each time to save model states')
parser.add_argument('--max-epoch', type=int, default=400, help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=32, help='batch size for generator')
parser.add_argument('--graph-data-name', default='ofa_mbv3', help='graph dataset name')
parser.add_argument('--nvt', type=int, default=27, help='number of different node types, 21 for ofa_mbv3 without in/out node')
# set encoder
parser.add_argument('--num-sample', type=int, default=20, help='the number of images as input for set encoder')
# graph encoder
parser.add_argument('--hs', type=int, default=56, help='hidden size of GRUs')
parser.add_argument('--nz', type=int, default=56, help='the number of dimensions of latent vectors z')
# test
parser.add_argument('--test', action='store_true', default=False, help='turn on test mode')
parser.add_argument('--load-epoch', type=int, default=20, help='checkpoint epoch loaded for meta-test')
parser.add_argument('--data-name', type=str, default=None, help='meta-test dataset name')
parser.add_argument('--num-class', type=int, default=None, help='the number of class of dataset')
parser.add_argument('--num-gen-arch', type=int, default=200, help='the number of candidate architectures generated by the generator')
parser.add_argument('--train-arch', type=str2bool, default=True, help='whether to train the searched architecture')
# database
parser.add_argument('--index', type=int, default=None, help='the process number when creating DB')
parser.add_argument('--imgnet', type=str, default=None, help='The path of imagenet')
    parser.add_argument('--collect', action='store_true', default=False, help='whether to collect the created database files into a single file')
args = parser.parse_args()
return args

View File

@@ -0,0 +1,6 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .predictor import Predictor
from .predictor_model import PredictorModel

View File

@@ -0,0 +1,172 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import torch
import os
import random
from tqdm import tqdm
import numpy as np
import time
import shutil
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.stats import pearsonr
from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import load_graph_config, decode_ofa_mbv3_to_igraph
from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import Log, get_log
from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import load_model, save_model
from transfer_nag_lib.MetaD2A_mobilenetV3.loader import get_meta_train_loader
from .predictor_model import PredictorModel
from all_path import *
class Predictor:
def __init__(self, args):
self.args = args
self.batch_size = args.batch_size
self.data_path = args.data_path
self.num_sample = args.num_sample
self.max_epoch = args.max_epoch
self.save_epoch = args.save_epoch
self.model_path = UNNOISE_META_PREDICTOR_CKPT_PATH #MODEL_METAD2A_PATH_OFA
self.save_path = args.save_path
self.model_name = 'predictor'
self.test = args.test
self.device = torch.device("cuda:0")
self.max_corr_dict = {'corr': -1, 'epoch': -1}
self.train_arch = args.train_arch
graph_config = load_graph_config(
args.graph_data_name, args.nvt, args.data_path)
self.model = PredictorModel(args, graph_config)
self.model.to(self.device)
if self.test:
self.data_name = args.data_name
self.num_class = args.num_class
self.load_epoch = args.load_epoch
load_model(self.model, self.model_path, load_max_pt='ckpt_max_corr.pt')
self.model.to(self.device)
else:
self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',
factor=0.1, patience=10, verbose=True)
self.mtrloader = get_meta_train_loader(
self.batch_size, self.data_path, self.num_sample, is_pred=True)
self.acc_mean = self.mtrloader.dataset.mean
self.acc_std = self.mtrloader.dataset.std
self.mtrlog = Log(self.args, open(os.path.join(
self.save_path, self.model_name, 'meta_train_predictor.log'), 'w'))
self.mtrlog.print_args()
def forward(self, x, arch):
D_mu = self.model.set_encode(x.unsqueeze(0).to(self.device)).unsqueeze(0)
G_mu = self.model.graph_encode(arch[0])
y_pred = self.model.predict(D_mu, G_mu)
return y_pred
def meta_train(self):
sttime = time.time()
for epoch in range(1, self.max_epoch + 1):
self.mtrlog.ep_sttime = time.time()
loss, corr = self.meta_train_epoch(epoch)
self.scheduler.step(loss)
self.mtrlog.print_pred_log(loss, corr, 'train', epoch)
valoss, vacorr = self.meta_validation(epoch)
if self.max_corr_dict['corr'] < vacorr:
self.max_corr_dict['corr'] = vacorr
self.max_corr_dict['epoch'] = epoch
self.max_corr_dict['loss'] = valoss
save_model(epoch, self.model, self.model_path, max_corr=True)
self.mtrlog.print_pred_log(
valoss, vacorr, 'valid', max_corr_dict=self.max_corr_dict)
if epoch % self.save_epoch == 0:
save_model(epoch, self.model, self.model_path)
self.mtrlog.save_time_log()
self.mtrlog.max_corr_log(self.max_corr_dict)
def meta_train_epoch(self, epoch):
self.model.to(self.device)
self.model.train()
self.mtrloader.dataset.set_mode('train')
dlen = len(self.mtrloader.dataset)
trloss = 0
y_all, y_pred_all = [], []
pbar = tqdm(self.mtrloader)
for batch in pbar:
batch_loss = 0
y_batch, y_pred_batch = [], []
self.optimizer.zero_grad()
for x, g, acc in batch:
y_pred = self.forward(x, decode_ofa_mbv3_to_igraph(g))
y = acc.to(self.device)
batch_loss += self.model.mseloss(y_pred, y)
y = y.squeeze().tolist()
y_pred = y_pred.squeeze().tolist()
y_batch.append(y)
y_pred_batch.append(y_pred)
y_all.append(y)
y_pred_all.append(y_pred)
batch_loss.backward()
trloss += float(batch_loss)
self.optimizer.step()
pbar.set_description(get_log(
epoch, batch_loss, y_pred_batch, y_batch, self.acc_std, self.acc_mean))
return trloss / dlen, pearsonr(np.array(y_all),
np.array(y_pred_all))[0]
def meta_validation(self, epoch):
self.model.to(self.device)
self.model.eval()
valoss = 0
self.mtrloader.dataset.set_mode('valid')
dlen = len(self.mtrloader.dataset)
y_all, y_pred_all = [], []
pbar = tqdm(self.mtrloader)
with torch.no_grad():
for batch in pbar:
batch_loss = 0
y_batch, y_pred_batch = [], []
for x, g, acc in batch:
y_pred = self.forward(x, decode_ofa_mbv3_to_igraph(g))
y = acc.to(self.device)
batch_loss += self.model.mseloss(y_pred, y)
y = y.squeeze().tolist()
y_pred = y_pred.squeeze().tolist()
y_batch.append(y)
y_pred_batch.append(y_pred)
y_all.append(y)
y_pred_all.append(y_pred)
valoss += float(batch_loss)
pbar.set_description(get_log(
epoch, batch_loss, y_pred_batch, y_batch, self.acc_std, self.acc_mean, tag='val'))
return valoss / dlen, pearsonr(np.array(y_all),
np.array(y_pred_all))[0]

View File

@@ -0,0 +1,241 @@
######################################################################################
# Copyright (c) muhanzhang, D-VAE, NeurIPS 2019 [GitHub D-VAE]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
from torch import nn
from transfer_nag_lib.MetaD2A_mobilenetV3.set_encoder.setenc_models import SetPool
class PredictorModel(nn.Module):
def __init__(self, args, graph_config):
super(PredictorModel, self).__init__()
self.max_n = graph_config['max_n'] # maximum number of vertices
self.nvt = graph_config['num_vertex_type'] # number of vertex types
self.START_TYPE = graph_config['START_TYPE']
self.END_TYPE = graph_config['END_TYPE']
self.hs = args.hs # hidden state size of each vertex
self.nz = args.nz # size of latent representation z
self.gs = args.hs # size of graph state
self.bidir = True # whether to use bidirectional encoding
self.vid = True
self.device = None
self.input_type = 'DG'
self.num_sample = args.num_sample
if self.vid:
self.vs = self.hs + self.max_n # vertex state size = hidden state + vid
else:
self.vs = self.hs
# 0. encoding-related
self.grue_forward = nn.GRUCell(self.nvt, self.hs) # encoder GRU
self.grue_backward = nn.GRUCell(self.nvt, self.hs) # backward encoder GRU
self.fc1 = nn.Linear(self.gs, self.nz) # latent mean
self.fc2 = nn.Linear(self.gs, self.nz) # latent logvar
# 2. gate-related
self.gate_forward = nn.Sequential(
nn.Linear(self.vs, self.hs),
nn.Sigmoid()
)
self.gate_backward = nn.Sequential(
nn.Linear(self.vs, self.hs),
nn.Sigmoid()
)
self.mapper_forward = nn.Sequential(
nn.Linear(self.vs, self.hs, bias=False),
) # disable bias to ensure padded zeros also mapped to zeros
self.mapper_backward = nn.Sequential(
nn.Linear(self.vs, self.hs, bias=False),
)
# 3. bidir-related, to unify sizes
if self.bidir:
self.hv_unify = nn.Sequential(
nn.Linear(self.hs * 2, self.hs),
)
self.hg_unify = nn.Sequential(
nn.Linear(self.gs * 2, self.gs),
)
# 4. other
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.logsoftmax1 = nn.LogSoftmax(1)
# 6. predictor
self.intra_setpool = SetPool(dim_input=512,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.inter_setpool = SetPool(dim_input=self.nz,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.set_fc = nn.Sequential(
nn.Linear(512, self.nz),
nn.ReLU())
input_dim = 0
if 'D' in self.input_type:
input_dim += self.nz
if 'G' in self.input_type:
input_dim += self.nz
self.pred_fc = nn.Sequential(
nn.Linear(input_dim, self.hs),
nn.Tanh(),
nn.Linear(self.hs, 1)
)
self.mseloss = nn.MSELoss(reduction='sum')
def predict(self, D_mu, G_mu):
input_vec = []
if 'D' in self.input_type:
input_vec.append(D_mu)
if 'G' in self.input_type:
input_vec.append(G_mu)
input_vec = torch.cat(input_vec, dim=1)
return self.pred_fc(input_vec)
def get_device(self):
if self.device is None:
self.device = next(self.parameters()).device
return self.device
def _get_zeros(self, n, length):
return torch.zeros(n, length).to(self.get_device()) # get a zero hidden state
def _get_zero_hidden(self, n=1):
return self._get_zeros(n, self.hs) # get a zero hidden state
def _one_hot(self, idx, length):
if type(idx) in [list, range]:
if idx == []:
return None
idx = torch.LongTensor(idx).unsqueeze(0).t()
x = torch.zeros((len(idx), length)).scatter_(1, idx, 1).to(self.get_device())
else:
idx = torch.LongTensor([idx]).unsqueeze(0)
x = torch.zeros((1, length)).scatter_(1, idx, 1).to(self.get_device())
return x
def _gated(self, h, gate, mapper):
return gate(h) * mapper(h)
def _collate_fn(self, G):
return [g.copy() for g in G]
def _propagate_to(self, G, v, propagator, H=None, reverse=False, gate=None, mapper=None):
# propagate messages to vertex index v for all graphs in G
# return the new messages (states) at v
G = [g for g in G if g.vcount() > v]
if len(G) == 0:
return
if H is not None:
idx = [i for i, g in enumerate(G) if g.vcount() > v]
H = H[idx]
v_types = [g.vs[v]['type'] for g in G]
X = self._one_hot(v_types, self.nvt)
if reverse:
H_name = 'H_backward' # name of the hidden states attribute
H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
if self.vid:
vids = [self._one_hot(g.successors(v), self.max_n) for g in G]
gate, mapper = self.gate_backward, self.mapper_backward
else:
H_name = 'H_forward' # name of the hidden states attribute
H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
if self.vid:
vids = [self._one_hot(g.predecessors(v), self.max_n) for g in G]
if gate is None:
gate, mapper = self.gate_forward, self.mapper_forward
if self.vid:
H_pred = [[torch.cat([x[i], y[i:i + 1]], 1) for i in range(len(x))] for x, y in zip(H_pred, vids)]
# if h is not provided, use gated sum of v's predecessors' states as the input hidden state
if H is None:
max_n_pred = max([len(x) for x in H_pred]) # maximum number of predecessors
if max_n_pred == 0:
H = self._get_zero_hidden(len(G))
else:
H_pred = [torch.cat(h_pred +
[self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0)
for h_pred in H_pred] # pad all to same length
H_pred = torch.cat(H_pred, 0) # batch * max_n_pred * vs
H = self._gated(H_pred, gate, mapper).sum(1) # batch * hs
Hv = propagator(X, H)
for i, g in enumerate(G):
g.vs[v][H_name] = Hv[i:i + 1]
return Hv
def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
# perform a series of propagation_to steps starting from v following a topo order
# assume the original vertex indices are in a topological order
if reverse:
prop_order = range(v, -1, -1)
else:
prop_order = range(v, self.max_n)
Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse) # the initial vertex
for v_ in prop_order[1:]:
self._propagate_to(G, v_, propagator, reverse=reverse)
return Hv
def _get_graph_state(self, G, decode=False):
# get the graph states
# when decoding, use the last generated vertex's state as the graph state
# when encoding, use the ending vertex state or unify the starting and ending vertex states
Hg = []
for g in G:
hg = g.vs[g.vcount() - 1]['H_forward']
if self.bidir and not decode: # decoding never uses backward propagation
hg_b = g.vs[0]['H_backward']
hg = torch.cat([hg, hg_b], 1)
Hg.append(hg)
Hg = torch.cat(Hg, 0)
if self.bidir and not decode:
Hg = self.hg_unify(Hg)
return Hg
def set_encode(self, X):
proto_batch = []
for x in X:
cls_protos = self.intra_setpool(
x.view(-1, self.num_sample, 512)).squeeze(1)
proto_batch.append(
self.inter_setpool(cls_protos.unsqueeze(0)))
v = torch.stack(proto_batch).squeeze()
return v
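    # Shape walkthrough (editor's note): with num_sample=20, each dataset x of
    # shape (n_classes * 20, 512) is viewed as (n_classes, 20, 512);
    # intra_setpool pools every class to (n_classes, 1, nz), inter_setpool
    # pools the class prototypes to (1, 1, nz), and stacking over the batch
    # plus the final squeeze yields a (batch, nz) dataset embedding D_mu.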
def graph_encode(self, G):
# encode graphs G into latent vectors
if type(G) != list:
G = [G]
self._propagate_from(G, 0, self.grue_forward, H0=self._get_zero_hidden(len(G)),
reverse=False)
if self.bidir:
self._propagate_from(G, self.max_n - 1, self.grue_backward,
H0=self._get_zero_hidden(len(G)), reverse=True)
Hg = self._get_graph_state(G)
mu = self.fc1(Hg)
#logvar = self.fc2(Hg)
return mu #, logvar
def reparameterize(self, mu, logvar, eps_scale=0.01):
# return z ~ N(mu, std)
if self.training:
std = logvar.mul(0.5).exp_()
eps = torch.randn_like(std) * eps_scale
return eps.mul(std).add_(mu)
else:
return mu

View File

@@ -0,0 +1,158 @@
import numpy as np
import torchvision.models as models
import torchvision.datasets as dset
import os
import torch
import argparse
import random
import torchvision.transforms as transforms
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
from PIL import Image
parser = argparse.ArgumentParser("sota")
parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
parser.add_argument('--data-path', type=str, default='data', help='the path of the data directory')
# --save-path is used below when saving the extracted features but was missing
# from the original parser; the default of 'data' mirrors --data-path (assumption)
parser.add_argument('--save-path', type=str, default='data', help='the directory where extracted features are saved')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--seed', type=int, default=-1, help='random seed')
args = parser.parse_args()
if args.seed is None or args.seed < 0: args.seed = random.randint(1, 100000)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
np.random.seed(args.seed)
random.seed(args.seed)
# remove last fully-connected layer
model = models.resnet18(pretrained=True).eval()
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
def get_transform(dataset):
    if dataset == 'mnist':
        mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081]
    elif dataset == 'svhn':
        mean, std = [0.4376821, 0.4437697, 0.47280442], [0.19803012, 0.20101562, 0.19703614]
    elif dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif dataset == 'imagenet32':
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [66.22, 64.20, 67.86]]
    else:
        raise ValueError('unsupported dataset: %s' % dataset)
    t = [transforms.Resize((32, 32)), transforms.ToTensor()]
    if dataset == 'mnist':
        # replicate the grayscale channel before normalization so the
        # 3-channel mean/std (and the ImageNet feature extractor) apply
        t.append(transforms.Lambda(lambda x: x.repeat(3, 1, 1)))
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)
def process(dataset, n_classes):
data_label = {i: [] for i in range(n_classes)}
for x, y in dataset:
data_label[y].append(x)
for i in range(n_classes):
data_label[i] = torch.stack(data_label[i])
holder = {i: [] for i in range(n_classes)}
for i in range(n_classes):
with torch.no_grad():
data = feature_extractor(data_label[i])
holder[i].append(data.squeeze())
return holder
class ImageNet32(object):
train_list = [
['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
['train_data_batch_10', '8f03f34ac4b42271a294f91bf480f29b'],
]
valid_list = [
['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
]
def __init__(self, root, n_class, transform):
self.transform = transform
downloaded_list = self.train_list
self.n_class = n_class
self.data_label = {i: [] for i in range(n_class)}
self.data = []
self.targets = []
for i, (file_name, checksum) in enumerate(downloaded_list):
file_path = os.path.join(root, file_name)
with open(file_path, 'rb') as f:
if sys.version_info[0] == 2:
entry = pickle.load(f)
else:
entry = pickle.load(f, encoding='latin1')
for j, k in enumerate(entry['labels']):
self.data_label[k - 1].append(entry['data'][j])
for i in range(n_class):
self.data_label[i] = np.vstack(self.data_label[i]).reshape(-1, 3, 32, 32)
self.data_label[i] = self.data_label[i].transpose((0, 2, 3, 1)) # convert to HWC
def get(self, use_num_cls, max_num=None):
assert isinstance(use_num_cls, list) \
and len(use_num_cls) > 0 and len(use_num_cls) < self.n_class, \
'invalid use_num_cls : {:}'.format(use_num_cls)
new_data, new_targets = [], []
for i in use_num_cls:
new_data.append(self.data_label[i][:max_num] if max_num is not None else self.data_label[i])
new_targets.extend([i] * max_num if max_num is not None
else [i] * len(self.data_label[i]))
self.data = np.concatenate(new_data)
self.targets = new_targets
imgs = []
for img in self.data:
img = Image.fromarray(img)
img = self.transform(img)
with torch.no_grad():
imgs.append(feature_extractor(img.unsqueeze(0)).squeeze().unsqueeze(0))
return torch.cat(imgs)
if __name__ == '__main__':
ncls = {'mnist': 10, 'svhn': 10, 'cifar10': 10, 'cifar100': 100, 'imagenet32': 1000}
transform = get_transform(args.dataset)
if args.dataset == 'imagenet32':
        imgnet32 = ImageNet32(args.data_path, ncls[args.dataset], transform)
data_label = {i: [] for i in range(1000)}
for i in range(1000):
m = imgnet32.get([i])
data_label[i].append(m)
if i % 10 == 0:
print(f'Currently saving features of {i}-th class')
torch.save(data_label, f'{args.save_path}/{args.dataset}bylabel.pt')
else:
if args.dataset == 'mnist':
data = dset.MNIST(args.data_path, train=True, transform=transform, download=True)
elif args.dataset == 'svhn':
data = dset.SVHN(args.data_path, split='train', transform=transform, download=True)
elif args.dataset == 'cifar10':
data = dset.CIFAR10(args.data_path, train=True, transform=transform, download=True)
elif args.dataset == 'cifar100':
data = dset.CIFAR100(args.data_path, train=True, transform=transform, download=True)
dataset = process(data, ncls[args.dataset])
torch.save(dataset, f'{args.save_path}/{args.dataset}bylabel.pt')

View File

@@ -0,0 +1,37 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from transfer_nag_lib.MetaD2A_mobilenetV3.set_encoder.setenc_modules import *
class SetPool(nn.Module):
def __init__(self, dim_input, num_outputs, dim_output,
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
super(SetPool, self).__init__()
if 'sab' in mode: # [32, 400, 128]
self.enc = nn.Sequential(
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
else: # [32, 400, 128]
self.enc = nn.Sequential(
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
if 'PF' in mode: #[32, 1, 501]
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
nn.Linear(dim_hidden, dim_output))
elif 'P' in mode:
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
else: #torch.Size([32, 1, 501])
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
nn.Linear(dim_hidden, dim_output))
# "", sm, sab, sabsm
def forward(self, X):
x1 = self.enc(X)
x2 = self.dec(x1)
return x2
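    # Shape sketch (editor's note): in mode 'sabPF' with dim_input=512,
    # dim_hidden=dim_output=56 and num_outputs=1, an input X of shape
    # (B, N, 512) is encoded by the SAB stack to (B, N, 56) and pooled by
    # PMA + Linear to (B, 1, 56): one permutation-invariant vector per set.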

Some files were not shown because too many files have changed in this diff.