diff --git a/src/datasets/utilities.py b/src/datasets/utilities.py index 4803ffc..b553727 100644 --- a/src/datasets/utilities.py +++ b/src/datasets/utilities.py @@ -13,7 +13,8 @@ Dataset2Class = {'cifar10': 10, 'ImageNet16' : 1000, 'ImageNet16-120': 120, 'ImageNet16-150': 150, - 'ImageNet16-200': 200} + 'ImageNet16-200': 200, + 'aircraft': 100} class RandChannel(object): # randomly pick channels from input @@ -46,6 +47,10 @@ def get_datasets(name, root, input_size, cutout=-1): elif name.startswith('ImageNet16'): mean = [0.481098, 0.45749, 0.407882] std = [0.247922, 0.240235, 0.255255] + elif name == 'aircraft': + mean = [0.4785, 0.5100, 0.5338] + std = [0.1845, 0.1830, 0.2060] + else: raise TypeError("Unknow dataset : {:}".format(name)) @@ -55,6 +60,12 @@ def get_datasets(name, root, input_size, cutout=-1): if cutout > 0 : lists += [CUTOUT(cutout)] train_transform = transforms.Compose(lists) test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) + elif name == 'aircraft': + lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)] + if cutout > 0 : lists += [CUTOUT(cutout)] + train_transform = transforms.Compose(lists) + test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)]) + elif name.startswith('ImageNet16'): lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])] if cutout > 0 : lists += [CUTOUT(cutout)] @@ -86,9 +97,12 @@ def get_datasets(name, root, input_size, cutout=-1): train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True) assert len(train_data) == 50000 and len(test_data) == 10000 + elif name == 'aircraft': + train_data = dset.ImageFolder(osp.join(root, 'train_sorted_images'), train_transform) + test_data = dset.ImageFolder(osp.join(root, 'test_sorted_images'), test_transform) elif name.startswith('imagenet-1k'): train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) - test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform) + test_data = dset.ImageFolder(osp.join(root, 'test'), test_transform) elif name == 'ImageNet16': root = osp.join(root, 'ImageNet16') train_data = ImageNet16(root, True , train_transform)