Fix the potential memory leak in NAS-Bench-201 clear_params

This commit is contained in:
D-X-Y
2020-03-21 01:33:07 -07:00
parent b702ddf5a2
commit 22025887f1
9 changed files with 40 additions and 38 deletions

View File

@@ -6,7 +6,7 @@
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
###############################################################################################
import os, gc, sys, time, glob, random, argparse
import os, gc, sys, argparse, psutil
import numpy as np
import torch
from pathlib import Path
@@ -33,7 +33,7 @@ def tostr(accdict, norms):
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
print('\nEvaluate dataset={:}'.format(data))
norms = []
norms, process = [], psutil.Process(os.getpid())
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
for idx in range(len(api)):
@@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
with torch.no_grad():
net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append( summary['lognorm'] )
cur_norms.append(summary['lognorm'])
norms.append( float(np.mean(cur_norms)) )
api.clear_params(idx, use_12epochs_result)
api.clear_params(idx, None)
if idx % 200 == 199 or idx + 1 == len(api):
head = '{:05d}/{:05d}'.format(idx, len(api))
stem_val = tostr(final_val_accs, norms)
stem_test = tostr(final_test_accs, norms)
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200))
print(' -->> {:} || {:}'.format(stem_val, stem_test))
torch.cuda.empty_cache() ; gc.collect()
print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6))
print(' [Valid] -->> {:}'.format(stem_val))
print(' [Test.] -->> {:}'.format(stem_test))
gc.collect()
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):