update_name

gang liu
2024-05-25 15:32:36 -04:00
parent a6bd0117d4
commit 2c00828630
28 changed files with 178 additions and 19 deletions

@@ -0,0 +1,138 @@
import torch
from torch import Tensor
from torch.nn import functional as F
from torchmetrics import Metric, MeanSquaredError
class TrainAbstractMetricsDiscrete(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
pass
def reset(self):
pass
def log_epoch_metrics(self, current_epoch):
pass
class TrainAbstractMetrics(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log):
pass
def reset(self):
pass
def log_epoch_metrics(self, current_epoch):
pass
class SumExceptBatchMetric(Metric):
def __init__(self):
super().__init__()
self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, values) -> None:
self.total_value += torch.sum(values)
self.total_samples += values.shape[0]
def compute(self):
return self.total_value / self.total_samples
class SumExceptBatchMSE(MeanSquaredError):
def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape
sum_squared_error, n_obs = self._mean_squared_error_update(preds, target)
self.sum_squared_error += sum_squared_error
self.total += n_obs
    def _mean_squared_error_update(self, preds: Tensor, target: Tensor):
        """Return the summed squared error and the number of observations
        needed to update the Mean Squared Error state (the shape check lives in `update`).
        preds: Predicted tensor
        target: Ground truth tensor
        """
diff = preds - target
sum_squared_error = torch.sum(diff * diff)
n_obs = preds.shape[0]
return sum_squared_error, n_obs
class SumExceptBatchKL(Metric):
def __init__(self):
super().__init__()
self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, p, q) -> None:
        self.total_value += F.kl_div(q, p, reduction='sum')  # F.kl_div(input, target) expects the first argument as log-probabilities
self.total_samples += p.size(0)
def compute(self):
return self.total_value / self.total_samples
class CrossEntropyMetric(Metric):
def __init__(self):
super().__init__()
self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, preds: Tensor, target: Tensor, weight=None) -> None:
""" Update state with predictions and targets.
preds: Predictions from model (bs * n, d) or (bs * n * n, d)
target: Ground truth values (bs * n, d) or (bs * n * n, d). """
target = torch.argmax(target, dim=-1)
if weight is not None:
weight = weight.type_as(preds)
            output = F.cross_entropy(preds, target, weight=weight, reduction='sum')
else:
output = F.cross_entropy(preds, target, reduction='sum')
self.total_ce += output
self.total_samples += preds.size(0)
def compute(self):
return self.total_ce / self.total_samples
class ProbabilityMetric(Metric):
def __init__(self):
""" This metric is used to track the marginal predicted probability of a class during training. """
super().__init__()
self.add_state('prob', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total', default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, preds: Tensor) -> None:
self.prob += preds.sum()
self.total += preds.numel()
def compute(self):
return self.prob / self.total
class NLL(Metric):
def __init__(self):
super().__init__()
self.add_state('total_nll', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, batch_nll) -> None:
self.total_nll += torch.sum(batch_nll)
self.total_samples += batch_nll.numel()
def compute(self):
return self.total_nll / self.total_samples
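
These Metric subclasses all follow the same torchmetrics pattern: update() accumulates distributed-sum states, compute() returns the running mean over samples. A minimal usage sketch (hypothetical values, not part of this commit):

# Hypothetical sketch: accumulate per-sample NLL over batches, then average.
nll_metric = NLL()
for _ in range(3):
    batch_nll = torch.rand(8)      # (bs,) per-sample negative log-likelihoods
    nll_metric.update(batch_nll)   # adds the batch sum and sample count to the state
print(nll_metric.compute())        # total_nll / total_samples
nll_metric.reset()                 # clear accumulated state between epochs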

Binary file not shown.

@@ -0,0 +1,138 @@
### packages for visualization
from analysis.rdkit_functions import compute_molecular_metrics
from mini_moses.metrics.metrics import compute_intermediate_statistics
from metrics.property_metric import TaskModel
import torch
import torch.nn as nn
import os
import csv
import time
def result_to_csv(path, dict_data):
file_exists = os.path.exists(path)
log_name = dict_data.pop("log_name", None)
if log_name is None:
raise ValueError("The provided dictionary must contain a 'log_name' key.")
field_names = ["log_name"] + list(dict_data.keys())
dict_data["log_name"] = log_name
with open(path, "a", newline="") as file:
writer = csv.DictWriter(file, fieldnames=field_names)
if not file_exists:
writer.writeheader()
writer.writerow(dict_data)
class SamplingMolecularMetrics(nn.Module):
def __init__(
self,
dataset_infos,
train_smiles,
reference_smiles,
n_jobs=1,
device="cpu",
batch_size=512,
):
super().__init__()
self.task_name = dataset_infos.task
self.dataset_infos = dataset_infos
self.active_atoms = dataset_infos.active_atoms
self.train_smiles = train_smiles
if reference_smiles is not None:
            print(
                f"--- Computing intermediate statistics over {len(reference_smiles)} reference SMILES ---"
            )
start_time = time.time()
self.stat_ref = compute_intermediate_statistics(
reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
)
end_time = time.time()
elapsed_time = end_time - start_time
print(
f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
)
else:
self.stat_ref = None
        self.compute_config = {
"n_jobs": n_jobs,
"device": device,
"batch_size": batch_size,
}
self.task_evaluator = {'meta_taskname': dataset_infos.task, 'sas': None, 'scs': None}
        for cur_task in dataset_infos.task.split("-"):
# print('loading evaluator for task', cur_task)
model_path = os.path.join(
dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
)
os.makedirs(os.path.dirname(model_path), exist_ok=True)
evaluator = TaskModel(model_path, cur_task)
self.task_evaluator[cur_task] = evaluator
def forward(self, molecules, targets, name, current_epoch, val_counter, test=False):
if isinstance(targets, list):
targets_cat = torch.cat(targets, dim=0)
targets_np = targets_cat.detach().cpu().numpy()
else:
targets_np = targets.detach().cpu().numpy()
unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics(
molecules,
targets_np,
self.train_smiles,
self.stat_ref,
self.dataset_infos,
self.task_evaluator,
            self.compute_config,
)
if test:
file_name = "final_smiles.txt"
with open(file_name, "w") as fp:
                all_tasks_name = list(self.task_evaluator.keys())  # fresh list; safe to mutate below
if 'meta_taskname' in all_tasks_name:
all_tasks_name.remove('meta_taskname')
if 'scs' in all_tasks_name:
all_tasks_name.remove('scs')
all_tasks_str = "smiles, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name])
fp.write(all_tasks_str + "\n")
for i, smiles in enumerate(all_smiles):
if targets_log is not None:
all_result_str = f"{smiles}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name])
fp.write(all_result_str + "\n")
else:
fp.write("%s\n" % smiles)
print("All smiles saved")
else:
result_path = os.path.join(os.getcwd(), f"graphs/{name}")
os.makedirs(result_path, exist_ok=True)
text_path = os.path.join(
result_path,
f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt",
)
            with open(text_path, "w") as textfile:
                for smiles in unique_smiles:
                    textfile.write(smiles + "\n")
all_logs = all_metrics
if test:
all_logs["log_name"] = "test"
else:
all_logs["log_name"] = (
"epoch" + str(current_epoch) + "_batch" + str(val_counter)
)
result_to_csv("output.csv", all_logs)
return all_smiles
def reset(self):
pass
if __name__ == "__main__":
pass
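
result_to_csv expects the metrics dict to carry a 'log_name' key, which it pops and re-inserts so that column leads the CSV header. A hedged example of the contract (the metric names here are hypothetical):

# Hypothetical call: appends one row to output.csv, writing the header on first use.
metrics = {"log_name": "epoch0_batch1", "validity": 0.98, "uniqueness": 0.95}
result_to_csv("output.csv", metrics)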

@@ -0,0 +1,126 @@
import torch
from torchmetrics import Metric, MetricCollection
from torch import Tensor
import torch.nn as nn
class CEPerClass(Metric):
full_state_update = False
def __init__(self, class_id):
super().__init__()
self.class_id = class_id
self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
self.softmax = torch.nn.Softmax(dim=-1)
self.binary_cross_entropy = torch.nn.BCELoss(reduction='sum')
def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets.
Args:
preds: Predictions from model (bs, n, d) or (bs, n, n, d)
target: Ground truth values (bs, n, d) or (bs, n, n, d)
"""
target = target.reshape(-1, target.shape[-1])
mask = (target != 0.).any(dim=-1)
prob = self.softmax(preds)[..., self.class_id]
prob = prob.flatten()[mask]
target = target[:, self.class_id]
target = target[mask]
output = self.binary_cross_entropy(prob, target)
self.total_ce += output
self.total_samples += prob.numel()
def compute(self):
return self.total_ce / self.total_samples
class AtomCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class NoBondCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class SingleCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class DoubleCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class TripleCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class AromaticCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class AtomMetricsCE(MetricCollection):
def __init__(self, active_atoms):
metrics_list = []
for i, atom_type in enumerate(active_atoms):
            metrics_list.append(type(f'{atom_type}_CE', (AtomCE,), {})(i))  # dynamically named subclass per atom type
super().__init__(metrics_list)
class BondMetricsCE(MetricCollection):
def __init__(self):
ce_no_bond = NoBondCE(0)
ce_SI = SingleCE(1)
ce_DO = DoubleCE(2)
ce_TR = TripleCE(3)
super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
class TrainMolecularMetricsDiscrete(nn.Module):
def __init__(self, dataset_infos):
super().__init__()
active_atoms = dataset_infos.active_atoms
self.train_atom_metrics = AtomMetricsCE(active_atoms=active_atoms)
self.train_bond_metrics = BondMetricsCE()
def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
self.train_atom_metrics(masked_pred_X, true_X)
self.train_bond_metrics(masked_pred_E, true_E)
        if log:
            # Assembled for an external logger; nothing is emitted here by default.
            to_log = {}
            for key, val in self.train_atom_metrics.compute().items():
                to_log['train/' + key] = val.item()
            for key, val in self.train_bond_metrics.compute().items():
                to_log['train/' + key] = val.item()
def reset(self):
for metric in [self.train_atom_metrics, self.train_bond_metrics]:
metric.reset()
def log_epoch_metrics(self, current_epoch, log=True):
epoch_atom_metrics = self.train_atom_metrics.compute()
epoch_bond_metrics = self.train_bond_metrics.compute()
to_log = {}
for key, val in epoch_atom_metrics.items():
to_log['train_epoch/' + key] = val.item()
for key, val in epoch_bond_metrics.items():
to_log['train_epoch/' + key] = val.item()
for key, val in epoch_atom_metrics.items():
epoch_atom_metrics[key] = round(val.item(),4)
for key, val in epoch_bond_metrics.items():
epoch_bond_metrics[key] = round(val.item(),4)
if log:
print(f"Epoch {current_epoch}: {epoch_atom_metrics} -- {epoch_bond_metrics}")

@@ -0,0 +1,201 @@
import math, os
import pickle
import os.path as op
import numpy as np
import pandas as pd
from joblib import dump, load
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import mean_absolute_error, roc_auc_score
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem import rdMolDescriptors
rdBase.DisableLog('rdApp.error')
task_to_colname = {
'hiv_b': 'HIV_active',
'bace_b': 'Class',
'bbbp_b': 'p_np',
'O2': 'O2',
'N2': 'N2',
'CO2': 'CO2',
}
tasktype_name = {
'hiv_b': 'classification',
'bace_b': 'classification',
'bbbp_b': 'classification',
'O2': 'regression',
'N2': 'regression',
'CO2': 'regression',
}
class TaskModel():
"""Scores based on an ECFP classifier."""
def __init__(self, model_path, task_name):
task_type = tasktype_name[task_name]
self.task_name = task_name
self.task_type = task_type
self.model_path = model_path
self.metric_func = roc_auc_score if 'classification' in self.task_type else mean_absolute_error
try:
self.model = load(model_path)
print(self.task_name, ' evaluator loaded')
        except Exception:
print(self.task_name, ' evaluator not found, training new one...')
if 'classification' in task_type:
self.model = RandomForestClassifier(random_state=0)
elif 'regression' in task_type:
self.model = RandomForestRegressor(random_state=0)
            performance = self.train()
            dump(self.model, model_path)
            print('Oracle performance: ', performance)
def train(self):
        data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz')
df = pd.read_csv(data_path)
col_name = task_to_colname[self.task_name]
y = df[col_name].to_numpy()
x_smiles = df['smiles'].to_numpy()
mask = ~np.isnan(y)
y = y[mask]
if 'classification' in self.task_type:
y = y.astype(int)
x_smiles = x_smiles[mask]
        x_fps = []
        for smiles in x_smiles:
            mol = Chem.MolFromSmiles(smiles)
            # invalid SMILES fall back to an all-zero fingerprint
            fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048))
            x_fps.append(fp)
x_fps = np.concatenate(x_fps, axis=0)
self.model.fit(x_fps, y)
y_pred = self.model.predict(x_fps)
perf = self.metric_func(y, y_pred)
print(f'{self.task_name} performance: {perf}')
return perf
def __call__(self, smiles_list):
fps = []
mask = []
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            mask.append(int(mol is not None))
fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048))
fps.append(fp)
fps = np.concatenate(fps, axis=0)
if 'classification' in self.task_type:
scores = self.model.predict_proba(fps)[:, 1]
else:
scores = self.model.predict(fps)
scores = scores * np.array(mask)
return np.float32(scores)
@classmethod
def fingerprints_from_mol(cls, mol): # use ECFP4
features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
features = np.zeros((1,))
DataStructs.ConvertToNumpyArray(features_vec, features)
return features.reshape(1, -1)
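
TaskModel is callable on a list of SMILES: invalid molecules get a zero fingerprint and their scores are zeroed out by the validity mask. A hedged usage sketch (the model path and task name are illustrative, and train() expects the task's CSV under ../raw/ if no saved model exists):

# Hypothetical call: property scores for two inputs, 0.0 for the invalid SMILES.
oracle = TaskModel('data/evaluator/O2.joblib', 'O2')
print(oracle(['CCO', 'not_a_smiles']))  # second entry forced to 0.0 by the mask
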
###### SAS Score ######
_fscores = None
def readFragmentScores(name='fpscores'):
import gzip
global _fscores
# generate the full path filename:
if name == "fpscores":
name = op.join(op.dirname(__file__), name)
data = pickle.load(gzip.open('%s.pkl.gz' % name))
outDict = {}
for i in data:
for j in range(1, len(i)):
outDict[i[j]] = float(i[0])
_fscores = outDict
def numBridgeheadsAndSpiro(mol, ri=None):
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
return nBridgehead, nSpiro
def calculateSAS(smiles_list):
scores = []
    for smiles in smiles_list:
mol = Chem.MolFromSmiles(smiles)
score = calculateScore(mol)
scores.append(score)
return np.float32(scores)
def calculateScore(m):
if _fscores is None:
readFragmentScores()
# fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m, 2)  # 2 is the radius of the circular fingerprint (ECFP4)
fps = fp.GetNonzeroElements()
score1 = 0.
nf = 0
for bitId, v in fps.items():
nf += v
sfp = bitId
score1 += _fscores.get(sfp, -4) * v
score1 /= nf
# features score
nAtoms = m.GetNumAtoms()
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
ri = m.GetRingInfo()
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
nMacrocycles = 0
for x in ri.AtomRings():
if len(x) > 8:
nMacrocycles += 1
sizePenalty = nAtoms**1.005 - nAtoms
stereoPenalty = math.log10(nChiralCenters + 1)
spiroPenalty = math.log10(nSpiro + 1)
bridgePenalty = math.log10(nBridgeheads + 1)
macrocyclePenalty = 0.
# ---------------------------------------
# This differs from the paper, which defines:
# macrocyclePenalty = math.log10(nMacrocycles+1)
# This form generates better results when 2 or more macrocycles are present
if nMacrocycles > 0:
macrocyclePenalty = math.log10(2)
score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
# correction for the fingerprint density
# not in the original publication, added in version 1.1
# to make highly symmetrical molecules easier to synthetise
score3 = 0.
if nAtoms > len(fps):
score3 = math.log(float(nAtoms) / len(fps)) * .5
sascore = score1 + score2 + score3
# need to transform "raw" value into scale between 1 and 10
    raw_min = -4.0
    raw_max = 2.5
    sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.
# smooth the 10-end
if sascore > 8.:
sascore = 8. + math.log(sascore + 1. - 9.)
if sascore > 10.:
sascore = 10.0
elif sascore < 1.:
sascore = 1.0
return sascore
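
readFragmentScores caches the fpscores table in a module-level global on first use, and calculateScore maps the raw fragment/feature score onto the 1-10 SA scale (lower means easier to synthesize). A hedged usage sketch, assuming fpscores.pkl.gz sits next to this module as readFragmentScores expects:

# Hypothetical usage: ethanol and benzene should both score near the easy end.
if __name__ == '__main__':
    print(calculateSAS(['CCO', 'c1ccccc1']))  # array of SA scores in [1, 10]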

@@ -0,0 +1,94 @@
import time
import torch
import torch.nn as nn
from metrics.abstract_metrics import CrossEntropyMetric
from torchmetrics import Metric, MeanSquaredError
# per-element lookup tables, indexed from 2:He to 119:*
valencies_check = [0, 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 6, 6, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 6, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
valencies_check = torch.tensor(valencies_check)
weight_check = [4.003, 6.941, 9.012, 10.812, 12.011, 14.007, 15.999, 18.998, 20.18, 22.99, 24.305, 26.982, 28.086, 30.974, 32.067, 35.453, 39.948, 39.098, 40.078, 44.956, 47.867, 50.942, 51.996, 54.938, 55.845, 58.933, 58.693, 63.546, 65.39, 69.723, 72.61, 74.922, 78.96, 79.904, 83.8, 85.468, 87.62, 88.906, 91.224, 92.906, 95.94, 98.0, 101.07, 102.906, 106.42, 107.868, 112.412, 114.818, 118.711, 121.76, 127.6, 126.904, 131.29, 132.905, 137.328, 138.906, 140.116, 140.908, 144.24, 145.0, 150.36, 151.964, 157.25, 158.925, 162.5, 164.93, 167.26, 168.934, 173.04, 174.967, 178.49, 180.948, 183.84, 186.207, 190.23, 192.217, 195.078, 196.967, 200.59, 204.383, 207.2, 208.98, 209.0, 210.0, 222.0, 223.0, 226.0, 227.0, 232.038, 231.036, 238.029, 237.0, 244.0, 243.0, 247.0, 247.0, 251.0, 252.0, 257.0, 258.0, 259.0, 262.0, 267.0, 268.0, 269.0, 270.0, 269.0, 278.0, 281.0, 281.0, 285.0, 284.0, 289.0, 288.0, 293.0, 292.0, 294.0, 294.0]
weight_check = torch.tensor(weight_check)
class AtomWeightMetric(Metric):
def __init__(self):
super().__init__()
self.add_state('total_loss', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
global weight_check
self.weight_check = weight_check
def update(self, X, Y):
atom_pred_num = X.argmax(dim=-1)
atom_real_num = Y.argmax(dim=-1)
self.weight_check = self.weight_check.type_as(X)
pred_weight = self.weight_check[atom_pred_num]
real_weight = self.weight_check[atom_real_num]
        lss = torch.abs(pred_weight.sum(dim=-1) - real_weight.sum(dim=-1)).sum()
        self.total_loss += lss
self.total_samples += X.size(0)
def compute(self):
return self.total_loss / self.total_samples
class TrainLossDiscrete(nn.Module):
""" Train with Cross entropy"""
def __init__(self, lambda_train, weight_node=None, weight_edge=None):
super().__init__()
self.node_loss = CrossEntropyMetric()
self.edge_loss = CrossEntropyMetric()
self.weight_loss = AtomWeightMetric()
self.y_loss = MeanSquaredError()
self.lambda_train = lambda_train
def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, node_mask, log: bool):
""" Compute train metrics
masked_pred_X : tensor -- (bs, n, dx)
masked_pred_E : tensor -- (bs, n, n, de)
pred_y : tensor -- (bs, )
true_X : tensor -- (bs, n, dx)
true_E : tensor -- (bs, n, n, de)
true_y : tensor -- (bs, )
log : boolean. """
loss_weight = self.weight_loss(masked_pred_X, true_X)
true_X = torch.reshape(true_X, (-1, true_X.size(-1))) # (bs * n, dx)
true_E = torch.reshape(true_E, (-1, true_E.size(-1))) # (bs * n * n, de)
masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1))) # (bs * n, dx)
masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1))) # (bs * n * n, de)
# Remove masked rows
mask_X = (true_X != 0.).any(dim=-1)
mask_E = (true_E != 0.).any(dim=-1)
flat_true_X = true_X[mask_X, :]
flat_pred_X = masked_pred_X[mask_X, :]
flat_true_E = true_E[mask_E, :]
flat_pred_E = masked_pred_E[mask_E, :]
loss_X = self.node_loss(flat_pred_X, flat_true_X) if true_X.numel() > 0 else 0.0
loss_E = self.edge_loss(flat_pred_E, flat_true_E) if true_E.numel() > 0 else 0.0
return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + loss_weight
def reset(self):
        for metric in [self.node_loss, self.edge_loss, self.y_loss, self.weight_loss]:
metric.reset()
def log_epoch_metrics(self, current_epoch, start_epoch_time, log=True):
epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1
epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1
epoch_weight_loss = self.weight_loss.compute() if self.weight_loss.total_samples > 0 else -1
if log:
print(f"Epoch {current_epoch} finished: X_CE: {epoch_node_loss :.4f} -- E_CE: {epoch_edge_loss :.4f} "
f"Weight: {epoch_weight_loss :.4f} "
f"-- Time taken {time.time() - start_epoch_time:.1f}s ")