first commit
@@ -0,0 +1,4 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
@@ -0,0 +1,401 @@
from __future__ import print_function

import os
import math
import warnings
import numpy as np

from timm.data.transforms import _pil_interp  # used for the RandAugment interpolation below
from timm.data.auto_augment import rand_augment_transform

import torch.utils.data
import torchvision.transforms as transforms
from torchvision.datasets.folder import default_loader

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


def make_dataset(dir, image_ids, targets):
    assert len(image_ids) == len(targets)
    images = []
    dir = os.path.expanduser(dir)
    for i in range(len(image_ids)):
        item = (os.path.join(dir, 'data', 'images',
                             '%s.jpg' % image_ids[i]), targets[i])
        images.append(item)
    return images


def find_classes(classes_file):
    # read classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    with open(classes_file, 'r') as f:
        for line in f:
            split_line = line.split(' ')
            image_ids.append(split_line[0])
            targets.append(' '.join(split_line[1:]))

    # index class names
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]

    return image_ids, targets, classes, class_to_idx


class FGVCAircraft(torch.utils.data.Dataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.

    Args:
        root (string): Root directory path to the dataset.
        class_type (string, optional): The level of FGVC-Aircraft fine-grained classification
            to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g. ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
    class_types = ('variant', 'family', 'manufacturer')
    splits = ('train', 'val', 'trainval', 'test')

    def __init__(self, root, class_type='variant', split='train', transform=None,
                 target_transform=None, loader=default_loader, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        if class_type not in self.class_types:
            raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
                class_type, ', '.join(self.class_types),
            ))
        self.root = os.path.expanduser(root)
        self.class_type = class_type
        self.split = split
        self.classes_file = os.path.join(self.root, 'data',
                                         'images_%s_%s.txt' % (self.class_type, self.split))

        if download:
            self.download()

        (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
        samples = make_dataset(self.root, image_ids, targets)

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        self.samples = samples
        self.classes = classes
        self.class_to_idx = class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
            os.path.exists(self.classes_file)

    def download(self):
        """Download the FGVC-Aircraft data if it doesn't exist already."""
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        # prepare to download data to PARENT_DIR/fgvc-aircraft-2013b.tar.gz
        print('Downloading %s ... (may take a few minutes)' % self.url)

        parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
        tar_name = self.url.rpartition('/')[-1]
        tar_path = os.path.join(parent_dir, tar_name)
        data = urllib.request.urlopen(self.url)

        # download .tar.gz file
        with open(tar_path, 'wb') as f:
            f.write(data.read())

        # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
        # (note: str.strip would remove a character set, not the suffix)
        data_folder = tar_path[:-len('.tar.gz')]
        print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
        tar = tarfile.open(tar_path)
        tar.extractall(parent_dir)

        # if necessary, rename data folder to self.root
        if not os.path.samefile(data_folder, self.root):
            print('Renaming %s to %s ...' % (data_folder, self.root))
            os.rename(data_folder, self.root)

        # delete .tar.gz file
        print('Deleting %s ...' % tar_path)
        os.remove(tar_path)

        print('Done!')


class FGVCAircraftDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'aircraft'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 100

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/Aircraft'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/Aircraft'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = FGVCAircraft(
            root=self.train_path, split='trainval', download=True, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = FGVCAircraft(
            root=self.valid_path, split='test', download=True, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        return self.save_path

    @property
    def valid_path(self):
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.48933587508932375, 0.5183537408957618, 0.5387914411673883],
            std=[0.22388883112804625, 0.21641635409388751, 0.24615605842636115])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        if image_size is None:
            image_size = self.image_size

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.48933587508932375, 0.5183537408957618,
                                                               0.5387914411673883]]),
        )
        # rand_augment_transform expects a PIL interpolation constant here,
        # not a torchvision transform
        aa_params['interpolation'] = _pil_interp('bicubic')
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]


if __name__ == '__main__':
    data = FGVCAircraft(root='/mnt/datastore/Aircraft',
                        split='trainval', download=True)
    print(len(data.classes))
    print(len(data.samples))
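    # A hedged sketch of the full provider: the save_path and batch sizes here
    # are illustrative assumptions, not values fixed by this commit.
    provider = FGVCAircraftDataProvider(
        save_path='/mnt/datastore/Aircraft', train_batch_size=32,
        test_batch_size=200, valid_size=0.2, image_size=224)
    images, labels = next(iter(provider.train))
    print(images.shape, labels.shape)  # expect [32, 3, 224, 224] and [32]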
@@ -0,0 +1,238 @@
"""
Taken from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
"""

from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random


class ImageNetPolicy(object):
    """Randomly choose one of the best 24 sub-policies on ImageNet.

    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     ImageNetPolicy(),
    >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),

            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),

            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),

            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),

            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"


class CIFAR10Policy(object):
    """Randomly choose one of the best 25 sub-policies on CIFAR10.

    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     CIFAR10Policy(),
    >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class SVHNPolicy(object):
    """Randomly choose one of the best 25 sub-policies on SVHN.

    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     SVHNPolicy(),
    >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),

            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"


class SubPolicy(object):
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),  # np.int is deprecated
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
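

if __name__ == '__main__':
    # Minimal sanity check (illustrative only, not part of the training
    # pipeline): apply the CIFAR-10 policy to a synthetic gray image and
    # confirm a PIL image of the same size comes back.
    img = Image.new('RGB', (32, 32), (128, 128, 128))
    policy = CIFAR10Policy()
    out = policy(img)
    print(policy, out.size)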
@@ -0,0 +1,657 @@
import os
import math
import numpy as np

import torchvision
import torch.utils.data
import torchvision.transforms as transforms

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class CIFAR10DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'cifar10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CIFAR'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/CIFAR'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # train_path and valid_path both resolve to save_path for CIFAR
        dataset = torchvision.datasets.CIFAR10(
            root=self.train_path, train=True, download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = torchvision.datasets.CIFAR10(
            root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        return self.save_path

    @property
    def valid_path(self):
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]


class CIFAR100DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'cifar100'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 100

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CIFAR'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/CIFAR'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # train_path and valid_path both resolve to save_path for CIFAR
        dataset = torchvision.datasets.CIFAR100(
            root=self.train_path, train=True, download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = torchvision.datasets.CIFAR100(
            root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        return self.save_path

    @property
    def valid_path(self):
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]


class CINIC10DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'cinic10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CINIC10'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/CINIC10'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = torchvision.datasets.ImageFolder(self.train_path, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = torchvision.datasets.ImageFolder(self.valid_path, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train_and_valid')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'test')

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.47889522, 0.47227842, 0.43047404], std=[0.24205776, 0.23828046, 0.25874835])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)  # ImageFolder exposes .samples, not .data
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
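

if __name__ == '__main__':
    # Hedged usage sketch: build_sub_train_loader caches a few batches for
    # re-estimating BN running statistics. The save_path is an assumption and
    # the CIFAR data must already be on disk (the providers use download=False).
    provider = CIFAR10DataProvider(save_path='/mnt/datastore/CIFAR',
                                   valid_size=5000, image_size=32)
    sub_train = provider.build_sub_train_loader(n_images=2000, batch_size=200)
    print(len(sub_train), sub_train[0][0].shape)  # 10 batches of 200 images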
@@ -0,0 +1,237 @@
import os
import warnings
import numpy as np

from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class DTDDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'dtd'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 47

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/dtd'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/dtd'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'valid')

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.5329876098715876, 0.474260843249454, 0.42627281899380676],
            std=[0.26549755708788914, 0.25473554309855373, 0.2631728035662832])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        if image_size is None:
            image_size = self.image_size
        # if print_log:
        #     print('Color jitter: %s, resize_scale: %s, img_size: %s' %
        #           (self.distort_color, self.resize_scale, image_size))

        # if self.distort_color == 'torch':
        #     color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        # elif self.distort_color == 'tf':
        #     color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        # else:
        #     color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.5329876098715876, 0.474260843249454,
                                                               0.42627281899380676]]),
        )
        aa_params['interpolation'] = _pil_interp('bicubic')
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]

        # if color_transform is not None:
        #     train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            # transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.Resize((image_size, image_size), interpolation=3),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
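
# Illustrative usage sketch (an editorial addition, not part of the original commit):
# passing a list of resolutions activates the elastic-resolution path, where MyDataLoader
# samples an image size per training batch; passing an int falls back to a plain
# DataLoader. The save_path below is an assumption and must contain train/ and valid/
# ImageFolder splits.
if __name__ == '__main__':
    provider = DTDDataProvider(
        save_path='/mnt/datastore/dtd',
        train_batch_size=32, test_batch_size=200,
        image_size=[160, 192, 224],
    )
    print(provider.data_shape, provider.n_classes)  # (3, 224, 224) 47
    provider.assign_active_img_size(192)  # re-targets the valid/test transforms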
@@ -0,0 +1,241 @@
import warnings
import os
import math
import numpy as np

import PIL

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class Flowers102DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        # warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        weights = self.make_weights_for_balanced_classes(
            train_dataset.imgs, self.n_classes)
        weights = torch.DoubleTensor(weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

        if valid_size is not None:
            raise NotImplementedError("validation dataset not yet implemented")
            # valid_dataset = self.valid_dataset(valid_transforms)

            # self.train = train_loader_class(
            #     train_dataset, batch_size=train_batch_size, sampler=train_sampler,
            #     num_workers=n_worker, pin_memory=True)
            # self.valid = torch.utils.data.DataLoader(
            #     valid_dataset, batch_size=test_batch_size,
            #     num_workers=n_worker, pin_memory=True)
        else:
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        self.test = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
        )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'flowers102'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 102

    @property
    def save_path(self):
        if self._save_path is None:
            # self._save_path = '/mnt/datastore/Oxford102Flowers'  # home server
            self._save_path = '/mnt/datastore/Flowers102'  # home server

        if not os.path.exists(self._save_path):
            # self._save_path = '/mnt/datastore/Oxford102Flowers'  # home server
            self._save_path = '/mnt/datastore/Flowers102'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    # def valid_dataset(self, _transforms):
    #     dataset = datasets.ImageFolder(self.valid_path, _transforms)
    #     return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.test_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    # @property
    # def valid_path(self):
    #     return os.path.join(self.save_path, 'train')

    @property
    def test_path(self):
        return os.path.join(self.save_path, 'test')

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.5178361839861569, 0.4106749456881299, 0.32864167836880803],
            std=[0.2972239085211309, 0.24976049135203868, 0.28533308036347665])

    @staticmethod
    def make_weights_for_balanced_classes(images, nclasses):
        count = [0] * nclasses

        # Counts per label
        for item in images:
            count[item[1]] += 1

        weight_per_class = [0.] * nclasses

        # Total number of images.
        N = float(sum(count))

        # super-sample the smaller classes.
        for i in range(nclasses):
            weight_per_class[i] = N / float(count[i])

        weight = [0] * len(images)

        # Calculate a weight per image.
        for idx, val in enumerate(images):
            weight[idx] = weight_per_class[val[1]]

        return weight

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            transforms.RandomAffine(
                45, translate=(0.4, 0.4), scale=(0.75, 1.5), shear=None, resample=PIL.Image.BILINEAR, fillcolor=0),
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            # transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
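
# Worked sketch (an editorial addition, not part of the original commit) of the weighting
# rule in make_weights_for_balanced_classes: every image of class c gets weight
# N / count[c], so WeightedRandomSampler draws each class with equal expected frequency.
# With 3 images of class 0 and 1 of class 1 (N = 4): class 0 -> 4/3, class 1 -> 4.
if __name__ == '__main__':
    fake_imgs = [('a.jpg', 0), ('b.jpg', 0), ('c.jpg', 0), ('d.jpg', 1)]
    w = Flowers102DataProvider.make_weights_for_balanced_classes(fake_imgs, 2)
    print(w)  # [1.333..., 1.333..., 1.333..., 4.0]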
@@ -0,0 +1,225 @@
import warnings
import os
import math
import numpy as np

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class ImagenetDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'imagenet'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 1000

    @property
    def save_path(self):
        if self._save_path is None:
            # self._save_path = '/dataset/imagenet'
            # self._save_path = '/usr/local/soft/temp-datastore/ILSVRC2012'  # servers
            self._save_path = '/mnt/datastore/ILSVRC2012'  # home server

        if not os.path.exists(self._save_path):
            # self._save_path = os.path.expanduser('~/dataset/imagenet')
            # self._save_path = os.path.expanduser('/usr/local/soft/temp-datastore/ILSVRC2012')
            self._save_path = '/mnt/datastore/ILSVRC2012'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'val')

    @property
    def normalize(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
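
# Usage sketch (an editorial addition, not part of the original commit): valid_size
# controls whether a validation split is carved out of the training set. A float in
# (0, 1) is read as a fraction of the training samples, an int as an absolute count,
# and None leaves self.valid aliased to self.test. The path below is an assumption.
if __name__ == '__main__':
    provider = ImagenetDataProvider(save_path='/mnt/datastore/ILSVRC2012',
                                    valid_size=0.05, image_size=224, n_worker=8)
    print(len(provider.valid.sampler))  # ~5% of the ImageFolder training samples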
@@ -0,0 +1,237 @@
import os
import math
import warnings
import numpy as np

# from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from PIL import Image  # added: rand_augment_transform needs a PIL interpolation flag (see below)

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class OxfordIIITPetsDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'pets'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 37

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/Oxford-IIITPets'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/Oxford-IIITPets'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'valid')

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.4828895122298728, 0.4448394893850807, 0.39566558230789783],
            std=[0.25925664613996574, 0.2532760018681693, 0.25981017205097917])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        if image_size is None:
            image_size = self.image_size
        # if print_log:
        #     print('Color jitter: %s, resize_scale: %s, img_size: %s' %
        #           (self.distort_color, self.resize_scale, image_size))

        # if self.distort_color == 'torch':
        #     color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        # elif self.distort_color == 'tf':
        #     color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        # else:
        #     color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.4828895122298728, 0.4448394893850807,
                                                               0.39566558230789783]]),
        )
        # rand_augment_transform expects a PIL interpolation flag here; the original
        # passed transforms.Resize(image_size), which RandAugment cannot apply, so the
        # bicubic flag is used instead (matching the commented-out _pil_interp('bicubic'))
        aa_params['interpolation'] = Image.BICUBIC
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]

        # if color_transform is not None:
        #     train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
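
# Sketch (an editorial addition, not part of the original commit) of the hparams that
# rand_augment_transform receives above: 'rand-m9-mstd0.5' selects RandAugment with
# magnitude 9 and magnitude std 0.5, while the dict supplies the fill colour (the
# dataset channel means rounded to 0-255) and the PIL interpolation flag the ops use.
if __name__ == '__main__':
    params = dict(
        translate_const=int(224 * 0.45),
        img_mean=(123, 113, 101),
        interpolation=Image.BICUBIC,
    )
    print(rand_augment_transform('rand-m9-mstd0.5', params))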
@@ -0,0 +1,69 @@
import torch
from glob import glob
from torch.utils.data.dataset import Dataset
import os
from PIL import Image


def load_image(filename):
    img = Image.open(filename)
    img = img.convert('RGB')
    return img


class PetDataset(Dataset):
    def __init__(self, root, train=True, num_cl=37, val_split=0.15, transforms=None):
        pt_name = os.path.join(root, '{}{}.pth'.format('train' if train else 'test',
                                                       int(100 * (1 - val_split)) if train else int(
                                                           100 * val_split)))
        if not os.path.exists(pt_name):
            filenames = glob(os.path.join(root, 'images') + '/*.jpg')
            classes = set()

            data = []
            labels = []

            for image in filenames:
                class_name = image.rsplit("/", 1)[1].rsplit('_', 1)[0]
                classes.add(class_name)
                img = load_image(image)

                data.append(img)
                labels.append(class_name)

            # convert classnames to indices; sorted() keeps the class -> index mapping
            # deterministic across runs (plain set iteration order is not)
            class2idx = {cl: idx for idx, cl in enumerate(sorted(classes))}
            labels = torch.Tensor(list(map(lambda x: class2idx[x], labels))).long()
            data = list(zip(data, labels))

            class_values = [[] for x in range(num_cl)]

            # create arrays for each class type
            for d in data:
                class_values[d[1].item()].append(d)

            train_data = []
            val_data = []

            for class_dp in class_values:
                split_idx = int(len(class_dp) * (1 - val_split))
                train_data += class_dp[:split_idx]
                val_data += class_dp[split_idx:]
            torch.save(train_data, os.path.join(root, 'train{}.pth'.format(int(100 * (1 - val_split)))))
            torch.save(val_data, os.path.join(root, 'test{}.pth'.format(int(100 * val_split))))

        self.data = torch.load(pt_name)
        self.len = len(self.data)
        self.transform = transforms

    def __getitem__(self, index):
        img, label = self.data[index]

        if self.transform:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return self.len
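
# Usage sketch (an editorial addition, not part of the original commit): the first
# construction scans root/images/*.jpg, builds a per-class stratified train/val split,
# and caches both splits as .pth files under root (e.g. train85.pth / test15.pth for
# val_split=0.15); later constructions just load the cached file. The root path and
# transform below are assumptions.
if __name__ == '__main__':
    import torchvision.transforms as T
    train_set = PetDataset('/mnt/datastore/Oxford-IIITPets', train=True, val_split=0.15,
                           transforms=T.Compose([T.Resize((224, 224)), T.ToTensor()]))
    img, label = train_set[0]
    print(len(train_set), img.shape, label.item())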
@@ -0,0 +1,226 @@
import os
import math
import numpy as np

import torchvision
import torch.utils.data
import torchvision.transforms as transforms

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class STL10DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'stl10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/STL10'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/STL10'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        # note: originally passed root=self.valid_path; both properties resolve to
        # save_path, but train_path is used here for clarity
        dataset = torchvision.datasets.STL10(
            root=self.train_path, split='train', download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = torchvision.datasets.STL10(
            root=self.valid_path, split='test', download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        # return os.path.join(self.save_path, 'train')
        return self.save_path

    @property
    def valid_path(self):
        # return os.path.join(self.save_path, 'val')
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.44671097, 0.4398105, 0.4066468],
            std=[0.2603405, 0.25657743, 0.27126738])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
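
# Note (an editorial addition, not part of the original commit): both dataset calls
# above pass download=False, so the STL-10 binaries must already exist under save_path.
# On a fresh machine they can be fetched once via torchvision before the provider is used:
if __name__ == '__main__':
    torchvision.datasets.STL10(root='/mnt/datastore/STL10', split='train', download=True)
    torchvision.datasets.STL10(root='/mnt/datastore/STL10', split='test', download=True)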
@@ -0,0 +1,4 @@
from ofa.imagenet_codebase.networks.proxyless_nets import ProxylessNASNets, proxyless_base, MobileNetV2
from ofa.imagenet_codebase.networks.mobilenet_v3 import MobileNetV3, MobileNetV3Large
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks.nsganetv2 import NSGANetV2
@@ -0,0 +1,126 @@
from timm.models.layers import drop_path
from ofa.imagenet_codebase.modules.layers import *
from ofa.imagenet_codebase.networks import MobileNetV3


class MobileInvertedResidualBlock(MyModule):
    """
    Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
    imagenet_codebase/networks/proxyless_nets.py to include drop path in training
    """

    def __init__(self, mobile_inverted_conv, shortcut, drop_connect_rate=0.0):
        super(MobileInvertedResidualBlock, self).__init__()

        self.mobile_inverted_conv = mobile_inverted_conv
        self.shortcut = shortcut
        self.drop_connect_rate = drop_connect_rate

    def forward(self, x):
        if self.mobile_inverted_conv is None or isinstance(self.mobile_inverted_conv, ZeroLayer):
            res = x
        elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
            res = self.mobile_inverted_conv(x)
        else:
            # res = self.mobile_inverted_conv(x) + self.shortcut(x)
            res = self.mobile_inverted_conv(x)

            if self.drop_connect_rate > 0.:
                res = drop_path(res, drop_prob=self.drop_connect_rate, training=self.training)

            res += self.shortcut(x)

        return res

    @property
    def module_str(self):
        return '(%s, %s)' % (
            self.mobile_inverted_conv.module_str if self.mobile_inverted_conv is not None else None,
            self.shortcut.module_str if self.shortcut is not None else None
        )

    @property
    def config(self):
        return {
            'name': MobileInvertedResidualBlock.__name__,
            'mobile_inverted_conv': self.mobile_inverted_conv.config if self.mobile_inverted_conv is not None else None,
            'shortcut': self.shortcut.config if self.shortcut is not None else None,
        }

    @staticmethod
    def build_from_config(config):
        mobile_inverted_conv = set_layer_from_config(config['mobile_inverted_conv'])
        shortcut = set_layer_from_config(config['shortcut'])
        return MobileInvertedResidualBlock(
            mobile_inverted_conv, shortcut, drop_connect_rate=config['drop_connect_rate'])


class NSGANetV2(MobileNetV3):
    """
    Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
    imagenet_codebase/networks/mobilenet_v3.py to include drop path in training
    and option to reset classification layer
    """

    @staticmethod
    def build_from_config(config, drop_connect_rate=0.0):
        first_conv = set_layer_from_config(config['first_conv'])
        final_expand_layer = set_layer_from_config(config['final_expand_layer'])
        feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
        classifier = set_layer_from_config(config['classifier'])

        blocks = []
        for block_idx, block_config in enumerate(config['blocks']):
            block_config['drop_connect_rate'] = drop_connect_rate * block_idx / len(config['blocks'])
            blocks.append(MobileInvertedResidualBlock.build_from_config(block_config))

        net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
        if 'bn' in config:
            net.set_bn_param(**config['bn'])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-3)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, MobileInvertedResidualBlock):
                if isinstance(m.mobile_inverted_conv, MBInvertedConvLayer) and isinstance(m.shortcut, IdentityLayer):
                    m.mobile_inverted_conv.point_linear.bn.weight.data.zero_()

    @staticmethod
    def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
        )
        # build mobile blocks
        feature_dim = input_channel
        blocks = []
        for stage_id, block_config_list in cfg.items():
            for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
                mb_conv = MBInvertedConvLayer(
                    feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
                )
                if stride == 1 and out_channel == feature_dim:
                    shortcut = IdentityLayer(out_channel, out_channel)
                else:
                    shortcut = None
                blocks.append(MobileInvertedResidualBlock(mb_conv, shortcut))
                feature_dim = out_channel
        # final expand layer
        final_expand_layer = ConvLayer(
            feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
        )
        feature_dim = feature_dim * 6
        # feature mix layer
        feature_mix_layer = ConvLayer(
            feature_dim, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
        )
        # classifier
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier

    @staticmethod
    def reset_classifier(model, last_channel, n_classes, dropout_rate=0.0):
        model.classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
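
# Sketch (an editorial addition, not part of the original commit) of the drop-path
# schedule applied in NSGANetV2.build_from_config: block i of B gets
# drop_connect_rate * i / B, so the rate grows linearly with depth and the first
# block is never dropped.
if __name__ == '__main__':
    drop_connect_rate, n_blocks = 0.2, 20
    rates = [drop_connect_rate * i / n_blocks for i in range(n_blocks)]
    print(rates[0], rates[10], rates[-1])  # 0.0 0.1 0.19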
@@ -0,0 +1,309 @@
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.imagenet import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.cifar import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import *

from ofa.imagenet_codebase.run_manager.run_manager import *


class ImagenetRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=1e-4, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=128, test_batch_size=512, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/ILSVRC2012',
                 **kwargs):
        super(ImagenetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.imagenet_data_path = data_path

    @property
    def data_provider(self):
        # lazily build the data provider on first access and cache it
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.imagenet_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']
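# The `data_provider` property above memoizes its result in `self.__dict__`.
# A minimal standalone sketch of the same lazy-caching pattern (illustrative
# names only; `build_provider` is assumed for the example):
#
#     class Config:
#         @property
#         def data_provider(self):
#             if self.__dict__.get('_data_provider') is None:
#                 self.__dict__['_data_provider'] = build_provider()  # built once
#             return self.__dict__['_data_provider']
#
# Repeated accesses then return the cached provider instead of rebuilding it.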

class CIFARRunConfig(RunConfig):
    def __init__(self, n_epochs=5, init_lr=0.01, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='cifar10', train_batch_size=96, test_batch_size=256, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/CIFAR',
                 **kwargs):
        super(CIFARRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.cifar_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == CIFAR10DataProvider.name():
                DataProviderClass = CIFAR10DataProvider
            elif self.dataset == CIFAR100DataProvider.name():
                DataProviderClass = CIFAR100DataProvider
            elif self.dataset == CINIC10DataProvider.name():
                DataProviderClass = CINIC10DataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.cifar_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class Flowers102RunConfig(RunConfig):

    def __init__(self, n_epochs=3, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='flowers102', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/Flowers102',
                 **kwargs):
        super(Flowers102RunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.flowers102_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == Flowers102DataProvider.name():
                DataProviderClass = Flowers102DataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.flowers102_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class STL10RunConfig(RunConfig):

    def __init__(self, n_epochs=5, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='stl10', train_batch_size=96, test_batch_size=256, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/STL10',
                 **kwargs):
        super(STL10RunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.stl10_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == STL10DataProvider.name():
                DataProviderClass = STL10DataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.stl10_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class DTDRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='dtd', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/dtd',
                 **kwargs):
        super(DTDRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == DTDDataProvider.name():
                DataProviderClass = DTDDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class PetsRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='pets', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/Oxford-IIITPets',
                 **kwargs):
        super(PetsRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.imagenet_data_path = data_path  # note: reuses the attribute name read by data_provider below

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == OxfordIIITPetsDataProvider.name():
                DataProviderClass = OxfordIIITPetsDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.imagenet_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class AircraftRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='aircraft', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/Aircraft',
                 **kwargs):
        super(AircraftRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == FGVCAircraftDataProvider.name():
                DataProviderClass = FGVCAircraftDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


def get_run_config(**kwargs):
    if kwargs['dataset'] == 'imagenet':
        run_config = ImagenetRunConfig(**kwargs)
    elif kwargs['dataset'].startswith('cifar') or kwargs['dataset'].startswith('cinic'):
        run_config = CIFARRunConfig(**kwargs)
    elif kwargs['dataset'] == 'flowers102':
        run_config = Flowers102RunConfig(**kwargs)
    elif kwargs['dataset'] == 'stl10':
        run_config = STL10RunConfig(**kwargs)
    elif kwargs['dataset'] == 'dtd':
        run_config = DTDRunConfig(**kwargs)
    elif kwargs['dataset'] == 'pets':
        run_config = PetsRunConfig(**kwargs)
    elif kwargs['dataset'] in ('aircraft', 'aircraft100'):
        run_config = AircraftRunConfig(**kwargs)
    else:
        raise NotImplementedError

    return run_config
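
# Example (sketch): building a run config for CIFAR-10. The './data' path is
# an assumption for illustration; keyword overrides beyond `dataset` fall
# through to the matching config class.
def _example_get_run_config():
    run_config = get_run_config(dataset='cifar10', data_path='./data',
                                n_epochs=5, train_batch_size=96)
    return run_config.data_provider  # built lazily on first access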
@@ -0,0 +1,122 @@
import numpy as np
import torch
import torchvision.datasets  # needed for the CIFAR constructors below
import torchvision.transforms as transforms
from PIL import Image
import torchvision.utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import FGVCAircraft
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets2 import PetDataset
import torch.utils.data as Data
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.autoaugment import CIFAR10Policy


def get_dataset(data_name, batch_size, data_path, num_workers,
                img_size, autoaugment, cutout, cutout_length):
    num_class_dict = {
        'cifar100': 100,
        'cifar10': 10,
        'mnist': 10,
        'aircraft': 100,
        'svhn': 10,
        'pets': 37
    }
    # 'aircraft30': 30,
    # 'aircraft100': 100,

    train_transform, valid_transform = _data_transforms(
        data_name, img_size, autoaugment, cutout, cutout_length)
    if data_name == 'cifar100':
        train_data = torchvision.datasets.CIFAR100(
            root=data_path, train=True, download=True, transform=train_transform)
        valid_data = torchvision.datasets.CIFAR100(
            root=data_path, train=False, download=True, transform=valid_transform)
    elif data_name == 'cifar10':
        train_data = torchvision.datasets.CIFAR10(
            root=data_path, train=True, download=True, transform=train_transform)
        valid_data = torchvision.datasets.CIFAR10(
            root=data_path, train=False, download=True, transform=valid_transform)
    elif data_name.startswith('aircraft'):
        print(data_path)
        if 'aircraft100' in data_path:
            data_path = data_path.replace('aircraft100', 'aircraft/fgvc-aircraft-2013b')
        else:
            data_path = data_path.replace('aircraft', 'aircraft/fgvc-aircraft-2013b')
        train_data = FGVCAircraft(data_path, class_type='variant', split='trainval',
                                  transform=train_transform, download=True)
        valid_data = FGVCAircraft(data_path, class_type='variant', split='test',
                                  transform=valid_transform, download=True)
    elif data_name.startswith('pets'):
        train_data = PetDataset(data_path, train=True, num_cl=37,
                                val_split=0.15, transforms=train_transform)
        valid_data = PetDataset(data_path, train=False, num_cl=37,
                                val_split=0.15, transforms=valid_transform)
    else:
        raise KeyError(data_name)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=num_workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=200, shuffle=False, pin_memory=True,
        num_workers=num_workers)

    return train_queue, valid_queue, num_class_dict[data_name]
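
# Example (sketch): loading CIFAR-10 with the pipeline above. The local
# './data' path is an assumption for illustration; CIFAR-10 is downloaded
# there on first use.
def _example_get_dataset():
    train_queue, valid_queue, n_cls = get_dataset(
        'cifar10', batch_size=96, data_path='./data', num_workers=2,
        img_size=224, autoaugment=True, cutout=True, cutout_length=16)
    images, labels = next(iter(train_queue))  # images: [96, 3, 224, 224]
    return n_cls  # 10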


class Cutout(object):
    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1: y2, x1: x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img
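
# Example (sketch): Cutout zeroes one randomly placed square of side
# `length` (clipped at the borders) while leaving the tensor shape unchanged.
def _example_cutout():
    img = torch.ones(3, 32, 32)
    img = Cutout(16)(img)
    assert img.shape == (3, 32, 32)
    return float(img.mean())  # < 1.0: one patch was zeroed out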


def _data_transforms(data_name, img_size, autoaugment, cutout, cutout_length):
    if 'cifar' in data_name:
        norm_mean = [0.49139968, 0.48215827, 0.44653124]
        norm_std = [0.24703233, 0.24348505, 0.26158768]
    elif 'aircraft' in data_name:
        norm_mean = [0.48933587508932375, 0.5183537408957618, 0.5387914411673883]
        norm_std = [0.22388883112804625, 0.21641635409388751, 0.24615605842636115]
    elif 'pets' in data_name:
        norm_mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783]
        norm_std = [0.25925664613996574, 0.2532760018681693, 0.25981017205097917]
    else:
        raise KeyError(data_name)

    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),
        transforms.RandomHorizontalFlip(),
    ])

    if autoaugment:
        train_transform.transforms.append(CIFAR10Policy())

    train_transform.transforms.append(transforms.ToTensor())

    if cutout:
        train_transform.transforms.append(Cutout(cutout_length))

    train_transform.transforms.append(transforms.Normalize(norm_mean, norm_std))

    valid_transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    return train_transform, valid_transform
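
# Example (sketch): building the CIFAR pipelines and applying the validation
# transform to one synthetic PIL image.
def _example_data_transforms():
    train_t, valid_t = _data_transforms(
        'cifar10', img_size=224, autoaugment=False, cutout=False, cutout_length=16)
    img = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
    x = valid_t(img)  # normalized tensor of shape [3, 224, 224]
    return x.shape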
@@ -0,0 +1,233 @@
import os
import torch
import numpy as np
import random
import sys
# imported as a bare name so that eval_utils.get_net_info below resolves
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation import eval_utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.run_manager import get_run_config
from ofa.elastic_nn.networks import OFAMobileNetV3
from ofa.imagenet_codebase.run_manager import RunManager
from ofa.elastic_nn.modules.dynamic_op import DynamicSeparableConv2d
from torchprofile import profile_macs
import copy
import json
import warnings

warnings.simplefilter("ignore")

DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = 1


class ArchManager:
    def __init__(self):
        self.num_blocks = 20
        self.num_stages = 5
        self.kernel_sizes = [3, 5, 7]
        self.expand_ratios = [3, 4, 6]
        self.depths = [2, 3, 4]
        self.resolutions = [160, 176, 192, 208, 224]

    def random_sample(self):
        d = []
        e = []
        ks = []
        for i in range(self.num_stages):
            d.append(random.choice(self.depths))

        for i in range(self.num_blocks):
            e.append(random.choice(self.expand_ratios))
            ks.append(random.choice(self.kernel_sizes))

        sample = {
            'wid': None,
            'ks': ks,
            'e': e,
            'd': d,
            'r': [random.choice(self.resolutions)]
        }

        return sample

    def random_resample(self, sample, i):
        assert 0 <= i < self.num_blocks
        sample['ks'][i] = random.choice(self.kernel_sizes)
        sample['e'][i] = random.choice(self.expand_ratios)

    def random_resample_depth(self, sample, i):
        assert 0 <= i < self.num_stages
        sample['d'][i] = random.choice(self.depths)

    def random_resample_resolution(self, sample):
        sample['r'][0] = random.choice(self.resolutions)
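
# Example (sketch): one random MobileNetV3-style architecture encoding.
def _example_arch_manager():
    sample = ArchManager().random_sample()
    # 20 per-block kernel sizes / expand ratios, 5 per-stage depths, 1 resolution
    assert len(sample['ks']) == 20 and len(sample['e']) == 20
    assert len(sample['d']) == 5 and len(sample['r']) == 1
    return sample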


def parse_string_list(string):
    if isinstance(string, str):
        # convert '[5 5 5 7 7 7 3 3 7 7 7 3 3]' to [5, 5, 5, 7, 7, 7, 3, 3, 7, 7, 7, 3, 3]
        return list(map(int, string[1:-1].split()))
    else:
        return string
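
# Example (sketch): parse_string_list decodes the flat string encoding and
# passes anything that is not a string through unchanged.
def _example_parse_string_list():
    assert parse_string_list('[5 5 7 3]') == [5, 5, 7, 3]
    assert parse_string_list([5, 5, 7, 3]) == [5, 5, 7, 3]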


def pad_none(x, depth, max_depth):
    new_x, counter = [], 0
    for d in depth:
        for _ in range(d):
            new_x.append(x[counter])
            counter += 1
        if d < max_depth:
            new_x += [None] * (max_depth - d)
    return new_x
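
# Example (sketch): pad per-block settings out to max_depth slots per stage.
def _example_pad_none():
    # two stages with depths 2 and 3, each padded to 4 slots
    padded = pad_none([3, 5, 3, 5, 7], depth=[2, 3], max_depth=4)
    assert padded == [3, 5, None, None, 3, 5, 7, None]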


def get_net_info(net, data_shape, measure_latency=None, print_info=True, clean=False, lut=None):
    net_info = eval_utils.get_net_info(
        net, data_shape, measure_latency, print_info=print_info, clean=clean, lut=lut)

    gpu_latency, cpu_latency = None, None
    for k in net_info.keys():
        if 'gpu' in k:
            gpu_latency = np.round(net_info[k]['val'], 2)
        if 'cpu' in k:
            cpu_latency = np.round(net_info[k]['val'], 2)

    return {
        'params': np.round(net_info['params'] / 1e6, 2),
        'flops': np.round(net_info['flops'] / 1e6, 2),
        'gpu': gpu_latency, 'cpu': cpu_latency
    }


def validate_config(config, max_depth=4):
    kernel_size, exp_ratio, depth = config['ks'], config['e'], config['d']

    if isinstance(kernel_size, str): kernel_size = parse_string_list(kernel_size)
    if isinstance(exp_ratio, str): exp_ratio = parse_string_list(exp_ratio)
    if isinstance(depth, str): depth = parse_string_list(depth)

    assert (isinstance(kernel_size, list) or isinstance(kernel_size, int))
    assert (isinstance(exp_ratio, list) or isinstance(exp_ratio, int))
    assert isinstance(depth, list)

    if len(kernel_size) < len(depth) * max_depth:
        kernel_size = pad_none(kernel_size, depth, max_depth)
    if len(exp_ratio) < len(depth) * max_depth:
        exp_ratio = pad_none(exp_ratio, depth, max_depth)

    # return {'ks': kernel_size, 'e': exp_ratio, 'd': depth, 'w': config['w']}
    return {'ks': kernel_size, 'e': exp_ratio, 'd': depth}
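
# Example (sketch): string-encoded configs are parsed and padded so that
# every stage exposes max_depth entries.
def _example_validate_config():
    cfg = validate_config({'ks': '[3 5 3 5 7]', 'e': '[3 4 3 4 6]', 'd': '[2 3]'})
    assert len(cfg['ks']) == 2 * 4 and cfg['d'] == [2, 3]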


def set_nas_test_dataset(path, test_data_name, max_img):
    if test_data_name not in ['mnist', 'svhn', 'cifar10',
                              'cifar100', 'aircraft', 'pets']:
        raise ValueError(test_data_name)

    dpath = path
    num_cls = 10  # mnist, svhn, cifar10
    if test_data_name in ['cifar100', 'aircraft']:
        num_cls = 100
    elif test_data_name == 'pets':
        num_cls = 37

    x = torch.load(dpath + f'/{test_data_name}bylabel')
    img_per_cls = min(int(max_img / num_cls), 20)
    return x, img_per_cls, num_cls


class OFAEvaluator:
    """ based on OnceForAll supernet taken from https://github.com/mit-han-lab/once-for-all """

    def __init__(self, num_gen_arch, img_size, drop_path,
                 n_classes=1000,
                 model_path=None,
                 kernel_size=None, exp_ratio=None, depth=None):
        # default configurations
        self.kernel_size = [3, 5, 7] if kernel_size is None else kernel_size  # depth-wise conv kernel size
        self.exp_ratio = [3, 4, 6] if exp_ratio is None else exp_ratio  # expansion rate
        self.depth = [2, 3, 4] if depth is None else depth  # number of MB block repetitions
        self.img_size = img_size    # used when profiling candidate subnets
        self.drop_path = drop_path  # drop-connect rate for decoded subnets

        if 'w1.0' in model_path:
            self.width_mult = 1.0
        elif 'w1.2' in model_path:
            self.width_mult = 1.2
        else:
            raise ValueError(model_path)

        self.engine = OFAMobileNetV3(
            n_classes=n_classes,
            dropout_rate=0, width_mult_list=self.width_mult, ks_list=self.kernel_size,
            expand_ratio_list=self.exp_ratio, depth_list=self.depth)

        init = torch.load(model_path, map_location='cpu')['state_dict']
        self.engine.load_weights_from_net(init)
        print(f'load {model_path}...')

        # metad2a
        self.arch_manager = ArchManager()
        self.num_gen_arch = num_gen_arch

    def sample_random_architecture(self):
        return self.arch_manager.random_sample()

    def get_architecture(self, bound=None):
        # self.lines must be populated with the generator's candidate
        # architecture strings before this method is called
        pred_acc_lst = []
        searched_g = None

        with torch.no_grad():
            for n in range(self.num_gen_arch):
                file_acc = self.lines[n].split()[0]
                g_dict = ' '.join(self.lines[n].split())
                g = json.loads(g_dict.replace("'", "\""))

                if bound is not None:
                    subnet, config = self.sample(config=g)
                    net = NSGANetV2.build_from_config(subnet.config,
                                                      drop_connect_rate=self.drop_path)
                    inputs = torch.randn(1, 3, self.img_size, self.img_size)
                    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
                    if flops <= bound:
                        searched_g = g
                        break
                else:
                    searched_g = g
                    pred_acc_lst.append(file_acc)
                    break

        if searched_g is None:
            raise ValueError(searched_g)
        return searched_g, pred_acc_lst

    def sample(self, config=None):
        """ randomly sample a sub-network """
        if config is not None:
            config = validate_config(config)
            self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
        else:
            config = self.engine.sample_active_subnet()

        subnet = self.engine.get_active_subnet(preserve_weight=True)
        return subnet, config

    @staticmethod
    def save_net_config(path, net, config_name='net.config'):
        """ dump run_config and net_config to the model_folder """
        net_save_path = os.path.join(path, config_name)
        json.dump(net.config, open(net_save_path, 'w'), indent=4)
        print('Network configs dump to %s' % net_save_path)

    @staticmethod
    def save_net(path, net, model_name):
        """ dump net weight as checkpoint """
        if isinstance(net, torch.nn.DataParallel):
            checkpoint = {'state_dict': net.module.state_dict()}
        else:
            checkpoint = {'state_dict': net.state_dict()}
        model_path = os.path.join(path, model_name)
        torch.save(checkpoint, model_path)
        print('Network model dump to %s' % model_path)
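
# Example (sketch): sampling a weight-inheriting subnet from the loaded
# supernet. The checkpoint path is an assumption for illustration.
def _example_ofa_evaluator():
    evaluator = OFAEvaluator(num_gen_arch=None, img_size=224, drop_path=0.2,
                             model_path='.torch/ofa_nets/ofa_mbv3_d234_e346_k357_w1.0')
    subnet, config = evaluator.sample()  # no config -> random active subnet
    return subnet, config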
@@ -0,0 +1,169 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim
from evaluator import OFAEvaluator
from torchprofile import profile_macs
from codebase.networks import NSGANetV2
from parser import get_parse
from eval_utils import get_dataset


args = get_parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device_list = [int(_) for _ in args.gpu.split(',')]
args.n_gpus = len(device_list)
args.device = torch.device("cuda:0")

if args.seed is None or args.seed < 0:
    args.seed = random.randint(1, 100000)
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)


evaluator = OFAEvaluator(args.num_gen_arch, args.img_size, args.drop_path,
                         model_path='../.torch/ofa_nets/ofa_mbv3_d234_e346_k357_w1.0')

args.save_path = os.path.join(args.save_path, f'evaluation/{args.data_name}')
if args.model_config.startswith('flops@'):
    args.save_path += f'-nsganetV2-{args.model_config}-{args.seed}'
else:
    args.save_path += f'-metaD2A-{args.bound}-{args.seed}'
if not os.path.exists(args.save_path):
    os.makedirs(args.save_path)

args.data_path = os.path.join(args.data_path, args.data_name)

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)
logging.info("args = %s", args)


def set_architecture(n_cls):
    if args.model_config.startswith('flops@'):
        names = {'cifar10': 'CIFAR-10', 'cifar100': 'CIFAR-100',
                 'aircraft100': 'Aircraft', 'pets': 'Pets'}
        p = os.path.join('./searched-architectures/{}/net-{}/net.subnet'.
                         format(names[args.data_name], args.model_config))
        g = json.load(open(p))
    else:
        g, acc = evaluator.get_architecture(args.bound)

    subnet, config = evaluator.sample(g)
    net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=args.drop_path)
    net.load_state_dict(subnet.state_dict())

    NSGANetV2.reset_classifier(
        net, last_channel=net.classifier.in_features,
        n_classes=n_cls, dropout_rate=args.drop)
    # calculate #Parameters and #FLOPS
    inputs = torch.randn(1, 3, args.img_size, args.img_size)
    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
    params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
    net_name = "net_flops@{:.0f}".format(flops)
    logging.info('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
    OFAEvaluator.save_net_config(args.save_path, net, net_name + '.config')
    if args.n_gpus > 1:
        net = nn.DataParallel(net)  # data parallel in case more than 1 gpu available
    net = net.to(args.device)

    return net, net_name


def train(train_queue, net, criterion, optimizer):
    net.train()
    train_loss, correct, total = 0, 0, 0
    for step, (inputs, targets) in enumerate(train_queue):
        # upsample by bicubic to match imagenet training size
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if step % args.report_freq == 0:
            logging.info('train %03d %e %f', step, train_loss / total, 100. * correct / total)
    logging.info('train acc %f', 100. * correct / total)
    return train_loss / total, 100. * correct / total


def infer(valid_queue, net, criterion, early_stop=False):
    net.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f', step, test_loss / total, 100. * correct / total)
            if early_stop and step == 10:
                break
    acc = 100. * correct / total
    logging.info('valid acc %f', 100. * correct / total)

    return test_loss / total, acc


def main():
    best_acc, top_checkpoints = 0, []

    train_queue, valid_queue, n_cls = get_dataset(
        args.data_name, args.batch_size, args.data_path, args.num_workers,
        args.img_size, args.autoaugment, args.cutout, args.cutout_length)
    net, net_name = set_architecture(n_cls)
    parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss().to(args.device)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    for epoch in range(args.epochs):
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])

        train(train_queue, net, criterion, optimizer)
        _, valid_acc = infer(valid_queue, net, criterion)
        # keep the top-k checkpoints by validation accuracy
        if len(top_checkpoints) < args.topk:
            OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
            top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
        else:
            idx = np.argmin([x[1] for x in top_checkpoints])
            if valid_acc > top_checkpoints[idx][1]:
                OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
                top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
                # remove the displaced worst checkpoint
                os.remove(top_checkpoints[idx][0])
                top_checkpoints.pop(idx)
        print(top_checkpoints)
        if valid_acc > best_acc:
            OFAEvaluator.save_net(args.save_path, net, net_name + '.best')
            best_acc = valid_acc
        scheduler.step()


if __name__ == '__main__':
    main()
@@ -0,0 +1,43 @@
import argparse

def get_parse():
    parser = argparse.ArgumentParser(description='MetaD2A vs NSGANETv2')
    parser.add_argument('--save-path', type=str, default='../results', help='the path of the save directory')
    parser.add_argument('--data-path', type=str, default='../data', help='the path of the data directory')
    parser.add_argument('--data-name', type=str, default=None, help='meta-test dataset name')
    parser.add_argument('--num-gen-arch', type=int, default=200,
                        help='the number of candidate architectures generated by the generator')
    parser.add_argument('--bound', type=int, default=None)

    # original setting
    parser.add_argument('--seed', type=int, default=-1, help='random seed')
    parser.add_argument('--batch-size', type=int, default=96, help='batch size')
    parser.add_argument('--num_workers', type=int, default=2, help='number of workers for data loading')
    parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
    parser.add_argument('--lr', type=float, default=0.01, help='init learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay', type=float, default=4e-5, help='weight decay')
    parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
    parser.add_argument('--epochs', type=int, default=150, help='num of training epochs')
    parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
    parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
    parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
    parser.add_argument('--autoaugment', action='store_true', default=True, help='use auto augmentation')

    parser.add_argument('--topk', type=int, default=10, help='top k checkpoints to save')
    parser.add_argument('--evaluate', action='store_true', default=False, help='evaluate a pretrained model')
    # model related
    parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                        help='name of model to train (default: "resnet101")')
    parser.add_argument('--model-config', type=str, default='search',
                        help='location of a json file of specific model declaration')
    parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                        help='initialize model from this checkpoint (default: none)')
    parser.add_argument('--drop', type=float, default=0.2,
                        help='dropout rate')
    parser.add_argument('--drop-path', type=float, default=0.2, metavar='PCT',
                        help='drop path rate (default: 0.2)')
    parser.add_argument('--img-size', type=int, default=224,
                        help='input resolution (192 -> 256)')
    args = parser.parse_args()
    return args
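
# Example (sketch): a typical invocation of the evaluation script that
# consumes these flags (the script name is an assumption):
#
#     python main.py --data-name cifar10 --data-path ../data --gpu 0 --bound 600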
@@ -0,0 +1,261 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim

from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.evaluator import OFAEvaluator
from torchprofile import profile_macs
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.parser import get_parse
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.eval_utils import get_dataset
from transfer_nag_lib.MetaD2A_nas_bench_201.metad2a_utils import reset_seed
from transfer_nag_lib.ofa_net import OFASubNet


# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# device_list = [int(_) for _ in args.gpu.split(',')]
# args.n_gpus = len(device_list)
# args.device = torch.device("cuda:0")

# if args.seed is None or args.seed < 0: args.seed = random.randint(1, 100000)
# torch.cuda.manual_seed(args.seed)
# torch.manual_seed(args.seed)
# np.random.seed(args.seed)
# random.seed(args.seed)


# args.save_path = os.path.join(args.save_path, f'evaluation/{args.data_name}')
# if args.model_config.startswith('flops@'):
#     args.save_path += f'-nsganetV2-{args.model_config}-{args.seed}'
# else:
#     args.save_path += f'-metaD2A-{args.bound}-{args.seed}'
# if not os.path.exists(args.save_path):
#     os.makedirs(args.save_path)

# args.data_path = os.path.join(args.data_path, args.data_name)

# log_format = '%(asctime)s %(message)s'
# logging.basicConfig(stream=sys.stdout, level=logging.INFO,
#                     format=log_format, datefmt='%m/%d %I:%M:%S %p')
# fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
# fh.setFormatter(logging.Formatter(log_format))
# logging.getLogger().addHandler(fh)
# if not torch.cuda.is_available():
#     print('no gpu device available')
#     sys.exit(1)
# print("args = %s", args)


def set_architecture(n_cls, evaluator, drop_path, drop, img_size, n_gpus, device, save_path, model_str):
    # g, acc = evaluator.get_architecture(model_str)
    g = OFASubNet(model_str).get_op_dict()
    subnet, config = evaluator.sample(g)
    net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=drop_path)
    net.load_state_dict(subnet.state_dict())

    NSGANetV2.reset_classifier(
        net, last_channel=net.classifier.in_features,
        n_classes=n_cls, dropout_rate=drop)
    # calculate #Parameters and #FLOPS
    inputs = torch.randn(1, 3, img_size, img_size)
    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
    params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
    net_name = "net_flops@{:.0f}".format(flops)
    print('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
    # OFAEvaluator.save_net_config(save_path, net, net_name + '.config')
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net = net.to(device)

    return net, net_name, params, flops


def train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq):
    net.train()
    train_loss, correct, total = 0, 0, 0
    for step, (inputs, targets) in enumerate(train_queue):
        # upsample by bicubic to match imagenet training size
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if step % report_freq == 0:
            print(f'train step {step:03d} loss {train_loss / total:.4f} train acc {100. * correct / total:.4f}')
    print(f'train acc {100. * correct / total:.4f}')
    return train_loss / total, 100. * correct / total


def infer(valid_queue, net, criterion, device, report_freq, early_stop=False):
    net.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if step % report_freq == 0:
                print(f'valid {step:03d} {test_loss / total:.4f} {100. * correct / total:.4f}')
            if early_stop and step == 10:
                break
    acc = 100. * correct / total
    print('valid acc {:.4f}'.format(100. * correct / total))

    return test_loss / total, acc


def train_single_model(save_path, workers, datasets, xpaths, splits, use_less,
                       seed, model_str, device,
                       lr=0.01,
                       momentum=0.9,
                       weight_decay=4e-5,
                       report_freq=50,
                       epochs=150,
                       grad_clip=5,
                       cutout=True,
                       cutout_length=16,
                       autoaugment=True,
                       drop=0.2,
                       drop_path=0.2,
                       img_size=224,
                       batch_size=96,
                       ):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    reset_seed(seed)
    # save_dir = Path(save_dir)
    # logger = Logger(str(save_dir), 0, False)
    os.makedirs(save_path, exist_ok=True)
    to_save_name = save_path + '/seed-{:04d}.pth'.format(seed)
    print(to_save_name)
    # args = get_parse()
    num_gen_arch = None
    evaluator = OFAEvaluator(num_gen_arch, img_size, drop_path,
                             model_path='/home/data/GTAD/checkpoints/ofa/ofa_net/ofa_mbv3_d234_e346_k357_w1.0')

    train_queue, valid_queue, n_cls = get_dataset(datasets, batch_size,
                                                  xpaths, workers, img_size, autoaugment, cutout, cutout_length)
    net, net_name, params, flops = set_architecture(n_cls, evaluator,
                                                    drop_path, drop, img_size, n_gpus=1, device=device,
                                                    save_path=save_path, model_str=model_str)

    # net.to(device)

    parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss().to(device)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # assert epochs == 1
    max_valid_acc = 0
    max_epoch = 0
    for epoch in range(epochs):
        print('epoch {:d} lr {:.4f}'.format(epoch, scheduler.get_lr()[0]))

        train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq)
        _, valid_acc = infer(valid_queue, net, criterion, device, report_freq)
        torch.save(valid_acc, to_save_name)
        print(f'seed {seed:04d} last acc {valid_acc:.4f} max acc {max_valid_acc:.4f}')
        if max_valid_acc < valid_acc:
            max_valid_acc = valid_acc
            max_epoch = epoch
        scheduler.step()  # advance the cosine schedule once per epoch
        # parent_path = os.path.abspath(os.path.join(save_path, os.pardir))
        # with open(parent_path + '/accuracy.txt', 'a+') as f:
        #     f.write(f'{model_str} seed {seed:04d} {valid_acc:.4f}\n')

    return valid_acc, max_valid_acc, params, flops
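
# Example (sketch): training one decoded architecture end to end. All paths
# and the model string are assumptions for illustration.
def _example_train_single_model(device):
    return train_single_model(
        save_path='./results/example', workers=2, datasets='cifar10',
        xpaths='./data/cifar10', splits=None, use_less=False,
        seed=1, model_str='<ofa-subnet-string>', device=device,
        epochs=1)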


################ NAS BENCH 201 #####################
# def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
#                        seeds, model_str, arch_config):
#     assert torch.cuda.is_available(), 'CUDA is not available.'
#     torch.backends.cudnn.enabled = True
#     torch.backends.cudnn.deterministic = True
#     torch.set_num_threads(workers)

#     save_dir = Path(save_dir)
#     logger = Logger(str(save_dir), 0, False)

#     if model_str in CellArchitectures:
#         arch = CellArchitectures[model_str]
#         logger.log(
#             'The model string is found in pre-defined architecture dict : {:}'.format(model_str))
#     else:
#         try:
#             arch = CellStructure.str2structure(model_str)
#         except:
#             raise ValueError(
#                 'Invalid model string : {:}. It can not be found or parsed.'.format(model_str))

#     assert arch.check_valid_op(get_search_spaces(
#         'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
#     # assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
#     logger.log('Start train-evaluate {:}'.format(arch.tostr()))
#     logger.log('arch_config : {:}'.format(arch_config))

#     start_time, seed_time = time.time(), AverageMeter()
#     for _is, seed in enumerate(seeds):
#         logger.log(
#             '\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds),
#                                                                                                          seed))
#         to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
#         if to_save_name.exists():
#             logger.log(
#                 'Find the existing file {:}, directly load!'.format(to_save_name))
#             checkpoint = torch.load(to_save_name)
#         else:
#             logger.log(
#                 'Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
#             checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
#                                                seed, arch_config, workers, logger)
#             torch.save(checkpoint, to_save_name)
#         # log information
#         logger.log('{:}'.format(checkpoint['info']))
#         all_dataset_keys = checkpoint['all_dataset_keys']
#         for dataset_key in all_dataset_keys:
#             logger.log('\n{:} dataset : {:} {:}'.format(
#                 '-' * 15, dataset_key, '-' * 15))
#             dataset_info = checkpoint[dataset_key]
#             # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
#             logger.log('Flops = {:} MB, Params = {:} MB'.format(
#                 dataset_info['flop'], dataset_info['param']))
#             logger.log('config : {:}'.format(dataset_info['config']))
#             logger.log('Training State (finish) = {:}'.format(
#                 dataset_info['finish-train']))
#             last_epoch = dataset_info['total_epoch'] - 1
#             train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
#             valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
#         # measure elapsed time
#         seed_time.update(time.time() - start_time)
#         start_time = time.time()
#         need_time = 'Time Left: {:}'.format(convert_secs2time(
#             seed_time.avg * (len(seeds) - _is - 1), True))
#         logger.log(
#             '\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed,
#                                                                                                      need_time))
#     logger.close()
# ###################

if __name__ == '__main__':
    # train_single_model has no defaults for save_path/workers/datasets/etc.;
    # supply them explicitly (see the example sketch above)
    train_single_model()