upload
zero-cost-nas/foresight/pruners/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from os.path import dirname, basename, isfile, join
import glob

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
zero-cost-nas/foresight/pruners/measures/__init__.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================


available_measures = []
_measure_impls = {}


def measure(name, bn=True, copy_net=True, force_clean=True, **impl_args):
    def make_impl(func):
        def measure_impl(net_orig, device, *args, **kwargs):
            if copy_net:
                net = net_orig.get_prunable_copy(bn=bn).to(device)
            else:
                net = net_orig
            ret = func(net, *args, **kwargs, **impl_args)
            if copy_net and force_clean:
                import gc
                import torch
                del net
                torch.cuda.empty_cache()
                gc.collect()
            return ret

        global _measure_impls
        if name in _measure_impls:
            raise KeyError(f'Duplicate measure! {name}')
        available_measures.append(name)
        _measure_impls[name] = measure_impl
        return func
    return make_impl


def calc_measure(name, net, device, *args, **kwargs):
    return _measure_impls[name](net, device, *args, **kwargs)


def load_all():
    from . import grad_norm
    from . import snip
    from . import grasp
    from . import fisher
    from . import jacob_cov
    from . import plain
    from . import synflow
    from . import var
    from . import cor
    from . import norm
    from . import meco
    from . import zico


# TODO: should we do this by default?
load_all()
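For orientation (a sketch, not part of the commit): every measure file below registers itself through the `measure` decorator above, so adding a proxy only needs the decorator plus an import in `load_all`. A minimal, hypothetical registration, assuming the package import path shown and a `weight_sum` name invented for this example:

import torch

from foresight.pruners.measures import measure, calc_measure

@measure('weight_sum', copy_net=False)  # 'weight_sum' is a made-up example measure
def compute_weight_sum(net, inputs, targets, split_data=1, loss_fn=None):
    # sum of absolute weights over the whole network
    return sum(p.abs().sum().item() for p in net.parameters())

# dispatch then goes through the registry the decorator populated:
#   score = calc_measure('weight_sum', net, 'cpu', inputs, targets)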
zero-cost-nas/foresight/pruners/measures/cor.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch

from . import measure


def get_score(net, x, target, device, split_data):
    result_list = []

    def forward_hook(module, data_input, data_output):
        # mean pairwise correlation of the features entering the classifier
        corr = np.mean(np.corrcoef(data_input[0].detach().cpu().numpy()))
        result_list.append(corr)

    net.classifier.register_forward_hook(forward_hook)

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])
    cor = result_list[0].item()
    result_list.clear()
    return cor


@measure('cor', bn=True)
def compute_cor(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # compute gradients (but don't apply them)
    net.zero_grad()

    try:
        cor = get_score(net, inputs, targets, device, split_data=split_data)
    except Exception as e:
        print(e)
        cor = np.nan

    return cor
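Note that get_score hooks net.classifier, so this measure (like var and norm below) assumes the network exposes a module named classifier. A toy illustration of that assumption, with hypothetical names:

import torch
import torch.nn as nn

# hypothetical minimal model satisfying the `net.classifier` assumption
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Flatten(), nn.Linear(32, 16), nn.ReLU())
        self.classifier = nn.Linear(16, 10)  # get_score hooks this module

    def forward(self, x):
        return self.classifier(self.features(x))

# cor = get_score(TinyNet(), torch.randn(8, 32), None, 'cpu', split_data=1)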
zero-cost-nas/foresight/pruners/measures/fisher.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F

import types

from . import measure
from ..p_utils import get_layer_metric_array, reshape_elements


def fisher_forward_conv2d(self, x):
    x = F.conv2d(x, self.weight, self.bias, self.stride,
                 self.padding, self.dilation, self.groups)
    # intercept and store the activations after passing through 'hooked' identity op
    self.act = self.dummy(x)
    return self.act

def fisher_forward_linear(self, x):
    x = F.linear(x, self.weight, self.bias)
    self.act = self.dummy(x)
    return self.act

@measure('fisher', bn=True, mode='channel')
def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1):

    device = inputs.device

    if mode == 'param':
        raise ValueError('Fisher pruning does not support parameter pruning.')

    net.train()
    all_hooks = []
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
            # variables/op needed for fisher computation
            layer.fisher = None
            layer.act = 0.
            layer.dummy = nn.Identity()

            # replace forward method of conv/linear
            if isinstance(layer, nn.Conv2d):
                layer.forward = types.MethodType(fisher_forward_conv2d, layer)
            if isinstance(layer, nn.Linear):
                layer.forward = types.MethodType(fisher_forward_linear, layer)

            # function to call during backward pass (hooked on identity op at output of layer)
            def hook_factory(layer):
                def hook(module, grad_input, grad_output):
                    act = layer.act.detach()
                    grad = grad_output[0].detach()
                    if len(act.shape) > 2:
                        g_nk = torch.sum((act * grad), list(range(2, len(act.shape))))
                    else:
                        g_nk = act * grad
                    del_k = g_nk.pow(2).mean(0).mul(0.5)
                    if layer.fisher is None:
                        layer.fisher = del_k
                    else:
                        layer.fisher += del_k
                    del layer.act  # without deleting this, a nasty memory leak occurs! related: https://discuss.pytorch.org/t/memory-leak-when-using-forward-hook-and-backward-hook-simultaneously/27555
                return hook

            # register backward hook on identity fcn to compute fisher info
            # (note: register_backward_hook is deprecated in recent PyTorch in
            # favour of register_full_backward_hook)
            layer.dummy.register_backward_hook(hook_factory(layer))

    N = inputs.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        net.zero_grad()
        outputs = net(inputs[st:en])
        loss = loss_fn(outputs, targets[st:en])
        loss.backward()

    # retrieve fisher info
    def fisher(layer):
        if layer.fisher is not None:
            return torch.abs(layer.fisher.detach())
        else:
            return torch.zeros(layer.weight.shape[0])  # size=ch

    grads_abs_ch = get_layer_metric_array(net, fisher, mode)

    # broadcast channel value here to all parameters in that channel
    # to be compatible with stuff downstream (which expects per-parameter metrics)
    # TODO cleanup on the selectors/apply_prune_mask side (?)
    shapes = get_layer_metric_array(net, lambda l: l.weight.shape[1:], mode)

    grads_abs = reshape_elements(grads_abs_ch, shapes, device)

    return grads_abs
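For reference, the hook above accumulates the standard Fisher channel saliency: with activations a and output gradients g summed over spatial positions (g_nk in the code), each channel k receives

\Delta_k = \frac{1}{2N} \sum_{n=1}^{N} \Big( \sum_{h,w} a_{n,k,h,w} \, g_{n,k,h,w} \Big)^{2}

which matches del_k = g_nk.pow(2).mean(0).mul(0.5) in the backward hook.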
zero-cost-nas/foresight/pruners/measures/grad_norm.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn.functional as F

from . import measure
from ..p_utils import get_layer_metric_array


@measure('grad_norm', bn=True)
def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=False):
    net.zero_grad()
    N = inputs.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        outputs = net.forward(inputs[st:en])
        loss = loss_fn(outputs, targets[st:en])
        loss.backward()

    grad_norm_arr = get_layer_metric_array(
        net,
        lambda l: l.weight.grad.norm() if l.weight.grad is not None else torch.zeros_like(l.weight),
        mode='param')

    return grad_norm_arr
zero-cost-nas/foresight/pruners/measures/grasp.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd

from . import measure
from ..p_utils import get_layer_metric_array


@measure('grasp', bn=True, mode='param')
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):

    # get all applicable weights
    weights = []
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
            weights.append(layer.weight)
            layer.weight.requires_grad_(True)  # TODO isn't this already true?

    # NOTE original code had some input/target splitting into 2
    # I am guessing this was because of GPU mem limit
    net.zero_grad()
    N = inputs.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        # forward/grad pass #1
        grad_w = None
        for _ in range(num_iters):
            # TODO get new data, otherwise num_iters is useless!
            outputs = net.forward(inputs[st:en]) / T
            loss = loss_fn(outputs, targets[st:en])
            grad_w_p = autograd.grad(loss, weights, allow_unused=True)
            if grad_w is None:
                grad_w = list(grad_w_p)
            else:
                for idx in range(len(grad_w)):
                    grad_w[idx] += grad_w_p[idx]

    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        # forward/grad pass #2
        outputs = net.forward(inputs[st:en]) / T
        loss = loss_fn(outputs, targets[st:en])
        grad_f = autograd.grad(loss, weights, create_graph=True, allow_unused=True)

        # accumulate gradients computed in previous step and call backwards
        z, count = 0, 0
        for layer in net.modules():
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                if grad_w[count] is not None:
                    z += (grad_w[count].data * grad_f[count]).sum()
                count += 1
        z.backward()

    # compute final sensitivity metric and put in grads
    def grasp(layer):
        if layer.weight.grad is not None:
            return -layer.weight.data * layer.weight.grad  # -theta_q Hg
            # NOTE in the grasp code they take the *bottom* (1-p)% of values
            # but we take the *top* (1-p)%, therefore we remove the -ve sign
            # EDIT accuracy seems to be negatively correlated with this metric, so we add -ve sign here!
        else:
            return torch.zeros_like(layer.weight)

    grads = get_layer_metric_array(net, grasp, mode)

    return grads
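For reference, the two passes above assemble GRASP's Hessian-gradient product without ever forming the Hessian: pass #1 stores g = \nabla_\theta L, pass #2 differentiates the scalar

z = \sum_i g_i \, [\nabla_\theta L]_i

with create_graph=True, so z.backward() leaves Hg in each weight.grad, and the reported per-weight score is S(\theta) = -\,\theta \odot Hg.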
zero-cost-nas/foresight/pruners/measures/jacob_cov.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import numpy as np

from . import measure


def get_batch_jacobian(net, x, target, device, split_data):
    x.requires_grad_(True)

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])
        y.backward(torch.ones_like(y))

    jacob = x.grad.detach()
    x.requires_grad_(False)
    return jacob, target.detach()


def eval_score(jacob, labels=None):
    corrs = np.corrcoef(jacob)
    v, _ = np.linalg.eig(corrs)
    k = 1e-5
    return -np.sum(np.log(v + k) + 1. / (v + k))


@measure('jacob_cov', bn=True)
def compute_jacob_cov(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # compute gradients (but don't apply them)
    net.zero_grad()

    jacobs, labels = get_batch_jacobian(net, inputs, targets, device, split_data=split_data)
    jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()

    try:
        jc = eval_score(jacobs, labels)
    except Exception as e:
        print(e)
        jc = np.nan

    return jc
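For orientation, eval_score is the Jacobian-covariance score: with eigenvalues v_i of the correlation matrix of per-input Jacobians and k = 1e-5, the score is -sum_i [log(v_i + k) + 1/(v_i + k)], which is largest when the Jacobians are decorrelated (all v_i near 1). A minimal standalone sketch on a random stand-in Jacobian:

import numpy as np

rng = np.random.default_rng(0)
jacobs = rng.normal(size=(8, 128))   # 8 inputs, flattened input-gradients

corrs = np.corrcoef(jacobs)          # 8x8 correlation across inputs
v, _ = np.linalg.eig(corrs)
k = 1e-5
score = -np.sum(np.log(v + k) + 1. / (v + k))
print(score)                         # higher (less negative) -> more decorrelated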
zero-cost-nas/foresight/pruners/measures/l2_norm.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from . import measure
from ..p_utils import get_layer_metric_array


@measure('l2_norm', copy_net=False, mode='param')
def get_l2_norm_array(net, inputs, targets, mode, split_data=1):
    return get_layer_metric_array(net, lambda l: l.weight.norm(), mode=mode)
zero-cost-nas/foresight/pruners/measures/meco.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch
from torch import nn

from . import measure


def get_score(net, x, target, device, split_data):
    result_list = []

    def forward_hook(module, data_input, data_output):
        # per-module minimum eigenvalue of the feature correlation matrix
        fea = data_output[0].detach()
        fea = fea.reshape(fea.shape[0], -1)
        corr = torch.corrcoef(fea)
        corr[torch.isnan(corr)] = 0
        corr[torch.isinf(corr)] = 0
        values = torch.linalg.eig(corr)[0]
        # result = np.real(np.min(values)) / np.real(np.max(values))
        result = torch.min(torch.real(values))
        result_list.append(result)

    for name, modules in net.named_modules():
        modules.register_forward_hook(forward_hook)

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])
    results = torch.tensor(result_list)
    results = results[torch.logical_not(torch.isnan(results))]
    v = torch.sum(results)
    result_list.clear()
    return v.item()


@measure('meco', bn=True)
def compute_meco(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # compute gradients (but don't apply them)
    net.zero_grad()

    try:
        meco = get_score(net, inputs, targets, device, split_data=split_data)
    except Exception as e:
        print(e)
        meco = np.nan  # keep the failure value a scalar, matching the success path
    return meco
zero-cost-nas/foresight/pruners/measures/norm.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch

from . import measure


def get_score(net, x, target, device, split_data):
    result_list = []

    def forward_hook(module, data_input, data_output):
        norm = torch.norm(data_input[0])
        result_list.append(norm)

    net.classifier.register_forward_hook(forward_hook)

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])
    n = result_list[0].item()
    result_list.clear()
    return n


@measure('norm', bn=True)
def compute_norm(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # compute gradients (but don't apply them)
    net.zero_grad()

    try:
        # get_score returns a single scalar
        norm = get_score(net, inputs, targets, device, split_data=split_data)
    except Exception as e:
        print(e)
        norm = np.nan
    return norm
zero-cost-nas/foresight/pruners/measures/param_count.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import torch

from . import measure
from ..p_utils import get_layer_metric_array


@measure('param_count', copy_net=False, mode='param')
def get_param_count_array(net, inputs, targets, mode, loss_fn, split_data=1):
    # return only the per-layer array so the generic sum_arr aggregation
    # in predictive.py can consume it
    return get_layer_metric_array(
        net,
        lambda l: torch.tensor(sum(p.numel() for p in l.parameters() if p.requires_grad)),
        mode=mode)
zero-cost-nas/foresight/pruners/measures/plain.py (new file, 44 lines)
@@ -0,0 +1,44 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn.functional as F

from . import measure
from ..p_utils import get_layer_metric_array


@measure('plain', bn=True, mode='param')
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):

    net.zero_grad()
    N = inputs.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        outputs = net.forward(inputs[st:en])
        loss = loss_fn(outputs, targets[st:en])
        loss.backward()

    # select the gradients that we want to use for search/prune
    def plain(layer):
        if layer.weight.grad is not None:
            return layer.weight.grad * layer.weight
        else:
            return torch.zeros_like(layer.weight)

    grads_abs = get_layer_metric_array(net, plain, mode)
    return grads_abs
zero-cost-nas/foresight/pruners/measures/snip.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import types

from . import measure
from ..p_utils import get_layer_metric_array


def snip_forward_conv2d(self, x):
    return F.conv2d(x, self.weight * self.weight_mask, self.bias,
                    self.stride, self.padding, self.dilation, self.groups)

def snip_forward_linear(self, x):
    return F.linear(x, self.weight * self.weight_mask, self.bias)

@measure('snip', bn=True, mode='param')
def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
            layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
            layer.weight.requires_grad = False

            # Override the forward methods:
            if isinstance(layer, nn.Conv2d):
                layer.forward = types.MethodType(snip_forward_conv2d, layer)

            if isinstance(layer, nn.Linear):
                layer.forward = types.MethodType(snip_forward_linear, layer)

    # Compute gradients (but don't apply them)
    net.zero_grad()
    N = inputs.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        outputs = net.forward(inputs[st:en])
        loss = loss_fn(outputs, targets[st:en])
        loss.backward()

    # select the gradients that we want to use for search/prune
    def snip(layer):
        if layer.weight_mask.grad is not None:
            return torch.abs(layer.weight_mask.grad)
        else:
            return torch.zeros_like(layer.weight)

    grads_abs = get_layer_metric_array(net, snip, mode)

    return grads_abs
zero-cost-nas/foresight/pruners/measures/synflow.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch

from . import measure
from ..p_utils import get_layer_metric_array


@measure('synflow', bn=False, mode='param')
@measure('synflow_bn', bn=True, mode='param')
def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn=None):

    device = inputs.device

    # convert params to their abs; keep sign for converting it back
    @torch.no_grad()
    def linearize(net):
        signs = {}
        for name, param in net.state_dict().items():
            signs[name] = torch.sign(param)
            param.abs_()
        return signs

    # convert to orig values
    @torch.no_grad()
    def nonlinearize(net, signs):
        for name, param in net.state_dict().items():
            if 'weight_mask' not in name:
                param.mul_(signs[name])

    # keep signs of all params
    signs = linearize(net)

    # compute gradients with input of 1s
    net.zero_grad()
    net.double()
    input_dim = list(inputs[0, :].shape)
    inputs = torch.ones([1] + input_dim).double().to(device)
    output = net.forward(inputs)
    torch.sum(output).backward()

    # select the gradients that we want to use for search/prune
    def synflow(layer):
        if layer.weight.grad is not None:
            return torch.abs(layer.weight * layer.weight.grad)
        else:
            return torch.zeros_like(layer.weight)

    grads_abs = get_layer_metric_array(net, synflow, mode)

    # apply signs of all params
    nonlinearize(net, signs)

    return grads_abs
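For reference, with every parameter replaced by its absolute value (linearize) and a single all-ones input, the backward pass above computes the synflow saliency

S(\theta_i) = \Big| \, \theta_i \cdot \frac{\partial R}{\partial \theta_i} \Big|, \qquad R = \sum_{\text{outputs}} f_{|\theta|}(\mathbf{1}),

after which nonlinearize restores the original signs.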
zero-cost-nas/foresight/pruners/measures/var.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch

from . import measure


def get_score(net, x, target, device, split_data):
    result_list = []

    def forward_hook(module, data_input, data_output):
        var = torch.var(data_input[0])
        result_list.append(var)

    net.classifier.register_forward_hook(forward_hook)

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])
    v = result_list[0].item()
    result_list.clear()
    return v


@measure('var', bn=True)
def compute_var(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # compute gradients (but don't apply them)
    net.zero_grad()

    try:
        var = get_score(net, inputs, targets, device, split_data=split_data)
    except Exception as e:
        print(e)
        var = np.nan
    return var
zero-cost-nas/foresight/pruners/measures/zico.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch
from torch import nn

from . import measure


def getgrad(model: torch.nn.Module, grad_dict: dict, step_iter=0):
    if step_iter == 0:
        for name, mod in model.named_modules():
            if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
                try:
                    grad_dict[name] = [mod.weight.grad.data.cpu().reshape(-1).numpy()]
                except AttributeError:  # layer received no gradient
                    continue
    else:
        for name, mod in model.named_modules():
            if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
                try:
                    grad_dict[name].append(mod.weight.grad.data.cpu().reshape(-1).numpy())
                except (AttributeError, KeyError):
                    continue
    return grad_dict


def calculate_zico(grad_dict):
    for modname in grad_dict.keys():
        grad_dict[modname] = np.array(grad_dict[modname])
    nsr_mean_sum_abs = 0
    for modname in grad_dict.keys():
        # per-weight std and mean of |grad| across the data splits
        nsr_std = np.std(grad_dict[modname], axis=0)
        nonzero_idx = np.nonzero(nsr_std)[0]
        nsr_mean_abs = np.mean(np.abs(grad_dict[modname]), axis=0)
        tmpsum = np.sum(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx])
        if tmpsum != 0:
            nsr_mean_sum_abs += np.log(tmpsum)
    return nsr_mean_sum_abs


def getzico(network, inputs, targets, loss_fn, split_data=2):
    grad_dict = {}
    network.train()
    device = inputs.device
    network.to(device)
    N = inputs.shape[0]
    # ZiCo needs at least two gradient samples per weight to compute a std
    split_data = max(split_data, 2)

    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        outputs = network.forward(inputs[st:en])
        loss = loss_fn(outputs, targets[st:en])
        loss.backward()
        grad_dict = getgrad(network, grad_dict, sp)
    res = calculate_zico(grad_dict)
    return res


@measure('zico', bn=True)
def compute_zico(net, inputs, targets, split_data=2, loss_fn=None):

    # compute gradients (but don't apply them)
    net.zero_grad()

    try:
        zico = getzico(net, inputs, targets, loss_fn, split_data=split_data)
    except Exception as e:
        print(e)
        zico = np.nan
    return zico
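For reference, calculate_zico aggregates, for each conv/linear layer l, the mean-to-variability ratio of every weight's gradient across the data splits:

\mathrm{ZiCo} = \sum_{l} \log \Big( \sum_{i:\,\sigma_{l,i} > 0} \frac{\mathbb{E}\,[\,|g_{l,i}|\,]}{\sigma_{l,i}} \Big)

where g_{l,i} collects the gradients of weight i in layer l over the splits and \sigma_{l,i} is their standard deviation.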
zero-cost-nas/foresight/pruners/p_utils.py (new file, 83 lines)
@@ -0,0 +1,83 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..models import *


def get_some_data(train_dataloader, num_batches, device):
    traindata = []
    dataloader_iter = iter(train_dataloader)
    for _ in range(num_batches):
        traindata.append(next(dataloader_iter))
    inputs = torch.cat([a for a, _ in traindata])
    targets = torch.cat([b for _, b in traindata])
    inputs = inputs.to(device)
    targets = targets.to(device)
    return inputs, targets


def get_some_data_grasp(train_dataloader, num_classes, samples_per_class, device):
    datas = [[] for _ in range(num_classes)]
    labels = [[] for _ in range(num_classes)]
    mark = dict()
    dataloader_iter = iter(train_dataloader)
    while True:
        inputs, targets = next(dataloader_iter)
        for idx in range(inputs.shape[0]):
            x, y = inputs[idx:idx + 1], targets[idx:idx + 1]
            category = y.item()
            if len(datas[category]) == samples_per_class:
                mark[category] = True
                continue
            datas[category].append(x)
            labels[category].append(y)
        if len(mark) == num_classes:
            break

    x = torch.cat([torch.cat(_, 0) for _ in datas]).to(device)
    y = torch.cat([torch.cat(_) for _ in labels]).view(-1).to(device)
    return x, y


def get_layer_metric_array(net, metric, mode):
    metric_array = []

    for layer in net.modules():
        if mode == 'channel' and hasattr(layer, 'dont_ch_prune'):
            continue
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
            metric_array.append(metric(layer))

    return metric_array


def reshape_elements(elements, shapes, device):
    def broadcast_val(elements, shapes):
        ret_grads = []
        for e, sh in zip(elements, shapes):
            ret_grads.append(torch.stack([torch.Tensor(sh).fill_(v) for v in e], dim=0).to(device))
        return ret_grads
    if type(elements[0]) == list:
        outer = []
        for e, sh in zip(elements, shapes):
            outer.append(broadcast_val(e, sh))
        return outer
    else:
        return broadcast_val(elements, shapes)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
zero-cost-nas/foresight/pruners/predictive.py (new file, 116 lines)
@@ -0,0 +1,116 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F

from .p_utils import *
from . import measures

import types
import copy


def no_op(self, x):
    return x

def copynet(self, bn):
    net = copy.deepcopy(self)
    if not bn:
        for l in net.modules():
            if isinstance(l, nn.BatchNorm2d) or isinstance(l, nn.BatchNorm1d):
                l.forward = types.MethodType(no_op, l)
    return net

def find_measures_arrays(net_orig, trainloader, dataload_info, device, measure_names=None, loss_fn=F.cross_entropy):
    if measure_names is None:
        measure_names = measures.available_measures

    dataload, num_imgs_or_batches, num_classes = dataload_info

    if not hasattr(net_orig, 'get_prunable_copy'):
        net_orig.get_prunable_copy = types.MethodType(copynet, net_orig)

    # move to cpu to free up mem
    torch.cuda.empty_cache()
    net_orig = net_orig.cpu()
    torch.cuda.empty_cache()

    # given 1 minibatch of data
    if dataload == 'random':
        inputs, targets = get_some_data(trainloader, num_batches=num_imgs_or_batches, device=device)
    elif dataload == 'grasp':
        inputs, targets = get_some_data_grasp(trainloader, num_classes, samples_per_class=num_imgs_or_batches, device=device)
    else:
        raise NotImplementedError(f'dataload {dataload} is not supported')

    done, ds = False, 1
    measure_values = {}

    while not done:
        try:
            for measure_name in measure_names:
                if measure_name not in measure_values:
                    val = measures.calc_measure(measure_name, net_orig, device, inputs, targets, loss_fn=loss_fn, split_data=ds)
                    measure_values[measure_name] = val

            done = True
        except RuntimeError as e:
            if 'out of memory' in str(e):
                done = False
                if ds == inputs.shape[0] // 2:
                    raise ValueError('Can\'t split data any further, but still unable to run. Something is wrong.')
                ds += 1
                while inputs.shape[0] % ds != 0:
                    ds += 1
                torch.cuda.empty_cache()
                print(f'Caught CUDA OOM, retrying with data split into {ds} parts')
            else:
                raise e

    net_orig = net_orig.to(device).train()
    return measure_values

def find_measures(net_orig,                  # neural network
                  dataloader,                # a data loader (typically for training data)
                  dataload_info,             # a tuple with (dataload_type = {random, grasp}, number_of_batches_for_random_or_images_per_class_for_grasp, number of classes)
                  device,                    # GPU/CPU device used
                  loss_fn=F.cross_entropy,   # loss function to use within the zero-cost metrics
                  measure_names=None,        # an array of measure names to compute; if left blank, all measures are computed by default
                  measures_arr=None):        # [not used] if the measures are already computed but need to be summarized, pass them here

    # Given a neural net, some information about the input data (dataloader),
    # and a loss function (loss_fn), this function returns a dict of
    # zero-cost proxy metrics.

    def sum_arr(arr):
        total = 0.
        for i in range(len(arr)):
            total += torch.sum(arr[i])
        return total.item()

    if measures_arr is None:
        measures_arr = find_measures_arrays(net_orig, dataloader, dataload_info, device, loss_fn=loss_fn, measure_names=measure_names)

    measures = {}
    for k, v in measures_arr.items():
        if k in ['jacob_cov', 'var', 'cor', 'norm', 'meco', 'zico']:
            measures[k] = v            # these measures already return a scalar
        else:
            measures[k] = sum_arr(v)   # per-layer arrays are summed into one score

    return measures
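For orientation (a sketch, not part of the commit), a minimal way to drive this API end to end; the model and dataloader below are hypothetical stand-ins for a NAS candidate and a CIFAR loader:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from foresight.pruners import predictive

# hypothetical toy model and random data
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=32)

# dataload_info = (type, num batches for 'random' / imgs per class for 'grasp', num classes)
scores = predictive.find_measures(
    net, loader, ('random', 1, 10),
    device='cpu', measure_names=['synflow', 'grad_norm'])
print(scores)   # e.g. {'synflow': ..., 'grad_norm': ...}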