update scripts-cluster

This commit is contained in:
Xuanyi Dong
2019-03-31 22:49:43 +08:00
parent 280c9f3099
commit 4bac459bf9
20 changed files with 118 additions and 1248 deletions

View File

@@ -7,6 +7,7 @@ import torchvision.transforms as transforms
from utils import Cutout
from .TieredImageNet import TieredImageNet
Dataset2Class = {'cifar10' : 10,
'cifar100': 100,
'tiered' : -1,
@@ -59,11 +60,11 @@ def get_datasets(name, root, cutout):
else: raise TypeError("Unknow dataset : {:}".format(name))
if name == 'cifar10':
train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR10(root, train=True, transform=test_transform , download=True)
train_data = dset.CIFAR10(root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR10(root, train=False, transform=test_transform , download=True)
elif name == 'cifar100':
train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=True, transform=test_transform , download=True)
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
elif name == 'imagenet-1k' or name == 'imagenet-100':
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), train_transform)