simplify baselines
This commit is contained in:
@@ -99,24 +99,31 @@ def main(xargs, nas_bench):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
|
||||
if xargs.data_path is not None:
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
|
||||
else:
|
||||
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||
config = load_config(config_path, None, logger)
|
||||
extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
policy = Policy(xargs.max_nodes, search_space)
|
||||
|
Reference in New Issue
Block a user