update codes

This commit is contained in:
D-X-Y
2019-02-01 04:03:35 +11:00
parent 3f9b54d99e
commit 65d9c1c57f
11 changed files with 103 additions and 55 deletions

View File

@@ -10,7 +10,7 @@ from .TieredImageNet import TieredImageNet
Dataset2Class = {'cifar10' : 10,
'cifar100': 100,
'tiered' : -1,
'imagnet-1k' : 1000,
'imagenet-1k' : 1000,
'imagenet-100': 100}
@@ -25,8 +25,8 @@ def get_datasets(name, root, cutout):
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name == 'tiered':
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name == 'imagnet-1k' or name == 'imagenet-100':
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
elif name == 'imagenet-1k' or name == 'imagenet-100':
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
else: raise TypeError("Unknow dataset : {:}".format(name))
@@ -42,7 +42,7 @@ def get_datasets(name, root, cutout):
if cutout > 0 : lists += [Cutout(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
elif name == 'imagnet-1k' or name == 'imagenet-100':
elif name == 'imagenet-1k' or name == 'imagenet-100':
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
@@ -57,15 +57,14 @@ def get_datasets(name, root, cutout):
])
test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
else: raise TypeError("Unknow dataset : {:}".format(name))
train_data = TieredImageNet(root, 'train-val', train_transform)
test_data = None
if name == 'cifar10':
train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR10(root, train=True, transform=test_transform , download=True)
elif name == 'cifar100':
train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=True, transform=test_transform , download=True)
elif name == 'imagnet-1k' or name == 'imagenet-100':
elif name == 'imagenet-1k' or name == 'imagenet-100':
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), train_transform)
else: raise TypeError("Unknow dataset : {:}".format(name))