first commit

This commit is contained in:
CownowAn
2024-03-15 14:38:51 +00:00
commit bc2ed1304f
321 changed files with 44802 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################

View File

@@ -0,0 +1,401 @@
from __future__ import print_function
import os
import math
import warnings
import numpy as np
# from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform
import torch.utils.data
import torchvision.transforms as transforms
from torchvision.datasets.folder import default_loader
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
def make_dataset(dir, image_ids, targets):
    """Pair every image id with its label as (image_path, target) tuples.

    Paths are built as ``<dir>/data/images/<image_id>.jpg`` after expanding
    a leading ``~`` in *dir*.
    """
    assert (len(image_ids) == len(targets))
    root = os.path.expanduser(dir)
    return [
        (os.path.join(root, 'data', 'images', '%s.jpg' % image_id), target)
        for image_id, target in zip(image_ids, targets)
    ]
def find_classes(classes_file):
    """Parse an FGVC-Aircraft annotation file into ids and integer targets.

    Each line of *classes_file* is ``<image_id> <class name...>``; the class
    name is everything after the first space (trailing newline included, as in
    the original format).

    Returns:
        (image_ids, targets, classes, class_to_idx) where ``targets`` are
        integer indices into the sorted unique ``classes`` array.
    """
    # read classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    # Context manager guarantees the file is closed even if parsing raises
    # (the original left the handle open on error).
    with open(classes_file, 'r') as f:
        for line in f:
            split_line = line.split(' ')
            image_ids.append(split_line[0])
            targets.append(' '.join(split_line[1:]))
    # index class names
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]
    return (image_ids, targets, classes, class_to_idx)
class FGVCAircraft(torch.utils.data.Dataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.

    Args:
        root (string): Root directory path to dataset.
        class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
            to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
        split (string, optional): Which split to load; one of ``train``, ``val``,
            ``trainval``, or ``test``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g. ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
    class_types = ('variant', 'family', 'manufacturer')
    splits = ('train', 'val', 'trainval', 'test')

    def __init__(self, root, class_type='variant', split='train', transform=None,
                 target_transform=None, loader=default_loader, download=False):
        # Fail fast with a clear message on a bad split / class_type.
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        if class_type not in self.class_types:
            raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
                class_type, ', '.join(self.class_types),
            ))
        self.root = os.path.expanduser(root)
        self.class_type = class_type
        self.split = split
        # Annotation file, e.g. data/images_variant_trainval.txt
        self.classes_file = os.path.join(self.root, 'data',
                                         'images_%s_%s.txt' % (self.class_type, self.split))
        if download:
            self.download()
        (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
        samples = make_dataset(self.root, image_ids, targets)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.samples = samples        # list of (image_path, class_index)
        self.classes = classes
        self.class_to_idx = class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def _check_exists(self):
        # Dataset is considered present when both the image folder and the
        # split annotation file exist.
        return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
            os.path.exists(self.classes_file)

    def download(self):
        """Download the FGVC-Aircraft data if it doesn't exist already."""
        from six.moves import urllib
        import tarfile
        if self._check_exists():
            return
        # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
        print('Downloading %s ... (may take a few minutes)' % self.url)
        parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
        tar_name = self.url.rpartition('/')[-1]
        tar_path = os.path.join(parent_dir, tar_name)
        data = urllib.request.urlopen(self.url)
        # download .tar.gz file
        with open(tar_path, 'wb') as f:
            f.write(data.read())
        # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
        # BUGFIX: str.strip('.tar.gz') removes *characters* from both ends,
        # not the literal suffix, and could corrupt the folder path; slice the
        # suffix off instead.
        data_folder = tar_path[:-len('.tar.gz')]
        print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
        tar = tarfile.open(tar_path)
        tar.extractall(parent_dir)
        # if necessary, rename data folder to self.root
        if not os.path.samefile(data_folder, self.root):
            print('Renaming %s to %s ...' % (data_folder, self.root))
            os.rename(data_folder, self.root)
        # delete .tar.gz file
        print('Deleting %s ...' % tar_path)
        os.remove(tar_path)
        print('Done!')
class FGVCAircraftDataProvider(DataProvider):
    """DataProvider wiring FGVC-Aircraft into train/valid/test DataLoaders.

    Supports elastic image sizes (list-valued ``image_size``), an optional
    validation split carved out of the training set, and optional distributed
    sampling via ``num_replicas``/``rank``.
    """

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):
        """
        Args:
            save_path: dataset root directory (falls back to the ``save_path``
                property default when None).
            train_batch_size / test_batch_size: loader batch sizes.
            valid_size: if not None, size of a validation split taken from the
                training set — an int (sample count) or a float in (0, 1)
                (fraction of the training set).
            n_worker: DataLoader worker count.
            resize_scale: lower bound of the RandomResizedCrop area scale.
            distort_color: kept for interface parity; the color-jitter code in
                build_train_transform is commented out here.
            image_size: int, or sorted list of ints for elastic-resolution training.
            num_replicas / rank: when set, distributed samplers are used.
        """
        warnings.filterwarnings('ignore')
        self._save_path = save_path
        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            # Elastic resolution: one eval transform per candidate size, and a
            # train loader that re-samples the active size for each batch.
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader
        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)
        if valid_size is not None:
            # Carve a validation subset out of the training set; the validation
            # copy of the dataset uses eval-time transforms.
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)
            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None
        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )
        if self.valid is None:
            # No held-out split requested: validation falls back to the test set.
            self.valid = self.test

    @staticmethod
    def name():
        # Registry key identifying this provider.
        return 'aircraft'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        # FGVC-Aircraft at the 'variant' level has 100 classes.
        return 100

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/Aircraft'  # home server
        if not os.path.exists(self._save_path):
            # NOTE(review): both branches assign the same default path;
            # presumably a second fallback location was intended — confirm.
            self._save_path = '/mnt/datastore/Aircraft'  # home server
        return self._save_path

    @property
    def data_url(self):
        # Downloading is delegated to the FGVCAircraft dataset itself.
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # Training uses the combined 'trainval' split.
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        dataset = FGVCAircraft(
            root=self.train_path, split='trainval', download=True, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = FGVCAircraft(
            root=self.valid_path, split='test', download=True, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        return self.save_path

    @property
    def valid_path(self):
        return self.save_path

    @property
    def normalize(self):
        # Per-channel statistics for FGVC-Aircraft.
        return transforms.Normalize(
            mean=[0.48933587508932375, 0.5183537408957618, 0.5387914411673883],
            std=[0.22388883112804625, 0.21641635409388751, 0.24615605842636115])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        """Build the training pipeline: random resized crop + horizontal flip +
        timm RandAugment (config string ``auto_augment``) + normalize."""
        if image_size is None:
            image_size = self.image_size
        # if print_log:
        #     print('Color jitter: %s, resize_scale: %s, img_size: %s' %
        #           (self.distort_color, self.resize_scale, image_size))
        # if self.distort_color == 'torch':
        #     color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        # elif self.distort_color == 'tf':
        #     color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        # else:
        #     color_transform = None
        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        # RandAugment parameters: translation magnitude scales with the
        # smallest candidate size; fill color is the dataset mean in 0-255.
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.48933587508932375, 0.5183537408957618,
                                                               0.5387914411673883]]),
        )
        # NOTE(review): 'interpolation' is set to a transforms.Resize instance
        # rather than a PIL interpolation constant (the _pil_interp import is
        # commented out at the top of the file) — confirm this is what
        # rand_augment_transform expects.
        aa_params['interpolation'] = transforms.Resize(image_size)  # _pil_interp('bicubic')
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]
        # if color_transform is not None:
        #     train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]
        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        # Standard eval pipeline: resize (1/0.875 ratio) -> center crop -> normalize.
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        # Switch the active resolution and update eval transforms in place.
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        # Materializes (and memoizes per active_img_size) a small, seeded subset
        # of training batches, e.g. for BN-statistics recalibration.
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers
            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)  # fixed seed -> reproducible subset
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()
            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            # Cache the fully-materialized batches on the instance.
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
if __name__ == '__main__':
    # Smoke test: load (downloading if necessary) the trainval split and
    # report the number of classes and samples.
    data = FGVCAircraft(root='/mnt/datastore/Aircraft',
                        split='trainval', download=True)
    print(len(data.classes))
    print(len(data.samples))

View File

