update EdgeMetricsCE class
This commit is contained in:
@@ -77,6 +77,15 @@ class NodeMetricsCE(MetricCollection):
|
||||
|
||||
for i, node_type in enumerate(active_nodes) :
|
||||
metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i))
|
||||
super().__init__(metrics_list)
|
||||
|
||||
class EdgeMetricsCE(MetricCollection):
|
||||
def __init__(self):
|
||||
ce_no_bond = NoBondCE(0)
|
||||
ce_SI = SingleCE(1)
|
||||
ce_DO = DoubleCE(2)
|
||||
ce_TR = TripleCE(3)
|
||||
super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
|
||||
|
||||
class AtomMetricsCE(MetricCollection):
|
||||
def __init__(self, active_atoms):
|
||||
@@ -101,6 +110,41 @@ class BondMetricsCE(MetricCollection):
|
||||
class TrainGraphMetricsDiscrete(nn.Module):
|
||||
def __init__(self, dataset_infos):
|
||||
super().__init__()
|
||||
active_nodes = dataset_infos.active_nodes
|
||||
self.train_node_metrics = NodeMetricsCE(active_nodes=active_nodes)
|
||||
self.train_edge_metrics = EdgeMetricsCE()
|
||||
|
||||
def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
|
||||
self.train_node_metrics(masked_pred_X, true_X)
|
||||
self.train_edge_metrics(masked_pred_E, true_E)
|
||||
if log:
|
||||
to_log = {}
|
||||
for key, val in self.train_node_metrics.compute().items():
|
||||
to_log['train/' + key] = val.item()
|
||||
for key, val in self.train_edge_metrics.compute().items():
|
||||
to_log['train/' + key] = val.item()
|
||||
|
||||
def reset(self):
|
||||
for metric in [self.train_node_metrics, self.train_edge_metrics]:
|
||||
metric.reset()
|
||||
|
||||
def log_epoch_metrics(self, current_epoch, log=True):
|
||||
epoch_node_metrics = self.train_node_metrics.compute()
|
||||
epoch_edge_metrics = self.train_edge_metrics.compute()
|
||||
|
||||
to_log = {}
|
||||
for key, val in epoch_node_metrics.items():
|
||||
to_log['train_epoch/' + key] = val.item()
|
||||
for key, val in epoch_edge_metrics.items():
|
||||
to_log['train_epoch/' + key] = val.item()
|
||||
|
||||
for key, val in epoch_node_metrics.items():
|
||||
epoch_node_metrics[key] = round(val.item(),4)
|
||||
for key, val in epoch_edge_metrics.items():
|
||||
epoch_edge_metrics[key] = round(val.item(),4)
|
||||
|
||||
if log:
|
||||
print(f"Epoch {current_epoch}: {epoch_node_metrics} -- {epoch_edge_metrics}")
|
||||
|
||||
class TrainMolecularMetricsDiscrete(nn.Module):
|
||||
def __init__(self, dataset_infos):
|
||||
|
Reference in New Issue
Block a user