update_name
This commit is contained in:
0
graph_dit/metrics/__init__.py
Normal file
0
graph_dit/metrics/__init__.py
Normal file
138
graph_dit/metrics/abstract_metrics.py
Normal file
138
graph_dit/metrics/abstract_metrics.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics import Metric, MeanSquaredError
|
||||
|
||||
|
||||
class TrainAbstractMetricsDiscrete(torch.nn.Module):
    """No-op stand-in for the training metrics of a discrete model.

    Every method accepts its arguments, does nothing and returns ``None``.
    The signatures match those of ``TrainMolecularMetricsDiscrete`` so the
    two can be swapped.
    """

    def __init__(self):
        super().__init__()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        """Ignore the batch; nothing is accumulated."""
        return None

    def reset(self):
        """Nothing is stored, so there is nothing to clear."""
        return None

    def log_epoch_metrics(self, current_epoch):
        """Produce no epoch summary."""
        return None
|
||||
|
||||
|
||||
class TrainAbstractMetrics(torch.nn.Module):
    """No-op stand-in for training metrics taking noise predictions.

    Every method accepts its arguments, does nothing and returns ``None``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log):
        """Ignore the batch; nothing is accumulated."""
        return None

    def reset(self):
        """Nothing is stored, so there is nothing to clear."""
        return None

    def log_epoch_metrics(self, current_epoch):
        """Produce no epoch summary."""
        return None
|
||||
|
||||
|
||||
class SumExceptBatchMetric(Metric):
    """Running average of per-batch totals.

    Accumulates the sum of every element passed to ``update`` and the size
    of the leading dimension, and reports their ratio.
    """

    def __init__(self):
        super().__init__()
        # Both states are summed across processes when synchronised.
        for state_name in ('total_value', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, values) -> None:
        """Add all elements of ``values``; count ``values.shape[0]`` samples."""
        leading_dim = values.shape[0]
        self.total_value += values.sum()
        self.total_samples += leading_dim

    def compute(self):
        """Return accumulated sum divided by accumulated sample count."""
        return self.total_value / self.total_samples
|
||||
|
||||
|
||||
class SumExceptBatchMSE(MeanSquaredError):
    """Squared error summed over all elements, normalised by the leading dimension."""

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Accumulate the squared error of one batch.

        Args:
            preds: Predictions from model
            target: Ground truth values (must have the same shape as ``preds``)
        """
        assert preds.shape == target.shape
        batch_sse, batch_count = self._mean_squared_error_update(preds, target)

        self.sum_squared_error += batch_sse
        self.total += batch_count

    def _mean_squared_error_update(self, preds: Tensor, target: Tensor):
        """Return ``(sum of squared differences, preds.shape[0])`` for one batch.

        preds: Predicted tensor
        target: Ground truth tensor
        """
        residual = preds - target
        # The count is the leading dimension only, not the number of elements.
        return torch.sum(residual * residual), preds.shape[0]
|
||||
|
||||
|
||||
class SumExceptBatchKL(Metric):
    """Summed ``F.kl_div(q, p)`` normalised by the leading dimension of ``p``."""

    def __init__(self):
        super().__init__()
        for state_name in ('total_value', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, p, q) -> None:
        """Accumulate the divergence of one batch.

        NOTE(review): ``F.kl_div`` takes its first argument (``q`` here) in
        log-space by default -- confirm callers pass log-probabilities.
        """
        batch_kl = F.kl_div(q, p, reduction='sum')
        self.total_value += batch_kl
        self.total_samples += p.size(0)

    def compute(self):
        """Return accumulated divergence divided by accumulated sample count."""
        return self.total_value / self.total_samples
