Fix the potential memory leak in NAS-Bench-201 clear_param

This commit is contained in:
D-X-Y
2020-03-21 01:33:07 -07:00
parent b702ddf5a2
commit 22025887f1
9 changed files with 40 additions and 38 deletions

View File

@@ -114,15 +114,27 @@ class NASBench201API(object):
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path, map_location='cpu')
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
if index in self.arch2infos_less: del self.arch2infos_less[index]
if index in self.arch2infos_full: del self.arch2infos_full[index]
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
def clear_params(self, index: int, use_12epochs_result: bool):
"""Remove the architecture's weights to save memory.

`index` selects the architecture entry; a truthy `use_12epochs_result`
operates on `self.arch2infos_less`, otherwise on `self.arch2infos_full`.
"""
# NOTE(review): a second `clear_params` definition (also accepting None) follows
# directly after this one; this looks like the pre-change side of the hunk -- confirm
# which definition is meant to survive.
# Pick the result table to operate on.
if use_12epochs_result: arch2infos = self.arch2infos_less
else : arch2infos = self.arch2infos_full
# Delegate the actual clearing to the per-architecture result object.
archresult = arch2infos[index]
archresult.clear_params()
def clear_params(self, index: int, use_12epochs_result: Union[bool, None]):
  """Remove the stored weights of one architecture to save memory.

  Args:
    index: the index of the target architecture.
    use_12epochs_result: a flag to control which results are cleared.
      -- None: clear the weights in both `arch2infos_less` and `arch2infos_full`.
      -- True: clear the weights in `arch2infos_less`, which by default is the 12-epoch-training result.
      -- False: clear the weights in `arch2infos_full`, which by default is the 200-epoch-training result.
  """
  # Resolve the flag to the result tables that must be touched; `less` always
  # comes before `full` so the clearing order is stable when both are selected.
  if use_12epochs_result is None:
    targets = (self.arch2infos_less, self.arch2infos_full)
  elif use_12epochs_result:
    targets = (self.arch2infos_less,)
  else:
    targets = (self.arch2infos_full,)
  # The per-architecture result object performs the actual clearing.
  for arch2infos in targets:
    arch2infos[index].clear_params()
# This function is used to query the information of a specific architecture
# 'arch' can be an architecture index or an architecture string
@@ -193,7 +205,6 @@ class NASBench201API(object):
best_index, highest_accuracy = idx, accuracy
return best_index, highest_accuracy
def arch(self, index: int):
"""Return the topology structure of the `index`-th architecture."""
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
@@ -213,7 +224,6 @@ class NASBench201API(object):
else: arch2infos = self.arch2infos_full
arch_result = arch2infos[index]
return arch_result.get_net_param(dataset, seed)
def get_net_config(self, index: int, dataset: Text):
"""
@@ -235,7 +245,6 @@ class NASBench201API(object):
#print ('SEED [{:}] : {:}'.format(seed, result))
raise ValueError('Impossible to reach here!')
def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]:
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
if use_12epochs_result: arch2infos = self.arch2infos_less
@@ -243,7 +252,6 @@ class NASBench201API(object):
arch_result = arch2infos[index]
return arch_result.get_compute_costs(dataset)
def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float:
"""
To obtain the latency of the network (by default it will return the latency with the batch size of 256).
@@ -254,7 +262,6 @@ class NASBench201API(object):
cost_dict = self.get_cost_info(index, dataset, use_12epochs_result)
return cost_dict['latency']
# obtain the metric for the `index`-th architecture
# `dataset` indicates the dataset:
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
@@ -388,7 +395,6 @@ class NASBench201API(object):
return xifo
"""
def show(self, index: int = -1) -> None:
"""
This function will print the information of a specific (or all) architecture(s).
@@ -423,7 +429,6 @@ class NASBench201API(object):
else:
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]:
"""
This function will count the number of total trials.
@@ -443,7 +448,6 @@ class NASBench201API(object):
nums[len(dataset_seed[dataset])] += 1
return dict(nums)
@staticmethod
def str2lists(arch_str: Text) -> List[tuple]:
"""
@@ -471,7 +475,6 @@ class NASBench201API(object):
genotypes.append( input_infos )
return genotypes
@staticmethod
def str2matrix(arch_str: Text,
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
@@ -511,7 +514,6 @@ class NASBench201API(object):
return matrix
class ArchResults(object):
def __init__(self, arch_index, arch_str):
@@ -752,7 +754,6 @@ class ArchResults(object):
def __repr__(self):
# Summarise this record as `ClassName(arch-index=..., arch=..., N runs, clear=...)`:
# N is len(self.all_results) and `clear` echoes self.clear_net_done.
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
"""
@@ -872,8 +873,8 @@ class ResultsCount(object):
'cur_time': xtime,
'all_time': atime}
# get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).
def get_eval(self, name, iepoch=None):
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
@@ -890,8 +891,8 @@ class ResultsCount(object):
if clone: return copy.deepcopy(self.net_state_dict)
else: return self.net_state_dict
# This function is used to obtain the config dict for this architecture.
def get_config(self, str2structure):
"""This function is used to obtain the config dict for this architecture."""
if str2structure is None:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'],