diff --git a/exps/NAS-Bench-201/main.py b/exps/NAS-Bench-201/main.py
index 61be861..7eb3c1c 100644
--- a/exps/NAS-Bench-201/main.py
+++ b/exps/NAS-Bench-201/main.py
@@ -20,7 +20,92 @@ from functions import evaluate_for_seed
 from torchvision import datasets, transforms
 
 
-def evaluate_all_datasets(
+# NASBENCH201_CONFIG_PATH = os.path.join( os.getcwd(), 'main_exp', 'transfer_nag')
+
+NASBENCH201_CONFIG_PATH = '/lustre/hpe/ws11/ws11.1/ws/xmuhanma-nbdit/autodl-projects/configs/nas-benchmark'
+
+
+def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
+                          arch_config, workers, logger):
+    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
+    all_infos = {'info': machine_info}
+    all_dataset_keys = []
+    # loop over all the requested datasets
+    for dataset, xpath, split in zip(datasets, xpaths, splits):
+        # load the train/valid data
+        task = None
+        train_data, valid_data, xshape, class_num = get_datasets(
+            dataset, xpath, -1, task)
+
+        # load the configuration
+        if dataset in ['mnist', 'svhn', 'aircraft', 'oxford']:
+            if use_less:
+                # config_path = os.path.join(
+                #     NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config')
+                config_path = os.path.join(
+                    NASBENCH201_CONFIG_PATH, 'LESS.config')
+            else:
+                # config_path = os.path.join(
+                #     NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset))
+                config_path = os.path.join(
+                    NASBENCH201_CONFIG_PATH, '{}.config'.format(dataset))
+
+
+            p = os.path.join(
+                NASBENCH201_CONFIG_PATH, '{:}-split.txt'.format(dataset))
+            if not os.path.exists(p):
+                import json
+                label_list = list(range(len(train_data)))
+                random.shuffle(label_list)
+                strlist = [str(label_list[i]) for i in range(len(label_list))]
+                split_dict = {'train': ["int", strlist[:len(train_data) // 2]],
+                              'valid': ["int", strlist[len(train_data) // 2:]]}
+                with open(p, 'w') as f:
+                    f.write(json.dumps(split_dict))
+            split_info = load_config(os.path.join(
+                NASBENCH201_CONFIG_PATH, '{:}-split.txt'.format(dataset)), None, None)
+        else:
+            raise ValueError('invalid dataset : {:}'.format(dataset))
+
+        config = load_config(
+            config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+        # data loader
+        train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size,
+                                                   shuffle=True, num_workers=workers, pin_memory=True)
+        valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
+                                                   shuffle=False, num_workers=workers, pin_memory=True)
+        test_splits = load_config(os.path.join(
+            NASBENCH201_CONFIG_PATH, '{}-test-split.txt'.format(dataset)), None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
+                                                             sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                                 test_splits.xvalid),
+                                                             num_workers=workers, pin_memory=True),
+                      'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
+                                                            sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                                test_splits.xtest),
+                                                            num_workers=workers, pin_memory=True)
+                      }
+        dataset_key = '{:}'.format(dataset)
+        if bool(split):
+            dataset_key = dataset_key + '-valid'
+        logger.log(
+            'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.
+            format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
+        logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
+            dataset_key, config))
+        for key, value in ValLoaders.items():
+            logger.log(
+                'Evaluate ---->>>> {:10s} with {:} batches'.format(key, len(value)))
+
+        results = evaluate_for_seed(
+            arch_config, config, arch, train_loader, ValLoaders, seed, logger)
+        all_infos[dataset_key] = results
+        all_dataset_keys.append(dataset_key)
+    all_infos['all_dataset_keys'] = all_dataset_keys
+    return all_infos
+
+def evaluate_all_datasets1(
     arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
 ):
     machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
@@ -55,7 +140,14 @@ def evaluate_all_datasets(
             split_info = load_config(
                 "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
             )
-
+        elif dataset.startswith("oxford"):
+            if use_less:
+                config_path = "configs/nas-benchmark/LESS.config"
+            else:
+                config_path = "configs/nas-benchmark/oxford.config"
+            split_info = load_config(
+                "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
+            )
         else:
             raise ValueError("invalid dataset : {:}".format(dataset))
         config = load_config(
@@ -126,6 +218,31 @@ def evaluate_all_datasets(
             sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
             num_workers=workers,
             pin_memory=True)
+        elif dataset == "oxford":
+            ValLoaders = {
+                "ori-test": torch.utils.data.DataLoader(
+                    valid_data,
+                    batch_size=config.batch_size,
+                    shuffle=False,
+                    num_workers=workers,
+                    pin_memory=True
+                )
+            }
+            # train_data_v2 = deepcopy(train_data)
+            # train_data_v2.transform = valid_data.transform
+            train_loader = torch.utils.data.DataLoader(
+                train_data,
+                batch_size=config.batch_size,
+                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
+                num_workers=workers,
+                pin_memory=True)
+            valid_loader = torch.utils.data.DataLoader(
+                valid_data,
+                batch_size=config.batch_size,
+                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
+                num_workers=workers,
+                pin_memory=True)
+
         else:
             # data loader
             train_loader = torch.utils.data.DataLoader(
@@ -142,7 +259,7 @@ def evaluate_all_datasets(
                 num_workers=workers,
                 pin_memory=True,
             )
-        if dataset == "cifar10" or dataset == "aircraft":
+        if dataset == "cifar10" or dataset == "aircraft" or dataset == "oxford":
             ValLoaders = {"ori-test": valid_loader}
         elif dataset == "cifar100":
             cifar100_splits = load_config(
diff --git a/scripts-search/NAS-Bench-201/train-models.sh b/scripts-search/NAS-Bench-201/train-models.sh
index 1accb37..0a26dbf 100644
--- a/scripts-search/NAS-Bench-201/train-models.sh
+++ b/scripts-search/NAS-Bench-201/train-models.sh
@@ -46,7 +46,7 @@ OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
   --mode ${mode} --save_dir ${save_dir} --max_node 4 \
   --use_less ${use_less} \
   --datasets aircraft \
-  --xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/ \
+  --xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/ \
   --channel 16 \
   --splits 1 \
   --num_cells 5 \
@@ -54,4 +54,15 @@ OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
   --srange ${xstart} ${xend} --arch_index ${arch_index} \
   --seeds ${all_seeds}
 
+# OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
+#   --mode ${mode} --save_dir ${save_dir} --max_node 4 \
+#   --use_less ${use_less} \
+#   --datasets oxford \
+#   --xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/ \
+#   --channel 16 \
+#   --splits 1 \
+#   --num_cells 5 \
+#   --workers 4 \
+#   --srange ${xstart} ${xend} --arch_index ${arch_index} \
+#   --seeds ${all_seeds}
diff --git a/xautodl/datasets/get_dataset_with_transform.py b/xautodl/datasets/get_dataset_with_transform.py
index 491c96c..31f578e 100644
--- a/xautodl/datasets/get_dataset_with_transform.py
+++ b/xautodl/datasets/get_dataset_with_transform.py
@@ -1,42 +1,39 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
+# Modified by Hayeon Lee, Eunyoung Hyung 2021. 03.
 ##################################################
-import os, sys, torch
+import os
+import sys
+import torch
 import os.path as osp
 import numpy as np
 import torchvision.datasets as dset
 import torchvision.transforms as transforms
 from copy import deepcopy
-from PIL import Image
-
-from xautodl.config_utils import load_config
-
-from .DownsampledImageNet import ImageNet16
 from .SearchDatasetWrap import SearchDataset
+from PIL import Image  # keep active: Lighting.__call__ below uses Image.fromarray
+import random
+import pdb
+from .aircraft import FGVCAircraft
+from .pets import PetDataset
+from xautodl.config_utils import load_config
 
-Dataset2Class = {
-    "cifar10": 10,
-    "cifar100": 100,
-    "imagenet-1k-s": 1000,
-    "imagenet-1k": 1000,
-    "ImageNet16": 1000,
-    "ImageNet16-150": 150,
-    "ImageNet16-120": 120,
-    "ImageNet16-200": 200,
-    "aircraft": 100,
-    "oxford": 102
-}
+Dataset2Class = {'cifar10': 10,
+                 'cifar100': 100,
+                 'mnist': 10,
+                 'svhn': 10,
+                 'aircraft': 30,
+                 'oxford': 37}
 
 
 class CUTOUT(object):
+
     def __init__(self, length):
         self.length = length
 
     def __repr__(self):
-        return "{name}(length={length})".format(
-            name=self.__class__.__name__, **self.__dict__
-        )
+        return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
 
     def __call__(self, img):
         h, w = img.size(1), img.size(2)
@@ -49,7 +46,7 @@ class CUTOUT(object):
         x1 = np.clip(x - self.length // 2, 0, w)
         x2 = np.clip(x + self.length // 2, 0, w)
 
-        mask[y1:y2, x1:x2] = 0.0
+        mask[y1: y2, x1: x2] = 0.
         mask = torch.from_numpy(mask)
         mask = mask.expand_as(img)
         img *= mask
@@ -57,21 +54,19 @@
 
 
 imagenet_pca = {
-    "eigval": np.asarray([0.2175, 0.0188, 0.0045]),
-    "eigvec": np.asarray(
-        [
-            [-0.5675, 0.7192, 0.4009],
-            [-0.5808, -0.0045, -0.8140],
-            [-0.5836, -0.6948, 0.4203],
-        ]
-    ),
+    'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
+    'eigvec': np.asarray([
+        [-0.5675, 0.7192, 0.4009],
+        [-0.5808, -0.0045, -0.8140],
+        [-0.5836, -0.6948, 0.4203],
+    ])
 }
 
 
 class Lighting(object):
-    def __init__(
-        self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"]
-    ):
+    def __init__(self, alphastd,
+                 eigval=imagenet_pca['eigval'],
+                 eigvec=imagenet_pca['eigvec']):
         self.alphastd = alphastd
         assert eigval.shape == (3,)
         assert eigvec.shape == (3, 3)
@@ -79,10 +74,10 @@ class Lighting(object):
         self.eigvec = eigvec
 
     def __call__(self, img):
-        if self.alphastd == 0.0:
+        if self.alphastd == 0.:
             return img
         rnd = np.random.randn(3) * self.alphastd
-        rnd = rnd.astype("float32")
+        rnd = rnd.astype('float32')
         v = rnd
         old_dtype = np.asarray(img).dtype
         v = v * self.eigval
@@ -91,292 +86,222 @@ class Lighting(object):
         img = np.add(img, inc)
         if old_dtype == np.uint8:
             img = np.clip(img, 0, 255)
-        img = Image.fromarray(img.astype(old_dtype), "RGB")
+        img = Image.fromarray(img.astype(old_dtype), 'RGB')
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "()"
+        return self.__class__.__name__ + '()'
 
 
-def get_datasets(name, root, cutout):
-
-    if name == "cifar10":
+def get_datasets(name, root, cutout, use_num_cls=None):
+    if name == 'cifar10':
         mean = [x / 255 for x in [125.3, 123.0, 113.9]]
         std = [x / 255 for x in [63.0, 62.1, 66.7]]
-    elif name == "cifar100":
+    elif name == 'cifar100':
         mean = [x / 255 for x in [129.3, 124.1, 112.4]]
         std = [x / 255 for x in [68.2, 65.4, 70.4]]
-    elif name.startswith("imagenet-1k"):
-        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
-    elif name.startswith("ImageNet16"):
-        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
-        std = [x / 255 for x in [63.22, 61.26, 65.09]]
-    elif name == 'aircraft':
-        mean = [0.4785, 0.5100, 0.5338]
-        std = [0.1845, 0.1830, 0.2060]
-    elif name == 'oxford':
-        mean = [0.4811, 0.4492, 0.3957]
-        std = [0.2260, 0.2231, 0.2249]
+    elif name.startswith('mnist'):
+        mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081]
+    elif name.startswith('svhn'):
+        mean, std = [0.4376821, 0.4437697, 0.47280442], [
+            0.19803012, 0.20101562, 0.19703614]
+    elif name.startswith('aircraft'):
+        mean = [0.48933587508932375, 0.5183537408957618, 0.5387914411673883]
+        std = [0.22388883112804625, 0.21641635409388751, 0.24615605842636115]
+    elif name.startswith('oxford'):
+        mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783]
+        std = [0.25925664613996574, 0.2532760018681693, 0.25981017205097917]
     else:
         raise TypeError("Unknow dataset : {:}".format(name))
     # Data Argumentation
-    if name == "cifar10" or name == "cifar100":
-        lists = [
-            transforms.RandomHorizontalFlip(),
-            transforms.RandomCrop(32, padding=4),
-            transforms.ToTensor(),
-            transforms.Normalize(mean, std),
-        ]
+    if name == 'cifar10' or name == 'cifar100':
+        lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
+                 transforms.Normalize(mean, std)]
         if cutout > 0:
             lists += [CUTOUT(cutout)]
         train_transform = transforms.Compose(lists)
         test_transform = transforms.Compose(
-            [transforms.ToTensor(), transforms.Normalize(mean, std)]
-        )
+            [transforms.ToTensor(), transforms.Normalize(mean, std)])
         xshape = (1, 3, 32, 32)
-    elif name.startswith("aircraft") or name.startswith("oxford"):
-        lists = [transforms.RandomCrop(16, padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)]
-        if cutout > 0:
-            lists += [CUTOUT(cutout)]
-        train_transform = transforms.Compose(lists)
-        test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
-        xshape = (1, 3, 16, 16)
-    elif name.startswith("ImageNet16"):
-        lists = [
-            transforms.RandomHorizontalFlip(),
-            transforms.RandomCrop(16, padding=2),
+    elif name.startswith('cub200'):
+        train_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
             transforms.ToTensor(),
-            transforms.Normalize(mean, std),
-        ]
-        if cutout > 0:
-            lists += [CUTOUT(cutout)]
-        train_transform = transforms.Compose(lists)
-        test_transform = transforms.Compose(
-            [transforms.ToTensor(), transforms.Normalize(mean, std)]
-        )
-        xshape = (1, 3, 16, 16)
-    elif name == "tiered":
-        lists = [
-            transforms.RandomHorizontalFlip(),
-            transforms.RandomCrop(80, padding=4),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        test_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
             transforms.ToTensor(),
-            transforms.Normalize(mean, std),
-        ]
-        if cutout > 0:
-            lists += [CUTOUT(cutout)]
-        train_transform = transforms.Compose(lists)
-        test_transform = transforms.Compose(
-            [
-                transforms.CenterCrop(80),
-                transforms.ToTensor(),
-                transforms.Normalize(mean, std),
-            ]
-        )
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        xshape = (1, 3, 32, 32)
+    elif name.startswith('mnist'):
+        train_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
+            transforms.Normalize(mean, std),
+        ])
+        test_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
+            transforms.Normalize(mean, std)
+        ])
+        xshape = (1, 3, 32, 32)
+    elif name.startswith('svhn'):
+        train_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        test_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        xshape = (1, 3, 32, 32)
+    elif name.startswith('aircraft'):
+        train_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        test_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+        xshape = (1, 3, 32, 32)
+    elif name.startswith('oxford'):
+        train_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        test_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
         xshape = (1, 3, 32, 32)
-    elif name.startswith("imagenet-1k"):
-        normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-        )
-        if name == "imagenet-1k":
-            xlists = [transforms.RandomResizedCrop(224)]
-            xlists.append(
-                transforms.ColorJitter(
-                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
-                )
-            )
-            xlists.append(Lighting(0.1))
-        elif name == "imagenet-1k-s":
-            xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
-        else:
-            raise ValueError("invalid name : {:}".format(name))
-        xlists.append(transforms.RandomHorizontalFlip(p=0.5))
-        xlists.append(transforms.ToTensor())
-        xlists.append(normalize)
-        train_transform = transforms.Compose(xlists)
-        test_transform = transforms.Compose(
-            [
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]
-        )
-        xshape = (1, 3, 224, 224)
     else:
         raise TypeError("Unknow dataset : {:}".format(name))
 
-    if name == "cifar10":
+    if name == 'cifar10':
         train_data = dset.CIFAR10(
-            root, train=True, transform=train_transform, download=True
-        )
+            root, train=True, transform=train_transform, download=True)
         test_data = dset.CIFAR10(
-            root, train=False, transform=test_transform, download=True
-        )
+            root, train=False, transform=test_transform, download=True)
         assert len(train_data) == 50000 and len(test_data) == 10000
-    elif name == "cifar100":
+    elif name == 'cifar100':
         train_data = dset.CIFAR100(
-            root, train=True, transform=train_transform, download=True
-        )
+            root, train=True, transform=train_transform, download=True)
         test_data = dset.CIFAR100(
-            root, train=False, transform=test_transform, download=True
-        )
+            root, train=False, transform=test_transform, download=True)
         assert len(train_data) == 50000 and len(test_data) == 10000
-    elif name == "aircraft":
-        train_data = dset.ImageFolder(root='/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image', transform=train_transform)
-        test_data = dset.ImageFolder(root='/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image', transform=test_transform)
-
-    elif name.startswith("imagenet-1k"):
-        train_data = dset.ImageFolder(osp.join(root, "train"), train_transform)
-        test_data = dset.ImageFolder(osp.join(root, "val"), test_transform)
-        assert (
-            len(train_data) == 1281167 and len(test_data) == 50000
-        ), "invalid number of images : {:} & {:} vs {:} & {:}".format(
-            len(train_data), len(test_data), 1281167, 50000
-        )
-    elif name == "ImageNet16":
-        train_data = ImageNet16(root, True, train_transform)
-        test_data = ImageNet16(root, False, test_transform)
-        assert len(train_data) == 1281167 and len(test_data) == 50000
-    elif name == "ImageNet16-120":
-        train_data = ImageNet16(root, True, train_transform, 120)
-        test_data = ImageNet16(root, False, test_transform, 120)
-        assert len(train_data) == 151700 and len(test_data) == 6000
-    elif name == "ImageNet16-150":
-        train_data = ImageNet16(root, True, train_transform, 150)
-        test_data = ImageNet16(root, False, test_transform, 150)
-        assert len(train_data) == 190272 and len(test_data) == 7500
-    elif name == "ImageNet16-200":
-        train_data = ImageNet16(root, True, train_transform, 200)
-        test_data = ImageNet16(root, False, test_transform, 200)
-        assert len(train_data) == 254775 and len(test_data) == 10000
+    elif name == 'mnist':
+        train_data = dset.MNIST(
+            root, train=True, transform=train_transform, download=True)
+        test_data = dset.MNIST(
+            root, train=False, transform=test_transform, download=True)
+        assert len(train_data) == 60000 and len(test_data) == 10000
+    elif name == 'svhn':
+        train_data = dset.SVHN(root, split='train',
+                               transform=train_transform, download=True)
+        test_data = dset.SVHN(root, split='test',
+                              transform=test_transform, download=True)
+        assert len(train_data) == 73257 and len(test_data) == 26032
+    elif name == 'aircraft':
+        train_data = FGVCAircraft(root, class_type='manufacturer', split='trainval',
+                                  transform=train_transform, download=False)
+        test_data = FGVCAircraft(root, class_type='manufacturer', split='test',
+                                 transform=test_transform, download=False)
+        assert len(train_data) == 6667 and len(test_data) == 3333
+    elif name == 'oxford':
+        train_data = PetDataset(root, train=True, num_cl=37,
+                                val_split=0.15, transforms=train_transform)
+        test_data = PetDataset(root, train=False, num_cl=37,
+                               val_split=0.15, transforms=test_transform)
     else:
         raise TypeError("Unknow dataset : {:}".format(name))
 
-    class_num = Dataset2Class[name]
+    class_num = Dataset2Class[name] if use_num_cls is None else len(
+        use_num_cls)
     return train_data, test_data, xshape, class_num
 
 
-def get_nas_search_loaders(
-    train_data, valid_data, dataset, config_root, batch_size, workers
-):
+def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, num_cls=None):
     if isinstance(batch_size, (list, tuple)):
         batch, test_batch = batch_size
     else:
         batch, test_batch = batch_size, batch_size
-    if dataset == "cifar10":
+    if dataset == 'cifar10':
         # split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
-        cifar_split = load_config("{:}/cifar-split.txt".format(config_root), None, None)
-        train_split, valid_split = (
-            cifar_split.train,
-            cifar_split.valid,
-        )  # search over the proposed training and validation set
+        cifar_split = load_config(
+            '{:}/cifar-split.txt'.format(config_root), None, None)
+        # search over the proposed training and validation set
+        train_split, valid_split = cifar_split.train, cifar_split.valid
         # logger.log('Load split file from {:}'.format(split_Fpath))
         # they are two disjoint groups in the original CIFAR-10 training set
         # To split data
         xvalid_data = deepcopy(train_data)
-        if hasattr(xvalid_data, "transforms"):  # to avoid a print issue
+        if hasattr(xvalid_data, 'transforms'):  # to avoid a print issue
             xvalid_data.transforms = valid_data.transform
         xvalid_data.transform = deepcopy(valid_data.transform)
-        search_data = SearchDataset(dataset, train_data, train_split, valid_split)
+        search_data = SearchDataset(
+            dataset, train_data, train_split, valid_split)
         # data loader
-        search_loader = torch.utils.data.DataLoader(
-            search_data,
-            batch_size=batch,
-            shuffle=True,
-            num_workers=workers,
-            pin_memory=True,
-        )
-        train_loader = torch.utils.data.DataLoader(
-            train_data,
-            batch_size=batch,
-            sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
-            num_workers=workers,
-            pin_memory=True,
-        )
-        valid_loader = torch.utils.data.DataLoader(
-            xvalid_data,
-            batch_size=test_batch,
-            sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
-            num_workers=workers,
-            pin_memory=True,
-        )
-    elif dataset == "cifar100":
+        search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
+                                                    pin_memory=True)
+        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch,
                                                   sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                       train_split),
+                                                   num_workers=workers, pin_memory=True)
+        valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch,
+                                                   sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                       valid_split),
+                                                   num_workers=workers, pin_memory=True)
+    elif dataset == 'cifar100':
         cifar100_test_split = load_config(
-            "{:}/cifar100-test-split.txt".format(config_root), None, None
-        )
+            '{:}/cifar100-test-split.txt'.format(config_root), None, None)
         search_train_data = train_data
         search_valid_data = deepcopy(valid_data)
         search_valid_data.transform = train_data.transform
-        search_data = SearchDataset(
-            dataset,
-            [search_train_data, search_valid_data],
-            list(range(len(search_train_data))),
-            cifar100_test_split.xvalid,
-        )
-        search_loader = torch.utils.data.DataLoader(
-            search_data,
-            batch_size=batch,
-            shuffle=True,
-            num_workers=workers,
-            pin_memory=True,
-        )
-        train_loader = torch.utils.data.DataLoader(
-            train_data,
-            batch_size=batch,
-            shuffle=True,
-            num_workers=workers,
-            pin_memory=True,
-        )
-        valid_loader = torch.utils.data.DataLoader(
-            valid_data,
-            batch_size=test_batch,
-            sampler=torch.utils.data.sampler.SubsetRandomSampler(
-                cifar100_test_split.xvalid
-            ),
-            num_workers=workers,
-            pin_memory=True,
-        )
-    elif dataset == "ImageNet16-120":
-        imagenet_test_split = load_config(
-            "{:}/imagenet-16-120-test-split.txt".format(config_root), None, None
-        )
+        search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
+                                    list(range(len(search_train_data))),
+                                    cifar100_test_split.xvalid)
+        search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
+                                                    pin_memory=True)
+        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, num_workers=workers,
+                                                   pin_memory=True)
+        valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
                                                   sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                       cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
+    elif dataset in ['mnist', 'svhn', 'aircraft', 'oxford']:
+        if not os.path.exists('{:}/{}-test-split.txt'.format(config_root, dataset)):
+            import json
+            label_list = list(range(len(valid_data)))
+            random.shuffle(label_list)
+            strlist = [str(label_list[i]) for i in range(len(label_list))]
+            split = {'xvalid': ["int", strlist[:len(valid_data) // 2]],
+                     'xtest': ["int", strlist[len(valid_data) // 2:]]}
+            with open('{:}/{}-test-split.txt'.format(config_root, dataset), 'w') as f:
+                f.write(json.dumps(split))
+        test_split = load_config(
+            '{:}/{}-test-split.txt'.format(config_root, dataset), None, None)
+
         search_train_data = train_data
         search_valid_data = deepcopy(valid_data)
         search_valid_data.transform = train_data.transform
-        search_data = SearchDataset(
-            dataset,
-            [search_train_data, search_valid_data],
-            list(range(len(search_train_data))),
-            imagenet_test_split.xvalid,
-        )
-        search_loader = torch.utils.data.DataLoader(
-            search_data,
-            batch_size=batch,
-            shuffle=True,
-            num_workers=workers,
-            pin_memory=True,
-        )
-        train_loader = torch.utils.data.DataLoader(
-            train_data,
-            batch_size=batch,
-            shuffle=True,
-            num_workers=workers,
-            pin_memory=True,
-        )
-        valid_loader = torch.utils.data.DataLoader(
-            valid_data,
-            batch_size=test_batch,
-            sampler=torch.utils.data.sampler.SubsetRandomSampler(
-                imagenet_test_split.xvalid
-            ),
-            num_workers=workers,
-            pin_memory=True,
-        )
+        search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
                                    list(range(len(search_train_data))), test_split.xvalid)
+        search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True,
+                                                    num_workers=workers, pin_memory=True)
+        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True,
+                                                   num_workers=workers, pin_memory=True)
+        valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
+                                                   sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                       test_split.xvalid), num_workers=workers, pin_memory=True)
     else:
-        raise ValueError("invalid dataset : {:}".format(dataset))
+        raise ValueError('invalid dataset : {:}'.format(dataset))
     return search_loader, train_loader, valid_loader
-
-
-# if __name__ == '__main__':
-#     train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
-#     import pdb; pdb.set_trace()
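
For a quick sanity check of the patched data pipeline, a throwaway script along the following lines can be run from the repository root. This is a minimal sketch, not part of the patch: the two dataset roots are placeholders, the `configs/nas-benchmark` directory must exist and be writable (the loaders auto-generate `{dataset}-test-split.txt` there on first use, in the `load_config` JSON convention where each value is a `[type, payload]` pair), and the patched `xautodl` package must be importable.

# sanity_check_datasets.py -- illustrative only; dataset paths are placeholders.
from xautodl.datasets.get_dataset_with_transform import (
    get_datasets,
    get_nas_search_loaders,
)

for name, root in [("aircraft", "/path/to/fgvc-aircraft-2013b"),
                   ("oxford", "/path/to/oxford-iiit-pet")]:
    # cutout=-1 disables CUTOUT, matching how main.py calls get_datasets.
    train_data, test_data, xshape, class_num = get_datasets(name, root, -1)
    print(name, len(train_data), len(test_data), xshape, class_num)

    # First call writes configs/nas-benchmark/{name}-test-split.txt.
    search_loader, train_loader, valid_loader = get_nas_search_loaders(
        train_data, test_data, name,
        config_root="configs/nas-benchmark",
        batch_size=(64, 128), workers=4)

    # Pull one batch; with the Resize((32, 32)) transforms above, images
    # should come out as a [64, 3, 32, 32] tensor.
    images, labels = next(iter(train_loader))
    print(images.shape)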