This commit is contained in:
Jack Turner
2021-02-26 16:12:51 +00:00
parent c895924c99
commit b74255e1f3
74 changed files with 11326 additions and 537 deletions

View File

@@ -16,7 +16,9 @@ from config_utils import load_config
Dataset2Class = {'cifar10' : 10,
'cifar100': 100,
'fake':10,
'imagenet-1k-s':1000,
'imagenette2' : 10,
'imagenet-1k' : 1000,
'ImageNet16' : 1000,
'ImageNet16-150': 150,
@@ -98,8 +100,13 @@ def get_datasets(name, root, cutout):
elif name == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name == 'fake':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name.startswith('imagenet-1k'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name.startswith('imagenette'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name.startswith('ImageNet16'):
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26 , 65.09]]
@@ -113,6 +120,12 @@ def get_datasets(name, root, cutout):
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name == 'fake':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name.startswith('ImageNet16'):
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [CUTOUT(cutout)]
@@ -125,6 +138,15 @@ def get_datasets(name, root, cutout):
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name.startswith('imagenette'):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
xlists = []
xlists.append( transforms.ToTensor() )
xlists.append( normalize )
#train_transform = transforms.Compose(xlists)
train_transform = transforms.Compose([normalize, normalize, transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
test_transform = transforms.Compose([normalize, normalize, transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
xshape = (1, 3, 224, 224)
elif name.startswith('imagenet-1k'):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
if name == 'imagenet-1k':
@@ -156,6 +178,12 @@ def get_datasets(name, root, cutout):
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'fake':
train_data = dset.FakeData(size=50000, image_size=(3, 32, 32), transform=train_transform)
test_data = dset.FakeData(size=10000, image_size=(3, 32, 32), transform=test_transform)
elif name.startswith('imagenette2'):
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
elif name.startswith('imagenet-1k'):
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)