Fix small bugs

This commit is contained in:
Xuanyi Dong
2019-04-08 11:04:08 +08:00
parent 3b1d8f1e4b
commit 36bb07ef1a
4 changed files with 13 additions and 6 deletions

View File

@@ -60,14 +60,14 @@ def get_datasets(name, root, cutout):
else: raise TypeError("Unknow dataset : {:}".format(name))
if name == 'cifar10':
train_data = dset.CIFAR10(root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR10(root, train=False, transform=test_transform , download=True)
train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
elif name == 'cifar100':
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
elif name == 'imagenet-1k' or name == 'imagenet-100':
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
else: raise TypeError("Unknow dataset : {:}".format(name))
class_num = Dataset2Class[name]