update affines for NAS

D-X-Y
2019-12-02 18:03:40 +11:00
parent 487fec21bf
commit d175a361bd
9 changed files with 78 additions and 41 deletions


@@ -15,10 +15,10 @@ from procedures import get_machine_info
 from datasets import get_datasets
 from log_utils import Logger, AverageMeter, time_string, convert_secs2time
 from models import CellStructure, CellArchitectures, get_search_spaces
-from AA_functions import evaluate_for_seed
+from AA_functions_v2 import evaluate_for_seed


-def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, workers, logger):
+def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger):
   machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
   all_infos = {'info': machine_info}
   all_dataset_keys = []
@@ -28,10 +28,12 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
     # load the configuration
     if dataset == 'cifar10' or dataset == 'cifar100':
-      config_path = 'configs/nas-benchmark/CIFAR.config'
+      if use_less: config_path = 'configs/nas-benchmark/LESS.config'
+      else       : config_path = 'configs/nas-benchmark/CIFAR.config'
       split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
     elif dataset.startswith('ImageNet16'):
-      config_path = 'configs/nas-benchmark/ImageNet-16.config'
+      if use_less: config_path = 'configs/nas-benchmark/LESS.config'
+      else       : config_path = 'configs/nas-benchmark/ImageNet-16.config'
       split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
     else:
       raise ValueError('invalid dataset : {:}'.format(dataset))
@@ -41,6 +43,8 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
                                          logger)
     # check whether to use the split validation set
     if bool(split):
+      assert dataset == 'cifar10'
+      ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
       assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid))
       train_data_v2 = deepcopy(train_data)
       train_data_v2.transform = valid_data.transform
@@ -48,23 +52,42 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
       # data loader
       train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True)
       valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True)
+      ValLoaders['x-valid'] = valid_loader
     else:
       # data loader
       train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
       valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
+      if dataset == 'cifar10':
+        ValLoaders = {'ori-test': valid_loader}
+      elif dataset == 'cifar100':
+        cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True),
+                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True)
+                     }
+      elif dataset == 'ImageNet16-120':
+        imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True),
+                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True)
+                     }
+      else:
+        raise ValueError('invalid dataset : {:}'.format(dataset))

     dataset_key = '{:}'.format(dataset)
     if bool(split): dataset_key = dataset_key + '-valid'
     logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
     logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
-    results = evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, seed, logger)
+    for key, value in ValLoaders.items():
+      logger.log('Evaluate ---->>>> {:10s} with {:} batches'.format(key, len(value)))
+    results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger)
     all_infos[dataset_key] = results
     all_dataset_keys.append( dataset_key )
   all_infos['all_dataset_keys'] = all_dataset_keys
   return all_infos


-def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
+def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
   assert torch.cuda.is_available(), 'CUDA is not available.'
   torch.backends.cudnn.enabled = True
   #torch.backends.cudnn.benchmark = True
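Note that evaluate_all_datasets now hands the whole ValLoaders dict to evaluate_for_seed, but the consuming side of that dict is not among the hunks shown in this commit view. Below is a minimal sketch of what the per-epoch evaluation presumably looks like, assuming AA_functions_v2.evaluate_for_seed iterates the named loaders and keys its metrics by loader name; the helper name eval_all_loaders is illustrative, not from this commit.

# Hypothetical sketch: how a dict of named validation loaders could be
# consumed per epoch (the real AA_functions_v2 code is not shown above).
import torch

def eval_all_loaders(network, criterion, valid_loaders, procedure):
  """Evaluate `network` on every named loader and key the results by name."""
  valid_info = {}
  with torch.no_grad():
    for name, xloader in valid_loaders.items():
      # after this commit, `procedure` returns (loss, top1, top5, time)
      loss, top1, top5, tm = procedure(xloader, network, criterion, None, None, 'valid')
      valid_info[name] = {'loss': loss, 'top1': top1, 'top5': top5, 'time': tm}
  return valid_info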
@@ -73,7 +96,10 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,
   assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange)

-  sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
+  if use_less:
+    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
+  else:
+    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
   logger = Logger(str(sub_dir), 0, False)

   all_archs = meta_info['archs']
@@ -114,7 +140,7 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,
         has_continue = True
         continue
       results = evaluate_all_datasets(CellStructure.str2structure(arch), \
-                                      datasets, xpaths, splits, seed, \
+                                      datasets, xpaths, splits, use_less, seed, \
                                       arch_config, workers, logger)
       torch.save(results, to_save_name)
       logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name))
@@ -130,7 +156,7 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,
   logger.close()


-def train_single_model(save_dir, workers, datasets, xpaths, splits, seeds, model_str, arch_config):
+def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config):
   assert torch.cuda.is_available(), 'CUDA is not available.'
   torch.backends.cudnn.enabled = True
   torch.backends.cudnn.deterministic = True
@@ -160,7 +186,7 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, seeds, model
       checkpoint = torch.load(to_save_name)
     else:
       logger.log('Did not find the existing file {:}, train and evaluate!'.format(to_save_name))
-      checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, workers, logger)
+      checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger)
       torch.save(checkpoint, to_save_name)
     # log information
     logger.log('{:}'.format(checkpoint['info']))
@@ -252,6 +278,7 @@ if __name__ == '__main__':
   parser.add_argument('--datasets',  type=str, nargs='+', help='The applied datasets.')
   parser.add_argument('--xpaths',    type=str, nargs='+', help='The root path for this dataset.')
   parser.add_argument('--splits',    type=int, nargs='+', help='Whether to use the split validation set for each dataset.')
+  parser.add_argument('--use_less',  type=int, default=0, help='Using the less-training-epoch config.')
   parser.add_argument('--seeds',     type=int, nargs='+', help='The range of models to be evaluated')
   parser.add_argument('--channel',   type=int,            help='The number of channels.')
   parser.add_argument('--num_cells', type=int,            help='The number of cells in one stage.')
@@ -264,7 +291,7 @@ if __name__ == '__main__':
   elif args.mode.startswith('specific'):
     assert len(args.mode.split('-')) == 2, 'invalid mode : {:}'.format(args.mode)
     model_str = args.mode.split('-')[1]
-    train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, \
+    train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
                        tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells})
   else:
     meta_path = Path(args.save_dir) / 'meta-node-{:}.pth'.format(args.max_node)
@@ -276,7 +303,7 @@ if __name__ == '__main__':
   assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))
   assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)

-  main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, \
+  main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
        tuple(args.srange), args.arch_index, tuple(args.seeds), \
        args.mode == 'cover', meta_info, \
        {'channel': args.channel, 'num_cells': args.num_cells})
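The --use_less flag is declared as an int and converted to a bool at each call site via args.use_less>0; unlike action='store_true', this lets callers pass an explicit 0 or 1 on the command line. A self-contained sketch of the same pattern:

# Minimal demonstration of the int-flag pattern used above:
# `--use_less 1` enables the LESS config, `--use_less 0` (the default) disables it.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_less', type=int, default=0, help='Using the less-training-epoch config.')
args = parser.parse_args(['--use_less', '1'])

use_less = args.use_less > 0   # same conversion as `args.use_less>0` in the calls above
config_path = 'configs/nas-benchmark/LESS.config' if use_less else 'configs/nas-benchmark/CIFAR.config'
print(use_less, config_path)   # True configs/nas-benchmark/LESS.config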


@@ -47,6 +47,7 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
   elif mode == 'valid': network.eval()
   else: raise ValueError("The mode is not right : {:}".format(mode))

+  batch_time, end = AverageMeter(), time.time()
   for i, (inputs, targets) in enumerate(xloader):
     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
@@ -64,7 +65,10 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
     losses.update(loss.item(),  inputs.size(0))
     top1.update  (prec1.item(), inputs.size(0))
     top5.update  (prec5.item(), inputs.size(0))
-  return losses.avg, top1.avg, top5.avg
+    # count time
+    batch_time.update(time.time() - end)
+    end = time.time()
+  return losses.avg, top1.avg, top5.avg, batch_time.sum
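procedure now also returns batch_time.sum, the wall-clock time accumulated over all batches. AverageMeter is imported from log_utils and its definition is not part of this diff; here is a minimal sketch of the interface the timing code relies on, assuming the usual meter pattern (the actual class may differ in detail):

# Sketch of the AverageMeter interface assumed by the timing code above.
class AverageMeter:
  """Tracks the running sum, count, and average of a scalar value."""
  def __init__(self):
    self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0
  def update(self, val, n=1):
    self.val    = val
    self.sum   += val * n        # `batch_time.sum` reads this accumulated total
    self.count += n
    self.avg    = self.sum / self.count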
@@ -87,18 +91,21 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
   # start training
   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
   train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
+  train_times , valid_times = {}, {}
   for epoch in range(total_epoch):
     scheduler.update(epoch, 0.0)

-    train_loss, train_acc1, train_acc5 = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
+    train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
     with torch.no_grad():
-      valid_loss, valid_acc1, valid_acc5 = procedure(valid_loader, network, criterion, None, None, 'valid')
+      valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(valid_loader, network, criterion, None, None, 'valid')
     train_losses[epoch] = train_loss
     train_acc1es[epoch] = train_acc1
     train_acc5es[epoch] = train_acc5
     valid_losses[epoch] = valid_loss
     valid_acc1es[epoch] = valid_acc1
     valid_acc5es[epoch] = valid_acc5
+    train_times [epoch] = train_tm
+    valid_times [epoch] = valid_tm

     # measure elapsed time
     epoch_time.update(time.time() - start_time)
@@ -114,9 +121,11 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
              'train_losses': train_losses,
              'train_acc1es': train_acc1es,
              'train_acc5es': train_acc5es,
+             'train_times' : train_times,
              'valid_losses': valid_losses,
              'valid_acc1es': valid_acc1es,
              'valid_acc5es': valid_acc5es,
+             'valid_times' : valid_times,
              'net_state_dict': net.state_dict(),
              'net_string'  : '{:}'.format(net),
              'finish-train': True
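With per-epoch train_times and valid_times now stored alongside the accuracy curves, the training cost of a run can be recovered offline. A hypothetical consumer, assuming the result dict shown above is what torch.load returns for a saved run (the file path is illustrative, not from this commit):

# Hypothetical consumer of the new timing fields saved in the checkpoint.
import torch

# illustrative path; real runs are saved under the srange/arch/seed naming above
checkpoint = torch.load('output/000000-000100-C16-N5/arch-000001-seed-0777.pth', map_location='cpu')
train_times, valid_times = checkpoint['train_times'], checkpoint['valid_times']

total_train = sum(train_times.values())   # seconds spent in the train loops
total_valid = sum(valid_times.values())   # seconds spent in the valid loops
print('epochs={:} train={:.1f}s valid={:.1f}s'.format(len(train_times), total_train, total_valid))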