@@ -0,0 +1,238 @@
"""
Taken from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
"""
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
class ImageNetPolicy(object):
    """ Randomly choose one of the best 24 Sub-policies on ImageNet.
    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)
    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     ImageNetPolicy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        # Each row is (p1, op1, magnitude_idx1, p2, op2, magnitude_idx2).
        specs = [
            (0.4, "posterize", 8, 0.6, "rotate", 9),
            (0.6, "solarize", 5, 0.6, "autocontrast", 5),
            (0.8, "equalize", 8, 0.6, "equalize", 3),
            (0.6, "posterize", 7, 0.6, "posterize", 6),
            (0.4, "equalize", 7, 0.2, "solarize", 4),
            (0.4, "equalize", 4, 0.8, "rotate", 8),
            (0.6, "solarize", 3, 0.6, "equalize", 7),
            (0.8, "posterize", 5, 1.0, "equalize", 2),
            (0.2, "rotate", 3, 0.6, "solarize", 8),
            (0.6, "equalize", 8, 0.4, "posterize", 6),
            (0.8, "rotate", 8, 0.4, "color", 0),
            (0.4, "rotate", 9, 0.6, "equalize", 2),
            (0.0, "equalize", 7, 0.8, "equalize", 8),
            (0.6, "invert", 4, 1.0, "equalize", 8),
            (0.6, "color", 4, 1.0, "contrast", 8),
            (0.8, "rotate", 8, 1.0, "color", 2),
            (0.8, "color", 8, 0.8, "solarize", 7),
            (0.4, "sharpness", 7, 0.6, "invert", 8),
            (0.6, "shearX", 5, 1.0, "equalize", 9),
            (0.4, "color", 0, 0.6, "equalize", 3),
            (0.4, "equalize", 7, 0.2, "solarize", 4),
            (0.6, "solarize", 5, 0.6, "autocontrast", 5),
            (0.6, "invert", 4, 1.0, "equalize", 8),
            (0.6, "color", 4, 1.0, "contrast", 8),
            (0.8, "equalize", 8, 0.6, "equalize", 3),
        ]
        self.policies = [SubPolicy(p1, op1, m1, p2, op2, m2, fillcolor)
                         for p1, op1, m1, p2, op2, m2 in specs]

    def __call__(self, img):
        # Draw one sub-policy uniformly at random and apply it.
        return random.choice(self.policies)(img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"
class CIFAR10Policy(object):
    """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)
    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     CIFAR10Policy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        # Each row is (p1, op1, magnitude_idx1, p2, op2, magnitude_idx2).
        specs = [
            (0.1, "invert", 7, 0.2, "contrast", 6),
            (0.7, "rotate", 2, 0.3, "translateX", 9),
            (0.8, "sharpness", 1, 0.9, "sharpness", 3),
            (0.5, "shearY", 8, 0.7, "translateY", 9),
            (0.5, "autocontrast", 8, 0.9, "equalize", 2),
            (0.2, "shearY", 7, 0.3, "posterize", 7),
            (0.4, "color", 3, 0.6, "brightness", 7),
            (0.3, "sharpness", 9, 0.7, "brightness", 9),
            (0.6, "equalize", 5, 0.5, "equalize", 1),
            (0.6, "contrast", 7, 0.6, "sharpness", 5),
            (0.7, "color", 7, 0.5, "translateX", 8),
            (0.3, "equalize", 7, 0.4, "autocontrast", 8),
            (0.4, "translateY", 3, 0.2, "sharpness", 6),
            (0.9, "brightness", 6, 0.2, "color", 8),
            (0.5, "solarize", 2, 0.0, "invert", 3),
            (0.2, "equalize", 0, 0.6, "autocontrast", 0),
            (0.2, "equalize", 8, 0.6, "equalize", 4),
            (0.9, "color", 9, 0.6, "equalize", 6),
            (0.8, "autocontrast", 4, 0.2, "solarize", 8),
            (0.1, "brightness", 3, 0.7, "color", 0),
            (0.4, "solarize", 5, 0.9, "autocontrast", 3),
            (0.9, "translateY", 9, 0.7, "translateY", 9),
            (0.9, "autocontrast", 2, 0.8, "solarize", 3),
            (0.8, "equalize", 8, 0.1, "invert", 3),
            (0.7, "translateY", 9, 0.9, "autocontrast", 1),
        ]
        self.policies = [SubPolicy(p1, op1, m1, p2, op2, m2, fillcolor)
                         for p1, op1, m1, p2, op2, m2 in specs]

    def __call__(self, img):
        # Draw one sub-policy uniformly at random and apply it.
        return random.choice(self.policies)(img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"
class SVHNPolicy(object):
    """ Randomly choose one of the best 25 Sub-policies on SVHN.
    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)
    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     SVHNPolicy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        # Each row is (p1, op1, magnitude_idx1, p2, op2, magnitude_idx2).
        specs = [
            (0.9, "shearX", 4, 0.2, "invert", 3),
            (0.9, "shearY", 8, 0.7, "invert", 5),
            (0.6, "equalize", 5, 0.6, "solarize", 6),
            (0.9, "invert", 3, 0.6, "equalize", 3),
            (0.6, "equalize", 1, 0.9, "rotate", 3),
            (0.9, "shearX", 4, 0.8, "autocontrast", 3),
            (0.9, "shearY", 8, 0.4, "invert", 5),
            (0.9, "shearY", 5, 0.2, "solarize", 6),
            (0.9, "invert", 6, 0.8, "autocontrast", 1),
            (0.6, "equalize", 3, 0.9, "rotate", 3),
            (0.9, "shearX", 4, 0.3, "solarize", 3),
            (0.8, "shearY", 8, 0.7, "invert", 4),
            (0.9, "equalize", 5, 0.6, "translateY", 6),
            (0.9, "invert", 4, 0.6, "equalize", 7),
            (0.3, "contrast", 3, 0.8, "rotate", 4),
            (0.8, "invert", 5, 0.0, "translateY", 2),
            (0.7, "shearY", 6, 0.4, "solarize", 8),
            (0.6, "invert", 4, 0.8, "rotate", 4),
            (0.3, "shearY", 7, 0.9, "translateX", 3),
            (0.1, "shearX", 6, 0.6, "invert", 5),
            (0.7, "solarize", 2, 0.6, "translateY", 7),
            (0.8, "shearY", 4, 0.8, "invert", 8),
            (0.7, "shearX", 9, 0.8, "translateY", 3),
            (0.8, "shearY", 5, 0.7, "autocontrast", 3),
            (0.7, "shearX", 2, 0.1, "invert", 5),
        ]
        self.policies = [SubPolicy(p1, op1, m1, p2, op2, m2, fillcolor)
                         for p1, op1, m1, p2, op2, m2 in specs]

    def __call__(self, img):
        # Draw one sub-policy uniformly at random and apply it.
        return random.choice(self.policies)(img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"
class SubPolicy(object):
    """One AutoAugment sub-policy: two image operations, each applied with its
    own probability and a magnitude looked up from a 10-step schedule.

    Args:
        p1, p2: probabilities of applying operation 1 / operation 2.
        operation1, operation2: operation names (keys of ``ranges``/``func`` below).
        magnitude_idx1, magnitude_idx2: index (0-9) into the magnitude range.
        fillcolor: RGB fill used by the affine-style transforms.
    """
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        # Magnitude schedule for every supported operation (index 0-9).
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # BUGFIX: use the builtin `int` — the `np.int` alias was removed in
            # NumPy 1.24, so `astype(np.int)` crashes on modern NumPy.
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Rotate on an RGBA canvas and composite over a gray background so
            # corners exposed by the rotation are filled, then restore the mode.
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        # Operation name -> callable(img, magnitude). Shear/translate flip sign
        # randomly per call so each magnitude covers both directions.
        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }
        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        """Apply operation 1 then operation 2, each with its own probability."""
        if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
        return img

View File

@@ -0,0 +1,657 @@
import os
import math
import numpy as np
import torchvision
import torch.utils.data
import torchvision.transforms as transforms
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class CIFAR10DataProvider(DataProvider):
    """DataProvider wiring torchvision CIFAR-10 into train/valid/test loaders.

    Supports elastic image sizes (list-valued ``image_size``), an optional
    validation split carved out of the training set, optional color jitter,
    and optional distributed sampling via ``num_replicas``/``rank``.
    """

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
        """
        Args:
            save_path: dataset root directory (falls back to the ``save_path``
                property default when None).
            train_batch_size / test_batch_size: loader batch sizes.
            valid_size: if not None, size of a validation split taken from the
                training set — an int (sample count) or a float in (0, 1).
            n_worker: DataLoader worker count.
            resize_scale: lower bound of the RandomResizedCrop area scale.
            distort_color: 'torch', 'tf', or None — selects the color jitter.
            image_size: int, or sorted list of ints for elastic-resolution training.
            num_replicas / rank: when set, distributed samplers are used.
        """
        self._save_path = save_path
        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            # Elastic resolution: one eval transform per candidate size, and a
            # train loader that re-samples the active size for each batch.
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader
        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)
        if valid_size is not None:
            # Carve a validation subset out of the training set; the validation
            # copy of the dataset uses eval-time transforms.
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)
            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None
        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )
        if self.valid is None:
            # No held-out split requested: validation falls back to the test set.
            self.valid = self.test

    @staticmethod
    def name():
        # Registry key identifying this provider.
        return 'cifar10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        # CIFAR-10 has 10 classes.
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CIFAR'  # home server
        if not os.path.exists(self._save_path):
            # NOTE(review): both branches assign the same default path;
            # presumably a second fallback location was intended — confirm.
            self._save_path = '/mnt/datastore/CIFAR'  # home server
        return self._save_path

    @property
    def data_url(self):
        # Downloading is not handled by this provider (download=False below).
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        # NOTE(review): uses valid_path rather than train_path; both resolve to
        # save_path here, so behavior is unchanged — but the naming is odd.
        dataset = torchvision.datasets.CIFAR10(
            root=self.valid_path, train=True, download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = torchvision.datasets.CIFAR10(
            root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        # return os.path.join(self.save_path, 'train')
        return self.save_path

    @property
    def valid_path(self):
        # return os.path.join(self.save_path, 'val')
        return self.save_path

    @property
    def normalize(self):
        # Per-channel statistics for CIFAR-10.
        return transforms.Normalize(
            mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])

    def build_train_transform(self, image_size=None, print_log=True):
        """Build the training pipeline: random resized crop + horizontal flip
        + optional color jitter (per ``distort_color``) + normalize."""
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))
        # Select the color-jitter flavor: torch-style, tf-style, or none.
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None
        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]
        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        # Standard eval pipeline: resize (1/0.875 ratio) -> center crop -> normalize.
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        # Switch the active resolution and update eval transforms in place.
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        # Materializes (and memoizes per active_img_size) a small, seeded subset
        # of training batches, e.g. for BN-statistics recalibration.
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers
            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)  # fixed seed -> reproducible subset
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()
            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            # Cache the fully-materialized batches on the instance.
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
class CIFAR100DataProvider(DataProvider):
    """Data provider for CIFAR-100 (100 classes).

    Builds train/valid/test ``DataLoader``s over ``torchvision.datasets.CIFAR100``.
    ``image_size`` may be a single int or a list of ints; in the list case the
    training loader (``MyDataLoader``) randomly samples one resolution per batch
    and a deterministic validation transform is pre-built and cached for every
    candidate size in ``_valid_transform_dict``.
    """
    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
        """
        :param save_path: dataset root directory; defaults to a hard-coded path.
        :param train_batch_size: batch size of the training loader.
        :param test_batch_size: batch size of the valid/test loaders.
        :param valid_size: None (reuse the test set as valid), an int sample
            count, or a float fraction in (0, 1) of the training set.
        :param n_worker: number of DataLoader worker processes.
        :param resize_scale: lower bound of RandomResizedCrop's area scale.
        :param distort_color: 'torch' | 'tf' | None -- color-jitter recipe.
        :param image_size: int, or sorted list of ints for elastic-resolution training.
        :param num_replicas: if not None, use distributed samplers for all splits.
        :param rank: rank of this process when ``num_replicas`` is given.
        """
        self._save_path = save_path
        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            # Elastic-resolution mode: a list of candidate image sizes.
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
            # Pre-build one deterministic eval transform per candidate size.
            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader
        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)
        if valid_size is not None:
            # Carve a validation split out of the training set.
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                # `.data` is the raw image array of torchvision's CIFAR datasets.
                valid_size = int(len(train_dataset.data) * valid_size)
            # Same underlying images as train, but with deterministic eval transforms.
            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None
        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )
        if self.valid is None:
            # No held-out split was requested: validation falls back to the test set.
            self.valid = self.test
    @staticmethod
    def name():
        """Dataset identifier string."""
        return 'cifar100'
    @property
    def data_shape(self):
        """Input tensor shape at the active resolution."""
        return 3, self.active_img_size, self.active_img_size  # C, H, W
    @property
    def n_classes(self):
        """Number of target classes."""
        return 100
    @property
    def save_path(self):
        # NOTE(review): the default and the existence-check fallback assign the
        # same hard-coded path, so the os.path.exists() branch is currently a no-op.
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CIFAR'  # home server
            if not os.path.exists(self._save_path):
                self._save_path = '/mnt/datastore/CIFAR'  # home server
        return self._save_path
    @property
    def data_url(self):
        """Downloading is unsupported; the dataset must already be on disk."""
        raise ValueError('unable to download %s' % self.name())
    def train_dataset(self, _transforms):
        """CIFAR-100 training split with the given transforms."""
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        # NOTE(review): root is valid_path, but train_path and valid_path both
        # return save_path here, so this is equivalent to using train_path.
        dataset = torchvision.datasets.CIFAR100(
            root=self.valid_path, train=True, download=False, transform=_transforms)
        return dataset
    def test_dataset(self, _transforms):
        """CIFAR-100 test split with the given transforms."""
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = torchvision.datasets.CIFAR100(
            root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset
    @property
    def train_path(self):
        # return os.path.join(self.save_path, 'train')
        return self.save_path
    @property
    def valid_path(self):
        # return os.path.join(self.save_path, 'val')
        return self.save_path
    @property
    def normalize(self):
        # Per-channel mean/std normalization statistics.
        return transforms.Normalize(
            mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])
    def build_train_transform(self, image_size=None, print_log=True):
        """Build the stochastic training pipeline: (My)RandomResizedCrop +
        horizontal flip + optional color jitter + ToTensor + normalize."""
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None
        if isinstance(image_size, list):
            # Elastic resolution: the crop size is re-sampled per batch.
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]
        train_transforms = transforms.Compose(train_transforms)
        return train_transforms
    def build_valid_transform(self, image_size=None):
        """Deterministic eval pipeline: resize by 1/0.875, center-crop, normalize."""
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])
    def assign_active_img_size(self, new_img_size):
        """Switch the active eval resolution and retarget valid/test transforms in place."""
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        """Return (and cache per resolution) a reproducible in-memory subset of
        training batches; used for resetting running statistics. The subset is
        drawn with fixed seed ``DataProvider.SUB_SEED``."""
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers
            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)  # fixed seed -> reproducible subset
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()
            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            # Materialize all batches once so later passes need no workers.
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
class CINIC10DataProvider(DataProvider):
    """Data provider for CINIC-10 (10 classes, ImageFolder directory layout).

    Builds train/valid/test ``DataLoader``s over ``torchvision.datasets.ImageFolder``.
    ``image_size`` may be a single int or a list of ints; in the list case the
    training loader (``MyDataLoader``) randomly samples one resolution per batch
    and a deterministic validation transform is pre-built and cached for every
    candidate size in ``_valid_transform_dict``.
    """
    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
        """
        :param save_path: dataset root directory; defaults to a hard-coded path.
        :param train_batch_size: batch size of the training loader.
        :param test_batch_size: batch size of the valid/test loaders.
        :param valid_size: None (reuse the test set as valid), an int sample
            count, or a float fraction in (0, 1) of the training set.
        :param n_worker: number of DataLoader worker processes.
        :param resize_scale: lower bound of RandomResizedCrop's area scale.
        :param distort_color: 'torch' | 'tf' | None -- color-jitter recipe.
        :param image_size: int, or sorted list of ints for elastic-resolution training.
        :param num_replicas: if not None, use distributed samplers for all splits.
        :param rank: rank of this process when ``num_replicas`` is given.
        """
        self._save_path = save_path
        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            # Elastic-resolution mode: a list of candidate image sizes.
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
            # Pre-build one deterministic eval transform per candidate size.
            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader
        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)
        if valid_size is not None:
            # Carve a validation split out of the training set.
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                # BUGFIX: ImageFolder datasets expose `.samples`, not `.data`
                # (`.data` only exists on torchvision's CIFAR datasets, which is
                # where this code was copied from). Using `.data` raised
                # AttributeError whenever valid_size was set; `.samples` matches
                # build_sub_train_loader below and the DTD provider.
                valid_size = int(len(train_dataset.samples) * valid_size)
            # Same underlying images as train, but with deterministic eval transforms.
            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None
        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )
        if self.valid is None:
            # No held-out split was requested: validation falls back to the test set.
            self.valid = self.test
    @staticmethod
    def name():
        """Dataset identifier string."""
        return 'cinic10'
    @property
    def data_shape(self):
        """Input tensor shape at the active resolution."""
        return 3, self.active_img_size, self.active_img_size  # C, H, W
    @property
    def n_classes(self):
        """Number of target classes."""
        return 10
    @property
    def save_path(self):
        # NOTE(review): the default and the existence-check fallback assign the
        # same hard-coded path, so the os.path.exists() branch is currently a no-op.
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CINIC10'  # home server
            if not os.path.exists(self._save_path):
                self._save_path = '/mnt/datastore/CINIC10'  # home server
        return self._save_path
    @property
    def data_url(self):
        """Downloading is unsupported; the dataset must already be on disk."""
        raise ValueError('unable to download %s' % self.name())
    def train_dataset(self, _transforms):
        """Training split: ImageFolder over ``train_and_valid``."""
        dataset = torchvision.datasets.ImageFolder(self.train_path, transform=_transforms)
        # dataset = torchvision.datasets.CIFAR10(
        #     root=self.valid_path, train=True, download=False, transform=_transforms)
        return dataset
    def test_dataset(self, _transforms):
        """Test split: ImageFolder over ``test``."""
        dataset = torchvision.datasets.ImageFolder(self.valid_path, transform=_transforms)
        # dataset = torchvision.datasets.CIFAR10(
        #     root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset
    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train_and_valid')
        # return self.save_path
    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'test')
        # return self.save_path
    @property
    def normalize(self):
        # Per-channel mean/std normalization statistics.
        return transforms.Normalize(
            mean=[0.47889522, 0.47227842, 0.43047404], std=[0.24205776, 0.23828046, 0.25874835])
    def build_train_transform(self, image_size=None, print_log=True):
        """Build the stochastic training pipeline: (My)RandomResizedCrop +
        horizontal flip + optional color jitter + ToTensor + normalize."""
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None
        if isinstance(image_size, list):
            # Elastic resolution: the crop size is re-sampled per batch.
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]
        train_transforms = transforms.Compose(train_transforms)
        return train_transforms
    def build_valid_transform(self, image_size=None):
        """Deterministic eval pipeline: resize by 1/0.875, center-crop, normalize."""
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])
    def assign_active_img_size(self, new_img_size):
        """Switch the active eval resolution and retarget valid/test transforms in place."""
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        """Return (and cache per resolution) a reproducible in-memory subset of
        training batches; used for resetting running statistics. The subset is
        drawn with fixed seed ``DataProvider.SUB_SEED``."""
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers
            # ImageFolder exposes its (path, class) pairs via `.samples`.
            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)  # fixed seed -> reproducible subset
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()
            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            # Materialize all batches once so later passes need no workers.
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,237 @@
import os
import warnings
import numpy as np
from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class DTDDataProvider(DataProvider):
    """Data provider for DTD (Describable Textures Dataset, 47 classes).

    Builds train/valid/test ``DataLoader``s over ``ImageFolder`` directories
    (``train`` / ``valid``). Training augmentation uses timm's RandAugment
    instead of the color-jitter recipe used by the other providers.
    ``image_size`` may be an int or a list of ints; in the list case the
    training loader (``MyDataLoader``) samples one resolution per batch.
    """
    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):
        """
        :param save_path: dataset root directory; defaults to a hard-coded path.
        :param valid_size: None (reuse the test set as valid), an int sample
            count, or a float fraction in (0, 1) of the training set.
        :param distort_color: kept for interface parity; the jitter recipe is
            commented out in build_train_transform below.
        :param image_size: int, or sorted list of ints for elastic-resolution training.
        :param num_replicas, rank: if num_replicas is not None, use distributed samplers.
        """
        warnings.filterwarnings('ignore')
        self._save_path = save_path
        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            # Elastic-resolution mode: a list of candidate image sizes.
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
            # Pre-build one deterministic eval transform per candidate size.
            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader
        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)
        if valid_size is not None:
            # Carve a validation split out of the training set.
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                # ImageFolder exposes its (path, class) pairs via `.samples`.
                valid_size = int(len(train_dataset.samples) * valid_size)
            # Same underlying images as train, but with deterministic eval transforms.
            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None
        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )
        if self.valid is None:
            # No held-out split was requested: validation falls back to the test set.
            self.valid = self.test
    @staticmethod
    def name():
        """Dataset identifier string."""
        return 'dtd'
    @property
    def data_shape(self):
        """Input tensor shape at the active resolution."""
        return 3, self.active_img_size, self.active_img_size  # C, H, W
    @property
    def n_classes(self):
        """Number of target classes."""
        return 47
    @property
    def save_path(self):
        # NOTE(review): the default and the existence-check fallback assign the
        # same hard-coded path, so the os.path.exists() branch is currently a no-op.
        if self._save_path is None:
            self._save_path = '/mnt/datastore/dtd'  # home server
            if not os.path.exists(self._save_path):
                self._save_path = '/mnt/datastore/dtd'  # home server
        return self._save_path
    @property
    def data_url(self):
        """Downloading is unsupported; the dataset must already be on disk."""
        raise ValueError('unable to download %s' % self.name())
    def train_dataset(self, _transforms):
        """Training split: ImageFolder over ``train``."""
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset
    def test_dataset(self, _transforms):
        """Held-out split: ImageFolder over ``valid``."""
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset
    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')
    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'valid')
    @property
    def normalize(self):
        # Per-channel mean/std normalization statistics.
        return transforms.Normalize(
            mean=[0.5329876098715876, 0.474260843249454, 0.42627281899380676],
            std=[0.26549755708788914, 0.25473554309855373, 0.2631728035662832])
    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        """Build the training pipeline: (My)RandomResizedCrop + horizontal flip
        + timm RandAugment (policy given by ``auto_augment``) + ToTensor + normalize.

        The color-jitter recipe used by the sibling providers is kept below,
        commented out, in favor of RandAugment.
        """
        if image_size is None:
            image_size = self.image_size
        # if print_log:
        #     print('Color jitter: %s, resize_scale: %s, img_size: %s' %
        #           (self.distort_color, self.resize_scale, image_size))
        # if self.distort_color == 'torch':
        #     color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        # elif self.distort_color == 'tf':
        #     color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        # else:
        #     color_transform = None
        if isinstance(image_size, list):
            # Elastic resolution: the crop size is re-sampled per batch.
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        # RandAugment parameters: translation bound scales with the smallest
        # candidate size; fill color is the dataset mean mapped to 0-255.
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.5329876098715876, 0.474260843249454,
                                                               0.42627281899380676]]),
        )
        aa_params['interpolation'] = _pil_interp('bicubic')
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]
        # if color_transform is not None:
        #     train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]
        train_transforms = transforms.Compose(train_transforms)
        return train_transforms
    def build_valid_transform(self, image_size=None):
        """Deterministic eval pipeline: square resize (bicubic), center-crop, normalize."""
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            # transforms.Resize(int(math.ceil(image_size / 0.875))),
            # interpolation=3 is PIL's BICUBIC; the square resize makes the
            # following CenterCrop a no-op size-wise.
            transforms.Resize((image_size, image_size), interpolation=3),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])
    def assign_active_img_size(self, new_img_size):
        """Switch the active eval resolution and retarget valid/test transforms in place."""
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        """Return (and cache per resolution) a reproducible in-memory subset of
        training batches; used for resetting running statistics. The subset is
        drawn with fixed seed ``DataProvider.SUB_SEED``."""
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers
            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)  # fixed seed -> reproducible subset
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()
            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            # Materialize all batches once so later passes need no workers.
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,241 @@
import warnings
import os
import math
import numpy as np
import PIL
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class Flowers102DataProvider(DataProvider):
    """Data provider for Oxford Flowers-102 (102 classes, ImageFolder layout).

    Training uses a ``WeightedRandomSampler`` with inverse-class-frequency
    weights to counter class imbalance, and a heavier geometric augmentation
    (RandomAffine) instead of horizontal flips. A separate validation split is
    not implemented; ``valid`` falls back to the test set.
    """
    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):
        """
        :param save_path: dataset root directory; defaults to a hard-coded path.
        :param valid_size: must be None; a held-out split is not implemented.
        :param distort_color: 'torch' | 'tf' | None -- color-jitter recipe.
        :param image_size: int, or sorted list of ints for elastic-resolution training.
        :param num_replicas, rank: NOTE(review) -- accepted for interface parity
            but unused in this provider (no distributed samplers are created).
        :raises NotImplementedError: if ``valid_size`` is not None.
        """
        # warnings.filterwarnings('ignore')
        self._save_path = save_path
        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            # Elastic-resolution mode: a list of candidate image sizes.
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
            # Pre-build one deterministic eval transform per candidate size.
            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader
        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)
        # Inverse-frequency per-sample weights to rebalance the classes.
        weights = self.make_weights_for_balanced_classes(
            train_dataset.imgs, self.n_classes)
        weights = torch.DoubleTensor(weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
        if valid_size is not None:
            raise NotImplementedError("validation dataset not yet implemented")
            # valid_dataset = self.valid_dataset(valid_transforms)
            # self.train = train_loader_class(
            #     train_dataset, batch_size=train_batch_size, sampler=train_sampler,
            #     num_workers=n_worker, pin_memory=True)
            # self.valid = torch.utils.data.DataLoader(
            #     valid_dataset, batch_size=test_batch_size,
            #     num_workers=n_worker, pin_memory=True)
        else:
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = None
        test_dataset = self.test_dataset(valid_transforms)
        self.test = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
        )
        if self.valid is None:
            # No held-out split: validation falls back to the test set.
            self.valid = self.test
    @staticmethod
    def name():
        """Dataset identifier string."""
        return 'flowers102'
    @property
    def data_shape(self):
        """Input tensor shape at the active resolution."""
        return 3, self.active_img_size, self.active_img_size  # C, H, W
    @property
    def n_classes(self):
        """Number of target classes."""
        return 102
    @property
    def save_path(self):
        # NOTE(review): the default and the existence-check fallback assign the
        # same hard-coded path, so the os.path.exists() branch is currently a no-op.
        if self._save_path is None:
            # self._save_path = '/mnt/datastore/Oxford102Flowers'  # home server
            self._save_path = '/mnt/datastore/Flowers102'  # home server
            if not os.path.exists(self._save_path):
                # self._save_path = '/mnt/datastore/Oxford102Flowers'  # home server
                self._save_path = '/mnt/datastore/Flowers102'  # home server
        return self._save_path
    @property
    def data_url(self):
        """Downloading is unsupported; the dataset must already be on disk."""
        raise ValueError('unable to download %s' % self.name())
    def train_dataset(self, _transforms):
        """Training split: ImageFolder over ``train``."""
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset
    # def valid_dataset(self, _transforms):
    #     dataset = datasets.ImageFolder(self.valid_path, _transforms)
    #     return dataset
    def test_dataset(self, _transforms):
        """Test split: ImageFolder over ``test``."""
        dataset = datasets.ImageFolder(self.test_path, _transforms)
        return dataset
    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')
    # @property
    # def valid_path(self):
    #     return os.path.join(self.save_path, 'train')
    @property
    def test_path(self):
        return os.path.join(self.save_path, 'test')
    @property
    def normalize(self):
        # Per-channel mean/std normalization statistics.
        return transforms.Normalize(
            mean=[0.5178361839861569, 0.4106749456881299, 0.32864167836880803],
            std=[0.2972239085211309, 0.24976049135203868, 0.28533308036347665])
    @staticmethod
    def make_weights_for_balanced_classes(images, nclasses):
        """Compute one sampling weight per image, inversely proportional to its
        class frequency (weight = N_total / N_class), so a WeightedRandomSampler
        over-samples the rarer classes.

        :param images: list of (path, class_index) pairs (ImageFolder ``imgs``).
        :param nclasses: total number of classes.
        :return: list of per-image weights, aligned with ``images``.
        """
        count = [0] * nclasses
        # Counts per label
        for item in images:
            count[item[1]] += 1
        weight_per_class = [0.] * nclasses
        # Total number of images.
        N = float(sum(count))
        # super-sample the smaller classes.
        for i in range(nclasses):
            weight_per_class[i] = N / float(count[i])
        weight = [0] * len(images)
        # Calculate a weight per image.
        for idx, val in enumerate(images):
            weight[idx] = weight_per_class[val[1]]
        return weight
    def build_train_transform(self, image_size=None, print_log=True):
        """Build the training pipeline: RandomAffine (rotation/translation/scale)
        + (My)RandomResizedCrop + optional color jitter + ToTensor + normalize.
        Horizontal flip is intentionally commented out below."""
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None
        if isinstance(image_size, list):
            # Elastic resolution: the crop size is re-sampled per batch.
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop
        train_transforms = [
            # NOTE(review): `resample`/`fillcolor` are the legacy kwarg names,
            # removed in newer torchvision (renamed `interpolation`/`fill`) --
            # confirm against the pinned torchvision version before upgrading.
            transforms.RandomAffine(
                45, translate=(0.4, 0.4), scale=(0.75, 1.5), shear=None, resample=PIL.Image.BILINEAR, fillcolor=0),
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            # transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]
        train_transforms = transforms.Compose(train_transforms)
        return train_transforms
    def build_valid_transform(self, image_size=None):
        """Deterministic eval pipeline: resize by 1/0.875, center-crop, normalize."""
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])
    def assign_active_img_size(self, new_img_size):
        """Switch the active eval resolution and retarget valid/test transforms in place."""
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        """Return (and cache per resolution) a reproducible in-memory subset of
        training batches; used for resetting running statistics. The subset is
        drawn with fixed seed ``DataProvider.SUB_SEED``."""
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers
            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)  # fixed seed -> reproducible subset
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()
            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            # Materialize all batches once so later passes need no workers.
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,225 @@
import warnings
import os
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class ImagenetDataProvider(DataProvider):
    """Data provider for ImageNet (ILSVRC2012) classification.

    Builds the train / valid / test DataLoaders used by the run managers.
    ``image_size`` may be an int or a list of ints; with a list, MyDataLoader
    samples a random image size for each training batch (elastic resolution),
    configured through ``MyRandomResizedCrop`` class attributes.
    """

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):
        # valid_size: None -> reuse the test set for validation;
        #             int  -> number of training images held out;
        #             float in (0, 1) -> fraction of the training set held out.
        # num_replicas / rank: only set when running distributed training.
        warnings.filterwarnings('ignore')
        self._save_path = save_path
        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            # NOTE: mutates MyRandomResizedCrop class-level state shared across providers.
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
            # Pre-build one eval transform per candidate resolution.
            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader
        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)
        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)
            # Validation images come from the training set but use eval transforms.
            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None
        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )
        if self.valid is None:
            # No held-out split requested: validate on the test set.
            self.valid = self.test

    @staticmethod
    def name():
        """Dataset identifier matched against RunConfig.dataset."""
        return 'imagenet'

    @property
    def data_shape(self):
        """Input tensor shape at the currently active resolution."""
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 1000

    @property
    def save_path(self):
        """Dataset root; falls back to hard-coded server paths when unset."""
        if self._save_path is None:
            # self._save_path = '/dataset/imagenet'
            # self._save_path = '/usr/local/soft/temp-datastore/ILSVRC2012'  # servers
            self._save_path = '/mnt/datastore/ILSVRC2012'  # home server
            if not os.path.exists(self._save_path):
                # self._save_path = os.path.expanduser('~/dataset/imagenet')
                # self._save_path = os.path.expanduser('/usr/local/soft/temp-datastore/ILSVRC2012')
                self._save_path = '/mnt/datastore/ILSVRC2012'  # home server
        return self._save_path

    @property
    def data_url(self):
        # Downloading ImageNet automatically is not supported.
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        """ImageFolder over the train split with the given transforms."""
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    def test_dataset(self, _transforms):
        """ImageFolder over the val split with the given transforms."""
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'val')

    @property
    def normalize(self):
        # Standard ImageNet channel statistics.
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def build_train_transform(self, image_size=None, print_log=True):
        """Training pipeline: random-resized crop, flip, optional color jitter."""
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None
        if isinstance(image_size, list):
            # Elastic resolution: crop size is chosen per batch by the loader.
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]
        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        """Eval pipeline: resize (0.875 crop ratio), center crop, normalize."""
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        """Switch the active evaluation resolution in place."""
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        """Cache and return batches from a fixed random training subset.

        Used for resetting BN running statistics; seeded with SUB_SEED so the
        subset is identical across calls and processes.
        """
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers
            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()
            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            # Materialize all batches once so repeated calls avoid re-decoding.
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,237 @@
import os
import math
import warnings
import numpy as np
# from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class OxfordIIITPetsDataProvider(DataProvider):
    """Data provider for the Oxford-IIIT Pets dataset (37 breed classes).

    Mirrors ImagenetDataProvider but trains with timm RandAugment instead of
    color jitter. ``image_size`` may be an int or a list of ints; with a list,
    MyDataLoader samples a random image size per training batch.
    """

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):
        # valid_size: None -> reuse the test set for validation; int -> images
        # held out of training; float in (0, 1) -> fraction held out.
        warnings.filterwarnings('ignore')
        self._save_path = save_path
        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader
        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)
        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)
            # Validation split comes from the training images, eval transforms.
            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None
        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )
        if self.valid is None:
            # No held-out split requested: validate on the test set.
            self.valid = self.test

    @staticmethod
    def name():
        """Dataset identifier matched against RunConfig.dataset."""
        return 'pets'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 37

    @property
    def save_path(self):
        """Dataset root; falls back to a hard-coded server path when unset."""
        if self._save_path is None:
            self._save_path = '/mnt/datastore/Oxford-IIITPets'  # home server
            if not os.path.exists(self._save_path):
                self._save_path = '/mnt/datastore/Oxford-IIITPets'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        """ImageFolder over the train split with the given transforms."""
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    def test_dataset(self, _transforms):
        """ImageFolder over the valid split with the given transforms."""
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'valid')

    @property
    def normalize(self):
        # Dataset-specific channel statistics.
        return transforms.Normalize(
            mean=[0.4828895122298728, 0.4448394893850807, 0.39566558230789783],
            std=[0.25925664613996574, 0.2532760018681693, 0.25981017205097917])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        """Training pipeline: random-resized crop, flip, RandAugment, normalize.

        ``auto_augment`` is a timm RandAugment policy string.
        """
        if image_size is None:
            image_size = self.image_size
        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.4828895122298728, 0.4448394893850807,
                                                               0.39566558230789783]]),
        )
        # timm's rand_augment expects a PIL resample flag here, not a transform
        # object: previously this assigned transforms.Resize(image_size), which
        # fails inside the augmentation ops. 3 == PIL.Image.BICUBIC, matching
        # the intended _pil_interp('bicubic').
        aa_params['interpolation'] = 3
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]
        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        """Eval pipeline: resize (0.875 crop ratio), center crop, normalize."""
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        """Switch the active evaluation resolution in place."""
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        """Cache and return batches from a fixed random training subset.

        Used for resetting BN running statistics; seeded with SUB_SEED so the
        subset is identical across calls and processes.
        """
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers
            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()
            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,69 @@
import torch
from glob import glob
from torch.utils.data.dataset import Dataset
import os
from PIL import Image
def load_image(filename):
    """Open *filename* with PIL and return it converted to RGB."""
    return Image.open(filename).convert('RGB')
class PetDataset(Dataset):
    """Oxford-IIIT Pets dataset with an on-disk stratified train/val split.

    On first use, scans ``root/images/*.jpg``, derives the class from each
    filename (everything before the last ``_``), performs a per-class split by
    ``val_split``, and caches both splits as ``train<p>.pth`` / ``test<p>.pth``
    under ``root``. Later instantiations load the cached tensors directly.
    """

    def __init__(self, root, train=True, num_cl=37, val_split=0.15, transforms=None):
        # Cache file name encodes the split percentage, e.g. train85.pth / test15.pth.
        pt_name = os.path.join(root, '{}{}.pth'.format('train' if train else 'test',
                                                       int(100 * (1 - val_split)) if train else int(
                                                           100 * val_split)))
        if not os.path.exists(pt_name):
            filenames = glob(os.path.join(root, 'images') + '/*.jpg')
            classes = set()
            data = []
            labels = []
            for image in filenames:
                # Class name is the filename minus the trailing '_<index>.jpg'.
                # os.path.basename (instead of rsplit('/')) keeps this portable
                # across path separators.
                class_name = os.path.basename(image).rsplit('_', 1)[0]
                classes.add(class_name)
                img = load_image(image)
                data.append(img)
                labels.append(class_name)
            # convert classnames to indices; sorted() makes the mapping
            # deterministic (iterating a set varies with hash randomization,
            # which would make the cached label assignment irreproducible).
            class2idx = {cl: idx for idx, cl in enumerate(sorted(classes))}
            labels = torch.Tensor(list(map(lambda x: class2idx[x], labels))).long()
            data = list(zip(data, labels))
            # Stratified split: take the same fraction from every class.
            class_values = [[] for x in range(num_cl)]
            for d in data:
                class_values[d[1].item()].append(d)
            train_data = []
            val_data = []
            for class_dp in class_values:
                split_idx = int(len(class_dp) * (1 - val_split))
                train_data += class_dp[:split_idx]
                val_data += class_dp[split_idx:]
            torch.save(train_data, os.path.join(root, 'train{}.pth'.format(int(100 * (1 - val_split)))))
            torch.save(val_data, os.path.join(root, 'test{}.pth'.format(int(100 * val_split))))
        self.data = torch.load(pt_name)
        self.len = len(self.data)
        self.transform = transforms

    def __getitem__(self, index):
        """Return ``(image, label)``, applying the transform when set."""
        img, label = self.data[index]
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return self.len

View File

@@ -0,0 +1,226 @@
import os
import math
import numpy as np
import torchvision
import torch.utils.data
import torchvision.transforms as transforms
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class STL10DataProvider(DataProvider):
    """Data provider for STL-10 (10 classes) via torchvision.datasets.STL10.

    Builds the train / valid / test DataLoaders. ``image_size`` may be an int
    or a list of ints; with a list, MyDataLoader samples a random image size
    per training batch (elastic resolution).
    """

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
        # valid_size: None -> reuse the test set for validation; int -> images
        # held out of training; float in (0, 1) -> fraction held out.
        self._save_path = save_path
        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            # NOTE: mutates MyRandomResizedCrop class-level state shared across providers.
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader
        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)
        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                # STL10 exposes .data (ndarray), not .samples like ImageFolder.
                valid_size = int(len(train_dataset.data) * valid_size)
            # Validation split comes from the training images, eval transforms.
            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None
        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )
        if self.valid is None:
            # No held-out split requested: validate on the test set.
            self.valid = self.test

    @staticmethod
    def name():
        """Dataset identifier matched against RunConfig.dataset."""
        return 'stl10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        """Dataset root; falls back to a hard-coded server path when unset."""
        if self._save_path is None:
            self._save_path = '/mnt/datastore/STL10'  # home server
            if not os.path.exists(self._save_path):
                self._save_path = '/mnt/datastore/STL10'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        """STL10 'train' split; root is save_path (train_path == valid_path here)."""
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        dataset = torchvision.datasets.STL10(
            root=self.valid_path, split='train', download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        """STL10 'test' split with the given transforms."""
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = torchvision.datasets.STL10(
            root=self.valid_path, split='test', download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        # return os.path.join(self.save_path, 'train')
        return self.save_path

    @property
    def valid_path(self):
        # return os.path.join(self.save_path, 'val')
        return self.save_path

    @property
    def normalize(self):
        # STL-10 channel statistics.
        return transforms.Normalize(
            mean=[0.44671097, 0.4398105, 0.4066468],
            std=[0.2603405, 0.25657743, 0.27126738])

    def build_train_transform(self, image_size=None, print_log=True):
        """Training pipeline: random-resized crop, flip, optional color jitter."""
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None
        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]
        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        """Eval pipeline: resize (0.875 crop ratio), center crop, normalize."""
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        """Switch the active evaluation resolution in place."""
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        """Cache and return batches from a fixed random training subset.

        Used for resetting BN running statistics; seeded with SUB_SEED so the
        subset is identical across calls and processes.
        """
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers
            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()
            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            # Materialize all batches once so repeated calls avoid re-decoding.
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,4 @@
from ofa.imagenet_codebase.networks.proxyless_nets import ProxylessNASNets, proxyless_base, MobileNetV2
from ofa.imagenet_codebase.networks.mobilenet_v3 import MobileNetV3, MobileNetV3Large
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks.nsganetv2 import NSGANetV2

View File

@@ -0,0 +1,126 @@
from timm.models.layers import drop_path
from ofa.imagenet_codebase.modules.layers import *
from ofa.imagenet_codebase.networks import MobileNetV3
class MobileInvertedResidualBlock(MyModule):
    """
    Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
    imagenet_codebase/networks/proxyless_nets.py to include drop path in training
    """

    def __init__(self, mobile_inverted_conv, shortcut, drop_connect_rate=0.0):
        super(MobileInvertedResidualBlock, self).__init__()
        self.mobile_inverted_conv = mobile_inverted_conv
        self.shortcut = shortcut
        self.drop_connect_rate = drop_connect_rate

    def forward(self, x):
        """Residual forward pass with optional stochastic depth (drop path)."""
        if self.mobile_inverted_conv is None or isinstance(self.mobile_inverted_conv, ZeroLayer):
            res = x  # block reduced to identity
        elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
            res = self.mobile_inverted_conv(x)  # plain (non-residual) block
        else:
            # res = self.mobile_inverted_conv(x) + self.shortcut(x)
            res = self.mobile_inverted_conv(x)
            if self.drop_connect_rate > 0.:
                # Randomly drop the conv branch during training (stochastic depth).
                res = drop_path(res, drop_prob=self.drop_connect_rate, training=self.training)
            res += self.shortcut(x)
        return res

    @property
    def module_str(self):
        """Human-readable '(conv, shortcut)' summary."""
        return '(%s, %s)' % (
            self.mobile_inverted_conv.module_str if self.mobile_inverted_conv is not None else None,
            self.shortcut.module_str if self.shortcut is not None else None
        )

    @property
    def config(self):
        """Serializable config; note it does not include drop_connect_rate."""
        return {
            'name': MobileInvertedResidualBlock.__name__,
            'mobile_inverted_conv': self.mobile_inverted_conv.config if self.mobile_inverted_conv is not None else None,
            'shortcut': self.shortcut.config if self.shortcut is not None else None,
        }

    @staticmethod
    def build_from_config(config):
        """Rebuild a block from a config dict.

        drop_connect_rate defaults to 0.0 when absent: the ``config`` property
        above does not serialize it, so indexing config['drop_connect_rate']
        directly would raise KeyError on round-tripped configs.
        """
        mobile_inverted_conv = set_layer_from_config(config['mobile_inverted_conv'])
        shortcut = set_layer_from_config(config['shortcut'])
        return MobileInvertedResidualBlock(
            mobile_inverted_conv, shortcut, drop_connect_rate=config.get('drop_connect_rate', 0.0))
class NSGANetV2(MobileNetV3):
    """
    Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
    imagenet_codebase/networks/mobilenet_v3.py to include drop path in training
    and option to reset classification layer
    """

    @staticmethod
    def build_from_config(config, drop_connect_rate=0.0):
        """Build a network from a config dict.

        Each block's drop-path probability is scaled linearly with its depth
        (block_idx / n_blocks * drop_connect_rate).

        NOTE(review): this returns a MobileNetV3 instance, not an NSGANetV2 —
        presumably intentional since NSGANetV2 adds no forward-time state,
        but zero_last_gamma is then unavailable on the result; confirm.
        """
        first_conv = set_layer_from_config(config['first_conv'])
        final_expand_layer = set_layer_from_config(config['final_expand_layer'])
        feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
        classifier = set_layer_from_config(config['classifier'])
        blocks = []
        for block_idx, block_config in enumerate(config['blocks']):
            # Linearly increasing drop-path rate over network depth.
            block_config['drop_connect_rate'] = drop_connect_rate * block_idx / len(config['blocks'])
            blocks.append(MobileInvertedResidualBlock.build_from_config(block_config))
        net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
        if 'bn' in config:
            net.set_bn_param(**config['bn'])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-3)
        return net

    def zero_last_gamma(self):
        """Zero the last BN gamma of every residual block (a common init trick
        so each residual branch starts as identity)."""
        for m in self.modules():
            if isinstance(m, MobileInvertedResidualBlock):
                if isinstance(m.mobile_inverted_conv, MBInvertedConvLayer) and isinstance(m.shortcut, IdentityLayer):
                    m.mobile_inverted_conv.point_linear.bn.weight.data.zero_()

    @staticmethod
    def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
        """Construct the component layers from a stage-wise block config.

        cfg maps stage_id -> list of
        (kernel, mid_channel, out_channel, use_se, act_func, stride, expand_ratio).
        Returns (first_conv, blocks, final_expand_layer, feature_mix_layer, classifier).
        """
        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
        )
        # build mobile blocks
        feature_dim = input_channel
        blocks = []
        for stage_id, block_config_list in cfg.items():
            for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
                mb_conv = MBInvertedConvLayer(
                    feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
                )
                # Residual shortcut only when shapes match (stride 1, same width).
                if stride == 1 and out_channel == feature_dim:
                    shortcut = IdentityLayer(out_channel, out_channel)
                else:
                    shortcut = None
                blocks.append(MobileInvertedResidualBlock(mb_conv, shortcut))
                feature_dim = out_channel
        # final expand layer
        final_expand_layer = ConvLayer(
            feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
        )
        feature_dim = feature_dim * 6
        # feature mix layer
        feature_mix_layer = ConvLayer(
            feature_dim, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
        )
        # classifier
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
        return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier

    @staticmethod
    def reset_classifier(model, last_channel, n_classes, dropout_rate=0.0):
        """Replace the classification head, e.g. for transfer to a new label set."""
        model.classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
View File

@@ -0,0 +1,309 @@
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.imagenet import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.cifar import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import *
from ofa.imagenet_codebase.run_manager.run_manager import *
class ImagenetRunConfig(RunConfig):
    """RunConfig for ImageNet training/evaluation.

    Extends the base RunConfig with data-pipeline options (workers, resize
    scale, color distortion, image size, data path) and lazily builds the
    matching data provider on first access.
    """

    def __init__(self, n_epochs=1, init_lr=1e-4, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=128, test_batch_size=512, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/ILSVRC2012',
                 **kwargs):
        # Extra **kwargs are accepted (and ignored) for config-dict compatibility.
        super(ImagenetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.imagenet_data_path = data_path

    @property
    def data_provider(self):
        """Lazily construct and cache the data provider on first access."""
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.imagenet_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']
class CIFARRunConfig(RunConfig):
    """RunConfig for the CIFAR family (cifar10 / cifar100 / cinic10).

    Extends the base RunConfig with data-pipeline options and lazily builds
    the provider matching ``dataset`` on first access.
    """

    def __init__(self, n_epochs=5, init_lr=0.01, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='cifar10', train_batch_size=96, test_batch_size=256, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/CIFAR',
                 **kwargs):
        # Extra **kwargs are accepted (and ignored) for config-dict compatibility.
        super(CIFARRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.cifar_data_path = data_path

    @property
    def data_provider(self):
        """Lazily construct and cache the provider matching self.dataset."""
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == CIFAR10DataProvider.name():
                DataProviderClass = CIFAR10DataProvider
            elif self.dataset == CIFAR100DataProvider.name():
                DataProviderClass = CIFAR100DataProvider
            elif self.dataset == CINIC10DataProvider.name():
                DataProviderClass = CINIC10DataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.cifar_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']
class Flowers102RunConfig(RunConfig):
    """RunConfig for the Oxford Flowers-102 dataset."""

    def __init__(self, n_epochs=3, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='flowers102', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/Flowers102',
                 **kwargs):
        # Training-schedule / optimizer options are handled by the base class.
        super(Flowers102RunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        # Data-pipeline options consumed by the provider built below.
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.flowers102_data_path = data_path

    @property
    def data_provider(self):
        """Build the Flowers-102 data provider once and cache it."""
        provider = self.__dict__.get('_data_provider', None)
        if provider is None:
            if self.dataset != Flowers102DataProvider.name():
                raise NotImplementedError
            provider = Flowers102DataProvider(
                save_path=self.flowers102_data_path,
                train_batch_size=self.train_batch_size,
                test_batch_size=self.test_batch_size,
                valid_size=self.valid_size,
                n_worker=self.n_worker,
                resize_scale=self.resize_scale,
                distort_color=self.distort_color,
                image_size=self.image_size,
            )
            self.__dict__['_data_provider'] = provider
        return provider
class STL10RunConfig(RunConfig):
    """RunConfig for the STL-10 dataset."""

    def __init__(self, n_epochs=5, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='stl10', train_batch_size=96, test_batch_size=256, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/STL10',
                 **kwargs):
        # Training-schedule / optimizer options are handled by the base class.
        super(STL10RunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        # Data-pipeline options consumed by the provider built below.
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.stl10_data_path = data_path

    @property
    def data_provider(self):
        """Build the STL-10 data provider once and cache it."""
        provider = self.__dict__.get('_data_provider', None)
        if provider is None:
            if self.dataset != STL10DataProvider.name():
                raise NotImplementedError
            provider = STL10DataProvider(
                save_path=self.stl10_data_path,
                train_batch_size=self.train_batch_size,
                test_batch_size=self.test_batch_size,
                valid_size=self.valid_size,
                n_worker=self.n_worker,
                resize_scale=self.resize_scale,
                distort_color=self.distort_color,
                image_size=self.image_size,
            )
            self.__dict__['_data_provider'] = provider
        return provider
class DTDRunConfig(RunConfig):
    """RunConfig for the Describable Textures Dataset (DTD)."""

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='dtd', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/dtd',
                 **kwargs):
        # Training-schedule / optimizer options are handled by the base class.
        super(DTDRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        # Data-pipeline options consumed by the provider built below.
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.data_path = data_path

    @property
    def data_provider(self):
        """Build the DTD data provider once and cache it."""
        provider = self.__dict__.get('_data_provider', None)
        if provider is None:
            if self.dataset != DTDDataProvider.name():
                raise NotImplementedError
            provider = DTDDataProvider(
                save_path=self.data_path,
                train_batch_size=self.train_batch_size,
                test_batch_size=self.test_batch_size,
                valid_size=self.valid_size,
                n_worker=self.n_worker,
                resize_scale=self.resize_scale,
                distort_color=self.distort_color,
                image_size=self.image_size,
            )
            self.__dict__['_data_provider'] = provider
        return provider
class PetsRunConfig(RunConfig):
    """RunConfig for the Oxford-IIIT Pets dataset."""

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='pets', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/Oxford-IIITPets',
                 **kwargs):
        # Training-schedule / optimizer options are handled by the base class.
        super(PetsRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        # Data-pipeline options consumed by the provider built below.
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        # NOTE(review): historical attribute name kept for compatibility --
        # this stores the *Pets* data path, not an ImageNet one.
        self.imagenet_data_path = data_path

    @property
    def data_provider(self):
        """Build the Oxford-IIIT Pets data provider once and cache it."""
        provider = self.__dict__.get('_data_provider', None)
        if provider is None:
            if self.dataset != OxfordIIITPetsDataProvider.name():
                raise NotImplementedError
            provider = OxfordIIITPetsDataProvider(
                save_path=self.imagenet_data_path,
                train_batch_size=self.train_batch_size,
                test_batch_size=self.test_batch_size,
                valid_size=self.valid_size,
                n_worker=self.n_worker,
                resize_scale=self.resize_scale,
                distort_color=self.distort_color,
                image_size=self.image_size,
            )
            self.__dict__['_data_provider'] = provider
        return provider
class AircraftRunConfig(RunConfig):
    """RunConfig for the FGVC-Aircraft dataset."""

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='aircraft', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/Aircraft',
                 **kwargs):
        # Training-schedule / optimizer options are handled by the base class.
        super(AircraftRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        # Data-pipeline options consumed by the provider built below.
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.data_path = data_path

    @property
    def data_provider(self):
        """Build the FGVC-Aircraft data provider once and cache it."""
        provider = self.__dict__.get('_data_provider', None)
        if provider is None:
            if self.dataset != FGVCAircraftDataProvider.name():
                raise NotImplementedError
            provider = FGVCAircraftDataProvider(
                save_path=self.data_path,
                train_batch_size=self.train_batch_size,
                test_batch_size=self.test_batch_size,
                valid_size=self.valid_size,
                n_worker=self.n_worker,
                resize_scale=self.resize_scale,
                distort_color=self.distort_color,
                image_size=self.image_size,
            )
            self.__dict__['_data_provider'] = provider
        return provider
def get_run_config(**kwargs):
    """Construct the dataset-appropriate RunConfig from keyword arguments.

    Dispatches on ``kwargs['dataset']``; raises NotImplementedError for an
    unrecognized dataset name.
    """
    name = kwargs['dataset']
    if name == 'imagenet':
        return ImagenetRunConfig(**kwargs)
    if name.startswith('cifar') or name.startswith('cinic'):
        return CIFARRunConfig(**kwargs)
    if name == 'flowers102':
        return Flowers102RunConfig(**kwargs)
    if name == 'stl10':
        return STL10RunConfig(**kwargs)
    if name == 'dtd':
        return DTDRunConfig(**kwargs)
    if name == 'pets':
        return PetsRunConfig(**kwargs)
    # Both the 100-class aircraft variants share the same config class.
    if name in ('aircraft', 'aircraft100'):
        return AircraftRunConfig(**kwargs)
    raise NotImplementedError

View File

@@ -0,0 +1,122 @@
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
import torchvision.utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import FGVCAircraft
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets2 import PetDataset
import torch.utils.data as Data
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.autoaugment import CIFAR10Policy
def get_dataset(data_name, batch_size, data_path, num_workers,
                img_size, autoaugment, cutout, cutout_length):
    """Build train/valid DataLoaders for a meta-test dataset.

    Args:
        data_name: 'cifar10', 'cifar100', 'aircraft', 'aircraft100' or 'pets'.
        batch_size: training batch size (validation batch size is fixed at 200).
        data_path: dataset root directory.
        num_workers: DataLoader worker count.
        img_size, autoaugment, cutout, cutout_length: forwarded to _data_transforms.

    Returns:
        (train_queue, valid_queue, num_classes)

    Raises:
        KeyError: for unsupported dataset names.
    """
    num_class_dict = {
        'cifar100': 100,
        'cifar10': 10,
        'mnist': 10,
        'aircraft': 100,
        # BUG FIX: 'aircraft100' reaches the aircraft branch below (the name
        # starts with 'aircraft') but was missing from this table, so the
        # final num_class_dict[data_name] lookup raised KeyError.
        'aircraft100': 100,
        'svhn': 10,
        'pets': 37,
    }
    train_transform, valid_transform = _data_transforms(
        data_name, img_size, autoaugment, cutout, cutout_length)
    if data_name == 'cifar100':
        train_data = torchvision.datasets.CIFAR100(
            root=data_path, train=True, download=True, transform=train_transform)
        valid_data = torchvision.datasets.CIFAR100(
            root=data_path, train=False, download=True, transform=valid_transform)
    elif data_name == 'cifar10':
        train_data = torchvision.datasets.CIFAR10(
            root=data_path, train=True, download=True, transform=train_transform)
        valid_data = torchvision.datasets.CIFAR10(
            root=data_path, train=False, download=True, transform=valid_transform)
    elif data_name.startswith('aircraft'):
        print(data_path)
        # The FGVC archive unpacks under fgvc-aircraft-2013b; rewrite the root.
        if 'aircraft100' in data_path:
            data_path = data_path.replace('aircraft100', 'aircraft/fgvc-aircraft-2013b')
        else:
            data_path = data_path.replace('aircraft', 'aircraft/fgvc-aircraft-2013b')
        train_data = FGVCAircraft(data_path, class_type='variant', split='trainval',
                                  transform=train_transform, download=True)
        valid_data = FGVCAircraft(data_path, class_type='variant', split='test',
                                  transform=valid_transform, download=True)
    elif data_name.startswith('pets'):
        train_data = PetDataset(data_path, train=True, num_cl=37,
                                val_split=0.15, transforms=train_transform)
        valid_data = PetDataset(data_path, train=False, num_cl=37,
                                val_split=0.15, transforms=valid_transform)
    else:
        raise KeyError
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=num_workers)
    # NOTE(review): validation batch size is fixed at 200, independent of
    # `batch_size` -- presumably sized for evaluation memory; confirm.
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=200, shuffle=False, pin_memory=True,
        num_workers=num_workers)
    return train_queue, valid_queue, num_class_dict[data_name]
class Cutout(object):
    """Randomly zero out one square patch of a CHW image tensor, in place.

    The patch is centred at a uniformly random pixel and clipped to the
    image borders, so patches near an edge cover a smaller area.
    """

    def __init__(self, length):
        # Side length of the square patch to erase.
        self.length = length

    def __call__(self, img):
        height, width = img.size(1), img.size(2)
        # Draw the patch centre; keep these two randint calls in this order
        # so seeded runs reproduce the original behaviour.
        cy = np.random.randint(height)
        cx = np.random.randint(width)
        half = self.length // 2
        top = np.clip(cy - half, 0, height)
        bottom = np.clip(cy + half, 0, height)
        left = np.clip(cx - half, 0, width)
        right = np.clip(cx + half, 0, width)
        mask = np.ones((height, width), np.float32)
        mask[top: bottom, left: right] = 0.
        # Broadcast the 2-D mask over the channel dimension and erase in place.
        img *= torch.from_numpy(mask).expand_as(img)
        return img
def _data_transforms(data_name, img_size, autoaugment, cutout, cutout_length):
    """Return (train_transform, valid_transform) for the given dataset.

    Normalization statistics are matched by substring of ``data_name``
    (first hit wins); unknown names raise KeyError.
    """
    stats_by_key = (
        ('cifar', ([0.49139968, 0.48215827, 0.44653124],
                   [0.24703233, 0.24348505, 0.26158768])),
        ('aircraft', ([0.48933587508932375, 0.5183537408957618, 0.5387914411673883],
                      [0.22388883112804625, 0.21641635409388751, 0.24615605842636115])),
        ('pets', ([0.4828895122298728, 0.4448394893850807, 0.39566558230789783],
                  [0.25925664613996574, 0.2532760018681693, 0.25981017205097917])),
    )
    for key, (norm_mean, norm_std) in stats_by_key:
        if key in data_name:
            break
    else:
        raise KeyError
    # Training pipeline: resize -> flip -> [AutoAugment] -> tensor -> [Cutout] -> normalize.
    train_steps = [
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),  # BICUBIC interpolation
        transforms.RandomHorizontalFlip(),
    ]
    if autoaugment:
        train_steps.append(CIFAR10Policy())
    train_steps.append(transforms.ToTensor())
    if cutout:
        train_steps.append(Cutout(cutout_length))
    train_steps.append(transforms.Normalize(norm_mean, norm_std))
    train_transform = transforms.Compose(train_steps)
    valid_transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),  # BICUBIC interpolation
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    return train_transform, valid_transform

View File

@@ -0,0 +1,233 @@
import os
import torch
import numpy as np
import random
import sys
import transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.eval_utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.run_manager import get_run_config
from ofa.elastic_nn.networks import OFAMobileNetV3
from ofa.imagenet_codebase.run_manager import RunManager
from ofa.elastic_nn.modules.dynamic_op import DynamicSeparableConv2d
from torchprofile import profile_macs
import copy
import json
import warnings
# Silence all Python warnings for the evaluation run.
warnings.simplefilter("ignore")
# Mode 1 switches OFA's elastic depth-wise conv to its kernel-transform
# behaviour -- presumably matching how the supernet checkpoint was trained;
# TODO confirm against the OFA release notes.
DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = 1
class ArchManager:
    """Samples and mutates OFA-MobileNetV3 architecture encodings.

    An encoding is a dict with per-block kernel sizes ('ks') and expand
    ratios ('e'), per-stage depths ('d'), input resolution ('r') and a
    width id ('wid', unused here).
    """

    def __init__(self):
        # Search-space dimensions of the supernet.
        self.num_blocks = 20
        self.num_stages = 5
        self.kernel_sizes = [3, 5, 7]
        self.expand_ratios = [3, 4, 6]
        self.depths = [2, 3, 4]
        self.resolutions = [160, 176, 192, 208, 224]

    def random_sample(self):
        """Draw one uniformly random architecture encoding."""
        depths = [random.choice(self.depths) for _ in range(self.num_stages)]
        expands, kernels = [], []
        # One (expand, kernel) draw per block, in this order, so seeded RNG
        # sequences match the original implementation.
        for _ in range(self.num_blocks):
            expands.append(random.choice(self.expand_ratios))
            kernels.append(random.choice(self.kernel_sizes))
        return {
            'wid': None,
            'ks': kernels,
            'e': expands,
            'd': depths,
            'r': [random.choice(self.resolutions)],
        }

    def random_resample(self, sample, i):
        """Re-draw kernel size and expand ratio of block *i*, in place."""
        assert 0 <= i < self.num_blocks
        sample['ks'][i] = random.choice(self.kernel_sizes)
        sample['e'][i] = random.choice(self.expand_ratios)

    def random_resample_depth(self, sample, i):
        """Re-draw the depth of stage *i*, in place."""
        assert 0 <= i < self.num_stages
        sample['d'][i] = random.choice(self.depths)

    def random_resample_resolution(self, sample):
        """Re-draw the input resolution, in place."""
        sample['r'][0] = random.choice(self.resolutions)
def parse_string_list(string):
    """Parse a bracketed, space-separated string such as '[5 5 7]' into a
    list of ints; any non-string input is returned unchanged."""
    if not isinstance(string, str):
        return string
    # Strip the surrounding brackets, then split on whitespace.
    return [int(token) for token in string[1:-1].split()]
def pad_none(x, depth, max_depth):
    """Re-chunk the flat per-block list *x* by stage depths, padding every
    stage with None up to *max_depth* entries.

    Consumes sum(depth) items of *x*; each stage contributes its items
    followed by (max_depth - d) Nones.
    """
    padded, cursor = [], 0
    for stage_depth in depth:
        padded.extend(x[cursor:cursor + stage_depth])
        cursor += stage_depth
        # A full stage (stage_depth >= max_depth) gets no padding.
        padded.extend([None] * (max_depth - stage_depth))
    return padded
def get_net_info(net, data_shape, measure_latency=None, print_info=True, clean=False, lut=None):
    """Summarize a network: parameter count, FLOPs and measured latencies.

    Returns a dict with 'params' and 'flops' in millions (rounded to 2 dp)
    and 'gpu'/'cpu' latency values (or None when not measured).
    """
    # BUG FIX: the module-level `import transfer_nag_lib....evaluation.eval_utils`
    # only binds the top-level package name `transfer_nag_lib`, so the bare
    # `eval_utils` used here raised NameError; bind the submodule explicitly.
    from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation import eval_utils
    net_info = eval_utils.get_net_info(
        net, data_shape, measure_latency, print_info=print_info, clean=clean, lut=lut)
    gpu_latency, cpu_latency = None, None
    # Latency entries are keyed by strings containing 'gpu' / 'cpu'.
    for key in net_info.keys():
        if 'gpu' in key:
            gpu_latency = np.round(net_info[key]['val'], 2)
        if 'cpu' in key:
            cpu_latency = np.round(net_info[key]['val'], 2)
    return {
        'params': np.round(net_info['params'] / 1e6, 2),
        'flops': np.round(net_info['flops'] / 1e6, 2),
        'gpu': gpu_latency, 'cpu': cpu_latency,
    }
def validate_config(config, max_depth=4):
    """Normalize an architecture config dict.

    String-encoded lists (e.g. '[3 5 7]') are parsed to int lists, and the
    per-block 'ks'/'e' lists are padded with None so every stage owns
    *max_depth* slots. Returns a new dict with keys 'ks', 'e', 'd'.
    """
    kernel_size = config['ks']
    exp_ratio = config['e']
    depth = config['d']
    if isinstance(kernel_size, str):
        kernel_size = parse_string_list(kernel_size)
    if isinstance(exp_ratio, str):
        exp_ratio = parse_string_list(exp_ratio)
    if isinstance(depth, str):
        depth = parse_string_list(depth)
    assert isinstance(kernel_size, (list, int))
    assert isinstance(exp_ratio, (list, int))
    assert isinstance(depth, list)
    if len(kernel_size) < len(depth) * max_depth:
        kernel_size = pad_none(kernel_size, depth, max_depth)
    if len(exp_ratio) < len(depth) * max_depth:
        exp_ratio = pad_none(exp_ratio, depth, max_depth)
    # The width key ('w') is intentionally not propagated here.
    return {'ks': kernel_size, 'e': exp_ratio, 'd': depth}
def set_nas_test_dataset(path, test_data_name, max_img):
    """Load a '<name>bylabel' meta-test dump and size the per-class budget.

    Returns (data, img_per_cls, num_cls); img_per_cls is capped at 20.
    Raises ValueError for unknown dataset names.
    """
    valid_names = ('mnist', 'svhn', 'cifar10', 'cifar100', 'aircraft', 'pets')
    if test_data_name not in valid_names:
        raise ValueError(test_data_name)
    if test_data_name in ('cifar100', 'aircraft'):
        num_cls = 100
    elif test_data_name == 'pets':
        num_cls = 37
    else:
        num_cls = 10  # mnist, svhn, cifar10
    data = torch.load(path + f'/{test_data_name}bylabel')
    img_per_cls = min(int(max_img / num_cls), 20)
    return data, img_per_cls, num_cls
class OFAEvaluator:
    """ based on OnceForAll supernet taken from https://github.com/mit-han-lab/once-for-all """

    def __init__(self, num_gen_arch, img_size, drop_path,
                 n_classes=1000,
                 model_path=None,
                 kernel_size=None, exp_ratio=None, depth=None):
        """Load the OFA-MobileNetV3 supernet checkpoint at *model_path*.

        The width multiplier is inferred from the checkpoint file name
        ('w1.0' or 'w1.2'); any other name raises ValueError.
        """
        # default configurations
        self.kernel_size = [3, 5, 7] if kernel_size is None else kernel_size  # depth-wise conv kernel size
        self.exp_ratio = [3, 4, 6] if exp_ratio is None else exp_ratio  # expansion rate
        self.depth = [2, 3, 4] if depth is None else depth  # number of MB block repetition
        # BUG FIX: img_size and drop_path were accepted but never stored,
        # so get_architecture() crashed on self.img_size / self.drop_path.
        self.img_size = img_size
        self.drop_path = drop_path
        if 'w1.0' in model_path:
            self.width_mult = 1.0
        elif 'w1.2' in model_path:
            self.width_mult = 1.2
        else:
            raise ValueError
        self.engine = OFAMobileNetV3(
            n_classes=n_classes,
            dropout_rate=0, width_mult_list=self.width_mult, ks_list=self.kernel_size,
            expand_ratio_list=self.exp_ratio, depth_list=self.depth)
        init = torch.load(model_path, map_location='cpu')['state_dict']
        self.engine.load_weights_from_net(init)
        print(f'load {model_path}...')
        ## metad2a
        self.arch_manager = ArchManager()
        self.num_gen_arch = num_gen_arch

    def sample_random_architecture(self):
        """Return one uniformly random architecture encoding."""
        return self.arch_manager.random_sample()

    def get_architecture(self, bound=None):
        """Return the first candidate architecture from ``self.lines``,
        optionally the first whose FLOPs are <= *bound* (in M).

        NOTE(review): ``self.lines`` is never assigned in this class --
        callers apparently must populate it with generator output before
        calling this method; confirm against the surrounding pipeline.
        """
        pred_acc_lst = []
        searched_g = None
        with torch.no_grad():
            for n in range(self.num_gen_arch):
                file_acc = self.lines[n].split()[0]
                g_dict = ' '.join(self.lines[n].split())
                g = json.loads(g_dict.replace("'", "\""))
                if bound is not None:
                    # Materialize the candidate and measure its FLOPs.
                    subnet, config = self.sample(config=g)
                    net = NSGANetV2.build_from_config(subnet.config,
                                                      drop_connect_rate=self.drop_path)
                    inputs = torch.randn(1, 3, self.img_size, self.img_size)
                    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
                    if flops <= bound:
                        searched_g = g
                        break
                else:
                    # No FLOPs bound: take the top candidate directly.
                    searched_g = g
                    pred_acc_lst.append(file_acc)
                    break
        if searched_g is None:
            raise ValueError(searched_g)
        return searched_g, pred_acc_lst

    def sample(self, config=None):
        """ randomly sample a sub-network """
        if config is not None:
            config = validate_config(config)
            self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
        else:
            config = self.engine.sample_active_subnet()
        subnet = self.engine.get_active_subnet(preserve_weight=True)
        return subnet, config

    @staticmethod
    def save_net_config(path, net, config_name='net.config'):
        """ dump run_config and net_config to the model_folder """
        net_save_path = os.path.join(path, config_name)
        json.dump(net.config, open(net_save_path, 'w'), indent=4)
        print('Network configs dump to %s' % net_save_path)

    @staticmethod
    def save_net(path, net, model_name):
        """ dump net weight as checkpoint """
        if isinstance(net, torch.nn.DataParallel):
            checkpoint = {'state_dict': net.module.state_dict()}
        else:
            checkpoint = {'state_dict': net.state_dict()}
        model_path = os.path.join(path, model_name)
        torch.save(checkpoint, model_path)
        print('Network model dump to %s' % model_path)

View File

@@ -0,0 +1,169 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim
from evaluator import OFAEvaluator
from torchprofile import profile_macs
from codebase.networks import NSGANetV2
from parser import get_parse
from eval_utils import get_dataset
# ---- Script setup: parse CLI args, pin GPUs, and seed every RNG. ----
args = get_parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device_list = [int(_) for _ in args.gpu.split(',')]
args.n_gpus = len(device_list)
args.device = torch.device("cuda:0")
# Draw a random seed when none was supplied (the parser default is -1).
if args.seed is None or args.seed < 0: args.seed = random.randint(1, 100000)
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
# NOTE(review): the whole args namespace is passed as the first positional
# argument of the local `evaluator` module's OFAEvaluator -- confirm that
# module's constructor signature accepts it.
evaluator = OFAEvaluator(args,
                         model_path='../.torch/ofa_nets/ofa_mbv3_d234_e346_k357_w1.0')
# Results go under <save_path>/evaluation/<dataset>-<method>-<tag>-<seed>.
args.save_path = os.path.join(args.save_path, f'evaluation/{args.data_name}')
if args.model_config.startswith('flops@'):
    args.save_path += f'-nsganetV2-{args.model_config}-{args.seed}'
else:
    args.save_path += f'-metaD2A-{args.bound}-{args.seed}'
if not os.path.exists(args.save_path):
    os.makedirs(args.save_path)
args.data_path = os.path.join(args.data_path, args.data_name)
# Mirror log output to stdout and to <save_path>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
# This script requires a GPU; bail out early when none is visible.
if not torch.cuda.is_available():
    logging.info('no gpu self.args.device available')
    sys.exit(1)
logging.info("args = %s", args)
def set_architecture(n_cls):
    """Build the network to train, using the module-level `args`/`evaluator`.

    Loads either a pre-searched NSGANetV2 subnet from disk (when
    --model-config is 'flops@...') or an architecture produced by the
    MetaD2A evaluator, replaces its classifier head for *n_cls* classes,
    and moves it to args.device.

    Returns:
        (net, net_name) where net_name encodes the measured FLOPs.
    """
    if args.model_config.startswith('flops@'):
        # Load a previously searched architecture from disk.
        names = {'cifar10': 'CIFAR-10', 'cifar100': 'CIFAR-100',
                 'aircraft100': 'Aircraft', 'pets': 'Pets'}
        p = os.path.join('./searched-architectures/{}/net-{}/net.subnet'.
                         format(names[args.data_name], args.model_config))
        g = json.load(open(p))
    else:
        # Let the MetaD2A generator propose an architecture.
        g, acc = evaluator.get_architecture(args)
    subnet, config = evaluator.sample(g)
    net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=args.drop_path)
    net.load_state_dict(subnet.state_dict())
    # Swap the ImageNet head for one sized to the target dataset.
    NSGANetV2.reset_classifier(
        net, last_channel=net.classifier.in_features,
        n_classes=n_cls, dropout_rate=args.drop)
    # calculate #Paramaters and #FLOPS
    inputs = torch.randn(1, 3, args.img_size, args.img_size)
    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
    params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
    net_name = "net_flops@{:.0f}".format(flops)
    logging.info('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
    OFAEvaluator.save_net_config(args.save_path, net, net_name + '.config')
    if args.n_gpus > 1:
        net = nn.DataParallel(net)  # data parallel in case more than 1 gpu available
    net = net.to(args.device)
    return net, net_name
def train(train_queue, net, criterion, optimizer):
    """Run one training epoch over *train_queue* (uses module-level `args`).

    Returns:
        (average loss per sample, training accuracy in percent).
    """
    net.train()
    train_loss, correct, total = 0, 0, 0
    for step, (inputs, targets) in enumerate(train_queue):
        # upsample by bicubic to match imagenet training size
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        # Clip gradient norm to args.grad_clip to avoid exploding updates.
        nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if step % args.report_freq == 0:
            logging.info('train %03d %e %f', step, train_loss / total, 100. * correct / total)
    logging.info('train acc %f', 100. * correct / total)
    return train_loss / total, 100. * correct / total
def infer(valid_queue, net, criterion, early_stop=False):
    """Evaluate *net* on *valid_queue* (uses module-level `args`).

    With early_stop=True, evaluation stops after the 11th mini-batch
    (step == 10), giving a quick approximate accuracy.

    Returns:
        (average loss per sample, accuracy in percent).
    """
    net.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f', step, test_loss / total, 100. * correct / total)
            if early_stop and step == 10:
                break
    acc = 100. * correct / total
    logging.info('valid acc %f', 100. * correct / total)
    return test_loss / total, acc
def main():
    """Train the selected architecture, keeping the top-k checkpoints by
    validation accuracy plus a '.best' snapshot of the best epoch."""
    best_acc, top_checkpoints = 0, []
    train_queue, valid_queue, n_cls = get_dataset(args)
    net, net_name = set_architecture(n_cls)
    parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss().to(args.device)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
    for epoch in range(args.epochs):
        # NOTE(review): scheduler.get_lr() is deprecated in newer PyTorch in
        # favour of get_last_lr() -- confirm the pinned torch version.
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        train(train_queue, net, criterion, optimizer)
        _, valid_acc = infer(valid_queue, net, criterion)
        # checkpoint saving
        if len(top_checkpoints) < args.topk:
            # Fewer than topk checkpoints saved so far: always keep this one.
            OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
            top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
        else:
            # Replace the currently-worst checkpoint if this epoch beats it.
            idx = np.argmin([x[1] for x in top_checkpoints])
            if valid_acc > top_checkpoints[idx][1]:
                OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
                top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
                # remove the idx
                os.remove(top_checkpoints[idx][0])
                top_checkpoints.pop(idx)
                print(top_checkpoints)
        if valid_acc > best_acc:
            # Separately track the single best model seen so far.
            OFAEvaluator.save_net(args.save_path, net, net_name + '.best')
            best_acc = valid_acc
        scheduler.step()


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,43 @@
import argparse
def get_parse():
    """Build and parse the command-line arguments for MetaD2A vs NSGANETv2."""
    parser = argparse.ArgumentParser(description='MetaD2A vs NSGANETv2')
    add = parser.add_argument
    # Paths and generator options.
    add('--save-path', type=str, default='../results', help='the path of save directory')
    add('--data-path', type=str, default='../data', help='the path of save directory')
    add('--data-name', type=str, default=None, help='meta-test dataset name')
    add('--num-gen-arch', type=int, default=200,
        help='the number of candidate architectures generated by the generator')
    add('--bound', type=int, default=None)
    # Original training setting.
    add('--seed', type=int, default=-1, help='random seed')
    add('--batch-size', type=int, default=96, help='batch size')
    add('--num_workers', type=int, default=2, help='number of workers for data loading')
    add('--gpu', type=str, default='0', help='set visible gpus')
    add('--lr', type=float, default=0.01, help='init learning rate')
    add('--momentum', type=float, default=0.9, help='momentum')
    add('--weight_decay', type=float, default=4e-5, help='weight decay')
    add('--report_freq', type=float, default=50, help='report frequency')
    add('--epochs', type=int, default=150, help='num of training epochs')
    add('--grad_clip', type=float, default=5, help='gradient clipping')
    # NOTE(review): store_true with default=True can never become False.
    add('--cutout', action='store_true', default=True, help='use cutout')
    add('--cutout_length', type=int, default=16, help='cutout length')
    add('--autoaugment', action='store_true', default=True, help='use auto augmentation')
    add('--topk', type=int, default=10, help='top k checkpoints to save')
    add('--evaluate', action='store_true', default=False, help='evaluate a pretrained model')
    # Model-related options.
    add('--model', default='resnet101', type=str, metavar='MODEL',
        help='Name of model to train (default: "countception"')
    add('--model-config', type=str, default='search',
        help='location of a json file of specific model declaration')
    add('--initial-checkpoint', default='', type=str, metavar='PATH',
        help='Initialize model from this checkpoint (default: none)')
    add('--drop', type=float, default=0.2,
        help='dropout rate')
    add('--drop-path', type=float, default=0.2, metavar='PCT',
        help='Drop path rate (default: None)')
    add('--img-size', type=int, default=224,
        help='input resolution (192 -> 256)')
    return parser.parse_args()

View File

@@ -0,0 +1,261 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.evaluator import OFAEvaluator
from torchprofile import profile_macs
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.parser import get_parse
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.eval_utils import get_dataset
from transfer_nag_lib.MetaD2A_nas_bench_201.metad2a_utils import reset_seed
from transfer_nag_lib.ofa_net import OFASubNet
# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# device_list = [int(_) for _ in args.gpu.split(',')]
# args.n_gpus = len(device_list)
# args.device = torch.device("cuda:0")
# if args.seed is None or args.seed < 0: args.seed = random.randint(1, 100000)
# torch.cuda.manual_seed(args.seed)
# torch.manual_seed(args.seed)
# np.random.seed(args.seed)
# random.seed(args.seed)
# args.save_path = os.path.join(args.save_path, f'evaluation/{args.data_name}')
# if args.model_config.startswith('flops@'):
# args.save_path += f'-nsganetV2-{args.model_config}-{args.seed}'
# else:
# args.save_path += f'-metaD2A-{args.bound}-{args.seed}'
# if not os.path.exists(args.save_path):
# os.makedirs(args.save_path)
# args.data_path = os.path.join(args.data_path, args.data_name)
# log_format = '%(asctime)s %(message)s'
# logging.basicConfig(stream=sys.stdout, level=print,
# format=log_format, datefmt='%m/%d %I:%M:%S %p')
# fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
# fh.setFormatter(logging.Formatter(log_format))
# logging.getLogger().addHandler(fh)
# if not torch.cuda.is_available():
# print('no gpu self.args.device available')
# sys.exit(1)
# print("args = %s", args)
def set_architecture(n_cls, evaluator, drop_path, drop, img_size, n_gpus, device, save_path, model_str):
    """Instantiate the subnet encoded by *model_str*, re-head it for *n_cls*
    classes, and move it to *device*.

    Returns:
        (net, net_name, params_in_millions, flops_in_millions).
    """
    op_dict = OFASubNet(model_str).get_op_dict()
    subnet, _ = evaluator.sample(op_dict)
    net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=drop_path)
    net.load_state_dict(subnet.state_dict())
    # Swap the pretrained head for one sized to the target dataset.
    NSGANetV2.reset_classifier(
        net, last_channel=net.classifier.in_features,
        n_classes=n_cls, dropout_rate=drop)
    # Count parameters and FLOPs at the target input resolution.
    probe = torch.randn(1, 3, img_size, img_size)
    flops = profile_macs(copy.deepcopy(net), probe) / 1e6
    params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
    net_name = "net_flops@{:.0f}".format(flops)
    print('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net = net.to(device)
    return net, net_name, params, flops
def train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq):
net.train()
train_loss, correct, total = 0, 0, 0
for step, (inputs, targets) in enumerate(train_queue):
# upsample by bicubic to match imagenet training size
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if step % report_freq == 0:
print(f'train step {step:03d} loss {train_loss / total:.4f} train acc {100. * correct / total:.4f}')
print(f'train acc {100. * correct / total:.4f}')
return train_loss / total, 100. * correct / total
def infer(valid_queue, net, criterion, device, report_freq, early_stop=False):
    """Evaluate `net` on `valid_queue` without gradient tracking.

    Args:
        valid_queue: iterable of (inputs, targets) batches.
        net: model to evaluate; switched to eval mode here.
        criterion: loss function applied to (outputs, targets).
        device: device the batches are moved onto.
        report_freq: print running stats every `report_freq` steps.
        early_stop: if True, stop after step 10 (a quick smoke evaluation).

    Returns:
        (avg_loss, accuracy): summed batch-mean losses divided by the number
        of samples seen, and top-1 accuracy in percent. Returns (0.0, 0.0)
        for an empty queue instead of raising ZeroDivisionError.
    """
    net.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if step % report_freq == 0:
                print(f'valid {step:03d} {test_loss / total:.4f} {100. * correct / total:.4f}')
            if early_stop and step == 10:
                break
    if total == 0:
        # Empty loader: nothing was evaluated; avoid ZeroDivisionError below.
        return 0.0, 0.0
    acc = 100. * correct / total
    print('valid acc {:.4f}'.format(100. * correct / total))
    return test_loss / total, acc
def train_single_model(save_path, workers, datasets, xpaths, splits, use_less,
                       seed, model_str, device,
                       lr=0.01,
                       momentum=0.9,
                       weight_decay=4e-5,
                       report_freq=50,
                       epochs=150,
                       grad_clip=5,
                       cutout=True,
                       cutout_length=16,
                       autoaugment=True,
                       drop=0.2,
                       drop_path=0.2,
                       img_size=224,
                       batch_size=96,
                       ):
    """Train one architecture (encoded in `model_str`) from scratch on a dataset.

    Builds the subnet via OFAEvaluator/set_architecture, trains it with SGD and
    cosine-annealed LR for `epochs` epochs, and saves the latest validation
    accuracy to `save_path`/seed-XXXX.pth after every epoch.

    NOTE(review): `splits` and `use_less` are currently unused -- presumably
    kept for API parity with the NAS-Bench-201 variant below; confirm.

    Returns:
        (last_valid_acc, max_valid_acc, params, flops) with params/flops in M.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    reset_seed(seed)
    os.makedirs(save_path, exist_ok=True)
    to_save_name = save_path + '/seed-{:04d}.pth'.format(seed)
    print(to_save_name)
    num_gen_arch = None
    # NOTE(review): hard-coded supernet checkpoint path -- should come from
    # configuration rather than this absolute path.
    evaluator = OFAEvaluator(num_gen_arch, img_size, drop_path,
                             model_path='/home/data/GTAD/checkpoints/ofa/ofa_net/ofa_mbv3_d234_e346_k357_w1.0')
    train_queue, valid_queue, n_cls = get_dataset(datasets, batch_size,
                                                  xpaths, workers, img_size, autoaugment, cutout, cutout_length)
    net, net_name, params, flops = set_architecture(n_cls, evaluator,
                                                    drop_path, drop, img_size, n_gpus=1, device=device, save_path=save_path, model_str=model_str)
    parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss().to(device)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    max_valid_acc = 0
    max_epoch = 0
    for epoch in range(epochs):
        print('epoch {:d} lr {:.4f}'.format(epoch, scheduler.get_lr()[0]))
        train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq)
        _, valid_acc = infer(valid_queue, net, criterion, device, report_freq)
        # BUG FIX: scheduler.step() was never called, so the cosine schedule
        # never advanced and the LR stayed at its initial value for all epochs.
        scheduler.step()
        # Saves only the accuracy value (not weights), overwritten each epoch.
        torch.save(valid_acc, to_save_name)
        print(f'seed {seed:04d} last acc {valid_acc:.4f} max acc {max_valid_acc:.4f}')
        if max_valid_acc < valid_acc:
            max_valid_acc = valid_acc
            max_epoch = epoch
    return valid_acc, max_valid_acc, params, flops
################ NAS BENCH 201 #####################
# def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
# seeds, model_str, arch_config):
# assert torch.cuda.is_available(), 'CUDA is not available.'
# torch.backends.cudnn.enabled = True
# torch.backends.cudnn.deterministic = True
# torch.set_num_threads(workers)
# save_dir = Path(save_dir)
# logger = Logger(str(save_dir), 0, False)
# if model_str in CellArchitectures:
# arch = CellArchitectures[model_str]
# logger.log(
# 'The model string is found in pre-defined architecture dict : {:}'.format(model_str))
# else:
# try:
# arch = CellStructure.str2structure(model_str)
# except:
# raise ValueError(
# 'Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
# assert arch.check_valid_op(get_search_spaces(
# 'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
# # assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
# logger.log('Start train-evaluate {:}'.format(arch.tostr()))
# logger.log('arch_config : {:}'.format(arch_config))
# start_time, seed_time = time.time(), AverageMeter()
# for _is, seed in enumerate(seeds):
# logger.log(
# '\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds),
# seed))
# to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
# if to_save_name.exists():
# logger.log(
# 'Find the existing file {:}, directly load!'.format(to_save_name))
# checkpoint = torch.load(to_save_name)
# else:
# logger.log(
# 'Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
# checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
# seed, arch_config, workers, logger)
# torch.save(checkpoint, to_save_name)
# # log information
# logger.log('{:}'.format(checkpoint['info']))
# all_dataset_keys = checkpoint['all_dataset_keys']
# for dataset_key in all_dataset_keys:
# logger.log('\n{:} dataset : {:} {:}'.format(
# '-' * 15, dataset_key, '-' * 15))
# dataset_info = checkpoint[dataset_key]
# # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
# logger.log('Flops = {:} MB, Params = {:} MB'.format(
# dataset_info['flop'], dataset_info['param']))
# logger.log('config : {:}'.format(dataset_info['config']))
# logger.log('Training State (finish) = {:}'.format(
# dataset_info['finish-train']))
# last_epoch = dataset_info['total_epoch'] - 1
# train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
# valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
# # measure elapsed time
# seed_time.update(time.time() - start_time)
# start_time = time.time()
# need_time = 'Time Left: {:}'.format(convert_secs2time(
# seed_time.avg * (len(seeds) - _is - 1), True))
# logger.log(
# '\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed,
# need_time))
# logger.close()
# ###################
if __name__ == '__main__':
    # NOTE(review): train_single_model requires nine positional arguments
    # (save_path, workers, datasets, xpaths, splits, use_less, seed,
    # model_str, device), so running this module directly raises TypeError.
    # Presumably a CLI/argparse wrapper (see the commented-out arg handling
    # near the top of this section) was intended -- confirm and wire it up.
    train_single_model()