update_name
This commit is contained in:
138
graph_dit/metrics/abstract_metrics.py
Normal file
138
graph_dit/metrics/abstract_metrics.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics import Metric, MeanSquaredError
|
||||
|
||||
|
||||
class TrainAbstractMetricsDiscrete(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def log_epoch_metrics(self, current_epoch):
|
||||
pass
|
||||
|
||||
|
||||
class TrainAbstractMetrics(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def log_epoch_metrics(self, current_epoch):
|
||||
pass
|
||||
|
||||
|
||||
class SumExceptBatchMetric(Metric):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, values) -> None:
|
||||
self.total_value += torch.sum(values)
|
||||
self.total_samples += values.shape[0]
|
||||
|
||||
def compute(self):
|
||||
return self.total_value / self.total_samples
|
||||
|
||||
|
||||
class SumExceptBatchMSE(MeanSquaredError):
|
||||
def update(self, preds: Tensor, target: Tensor) -> None:
|
||||
"""Update state with predictions and targets.
|
||||
|
||||
Args:
|
||||
preds: Predictions from model
|
||||
target: Ground truth values
|
||||
"""
|
||||
assert preds.shape == target.shape
|
||||
sum_squared_error, n_obs = self._mean_squared_error_update(preds, target)
|
||||
|
||||
self.sum_squared_error += sum_squared_error
|
||||
self.total += n_obs
|
||||
|
||||
def _mean_squared_error_update(self, preds: Tensor, target: Tensor):
|
||||
""" Updates and returns variables required to compute Mean Squared Error. Checks for same shape of input
|
||||
tensors.
|
||||
preds: Predicted tensor
|
||||
target: Ground truth tensor
|
||||
"""
|
||||
diff = preds - target
|
||||
sum_squared_error = torch.sum(diff * diff)
|
||||
n_obs = preds.shape[0]
|
||||
return sum_squared_error, n_obs
|
||||
|
||||
|
||||
class SumExceptBatchKL(Metric):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, p, q) -> None:
|
||||
self.total_value += F.kl_div(q, p, reduction='sum')
|
||||
self.total_samples += p.size(0)
|
||||
|
||||
def compute(self):
|
||||
return self.total_value / self.total_samples
|
||||
|
||||
|
||||
class CrossEntropyMetric(Metric):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: Tensor, target: Tensor, weight=None) -> None:
|
||||
""" Update state with predictions and targets.
|
||||
preds: Predictions from model (bs * n, d) or (bs * n * n, d)
|
||||
target: Ground truth values (bs * n, d) or (bs * n * n, d). """
|
||||
target = torch.argmax(target, dim=-1)
|
||||
if weight is not None:
|
||||
weight = weight.type_as(preds)
|
||||
output = F.cross_entropy(preds, target, weight = weight, reduction='sum')
|
||||
else:
|
||||
output = F.cross_entropy(preds, target, reduction='sum')
|
||||
self.total_ce += output
|
||||
self.total_samples += preds.size(0)
|
||||
|
||||
def compute(self):
|
||||
return self.total_ce / self.total_samples
|
||||
|
||||
|
||||
class ProbabilityMetric(Metric):
|
||||
def __init__(self):
|
||||
""" This metric is used to track the marginal predicted probability of a class during training. """
|
||||
super().__init__()
|
||||
self.add_state('prob', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: Tensor) -> None:
|
||||
self.prob += preds.sum()
|
||||
self.total += preds.numel()
|
||||
|
||||
def compute(self):
|
||||
return self.prob / self.total
|
||||
|
||||
|
||||
class NLL(Metric):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state('total_nll', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, batch_nll) -> None:
|
||||
self.total_nll += torch.sum(batch_nll)
|
||||
self.total_samples += batch_nll.numel()
|
||||
|
||||
def compute(self):
|
||||
return self.total_nll / self.total_samples
|
Reference in New Issue
Block a user