v2
This commit is contained in:
@@ -3,3 +3,4 @@
|
||||
##################################################
|
||||
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
|
||||
from .SearchDatasetWrap import SearchDataset
|
||||
from .data import get_data
|
||||
|
69
datasets/data.py
Normal file
69
datasets/data.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from datasets import get_datasets
|
||||
from config_utils import load_config
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
class AddGaussianNoise(object):
    """Transform that adds element-wise Gaussian noise to a tensor.

    Each call draws fresh standard-normal noise of the input's shape and
    returns ``tensor + noise * std + mean``; the input itself is not modified.
    """

    def __init__(self, mean=0., std=0.001):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Fresh noise per call, shaped like the input.
        noise = torch.randn(tensor.size())
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
|
||||
|
||||
|
||||
|
||||
|
||||
class RepeatSampler(torch.utils.data.sampler.Sampler):
    """Sampler wrapper that yields every index of ``samp`` ``repeat`` times in a row.

    E.g. wrapping a sampler that yields 1, 2 with repeat=2 yields 1, 1, 2, 2.
    """

    def __init__(self, samp, repeat):
        # Wrapped sampler and the consecutive repetition count per index.
        self.samp = samp
        self.repeat = repeat

    def __iter__(self):
        for index in self.samp:
            for _ in range(self.repeat):
                yield index

    def __len__(self):
        return len(self.samp) * self.repeat
|
||||
|
||||
|
||||
def get_data(dataset, data_loc, trainval, batch_size, augtype, repeat, args, pin_memory=True):
    """Build the training DataLoader for ``dataset`` with optional re-augmentation.

    Improvements over the original: the dead locals ``acc_type`` /
    ``val_acc_type`` (computed but never used or returned) were removed, and
    the duplicated DataLoader construction in the train/val branch was folded
    into one call.

    Args:
        dataset: dataset name understood by ``get_datasets`` (e.g. 'cifar10').
        data_loc: root directory of the dataset on disk.
        trainval: if truthy (cifar datasets only), sample from the 'train'
            half of the split in config_utils/cifar-split.txt instead of the
            full training set.
        batch_size: mini-batch size.
        augtype: 'gaussnoise', 'cutout', or 'none' to strip the first two
            default transforms (and, for the first two, append a new
            augmentation); any other value leaves the transform untouched.
        repeat: if > 0, each sampled index is yielded ``repeat`` times in a
            row via RepeatSampler; otherwise plain random sampling/shuffling.
        args: namespace providing ``args.sigma`` when augtype == 'gaussnoise'.
        pin_memory: forwarded to the DataLoader.

    Returns:
        A torch.utils.data.DataLoader over the training data.
    """
    train_data, valid_data, xshape, class_num = get_datasets(dataset, data_loc, cutout=0)

    # NOTE(review): slicing off the first two transforms assumes the default
    # pipeline begins with two spatial augmentations (e.g. crop + flip) --
    # confirm against get_datasets before changing the slice.
    if augtype == 'gaussnoise':
        train_data.transform.transforms = train_data.transform.transforms[2:]
        train_data.transform.transforms.append(AddGaussianNoise(std=args.sigma))
    elif augtype == 'cutout':
        train_data.transform.transforms = train_data.transform.transforms[2:]
        train_data.transform.transforms.append(
            torchvision.transforms.RandomErasing(p=0.9, scale=(0.02, 0.04)))
    elif augtype == 'none':
        train_data.transform.transforms = train_data.transform.transforms[2:]

    if trainval and 'cifar10' in dataset:
        # NOTE(review): the substring test also matches 'cifar100' -- confirm
        # whether that is intended before tightening the condition.
        cifar_split = load_config('config_utils/cifar-split.txt', None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        sampler = torch.utils.data.sampler.SubsetRandomSampler(train_split)
        if repeat > 0:
            sampler = RepeatSampler(sampler, repeat)
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=batch_size, num_workers=0,
            pin_memory=pin_memory, sampler=sampler)
    else:
        if repeat > 0:
            # SubsetRandomSampler over the full index range gives the same
            # random order as shuffle=True, with each index repeated.
            sampler = RepeatSampler(
                torch.utils.data.sampler.SubsetRandomSampler(range(len(train_data))),
                repeat)
            train_loader = torch.utils.data.DataLoader(
                train_data, batch_size=batch_size, num_workers=0,
                pin_memory=pin_memory, sampler=sampler)
        else:
            train_loader = torch.utils.data.DataLoader(
                train_data, batch_size=batch_size, shuffle=True,
                num_workers=0, pin_memory=pin_memory)
    return train_loader
|
@@ -16,7 +16,9 @@ from config_utils import load_config
|
||||
|
||||
Dataset2Class = {'cifar10' : 10,
|
||||
'cifar100': 100,
|
||||
'fake':10,
|
||||
'imagenet-1k-s':1000,
|
||||
'imagenette2' : 10,
|
||||
'imagenet-1k' : 1000,
|
||||
'ImageNet16' : 1000,
|
||||
'ImageNet16-150': 150,
|
||||
@@ -98,8 +100,13 @@ def get_datasets(name, root, cutout):
|
||||
elif name == 'cifar100':
|
||||
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
|
||||
std = [x / 255 for x in [68.2, 65.4, 70.4]]
|
||||
elif name == 'fake':
|
||||
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
|
||||
std = [x / 255 for x in [68.2, 65.4, 70.4]]
|
||||
elif name.startswith('imagenet-1k'):
|
||||
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
elif name.startswith('imagenette'):
|
||||
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
elif name.startswith('ImageNet16'):
|
||||
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
|
||||
std = [x / 255 for x in [63.22, 61.26 , 65.09]]
|
||||
@@ -113,6 +120,12 @@ def get_datasets(name, root, cutout):
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name == 'fake':
|
||||
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('ImageNet16'):
|
||||
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||
if cutout > 0 : lists += [CUTOUT(cutout)]
|
||||
@@ -125,6 +138,15 @@ def get_datasets(name, root, cutout):
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('imagenette'):
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
xlists = []
|
||||
xlists.append( transforms.ToTensor() )
|
||||
xlists.append( normalize )
|
||||
#train_transform = transforms.Compose(xlists)
|
||||
train_transform = transforms.Compose([normalize, normalize, transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
|
||||
test_transform = transforms.Compose([normalize, normalize, transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
|
||||
xshape = (1, 3, 224, 224)
|
||||
elif name.startswith('imagenet-1k'):
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
if name == 'imagenet-1k':
|
||||
@@ -156,6 +178,12 @@ def get_datasets(name, root, cutout):
|
||||
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
|
||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||
elif name == 'fake':
|
||||
train_data = dset.FakeData(size=50000, image_size=(3, 32, 32), transform=train_transform)
|
||||
test_data = dset.FakeData(size=10000, image_size=(3, 32, 32), transform=test_transform)
|
||||
elif name.startswith('imagenette2'):
|
||||
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
||||
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
|
||||
elif name.startswith('imagenet-1k'):
|
||||
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
||||
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
|
||||
|
Reference in New Issue
Block a user