v2
This commit is contained in:
21
scores.py
Normal file
21
scores.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
|
||||
def hooklogdet(K, labels=None):
|
||||
s, ld = np.linalg.slogdet(K)
|
||||
return ld
|
||||
|
||||
def random_score(jacob, label=None):
|
||||
return np.random.normal()
|
||||
|
||||
|
||||
_scores = {
|
||||
'hook_logdet': hooklogdet,
|
||||
'random': random_score
|
||||
}
|
||||
|
||||
def get_score_func(score_name):
|
||||
return _scores[score_name]
|
Reference in New Issue
Block a user