NAS-sharing-parameters support 3 datasets / update ops / update pypi
This commit is contained in:
@@ -12,7 +12,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@@ -107,35 +107,7 @@ def main(xargs):
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
#config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
if xargs.dataset == 'cifar10':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
|
||||
logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
elif xargs.dataset == 'cifar100':
|
||||
cifar100_test_split = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
|
||||
elif xargs.dataset == 'ImageNet16-120':
|
||||
imagenet_test_split = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
|
@@ -12,7 +12,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@@ -169,28 +169,8 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
imagenet16_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
#config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
|
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@@ -184,29 +184,14 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
imagenet16_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
logger.log('use config from : {:}'.format(xargs.config_path))
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
logger.log('config: {:}'.format(config))
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = test_data.transform
|
||||
valid_data = train_data_v2
|
||||
_, train_loader, valid_loader = get_nas_search_loaders(train_data, test_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
|
||||
# since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader
|
||||
valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
|
||||
if hasattr(valid_loader.dataset, 'transforms'):
|
||||
valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms)
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
|
@@ -12,7 +12,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@@ -80,25 +80,10 @@ def main(xargs):
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, _, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
imagenet16_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
#config_path = 'configs/nas-benchmark/algos/GDAS.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
@@ -143,7 +128,7 @@ def main(xargs):
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
|
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@@ -117,32 +117,9 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
#elif xargs.dataset.startswith('ImageNet16'):
|
||||
# # all_indexes = list(range(len(train_data))) ; random.seed(111) ; random.shuffle(all_indexes)
|
||||
# # train_split, valid_split = sorted(all_indexes[: len(train_data)//2]), sorted(all_indexes[len(train_data)//2 :])
|
||||
# # imagenet16_split = dict2config({'train': train_split, 'valid': valid_split}, None)
|
||||
# # _ = configure2str(imagenet16_split, 'temp.txt')
|
||||
# split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
# imagenet16_split = load_config(split_Fpath, None, None)
|
||||
# train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
# logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
logger.log('config : {:}'.format(config))
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
|
||||
(config.batch_size, config.test_batch_size), xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
|
@@ -12,7 +12,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@@ -135,29 +135,9 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
imagenet16_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
#config_path = 'configs/nas-benchmark/algos/SETN.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
|
||||
(config.batch_size, config.test_batch_size), xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
@@ -202,7 +182,8 @@ def main(xargs):
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
init_genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: init_genotype}
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
|
Reference in New Issue
Block a user