Fix the potential memory leak in NAS-Bench-201 clear_param
This commit is contained in:
@@ -24,11 +24,11 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
|
||||
machine_info = get_machine_info()
|
||||
all_infos = {'info': machine_info}
|
||||
all_dataset_keys = []
|
||||
# look all the datasets
|
||||
# look all the dataset
|
||||
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||
# train valid data
|
||||
# the train and valid data
|
||||
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
|
||||
# load the configurature
|
||||
# load the configuration
|
||||
if dataset == 'cifar10' or dataset == 'cifar100':
|
||||
split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
|
||||
elif dataset.startswith('ImageNet16'):
|
||||
@@ -36,7 +36,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
|
||||
# check whether use splited validation set
|
||||
# check whether use the splitted validation set
|
||||
if bool(split):
|
||||
assert dataset == 'cifar10'
|
||||
ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
|
||||
@@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
|
||||
|
||||
log_dir = save_dir / 'logs'
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger = Logger(str(log_dir), 0, False)
|
||||
logger = Logger(str(log_dir), os.getpid(), False)
|
||||
|
||||
logger.log('xargs : seeds = {:}'.format(seeds))
|
||||
logger.log('xargs : cover_mode = {:}'.format(cover_mode))
|
||||
|
Reference in New Issue
Block a user