This commit is contained in:
HamsterMimi
2023-05-04 13:23:56 +08:00
parent 189df25fd3
commit fd43e67da1
8 changed files with 1 additions and 211 deletions

View File

@@ -96,20 +96,6 @@ def project_op(model, input, target, args, cell_type, proj_queue=None, selected_
model.candidate_flags[cell_type][selected_eid] = False
# print(model.get_projected_weights())
if proj_crit == 'comb':
synflow = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=['synflow'])
var = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=['var'])
# print(synflow, var)
comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
measures = {'comb': comb}
else:
measures = predictive.find_measures(model,
proj_queue,