|
||||
|
||||
|
||||
class CrossEntropyMetric(Metric):
    """Cross-entropy summed over rows and averaged over the number of rows seen."""

    def __init__(self):
        super().__init__()
        for state_name in ('total_ce', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor, weight=None) -> None:
        """ Update state with predictions and targets.
            preds: Predictions from model   (bs * n, d) or (bs * n * n, d)
            target: Ground truth values     (bs * n, d) or (bs * n * n, d).
            weight: optional per-class weights, cast to the type of ``preds``. """
        # Targets are reduced to class indices with argmax over the last axis.
        class_index = target.argmax(dim=-1)
        extra = {} if weight is None else {'weight': weight.type_as(preds)}
        self.total_ce += F.cross_entropy(preds, class_index, reduction='sum', **extra)
        self.total_samples += preds.size(0)

    def compute(self):
        """Return accumulated cross-entropy divided by the number of rows."""
        return self.total_ce / self.total_samples
|
||||
|
||||
|
||||
class ProbabilityMetric(Metric):
    """Mean of all values passed to ``update``."""

    def __init__(self):
        """ This metric is used to track the marginal predicted probability of a class during training. """
        super().__init__()
        for state_name in ('prob', 'total'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: Tensor) -> None:
        """Accumulate the sum and the element count of ``preds``."""
        element_count = preds.numel()
        self.prob += preds.sum()
        self.total += element_count

    def compute(self):
        """Return accumulated sum divided by accumulated element count."""
        return self.prob / self.total
|
||||
|
||||
|
||||
class NLL(Metric):
    """Mean of the negative log-likelihood values passed to ``update``."""

    def __init__(self):
        super().__init__()
        for state_name in ('total_nll', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, batch_nll) -> None:
        """Accumulate the sum and the element count of ``batch_nll``."""
        element_count = batch_nll.numel()
        self.total_nll += torch.sum(batch_nll)
        self.total_samples += element_count

    def compute(self):
        """Return accumulated NLL divided by accumulated element count."""
        return self.total_nll / self.total_samples
