Fix the potential memory leak in NAS-Bench-201 clear_param
This commit is contained in:
@@ -114,15 +114,27 @@ class NASBench201API(object):
|
||||
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
|
||||
xdata = torch.load(xfile_path, map_location='cpu')
|
||||
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
|
||||
if index in self.arch2infos_less: del self.arch2infos_less[index]
|
||||
if index in self.arch2infos_full: del self.arch2infos_full[index]
|
||||
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
|
||||
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
|
||||
|
||||
def clear_params(self, index: int, use_12epochs_result: bool):
|
||||
"""Remove the architecture's weights to save memory."""
|
||||
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||
else : arch2infos = self.arch2infos_full
|
||||
archresult = arch2infos[index]
|
||||
archresult.clear_params()
|
||||
def clear_params(self, index: int, use_12epochs_result: Union[bool, None]):
|
||||
"""Remove the architecture's weights to save memory.
|
||||
:arg
|
||||
index: the index of the target architecture
|
||||
use_12epochs_result: a flag to controll how to clear the parameters.
|
||||
-- None: clear all the weights in both `less` and `full`, which indicates the training hyper-parameters.
|
||||
-- True: clear all the weights in arch2infos_less, which by default is 12-epoch-training result.
|
||||
-- False: clear all the weights in arch2infos_full, which by default is 200-epoch-training result.
|
||||
"""
|
||||
if use_12epochs_result is None:
|
||||
self.arch2infos_less[index].clear_params()
|
||||
self.arch2infos_full[index].clear_params()
|
||||
else:
|
||||
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||
else : arch2infos = self.arch2infos_full
|
||||
arch2infos[index].clear_params()
|
||||
|
||||
# This function is used to query the information of a specific archiitecture
|
||||
# 'arch' can be an architecture index or an architecture string
|
||||
@@ -193,7 +205,6 @@ class NASBench201API(object):
|
||||
best_index, highest_accuracy = idx, accuracy
|
||||
return best_index, highest_accuracy
|
||||
|
||||
|
||||
def arch(self, index: int):
|
||||
"""Return the topology structure of the `index`-th architecture."""
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
|
||||
@@ -213,7 +224,6 @@ class NASBench201API(object):
|
||||
else: arch2infos = self.arch2infos_full
|
||||
arch_result = arch2infos[index]
|
||||
return arch_result.get_net_param(dataset, seed)
|
||||
|
||||
|
||||
def get_net_config(self, index: int, dataset: Text):
|
||||
"""
|
||||
@@ -235,7 +245,6 @@ class NASBench201API(object):
|
||||
#print ('SEED [{:}] : {:}'.format(seed, result))
|
||||
raise ValueError('Impossible to reach here!')
|
||||
|
||||
|
||||
def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]:
|
||||
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
|
||||
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||
@@ -243,7 +252,6 @@ class NASBench201API(object):
|
||||
arch_result = arch2infos[index]
|
||||
return arch_result.get_compute_costs(dataset)
|
||||
|
||||
|
||||
def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float:
|
||||
"""
|
||||
To obtain the latency of the network (by default it will return the latency with the batch size of 256).
|
||||
@@ -254,7 +262,6 @@ class NASBench201API(object):
|
||||
cost_dict = self.get_cost_info(index, dataset, use_12epochs_result)
|
||||
return cost_dict['latency']
|
||||
|
||||
|
||||
# obtain the metric for the `index`-th architecture
|
||||
# `dataset` indicates the dataset:
|
||||
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
|
||||
@@ -388,7 +395,6 @@ class NASBench201API(object):
|
||||
return xifo
|
||||
"""
|
||||
|
||||
|
||||
def show(self, index: int = -1) -> None:
|
||||
"""
|
||||
This function will print the information of a specific (or all) architecture(s).
|
||||
@@ -423,7 +429,6 @@ class NASBench201API(object):
|
||||
else:
|
||||
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
||||
|
||||
|
||||
def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]:
|
||||
"""
|
||||
This function will count the number of total trials.
|
||||
@@ -443,7 +448,6 @@ class NASBench201API(object):
|
||||
nums[len(dataset_seed[dataset])] += 1
|
||||
return dict(nums)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def str2lists(arch_str: Text) -> List[tuple]:
|
||||
"""
|
||||
@@ -471,7 +475,6 @@ class NASBench201API(object):
|
||||
genotypes.append( input_infos )
|
||||
return genotypes
|
||||
|
||||
|
||||
@staticmethod
|
||||
def str2matrix(arch_str: Text,
|
||||
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
|
||||
@@ -511,7 +514,6 @@ class NASBench201API(object):
|
||||
return matrix
|
||||
|
||||
|
||||
|
||||
class ArchResults(object):
|
||||
|
||||
def __init__(self, arch_index, arch_str):
|
||||
@@ -752,7 +754,6 @@ class ArchResults(object):
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
|
||||
|
||||
|
||||
|
||||
"""
|
||||
@@ -872,8 +873,8 @@ class ResultsCount(object):
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
|
||||
# get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).
|
||||
def get_eval(self, name, iepoch=None):
|
||||
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
|
||||
@@ -890,8 +891,8 @@ class ResultsCount(object):
|
||||
if clone: return copy.deepcopy(self.net_state_dict)
|
||||
else: return self.net_state_dict
|
||||
|
||||
# This function is used to obtain the config dict for this architecture.
|
||||
def get_config(self, str2structure):
|
||||
"""This function is used to obtain the config dict for this architecture."""
|
||||
if str2structure is None:
|
||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
|
||||
'N' : self.arch_config['num_cells'],
|
||||
|
Reference in New Issue
Block a user