@@ -105,28 +105,27 @@ def main(xargs):
   logger = prepare_logger(args)
   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
-  if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
+  #config_path = 'configs/nas-benchmark/algos/DARTS.config'
+  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+  if xargs.dataset == 'cifar10':
     split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
     cifar_split = load_config(split_Fpath, None, None)
-    train_split, valid_split = cifar_split.train, cifar_split.valid
-    logger.log('Load split file from {:}'.format(split_Fpath))
+    train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
+    logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
+    # To split data
+    train_data_v2 = deepcopy(train_data)
+    train_data_v2.transform = valid_data.transform
+    valid_data = train_data_v2
+    search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
+    # data loader
+    search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
+    valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
+  elif xargs.dataset == 'cifar100':
+    raise ValueError('not support yet : {:}'.format(xargs.dataset))
   elif xargs.dataset.startswith('ImageNet16'):
     split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
     imagenet16_split = load_config(split_Fpath, None, None)
     train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
     logger.log('Load split file from {:}'.format(split_Fpath))
+    raise ValueError('not support yet : {:}'.format(xargs.dataset))
   else:
     raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
-  config_path = 'configs/nas-benchmark/algos/DARTS.config'
-  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
-  # To split data
-  train_data_v2 = deepcopy(train_data)
-  train_data_v2.transform = valid_data.transform
-  valid_data = train_data_v2
-  search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
-  # data loader
-  search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
-  valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
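
For orientation, the new cifar10 branch above never touches the CIFAR-10 test images: valid_data is a deep copy of the training set with the evaluation transform swapped in, and the search and validation loaders draw from the two disjoint index lists in cifar-split.txt. Below is a minimal, self-contained sketch of just that deepcopy/transform/sampler pattern, assuming plain torchvision; the SearchDataset wrapper, the config-driven batch size, and the official split file are left out, and the './data' root, batch size 64, and 25k/25k index split are placeholders for illustration only.

# Standalone sketch (illustration only, not the repository code).
from copy import deepcopy
import torch
import torchvision
import torchvision.transforms as T

train_transform = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), T.ToTensor()])
valid_transform = T.Compose([T.ToTensor()])

# One physical dataset, two views: the deep copy keeps the same images/labels
# but applies the evaluation-time preprocessing.
train_data = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=train_transform)
valid_data = deepcopy(train_data)
valid_data.transform = valid_transform

# Placeholder 25k/25k split; the script loads the official indices from cifar-split.txt instead.
indices = list(range(len(train_data)))
train_split, valid_split = indices[:25000], indices[25000:]

# Disjoint samplers keep the two loaders from ever seeing each other's images.
search_loader = torch.utils.data.DataLoader(train_data, batch_size=64,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
    num_workers=2, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=64,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
    num_workers=2, pin_memory=True)

Splitting the original training set this way, rather than validating on the test set, is what lets the bi-level search update the weights on one half and the architecture parameters on the other without leaking test data.
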
@@ -231,6 +230,7 @@ if __name__ == '__main__':
   parser.add_argument('--data_path', type=str, help='Path to dataset')
   parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
   # channels and number-of-cells
+  parser.add_argument('--config_path', type=str, help='The config paths.')
   parser.add_argument('--search_space_name', type=str, help='The search space name.')
   parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
   parser.add_argument('--channel', type=int, help='The number of channels.')
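
The hunk above just registers the --config_path flag whose value the first hunk now passes to load_config(xargs.config_path, ...) in place of the previously hard-coded configs/nas-benchmark/algos/DARTS.config. A small hypothetical sketch of that wiring, for illustration only (the default value and the print are not part of the patch):

import argparse

parser = argparse.ArgumentParser('illustration of the new --config_path flag')
parser.add_argument('--config_path', type=str,
                    default='configs/nas-benchmark/algos/DARTS.config',  # default added for the sketch only
                    help='The config paths.')
xargs = parser.parse_args()

# In the script, this value replaces the old hard-coded constant:
#   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
print('loading search/optimization settings from:', xargs.config_path)
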