update_name
This commit is contained in:
126
graph_dit/metrics/molecular_metrics_train.py
Normal file
126
graph_dit/metrics/molecular_metrics_train.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import torch
|
||||
from torchmetrics import Metric, MetricCollection
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
class CEPerClass(Metric):
|
||||
full_state_update = False
|
||||
def __init__(self, class_id):
|
||||
super().__init__()
|
||||
self.class_id = class_id
|
||||
self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.softmax = torch.nn.Softmax(dim=-1)
|
||||
self.binary_cross_entropy = torch.nn.BCELoss(reduction='sum')
|
||||
|
||||
def update(self, preds: Tensor, target: Tensor) -> None:
|
||||
"""Update state with predictions and targets.
|
||||
Args:
|
||||
preds: Predictions from model (bs, n, d) or (bs, n, n, d)
|
||||
target: Ground truth values (bs, n, d) or (bs, n, n, d)
|
||||
"""
|
||||
target = target.reshape(-1, target.shape[-1])
|
||||
mask = (target != 0.).any(dim=-1)
|
||||
|
||||
prob = self.softmax(preds)[..., self.class_id]
|
||||
prob = prob.flatten()[mask]
|
||||
|
||||
target = target[:, self.class_id]
|
||||
target = target[mask]
|
||||
|
||||
output = self.binary_cross_entropy(prob, target)
|
||||
|
||||
self.total_ce += output
|
||||
self.total_samples += prob.numel()
|
||||
|
||||
def compute(self):
|
||||
return self.total_ce / self.total_samples
|
||||
|
||||
|
||||
class AtomCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
class NoBondCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
|
||||
class SingleCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
|
||||
class DoubleCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
|
||||
class TripleCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
|
||||
class AromaticCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
|
||||
class AtomMetricsCE(MetricCollection):
|
||||
def __init__(self, active_atoms):
|
||||
metrics_list = []
|
||||
|
||||
for i, atom_type in enumerate(active_atoms):
|
||||
metrics_list.append(type(f'{atom_type}_CE', (AtomCE,), {})(i))
|
||||
|
||||
super().__init__(metrics_list)
|
||||
|
||||
|
||||
class BondMetricsCE(MetricCollection):
|
||||
def __init__(self):
|
||||
ce_no_bond = NoBondCE(0)
|
||||
ce_SI = SingleCE(1)
|
||||
ce_DO = DoubleCE(2)
|
||||
ce_TR = TripleCE(3)
|
||||
super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
|
||||
|
||||
|
||||
class TrainMolecularMetricsDiscrete(nn.Module):
|
||||
def __init__(self, dataset_infos):
|
||||
super().__init__()
|
||||
active_atoms = dataset_infos.active_atoms
|
||||
self.train_atom_metrics = AtomMetricsCE(active_atoms=active_atoms)
|
||||
self.train_bond_metrics = BondMetricsCE()
|
||||
|
||||
def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
|
||||
self.train_atom_metrics(masked_pred_X, true_X)
|
||||
self.train_bond_metrics(masked_pred_E, true_E)
|
||||
if log:
|
||||
to_log = {}
|
||||
for key, val in self.train_atom_metrics.compute().items():
|
||||
to_log['train/' + key] = val.item()
|
||||
for key, val in self.train_bond_metrics.compute().items():
|
||||
to_log['train/' + key] = val.item()
|
||||
|
||||
def reset(self):
|
||||
for metric in [self.train_atom_metrics, self.train_bond_metrics]:
|
||||
metric.reset()
|
||||
|
||||
def log_epoch_metrics(self, current_epoch, log=True):
|
||||
epoch_atom_metrics = self.train_atom_metrics.compute()
|
||||
epoch_bond_metrics = self.train_bond_metrics.compute()
|
||||
|
||||
to_log = {}
|
||||
for key, val in epoch_atom_metrics.items():
|
||||
to_log['train_epoch/' + key] = val.item()
|
||||
for key, val in epoch_bond_metrics.items():
|
||||
to_log['train_epoch/' + key] = val.item()
|
||||
|
||||
for key, val in epoch_atom_metrics.items():
|
||||
epoch_atom_metrics[key] = round(val.item(),4)
|
||||
for key, val in epoch_bond_metrics.items():
|
||||
epoch_bond_metrics[key] = round(val.item(),4)
|
||||
|
||||
if log:
|
||||
print(f"Epoch {current_epoch}: {epoch_atom_metrics} -- {epoch_bond_metrics}")
|
||||
|
Reference in New Issue
Block a user