Fix the potential memory leak in NAS-Bench-201 clear_params

This commit is contained in:
D-X-Y
2020-03-21 01:33:07 -07:00
parent b702ddf5a2
commit 22025887f1
9 changed files with 40 additions and 38 deletions

View File

@@ -1,7 +1,7 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import os, sys, time, torch
import time, torch
from procedures import prepare_seed, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from config_utils import dict2config
@@ -9,11 +9,9 @@ from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net
__all__ = ['evaluate_for_seed', 'pure_evaluate']
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()

View File

@@ -28,7 +28,7 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c
for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configurature
# load the configuration
if dataset == 'cifar10' or dataset == 'cifar100':
if use_less: config_path = 'configs/nas-benchmark/LESS.config'
else : config_path = 'configs/nas-benchmark/CIFAR.config'

View File

@@ -3,7 +3,7 @@
################################################################################################
# python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth #
################################################################################################
import os, sys, time, glob, random, argparse
import sys, argparse
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))

View File

@@ -6,7 +6,7 @@
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
###############################################################################################
import os, gc, sys, time, glob, random, argparse
import os, gc, sys, argparse, psutil
import numpy as np
import torch
from pathlib import Path
@@ -33,7 +33,7 @@ def tostr(accdict, norms):
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
print('\nEvaluate dataset={:}'.format(data))
norms = []
norms, process = [], psutil.Process(os.getpid())
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
for idx in range(len(api)):
@@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
with torch.no_grad():
net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append( summary['lognorm'] )
cur_norms.append(summary['lognorm'])
norms.append( float(np.mean(cur_norms)) )
api.clear_params(idx, use_12epochs_result)
api.clear_params(idx, None)
if idx % 200 == 199 or idx + 1 == len(api):
head = '{:05d}/{:05d}'.format(idx, len(api))
stem_val = tostr(final_val_accs, norms)
stem_test = tostr(final_test_accs, norms)
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200))
print(' -->> {:} || {:}'.format(stem_val, stem_test))
torch.cuda.empty_cache() ; gc.collect()
print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6))
print(' [Valid] -->> {:}'.format(stem_val))
print(' [Test.] -->> {:}'.format(stem_test))
gc.collect()
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):

View File

@@ -3,7 +3,7 @@
#####################################################
# python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
#####################################################
import os, sys, time, argparse, collections
import sys, argparse
from tqdm import tqdm
from collections import OrderedDict
import numpy as np

View File

@@ -24,11 +24,11 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
machine_info = get_machine_info()
all_infos = {'info': machine_info}
all_dataset_keys = []
# look all the datasets
# loop over all the datasets
for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data
# the train and valid data
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configurature
# load the configuration
if dataset == 'cifar10' or dataset == 'cifar100':
split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
elif dataset.startswith('ImageNet16'):
@@ -36,7 +36,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
# check whether use splited validation set
# check whether to use the split validation set
if bool(split):
assert dataset == 'cifar10'
ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
@@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
log_dir = save_dir / 'logs'
log_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(str(log_dir), 0, False)
logger = Logger(str(log_dir), os.getpid(), False)
logger.log('xargs : seeds = {:}'.format(seeds))
logger.log('xargs : cover_mode = {:}'.format(cover_mode))