Upgrade API of NAS-Bench-201
This commit is contained in:
@@ -1,14 +1,17 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
import time, torch
|
||||
import os, time, copy, torch, pathlib
|
||||
|
||||
import datasets
|
||||
from config_utils import load_config
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
__all__ = ['evaluate_for_seed', 'pure_evaluate']
|
||||
__all__ = ['evaluate_for_seed', 'pure_evaluate', 'get_nas_bench_loaders']
|
||||
|
||||
|
||||
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||
@@ -127,3 +130,72 @@ def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed
|
||||
'finish-train': True
|
||||
}
|
||||
return info_seed
|
||||
|
||||
|
||||
def get_nas_bench_loaders(workers):
|
||||
|
||||
torch.set_num_threads(workers)
|
||||
|
||||
root_dir = (pathlib.Path(__file__).parent / '..' / '..').resolve()
|
||||
torch_dir = pathlib.Path(os.environ['TORCH_HOME'])
|
||||
# cifar
|
||||
cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config'
|
||||
cifar_config = load_config(cifar_config_path, None, None)
|
||||
get_datasets = datasets.get_datasets # a function to return the dataset
|
||||
break_line = '-' * 150
|
||||
print ('{:} Create data-loader for all datasets'.format(time_string()))
|
||||
print (break_line)
|
||||
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1)
|
||||
print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num))
|
||||
cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None)
|
||||
assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14]
|
||||
temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
|
||||
temp_dataset.transform = VALID_CIFAR10.transform
|
||||
# data loader
|
||||
trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
train_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True)
|
||||
valid_cifar10_loader = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True)
|
||||
test__cifar10_loader = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
|
||||
print ('CIFAR-10 : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size))
|
||||
print (break_line)
|
||||
# CIFAR-100
|
||||
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1)
|
||||
print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num))
|
||||
cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None)
|
||||
assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24]
|
||||
train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True)
|
||||
test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True)
|
||||
print ('CIFAR-100 : train-loader has {:3d} batch'.format(len(train_cifar100_loader)))
|
||||
print ('CIFAR-100 : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader)))
|
||||
print ('CIFAR-100 : test--loader has {:3d} batch'.format(len(test__cifar100_loader)))
|
||||
print (break_line)
|
||||
|
||||
imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config'
|
||||
imagenet16_config = load_config(imagenet16_config_path, None, None)
|
||||
TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1)
|
||||
print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num))
|
||||
imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None)
|
||||
assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20]
|
||||
train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True)
|
||||
test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True)
|
||||
print ('ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size))
|
||||
print ('ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size))
|
||||
print ('ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size))
|
||||
|
||||
# 'cifar10', 'cifar100', 'ImageNet16-120'
|
||||
loaders = {'cifar10@trainval': trainval_cifar10_loader,
|
||||
'cifar10@train' : train_cifar10_loader,
|
||||
'cifar10@valid' : valid_cifar10_loader,
|
||||
'cifar10@test' : test__cifar10_loader,
|
||||
'cifar100@train' : train_cifar100_loader,
|
||||
'cifar100@valid' : valid_cifar100_loader,
|
||||
'cifar100@test' : test__cifar100_loader,
|
||||
'ImageNet16-120@train': train_imagenet_loader,
|
||||
'ImageNet16-120@valid': valid_imagenet_loader,
|
||||
'ImageNet16-120@test' : test__imagenet_loader}
|
||||
return loaders
|
Reference in New Issue
Block a user