Fix the potential memory leak in NAS-Bench-201 clear_param

This commit is contained in:
D-X-Y
2020-03-21 01:33:07 -07:00
parent b702ddf5a2
commit 22025887f1
9 changed files with 40 additions and 38 deletions

View File

@@ -24,11 +24,11 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
machine_info = get_machine_info()
all_infos = {'info': machine_info}
all_dataset_keys = []
# look all the datasets
# look all the dataset
for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data
# the train and valid data
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configurature
# load the configuration
if dataset == 'cifar10' or dataset == 'cifar100':
split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
elif dataset.startswith('ImageNet16'):
@@ -36,7 +36,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
# check whether use splited validation set
# check whether use the splitted validation set
if bool(split):
assert dataset == 'cifar10'
ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
@@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
log_dir = save_dir / 'logs'
log_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(str(log_dir), 0, False)
logger = Logger(str(log_dir), os.getpid(), False)
logger.log('xargs : seeds = {:}'.format(seeds))
logger.log('xargs : cover_mode = {:}'.format(cover_mode))