NAS-sharing-parameters support 3 datasets / update ops / update pypi

This commit is contained in:
D-X-Y
2020-01-11 00:19:58 +11:00
parent 96152a9904
commit c66afa4df8
17 changed files with 192 additions and 153 deletions

View File

@@ -10,7 +10,7 @@ from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
@@ -184,29 +184,14 @@ def main(xargs):
logger = prepare_logger(args)
train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
elif xargs.dataset.startswith('ImageNet16'):
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
imagenet16_split = load_config(split_Fpath, None, None)
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
else:
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
logger.log('use config from : {:}'.format(xargs.config_path))
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
logger.log('config: {:}'.format(config))
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = test_data.transform
valid_data = train_data_v2
_, train_loader, valid_loader = get_nas_search_loaders(train_data, test_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
# since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader
valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
if hasattr(valid_loader.dataset, 'transforms'):
valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms)
# data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))