|
BIN
graph_dit/metrics/fpscores.pkl.gz
Normal file
BIN
graph_dit/metrics/fpscores.pkl.gz
Normal file
Binary file not shown.
138
graph_dit/metrics/molecular_metrics_sampling.py
Normal file
138
graph_dit/metrics/molecular_metrics_sampling.py
Normal file
@@ -0,0 +1,138 @@
|
||||
### packages for visualization
|
||||
from analysis.rdkit_functions import compute_molecular_metrics
|
||||
from mini_moses.metrics.metrics import compute_intermediate_statistics
|
||||
from metrics.property_metric import TaskModel
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import os
|
||||
import csv
|
||||
import time
|
||||
|
||||
def result_to_csv(path, dict_data):
    """Append one row of results to a CSV file.

    The ``log_name`` entry is always written as the first column; the other
    columns follow in the order they appear in ``dict_data``. A header row is
    written when the file does not exist yet or exists but is empty.

    Args:
        path: Destination CSV file; opened in append mode.
        dict_data: Mapping of column name to value. Must contain a
            ``"log_name"`` entry that is not ``None``. It is not modified.

    Raises:
        ValueError: If ``"log_name"`` is missing or ``None``.
    """
    log_name = dict_data.get("log_name")
    if log_name is None:
        raise ValueError("The provided dictionary must contain a 'log_name' key.")

    # Build the row separately so the caller's dict is left untouched.
    row = {"log_name": log_name}
    row.update((key, value) for key, value in dict_data.items() if key != "log_name")

    # An existing zero-byte file still needs a header.
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0

    with open(path, "a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=list(row))
        if needs_header:
            writer.writeheader()
        writer.writerow(row)
|
||||
|
||||
|
||||
class SamplingMolecularMetrics(nn.Module):
    """Evaluate sampled molecules and write the results to disk.

    On construction, reference statistics are computed (when reference SMILES
    are given) and one ``TaskModel`` evaluator is created per task named in
    ``dataset_infos.task`` (tasks are separated by ``-``).
    """

    def __init__(
        self,
        dataset_infos,
        train_smiles,
        reference_smiles,
        n_jobs=1,
        device="cpu",
        batch_size=512,
    ):
        """
        Args:
            dataset_infos: Object exposing ``task``, ``active_atoms`` and
                ``base_path``.
            train_smiles: Training SMILES, forwarded to
                ``compute_molecular_metrics``.
            reference_smiles: SMILES used for the reference statistics, or
                ``None`` to skip that step.
            n_jobs, device, batch_size: Forwarded to
                ``compute_intermediate_statistics`` and stored in
                ``self.comput_config``.
        """
        super().__init__()
        self.task_name = dataset_infos.task
        self.dataset_infos = dataset_infos
        self.active_atoms = dataset_infos.active_atoms
        self.train_smiles = train_smiles

        if reference_smiles is not None:
            print(
                f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
            )
            start_time = time.time()
            self.stat_ref = compute_intermediate_statistics(
                reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
            )
            elapsed_time = time.time() - start_time
            print(
                f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
            )
        else:
            self.stat_ref = None

        self.comput_config = {
            "n_jobs": n_jobs,
            "device": device,
            "batch_size": batch_size,
        }

        self.task_evaluator = {'meta_taskname': dataset_infos.task, 'sas': None, 'scs': None}
        for cur_task in dataset_infos.task.split("-"):
            model_path = os.path.join(
                dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
            )
            # TaskModel may dump a freshly trained model here, so the folder must exist.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            self.task_evaluator[cur_task] = TaskModel(model_path, cur_task)

    def forward(self, molecules, targets, name, current_epoch, val_counter, test=False):
        """Compute metrics for sampled molecules and persist them.

        In test mode every SMILES (plus input/output task values when
        available) is written to ``final_smiles.txt``. Otherwise the unique
        SMILES are written under ``graphs/<name>/`` in the working directory.
        In both modes the metric dict is appended to ``output.csv``.

        Args:
            molecules: Sampled molecules, forwarded to ``compute_molecular_metrics``.
            targets: A tensor, or a list of tensors concatenated along dim 0.
            name: Sub-folder name used for validation outputs.
            current_epoch, val_counter: Used in file and log names.
            test: Selects test-mode output.

        Returns:
            ``all_smiles`` as returned by ``compute_molecular_metrics``.
        """
        if isinstance(targets, list):
            targets_np = torch.cat(targets, dim=0).detach().cpu().numpy()
        else:
            targets_np = targets.detach().cpu().numpy()

        unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics(
            molecules,
            targets_np,
            self.train_smiles,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.comput_config,
        )

        if test:
            file_name = "final_smiles.txt"
            # Bookkeeping keys that get no column in the output file.
            all_tasks_name = [
                task for task in self.task_evaluator
                if task not in ('meta_taskname', 'scs')
            ]
            with open(file_name, "w") as fp:
                header_columns = (
                    [f"input_{task}" for task in all_tasks_name]
                    + [f"output_{task}" for task in all_tasks_name]
                )
                fp.write("smiles, " + ", ".join(header_columns) + "\n")
                for i, smiles in enumerate(all_smiles):
                    if targets_log is not None:
                        row_values = (
                            [f"{targets_log['input_'+task][i]}" for task in all_tasks_name]
                            + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name]
                        )
                        fp.write(f"{smiles}, " + ", ".join(row_values) + "\n")
                    else:
                        fp.write("%s\n" % smiles)
                print("All smiles saved")
        else:
            result_path = os.path.join(os.getcwd(), f"graphs/{name}")
            os.makedirs(result_path, exist_ok=True)
            text_path = os.path.join(
                result_path,
                f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt",
            )
            # Context manager closes the file even if a write fails.
            with open(text_path, "w") as textfile:
                for smiles in unique_smiles:
                    textfile.write(smiles + "\n")

        all_logs = all_metrics
        if test:
            all_logs["log_name"] = "test"
        else:
            all_logs["log_name"] = (
                "epoch" + str(current_epoch) + "_batch" + str(val_counter)
            )

        result_to_csv("output.csv", all_logs)
        return all_smiles

    def reset(self):
        """No state is accumulated between calls; nothing to reset."""
        pass
|
||||
|
||||
if __name__ == "__main__":
    # Placeholder: running this module as a script does nothing.
    pass
