NAS-sharing-parameters support 3 datasets / update ops / update pypi

This commit is contained in:
D-X-Y
2020-01-11 00:19:58 +11:00
parent 96152a9904
commit c66afa4df8
17 changed files with 192 additions and 153 deletions

View File

@@ -10,7 +10,7 @@ from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
@@ -117,32 +117,9 @@ def main(xargs):
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
#elif xargs.dataset.startswith('ImageNet16'):
# # all_indexes = list(range(len(train_data))) ; random.seed(111) ; random.shuffle(all_indexes)
# # train_split, valid_split = sorted(all_indexes[: len(train_data)//2]), sorted(all_indexes[len(train_data)//2 :])
# # imagenet16_split = dict2config({'train': train_split, 'valid': valid_split}, None)
# # _ = configure2str(imagenet16_split, 'temp.txt')
# split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
# imagenet16_split = load_config(split_Fpath, None, None)
# train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
# logger.log('Load split file from {:}'.format(split_Fpath))
else:
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
logger.log('config : {:}'.format(config))
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
(config.batch_size, config.test_batch_size), xargs.workers)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))