upload
This commit is contained in:
47
Scorers/scorer.py
Normal file
47
Scorers/scorer.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
class Jocab_Scorer:
|
||||
def __init__(self, gpu):
|
||||
self.gpu = gpu
|
||||
print('Jacob score init')
|
||||
|
||||
def score(self, model, input, target):
|
||||
batch_size = input.shape[0]
|
||||
model.K = torch.zeros(batch_size, batch_size).cuda()
|
||||
|
||||
input = input.cuda()
|
||||
with torch.no_grad():
|
||||
model(input)
|
||||
score = self.hooklogdet(model.K.cpu().numpy())
|
||||
|
||||
#print(score)
|
||||
return score
|
||||
|
||||
def setup_hooks(self, model, batch_size):
|
||||
#initalize score
|
||||
model = model.to(torch.device('cuda', self.gpu))
|
||||
model.eval()
|
||||
model.K = torch.zeros(batch_size, batch_size).cuda()
|
||||
def counting_forward_hook(module, inp, out):
|
||||
try:
|
||||
# if not module.visited_backwards:
|
||||
# return
|
||||
if isinstance(inp, tuple):
|
||||
inp = inp[0]
|
||||
inp = inp.view(inp.size(0), -1)
|
||||
x = (inp > 0).float()
|
||||
K = x @ x.t()
|
||||
K2 = (1.-x) @ (1.-x.t())
|
||||
model.K = model.K + K + K2
|
||||
except:
|
||||
pass
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if 'ReLU' in str(type(module)):
|
||||
module.register_forward_hook(counting_forward_hook)
|
||||
#module.register_backward_hook(counting_backward_hook)
|
||||
|
||||
def hooklogdet(self, K, labels=None):
|
||||
s, ld = np.linalg.slogdet(K)
|
||||
return ld
|
Reference in New Issue
Block a user