update EdgeMetricsCE class

This commit is contained in:
mhz
2024-06-30 17:37:18 +02:00
parent d57575586d
commit 0fc6f6e686
2 changed files with 156 additions and 15 deletions

View File

@@ -77,6 +77,15 @@ class NodeMetricsCE(MetricCollection):
for i, node_type in enumerate(active_nodes) :
metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i))
super().__init__(metrics_list)
class EdgeMetricsCE(MetricCollection):
def __init__(self):
ce_no_bond = NoBondCE(0)
ce_SI = SingleCE(1)
ce_DO = DoubleCE(2)
ce_TR = TripleCE(3)
super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
class AtomMetricsCE(MetricCollection):
def __init__(self, active_atoms):
@@ -101,6 +110,41 @@ class BondMetricsCE(MetricCollection):
class TrainGraphMetricsDiscrete(nn.Module):
def __init__(self, dataset_infos):
super().__init__()
active_nodes = dataset_infos.active_nodes
self.train_node_metrics = NodeMetricsCE(active_nodes=active_nodes)
self.train_edge_metrics = EdgeMetricsCE()
def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
self.train_node_metrics(masked_pred_X, true_X)
self.train_edge_metrics(masked_pred_E, true_E)
if log:
to_log = {}
for key, val in self.train_node_metrics.compute().items():
to_log['train/' + key] = val.item()
for key, val in self.train_edge_metrics.compute().items():
to_log['train/' + key] = val.item()
def reset(self):
for metric in [self.train_node_metrics, self.train_edge_metrics]:
metric.reset()
def log_epoch_metrics(self, current_epoch, log=True):
epoch_node_metrics = self.train_node_metrics.compute()
epoch_edge_metrics = self.train_edge_metrics.compute()
to_log = {}
for key, val in epoch_node_metrics.items():
to_log['train_epoch/' + key] = val.item()
for key, val in epoch_edge_metrics.items():
to_log['train_epoch/' + key] = val.item()
for key, val in epoch_node_metrics.items():
epoch_node_metrics[key] = round(val.item(),4)
for key, val in epoch_edge_metrics.items():
epoch_edge_metrics[key] = round(val.item(),4)
if log:
print(f"Epoch {current_epoch}: {epoch_node_metrics} -- {epoch_edge_metrics}")
class TrainMolecularMetricsDiscrete(nn.Module):
def __init__(self, dataset_infos):