|
126
graph_dit/metrics/molecular_metrics_train.py
Normal file
126
graph_dit/metrics/molecular_metrics_train.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import torch
|
||||
from torchmetrics import Metric, MetricCollection
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
class CEPerClass(Metric):
    """Binary cross-entropy for a single class of a softmax prediction.

    Rows whose target is all zeros are excluded before the loss is computed.
    """
    full_state_update = False

    def __init__(self, class_id):
        """
        Args:
            class_id: Index along the last axis of the class being tracked.
        """
        super().__init__()
        self.class_id = class_id
        self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.softmax = torch.nn.Softmax(dim=-1)
        self.binary_cross_entropy = torch.nn.BCELoss(reduction='sum')

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets.

        Args:
            preds: Predictions from model   (bs, n, d) or (bs, n, n, d)
            target: Ground truth values     (bs, n, d) or (bs, n, n, d)
        """
        flat_target = target.reshape(-1, target.shape[-1])
        # Keep only rows that carry at least one non-zero target entry.
        keep = (flat_target != 0.).any(dim=-1)

        class_prob = self.softmax(preds)[..., self.class_id].flatten()[keep]
        class_target = flat_target[:, self.class_id][keep]

        self.total_ce += self.binary_cross_entropy(class_prob, class_target)
        self.total_samples += class_prob.numel()

    def compute(self):
        """Return summed binary cross-entropy divided by the number of kept rows."""
        return self.total_ce / self.total_samples
|
||||
|
||||
|
||||
class AtomCE(CEPerClass):
    """Per-class cross-entropy for one atom class; ``i`` is the class index."""

    def __init__(self, i):
        super().__init__(class_id=i)
|
||||
|
||||
class NoBondCE(CEPerClass):
    """Per-class cross-entropy for the "no bond" edge class; ``i`` is the class index."""

    def __init__(self, i):
        super().__init__(class_id=i)
|
||||
|
||||
|
||||
class SingleCE(CEPerClass):
    """Per-class cross-entropy for the single-bond edge class; ``i`` is the class index."""

    def __init__(self, i):
        super().__init__(class_id=i)
|
||||
|
||||
|
||||
class DoubleCE(CEPerClass):
    """Per-class cross-entropy for the double-bond edge class; ``i`` is the class index."""

    def __init__(self, i):
        super().__init__(class_id=i)
|
||||
|
||||
|
||||
class TripleCE(CEPerClass):
    """Per-class cross-entropy for the triple-bond edge class; ``i`` is the class index."""

    def __init__(self, i):
        super().__init__(class_id=i)
|
||||
|
||||
|
||||
class AromaticCE(CEPerClass):
    """Per-class cross-entropy for the aromatic-bond edge class; ``i`` is the class index."""

    def __init__(self, i):
        super().__init__(class_id=i)
|
||||
|
||||
|
||||
class AtomMetricsCE(MetricCollection):
    """Collection holding one ``AtomCE`` metric per active atom type."""

    def __init__(self, active_atoms):
        """
        Args:
            active_atoms: Iterable of atom-type names; position ``i`` is used
                as the class index of the i-th metric.
        """
        # Each metric gets its own dynamically created AtomCE subclass named
        # "<atom_type>_CE", so every entry has a distinct class name.
        per_atom_metrics = [
            type(f'{atom_type}_CE', (AtomCE,), {})(index)
            for index, atom_type in enumerate(active_atoms)
        ]
        super().__init__(per_atom_metrics)
|
||||
|
||||
|
||||
class BondMetricsCE(MetricCollection):
    """Collection of per-class cross-entropies for edge classes 0 to 3."""

    def __init__(self):
        # Class indices: 0 = no bond, 1 = single, 2 = double, 3 = triple.
        super().__init__([NoBondCE(0), SingleCE(1), DoubleCE(2), TripleCE(3)])
|
||||
|
||||
|
||||
class TrainMolecularMetricsDiscrete(nn.Module):
    """Per-class cross-entropy tracking for atoms and bonds during training."""

    def __init__(self, dataset_infos):
        """
        Args:
            dataset_infos: Object exposing ``active_atoms``, the atom types
                for which one metric each is created.
        """
        super().__init__()
        active_atoms = dataset_infos.active_atoms
        self.train_atom_metrics = AtomMetricsCE(active_atoms=active_atoms)
        self.train_bond_metrics = BondMetricsCE()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        """Update the atom and bond metrics with one batch.

        Args:
            masked_pred_X, true_X: Node predictions and targets.
            masked_pred_E, true_E: Edge predictions and targets.
            log: When true, the running metric values are computed.

        Returns:
            ``None`` when ``log`` is false; otherwise a dict mapping
            ``"train/<metric name>"`` to a Python float. (The dict used to be
            built and thrown away; it is now returned so callers can log it.)
        """
        self.train_atom_metrics(masked_pred_X, true_X)
        self.train_bond_metrics(masked_pred_E, true_E)
        if not log:
            return None

        to_log = {}
        for key, val in self.train_atom_metrics.compute().items():
            to_log['train/' + key] = val.item()
        for key, val in self.train_bond_metrics.compute().items():
            to_log['train/' + key] = val.item()
        return to_log

    def reset(self):
        """Reset both metric collections."""
        for metric in [self.train_atom_metrics, self.train_bond_metrics]:
            metric.reset()

    def log_epoch_metrics(self, current_epoch, log=True):
        """Compute the accumulated metrics and optionally print them.

        Args:
            current_epoch: Epoch number shown in the printed line.
            log: When true, print values rounded to 4 decimals.

        Returns:
            Dict mapping ``"train_epoch/<metric name>"`` to the unrounded
            float value. (It used to be built and thrown away.)
        """
        epoch_atom_metrics = self.train_atom_metrics.compute()
        epoch_bond_metrics = self.train_bond_metrics.compute()

        to_log = {}
        for key, val in epoch_atom_metrics.items():
            to_log['train_epoch/' + key] = val.item()
        for key, val in epoch_bond_metrics.items():
            to_log['train_epoch/' + key] = val.item()

        # Round into fresh dicts rather than overwriting what compute() returned.
        rounded_atom_metrics = {key: round(val.item(), 4) for key, val in epoch_atom_metrics.items()}
        rounded_bond_metrics = {key: round(val.item(), 4) for key, val in epoch_bond_metrics.items()}

        if log:
            print(f"Epoch {current_epoch}: {rounded_atom_metrics} -- {rounded_bond_metrics}")
        return to_log
|
||||
|
201
graph_dit/metrics/property_metric.py
Normal file
201
graph_dit/metrics/property_metric.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import math, os
|
||||
import pickle
|
||||
import os.path as op
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from joblib import dump, load
|
||||
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
|
||||
from sklearn.metrics import mean_absolute_error, roc_auc_score
|
||||
|
||||
|
||||
from rdkit import Chem
|
||||
from rdkit import rdBase
|
||||
from rdkit.Chem import AllChem
|
||||
from rdkit import DataStructs
|
||||
from rdkit.Chem import rdMolDescriptors
|
||||
rdBase.DisableLog('rdApp.error')
|
||||
|
||||
# Task name -> label column read from the task's raw CSV in TaskModel.train.
task_to_colname = dict(
    hiv_b='HIV_active',
    bace_b='Class',
    bbbp_b='p_np',
    O2='O2',
    N2='N2',
    CO2='CO2',
)

# Task name -> learning problem type; selects the model class and the metric.
tasktype_name = {
    **dict.fromkeys(('hiv_b', 'bace_b', 'bbbp_b'), 'classification'),
    **dict.fromkeys(('O2', 'N2', 'CO2'), 'regression'),
}
|
||||
|
||||
class TaskModel():
    """Scores based on an ECFP classifier."""

    def __init__(self, model_path, task_name):
        """Load the evaluator for ``task_name``, training and saving one if loading fails.

        Args:
            model_path: Path of the joblib file to load from or dump to.
            task_name: Key into ``tasktype_name`` / ``task_to_colname``.

        Raises:
            KeyError: If ``task_name`` is not a known task.
        """
        task_type = tasktype_name[task_name]
        self.task_name = task_name
        self.task_type = task_type
        self.model_path = model_path
        self.metric_func = roc_auc_score if 'classification' in self.task_type else mean_absolute_error

        try:
            self.model = load(model_path)
        except Exception:
            # Any load failure falls back to training a new model.
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            print(self.task_name, ' evaluator not found, training new one...')
            if 'classification' in task_type:
                self.model = RandomForestClassifier(random_state=0)
            elif 'regression' in task_type:
                self.model = RandomForestRegressor(random_state=0)
            performance = self.train()
            dump(self.model, model_path)
            print('Oracle peformance: ', performance)
        else:
            print(self.task_name, ' evaluator loaded')

    def train(self):
        """Fit ``self.model`` on the task's raw CSV and return the training-set score.

        The data is read from ``<model dir>/../raw/<task_name>.csv.gz``. Rows
        with a NaN label are dropped.

        NOTE(review): SMILES that fail to parse are kept and featurized as an
        all-zero fingerprint -- confirm this is intended.
        """
        data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz')
        df = pd.read_csv(data_path)
        col_name = task_to_colname[self.task_name]
        y = df[col_name].to_numpy()
        x_smiles = df['smiles'].to_numpy()

        labelled = ~np.isnan(y)
        y = y[labelled]
        if 'classification' in self.task_type:
            y = y.astype(int)
        x_smiles = x_smiles[labelled]

        x_fps, _ = self._featurize(x_smiles)
        x_fps = np.concatenate(x_fps, axis=0)
        self.model.fit(x_fps, y)
        y_pred = self.model.predict(x_fps)
        perf = self.metric_func(y, y_pred)
        print(f'{self.task_name} performance: {perf}')
        return perf

    def __call__(self, smiles_list):
        """Score each SMILES with the trained model.

        Classification tasks return the probability in column 1 of
        ``predict_proba``; regression tasks return ``predict``. Entries whose
        SMILES cannot be parsed get a score of 0.

        Returns:
            float32 array with one score per input; an empty array for empty
            input (previously ``np.concatenate`` raised ``ValueError``).
        """
        fps, valid = self._featurize(smiles_list)
        if not fps:
            return np.zeros((0,), dtype=np.float32)

        fps = np.concatenate(fps, axis=0)
        if 'classification' in self.task_type:
            scores = self.model.predict_proba(fps)[:, 1]
        else:
            scores = self.model.predict(fps)
        # Zero the scores of unparseable molecules.
        scores = scores * np.array(valid)
        return np.float32(scores)

    @classmethod
    def _featurize(cls, smiles_list):
        """Return ``(fingerprints, validity flags)`` for the given SMILES.

        Each fingerprint is a (1, 2048) array; unparseable SMILES yield zeros
        and a validity flag of 0.
        """
        fps = []
        valid = []
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            valid.append(int(mol is not None))
            fps.append(cls.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048)))
        return fps, valid

    @classmethod
    def fingerprints_from_mol(cls, mol):  # use ECFP4
        """Return the radius-2, 2048-bit Morgan fingerprint of ``mol`` as a (1, n) array."""
        features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
        features = np.zeros((1,))
        DataStructs.ConvertToNumpyArray(features_vec, features)
        return features.reshape(1, -1)
|
||||
|
||||
###### SAS Score ######
|
||||
_fscores = None


def readFragmentScores(name='fpscores'):
    """Load the fragment score table into the module-level ``_fscores`` dict.

    The file ``<name>.pkl.gz`` is a gzip-compressed pickle of sequences whose
    first item is a score and whose remaining items are fragment ids. Each id
    is mapped to ``float(score)``.

    Args:
        name: Path without the ``.pkl.gz`` suffix. The default
            ``'fpscores'`` is resolved relative to this module's directory.
    """
    import gzip
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    # The context manager closes the gzip handle, which was previously leaked.
    with gzip.open('%s.pkl.gz' % name) as handle:
        data = pickle.load(handle)
    _fscores = {
        fragment_id: float(entry[0])
        for entry in data
        for fragment_id in entry[1:]
    }
|
||||
|
||||
def numBridgeheadsAndSpiro(mol, ri=None):
    """Return ``(bridgehead atom count, spiro atom count)`` for ``mol``.

    ``ri`` is accepted but not used.
    """
    bridgehead_count = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    spiro_count = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    return bridgehead_count, spiro_count
|
||||
|
||||
def calculateSAS(smiles_list):
    """Return the ``calculateScore`` value of each SMILES as a float32 array.

    Each SMILES is parsed with ``Chem.MolFromSmiles`` and passed straight to
    ``calculateScore``; parse failures are not handled here.
    """
    per_molecule = [
        calculateScore(Chem.MolFromSmiles(smiles))
        for smiles in smiles_list
    ]
    return np.float32(per_molecule)
|
||||
|
||||
def calculateScore(m):
    """Return the synthetic-accessibility score of molecule ``m``, clamped to [1, 10].

    The raw score is the sum of a fragment score (fingerprint contributions
    looked up in ``_fscores``), a structural penalty and a
    fingerprint-density correction. It is then rescaled and clamped.

    The fragment table is loaded on first use. Raises ``ZeroDivisionError``
    if the fingerprint of ``m`` has no non-zero elements.
    """
    if _fscores is None:
        readFragmentScores()

    # fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m,
                                               2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        # Fragments missing from the table contribute -4.
        score1 += _fscores.get(bitId, -4) * v
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    # Rings with more than 8 atoms count as macrocycles.
    nMacrocycles = sum(1 for ring in ri.AtomRings() if len(ring) > 8)

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    #  macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthetise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    # (named raw_min / raw_max so the builtins min / max are not shadowed)
    raw_min = -4.0
    raw_max = 2.5
    sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
|
94
graph_dit/metrics/train_loss.py
Normal file
94
graph_dit/metrics/train_loss.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from metrics.abstract_metrics import CrossEntropyMetric
|
||||
from torchmetrics import Metric, MeanSquaredError
|
||||
|
||||
# from 2:He to 119:*
# Lookup tables with one entry per element; index 0 corresponds to He.
# NOTE(review): valencies_check is not used elsewhere in the visible code.
valencies_check = [0, 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 6, 6, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 6, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
valencies_check = torch.tensor(valencies_check)

# Atomic weights, indexed the same way; read by AtomWeightMetric.
weight_check = [4.003, 6.941, 9.012, 10.812, 12.011, 14.007, 15.999, 18.998, 20.18, 22.99, 24.305, 26.982, 28.086, 30.974, 32.067, 35.453, 39.948, 39.098, 40.078, 44.956, 47.867, 50.942, 51.996, 54.938, 55.845, 58.933, 58.693, 63.546, 65.39, 69.723, 72.61, 74.922, 78.96, 79.904, 83.8, 85.468, 87.62, 88.906, 91.224, 92.906, 95.94, 98.0, 101.07, 102.906, 106.42, 107.868, 112.412, 114.818, 118.711, 121.76, 127.6, 126.904, 131.29, 132.905, 137.328, 138.906, 140.116, 140.908, 144.24, 145.0, 150.36, 151.964, 157.25, 158.925, 162.5, 164.93, 167.26, 168.934, 173.04, 174.967, 178.49, 180.948, 183.84, 186.207, 190.23, 192.217, 195.078, 196.967, 200.59, 204.383, 207.2, 208.98, 209.0, 210.0, 222.0, 223.0, 226.0, 227.0, 232.038, 231.036, 238.029, 237.0, 244.0, 243.0, 247.0, 247.0, 251.0, 252.0, 257.0, 258.0, 259.0, 262.0, 267.0, 268.0, 269.0, 270.0, 269.0, 278.0, 281.0, 281.0, 285.0, 284.0, 289.0, 288.0, 293.0, 292.0, 294.0, 294.0]
weight_check = torch.tensor(weight_check)
|
||||
|
||||
class AtomWeightMetric(Metric):
    """Mean absolute difference between predicted and true summed atom weights.

    Each row's argmax over the last axis is used as an index into the
    module-level ``weight_check`` table.
    """

    def __init__(self):
        super().__init__()
        for state_name in ('total_loss', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")
        # Reading a module global needs no ``global`` statement.
        self.weight_check = weight_check

    def update(self, X, Y):
        """Accumulate the weight gap between predictions ``X`` and targets ``Y``.

        ``X.size(0)`` is counted as the number of samples.
        """
        # Match the table to the tensor type of X before indexing.
        self.weight_check = self.weight_check.type_as(X)
        table = self.weight_check

        predicted_total = table[X.argmax(dim=-1)].sum(dim=-1)
        reference_total = table[Y.argmax(dim=-1)].sum(dim=-1)

        batch_gap = 0
        batch_gap += torch.abs(predicted_total - reference_total).sum()
        self.total_loss += batch_gap
        self.total_samples += X.size(0)

    def compute(self):
        """Return accumulated absolute gap divided by accumulated sample count."""
        return self.total_loss / self.total_samples
|
||||
|
||||
|
||||
class TrainLossDiscrete(nn.Module):
    """ Train with Cross entropy"""

    def __init__(self, lambda_train, weight_node=None, weight_edge=None):
        """
        Args:
            lambda_train: Sequence whose items 0 and 1 scale the node and
                edge cross-entropy terms.
            weight_node, weight_edge: Accepted for interface compatibility;
                not used by this class.
        """
        super().__init__()
        self.node_loss = CrossEntropyMetric()
        self.edge_loss = CrossEntropyMetric()
        self.weight_loss = AtomWeightMetric()

        self.y_loss = MeanSquaredError()
        self.lambda_train = lambda_train

    def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, node_mask, log: bool):
        """ Compute train metrics
        masked_pred_X : tensor -- (bs, n, dx)
        masked_pred_E : tensor -- (bs, n, n, de)
        pred_y : tensor -- (bs, )
        true_X : tensor -- (bs, n, dx)
        true_E : tensor -- (bs, n, n, de)
        true_y : tensor -- (bs, )
        log : boolean.

        ``pred_y``, ``true_y``, ``node_mask`` and ``log`` are not used here.

        Returns ``lambda_train[0] * loss_X + lambda_train[1] * loss_E + loss_weight``.
        """
        loss_weight = self.weight_loss(masked_pred_X, true_X)

        true_X = torch.reshape(true_X, (-1, true_X.size(-1)))  # (bs * n, dx)
        true_E = torch.reshape(true_E, (-1, true_E.size(-1)))  # (bs * n * n, de)
        masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1)))  # (bs * n, dx)
        masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1)))   # (bs * n * n, de)

        # Remove masked rows
        mask_X = (true_X != 0.).any(dim=-1)
        mask_E = (true_E != 0.).any(dim=-1)

        flat_true_X = true_X[mask_X, :]
        flat_pred_X = masked_pred_X[mask_X, :]

        flat_true_E = true_E[mask_E, :]
        flat_pred_E = masked_pred_E[mask_E, :]

        # Guard on the rows left after masking: with none left the metric
        # would return 0/0 = NaN and poison the loss.
        loss_X = self.node_loss(flat_pred_X, flat_true_X) if flat_true_X.numel() > 0 else 0.0
        loss_E = self.edge_loss(flat_pred_E, flat_true_E) if flat_true_E.numel() > 0 else 0.0

        return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + loss_weight

    def reset(self):
        """Reset every accumulated metric.

        ``weight_loss`` is included so the epoch "Weight" value no longer
        carries over from earlier epochs.
        """
        for metric in [self.node_loss, self.edge_loss, self.weight_loss, self.y_loss]:
            metric.reset()

    def log_epoch_metrics(self, current_epoch, start_epoch_time, log=True):
        """Print the epoch averages of the node, edge and weight losses.

        A metric that has seen no samples is reported as -1.

        Args:
            current_epoch: Epoch number shown in the printed line.
            start_epoch_time: ``time.time()`` value taken at epoch start.
            log: When false, nothing is printed.
        """
        epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1
        epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1
        epoch_weight_loss = self.weight_loss.compute() if self.weight_loss.total_samples > 0 else -1

        if log:
            print(f"Epoch {current_epoch} finished: X_CE: {epoch_node_loss :.4f} -- E_CE: {epoch_edge_loss :.4f} "
                  f"Weight: {epoch_weight_loss :.4f} "
                  f"-- Time taken {time.time() - start_epoch_time:.1f}s ")
|
Reference in New Issue
Block a user