Fix the potential memory leak in NAS-Bench-201 clear_param
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
|
||||
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
|
||||
###############################################################################################
|
||||
import os, gc, sys, time, glob, random, argparse
|
||||
import os, gc, sys, argparse, psutil
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
@@ -33,7 +33,7 @@ def tostr(accdict, norms):
|
||||
|
||||
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
print('\nEvaluate dataset={:}'.format(data))
|
||||
norms = []
|
||||
norms, process = [], psutil.Process(os.getpid())
|
||||
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
for idx in range(len(api)):
|
||||
@@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
with torch.no_grad():
|
||||
net.load_state_dict(param)
|
||||
_, summary = weight_watcher.analyze(net, alphas=False)
|
||||
cur_norms.append( summary['lognorm'] )
|
||||
cur_norms.append(summary['lognorm'])
|
||||
norms.append( float(np.mean(cur_norms)) )
|
||||
api.clear_params(idx, use_12epochs_result)
|
||||
api.clear_params(idx, None)
|
||||
if idx % 200 == 199 or idx + 1 == len(api):
|
||||
head = '{:05d}/{:05d}'.format(idx, len(api))
|
||||
stem_val = tostr(final_val_accs, norms)
|
||||
stem_test = tostr(final_test_accs, norms)
|
||||
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200))
|
||||
print(' -->> {:} || {:}'.format(stem_val, stem_test))
|
||||
torch.cuda.empty_cache() ; gc.collect()
|
||||
print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6))
|
||||
print(' [Valid] -->> {:}'.format(stem_val))
|
||||
print(' [Test.] -->> {:}'.format(stem_test))
|
||||
gc.collect()
|
||||
|
||||
|
||||
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):
|
||||
|
Reference in New Issue
Block a user