Fix the potential memory leak in NAS-Bench-201 clear_params
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
import os, sys, time, torch
|
||||
import time, torch
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from config_utils import dict2config
|
||||
@@ -9,11 +9,9 @@ from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
|
||||
__all__ = ['evaluate_for_seed', 'pure_evaluate']
|
||||
|
||||
|
||||
|
||||
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
|
@@ -28,7 +28,7 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c
|
||||
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||
# train valid data
|
||||
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
|
||||
# load the configurature
|
||||
# load the configuration
|
||||
if dataset == 'cifar10' or dataset == 'cifar100':
|
||||
if use_less: config_path = 'configs/nas-benchmark/LESS.config'
|
||||
else : config_path = 'configs/nas-benchmark/CIFAR.config'
|
||||
|
@@ -3,7 +3,7 @@
|
||||
################################################################################################
|
||||
# python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth #
|
||||
################################################################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import sys, argparse
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
|
@@ -6,7 +6,7 @@
|
||||
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
|
||||
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
|
||||
###############################################################################################
|
||||
import os, gc, sys, time, glob, random, argparse
|
||||
import os, gc, sys, argparse, psutil
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
@@ -33,7 +33,7 @@ def tostr(accdict, norms):
|
||||
|
||||
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
print('\nEvaluate dataset={:}'.format(data))
|
||||
norms = []
|
||||
norms, process = [], psutil.Process(os.getpid())
|
||||
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
for idx in range(len(api)):
|
||||
@@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
with torch.no_grad():
|
||||
net.load_state_dict(param)
|
||||
_, summary = weight_watcher.analyze(net, alphas=False)
|
||||
cur_norms.append( summary['lognorm'] )
|
||||
cur_norms.append(summary['lognorm'])
|
||||
norms.append( float(np.mean(cur_norms)) )
|
||||
api.clear_params(idx, use_12epochs_result)
|
||||
api.clear_params(idx, None)
|
||||
if idx % 200 == 199 or idx + 1 == len(api):
|
||||
head = '{:05d}/{:05d}'.format(idx, len(api))
|
||||
stem_val = tostr(final_val_accs, norms)
|
||||
stem_test = tostr(final_test_accs, norms)
|
||||
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200))
|
||||
print(' -->> {:} || {:}'.format(stem_val, stem_test))
|
||||
torch.cuda.empty_cache() ; gc.collect()
|
||||
print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6))
|
||||
print(' [Valid] -->> {:}'.format(stem_val))
|
||||
print(' [Test.] -->> {:}'.format(stem_test))
|
||||
gc.collect()
|
||||
|
||||
|
||||
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):
|
||||
|
@@ -3,7 +3,7 @@
|
||||
#####################################################
|
||||
# python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
|
||||
#####################################################
|
||||
import os, sys, time, argparse, collections
|
||||
import sys, argparse
|
||||
from tqdm import tqdm
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
|
@@ -24,11 +24,11 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
|
||||
machine_info = get_machine_info()
|
||||
all_infos = {'info': machine_info}
|
||||
all_dataset_keys = []
|
||||
# look all the datasets
|
||||
# loop over all the datasets
|
||||
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||
# train valid data
|
||||
# the train and valid data
|
||||
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
|
||||
# load the configurature
|
||||
# load the configuration
|
||||
if dataset == 'cifar10' or dataset == 'cifar100':
|
||||
split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
|
||||
elif dataset.startswith('ImageNet16'):
|
||||
@@ -36,7 +36,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
|
||||
# check whether use splited validation set
|
||||
# check whether to use the split validation set
|
||||
if bool(split):
|
||||
assert dataset == 'cifar10'
|
||||
ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
|
||||
@@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
|
||||
|
||||
log_dir = save_dir / 'logs'
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger = Logger(str(log_dir), 0, False)
|
||||
logger = Logger(str(log_dir), os.getpid(), False)
|
||||
|
||||
logger.log('xargs : seeds = {:}'.format(seeds))
|
||||
logger.log('xargs : cover_mode = {:}'.format(cover_mode))
|
||||
|
Reference in New Issue
Block a user