update_name

This commit is contained in:
gang liu
2024-05-25 15:32:36 -04:00
parent a6bd0117d4
commit 2c00828630
28 changed files with 178 additions and 19 deletions

View File

View File

@@ -0,0 +1,411 @@
from rdkit import Chem, RDLogger
RDLogger.DisableLog('rdApp.*')
from fcd_torch import FCD as FCDMetric
from mini_moses.metrics.metrics import FragMetric, internal_diversity
from mini_moses.metrics.utils import get_mol, mapper
import re
import time
import random
random.seed(0)
import numpy as np
from multiprocessing import Pool
import torch
from metrics.property_metric import calculateSAS
bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC]
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}
bd_dict_x = {'O2-N2': [5.00E+04, 1.00E-03]}
bd_dict_y = {'O2-N2': [5.00E+04/2.78E+04, 1.00E-03/2.43E-05]}
selectivity = ['O2-N2']
a_dict = {}
b_dict = {}
for prop_name in selectivity:
x1, x2 = np.log10(bd_dict_x[prop_name][0]), np.log10(bd_dict_x[prop_name][1])
y1, y2 = np.log10(bd_dict_y[prop_name][0]), np.log10(bd_dict_y[prop_name][1])
a = (y1-y2)/(x1-x2)
b = y1-a*x1
a_dict[prop_name] = a
b_dict[prop_name] = b
def selectivity_evaluation(gas1, gas2, prop_name):
x = np.log10(np.array(gas1))
y = np.log10(np.array(gas1) / np.array(gas2))
upper = (y - (a_dict[prop_name] * x + b_dict[prop_name])) > 0
return upper
class BasicMolecularMetrics(object):
def __init__(self, atom_decoder, train_smiles=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
self.dataset_smiles_list = train_smiles
self.atom_decoder = atom_decoder
self.n_jobs = n_jobs
self.device = device
self.batch_size = batch_size
self.stat_ref = stat_ref
self.task_evaluator = task_evaluator
def compute_relaxed_validity(self, generated, ensure_connected):
valid = []
num_components = []
all_smiles = []
valid_mols = []
covered_atoms = set()
direct_valid_count = 0
for graph in generated:
atom_types, edge_types = graph
mol = build_molecule_with_partial_charges(atom_types, edge_types, self.atom_decoder)
direct_valid_flag = True if check_mol(mol, largest_connected_comp=True) is not None else False
if direct_valid_flag:
direct_valid_count += 1
if not ensure_connected:
mol_conn, _ = correct_mol(mol, connection=True)
mol = mol_conn if mol_conn is not None else correct_mol(mol, connection=False)[0]
else: # ensure fully connected
mol, _ = correct_mol(mol, connection=True)
smiles = mol2smiles(mol)
mol = get_mol(smiles)
try:
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
num_components.append(len(mol_frags))
largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
smiles = mol2smiles(largest_mol)
if smiles is not None and largest_mol is not None and len(smiles) > 1 and Chem.MolFromSmiles(smiles) is not None:
valid_mols.append(largest_mol)
valid.append(smiles)
for atom in largest_mol.GetAtoms():
covered_atoms.add(atom.GetSymbol())
all_smiles.append(smiles)
else:
all_smiles.append(None)
except Exception as e:
# print(f"An error occurred: {e}")
all_smiles.append(None)
return valid, len(valid) / len(generated), direct_valid_count / len(generated), np.array(num_components), all_smiles, covered_atoms
def evaluate(self, generated, targets, ensure_connected, active_atoms=None):
""" generated: list of pairs (positions: n x 3, atom_types: n [int])
the positions and atom types should already be masked. """
valid, validity, nc_validity, num_components, all_smiles, covered_atoms = self.compute_relaxed_validity(generated, ensure_connected=ensure_connected)
nc_mu = num_components.mean() if len(num_components) > 0 else 0
nc_min = num_components.min() if len(num_components) > 0 else 0
nc_max = num_components.max() if len(num_components) > 0 else 0
len_active = len(active_atoms) if active_atoms is not None else 1
cover_str = f"Cover {len(covered_atoms)} ({len(covered_atoms)/len_active * 100:.2f}%) atoms: {covered_atoms}"
print(f"Validity over {len(generated)} molecules: {validity * 100 :.2f}% (w/o correction: {nc_validity * 100 :.2f}%), cover {len(covered_atoms)} ({len(covered_atoms)/len_active * 100:.2f}%) atoms: {covered_atoms}")
print(f"Number of connected components of {len(generated)} molecules: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")
if validity > 0:
dist_metrics = {'cover_str': cover_str ,'validity': validity, 'validity_nc': nc_validity}
unique = list(set(valid))
close_pool = False
if self.n_jobs != 1:
pool = Pool(self.n_jobs)
close_pool = True
else:
pool = 1
valid_mols = mapper(pool)(get_mol, valid)
dist_metrics['interval_diversity'] = internal_diversity(valid_mols, pool, device=self.device)
start_time = time.time()
if self.stat_ref is not None:
kwargs = {'n_jobs': pool, 'device': self.device, 'batch_size': self.batch_size}
kwargs_fcd = {'n_jobs': self.n_jobs, 'device': self.device, 'batch_size': self.batch_size}
try:
dist_metrics['sim/Frag'] = FragMetric(**kwargs)(gen=valid_mols, pref=self.stat_ref['Frag'])
except:
print('error: ', 'pool', pool)
print('valid_mols: ', valid_mols)
dist_metrics['dist/FCD'] = FCDMetric(**kwargs_fcd)(gen=valid, pref=self.stat_ref['FCD'])
if self.task_evaluator is not None:
evaluation_list = list(self.task_evaluator.keys())
evaluation_list = evaluation_list.copy()
assert 'meta_taskname' in evaluation_list
meta_taskname = self.task_evaluator['meta_taskname']
evaluation_list.remove('meta_taskname')
meta_split = meta_taskname.split('-')
valid_index = np.array([True if smiles else False for smiles in all_smiles])
targets_log = {}
for i, name in enumerate(evaluation_list):
targets_log[f'input_{name}'] = np.array([float('nan')] * len(valid_index))
targets_log[f'input_{name}'] = targets[:, i]
targets = targets[valid_index]
if len(meta_split) == 2:
cached_perm = {meta_split[0]: None, meta_split[1]: None}
for i, name in enumerate(evaluation_list):
if name == 'scs':
continue
elif name == 'sas':
scores = calculateSAS(valid)
else:
scores = self.task_evaluator[name](valid)
targets_log[f'output_{name}'] = np.array([float('nan')] * len(valid_index))
targets_log[f'output_{name}'][valid_index] = scores
if name in ['O2', 'N2', 'CO2']:
if len(meta_split) == 2:
cached_perm[name] = scores
scores, cur_targets = np.log10(scores), np.log10(targets[:, i])
dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - cur_targets))
elif name == 'sas':
dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - targets[:, i]))
else:
true_y = targets[:, i]
predicted_labels = (scores >= 0.5).astype(int)
acc = (predicted_labels == true_y).sum() / len(true_y)
dist_metrics[f'{name}/acc'] = acc
if len(meta_split) == 2:
if cached_perm[meta_split[0]] is not None and cached_perm[meta_split[1]] is not None:
task_name = self.task_evaluator['meta_taskname']
upper = selectivity_evaluation(cached_perm[meta_split[0]], cached_perm[meta_split[1]], task_name)
dist_metrics[f'selectivity/{task_name}'] = np.sum(upper)
end_time = time.time()
elapsed_time = end_time - start_time
max_key_length = max(len(key) for key in dist_metrics)
print(f'Details over {len(valid)} ({len(generated)}) valid (total) molecules, calculating metrics using {elapsed_time:.2f} s:')
strs = ''
for i, (key, value) in enumerate(dist_metrics.items()):
if isinstance(value, (int, float, np.floating, np.integer)):
strs = strs + f'{key:>{max_key_length}}:{value:<7.4f}\t'
if i % 4 == 3:
strs = strs + '\n'
print(strs)
if close_pool:
pool.close()
pool.join()
else:
unique = []
dist_metrics = {}
targets_log = None
return unique, dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_smiles, dist_metrics, targets_log
def mol2smiles(mol):
if mol is None:
return None
try:
Chem.SanitizeMol(mol)
except ValueError:
return None
return Chem.MolToSmiles(mol)
def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False):
if verbose:
print("\nbuilding new molecule")
mol = Chem.RWMol()
for atom in atom_types:
a = Chem.Atom(atom_decoder[atom.item()])
mol.AddAtom(a)
if verbose:
print("Atom added: ", atom.item(), atom_decoder[atom.item()])
edge_types = torch.triu(edge_types)
all_bonds = torch.nonzero(edge_types)
for i, bond in enumerate(all_bonds):
if bond[0].item() != bond[1].item():
mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()])
if verbose:
print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(),
bond_dict[edge_types[bond[0], bond[1]].item()])
# add formal charge to atom: e.g. [O+], [N+], [S+]
# not support [O-], [N-], [S-], [NH+] etc.
flag, atomid_valence = check_valency(mol)
if verbose:
print("flag, valence", flag, atomid_valence)
if flag:
continue
else:
if len(atomid_valence) == 2:
idx = atomid_valence[0]
v = atomid_valence[1]
an = mol.GetAtomWithIdx(idx).GetAtomicNum()
if verbose:
print("atomic num of atom with a large valence", an)
if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
mol.GetAtomWithIdx(idx).SetFormalCharge(1)
# print("Formal charge added")
else:
continue
return mol
def check_valency(mol):
try:
Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
return True, None
except ValueError as e:
e = str(e)
p = e.find('#')
e_sub = e[p:]
atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
return False, atomid_valence
def correct_mol(mol, connection=False):
#####
no_correct = False
flag, _ = check_valency(mol)
if flag:
no_correct = True
while True:
if connection:
mol_conn = connect_fragments(mol)
# if mol_conn is not None:
mol = mol_conn
if mol is None:
return None, no_correct
flag, atomid_valence = check_valency(mol)
if flag:
break
else:
try:
assert len(atomid_valence) == 2
idx = atomid_valence[0]
v = atomid_valence[1]
queue = []
check_idx = 0
for b in mol.GetAtomWithIdx(idx).GetBonds():
type = int(b.GetBondType())
queue.append((b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
if type == 12:
check_idx += 1
queue.sort(key=lambda tup: tup[1], reverse=True)
if queue[-1][1] == 12:
return None, no_correct
elif len(queue) > 0:
start = queue[check_idx][2]
end = queue[check_idx][3]
t = queue[check_idx][1] - 1
mol.RemoveBond(start, end)
if t >= 1:
mol.AddBond(start, end, bond_dict[t])
except Exception as e:
# print(f"An error occurred in correction: {e}")
return None, no_correct
return mol, no_correct
def check_mol(m, largest_connected_comp=True):
if m is None:
return None
sm = Chem.MolToSmiles(m, isomericSmiles=True)
if largest_connected_comp and '.' in sm:
vsm = [(s, len(s)) for s in sm.split('.')] # 'C.CC.CCc1ccc(N)cc1CCC=O'.split('.')
vsm.sort(key=lambda tup: tup[1], reverse=True)
mol = Chem.MolFromSmiles(vsm[0][0])
else:
mol = Chem.MolFromSmiles(sm)
return mol
##### connect fragements
def select_atom_with_available_valency(frag):
atoms = list(frag.GetAtoms())
random.shuffle(atoms)
for atom in atoms:
if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0:
return atom
return None
def select_atoms_with_available_valency(frag):
return [atom for atom in frag.GetAtoms() if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0]
def try_to_connect_fragments(combined_mol, frag, atom1, atom2):
# Make copies of the molecules to try the connection
trial_combined_mol = Chem.RWMol(combined_mol)
trial_frag = Chem.RWMol(frag)
# Add the new fragment to the combined molecule with new indices
new_indices = {atom.GetIdx(): trial_combined_mol.AddAtom(atom) for atom in trial_frag.GetAtoms()}
# Add the bond between the suitable atoms from each fragment
trial_combined_mol.AddBond(atom1.GetIdx(), new_indices[atom2.GetIdx()], Chem.BondType.SINGLE)
# Adjust the hydrogen count of the connected atoms
for atom_idx in [atom1.GetIdx(), new_indices[atom2.GetIdx()]]:
atom = trial_combined_mol.GetAtomWithIdx(atom_idx)
num_h = atom.GetTotalNumHs()
atom.SetNumExplicitHs(max(0, num_h - 1))
# Add bonds for the new fragment
for bond in trial_frag.GetBonds():
trial_combined_mol.AddBond(new_indices[bond.GetBeginAtomIdx()], new_indices[bond.GetEndAtomIdx()], bond.GetBondType())
# Convert to a Mol object and try to sanitize it
new_mol = Chem.Mol(trial_combined_mol)
try:
Chem.SanitizeMol(new_mol)
return new_mol # Return the new valid molecule
except Chem.MolSanitizeException:
return None # If the molecule is not valid, return None
def connect_fragments(mol):
# Get the separate fragments
frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
if len(frags) < 2:
return mol
combined_mol = Chem.RWMol(frags[0])
for frag in frags[1:]:
# Select all atoms with available valency from both molecules
atoms1 = select_atoms_with_available_valency(combined_mol)
atoms2 = select_atoms_with_available_valency(frag)
# Try to connect using all combinations of available valency atoms
for atom1 in atoms1:
for atom2 in atoms2:
new_mol = try_to_connect_fragments(combined_mol, frag, atom1, atom2)
if new_mol is not None:
# If a valid connection is made, update the combined molecule and break
combined_mol = new_mol
break
else:
# Continue if the inner loop didn't break (no valid connection found for atom1)
continue
# Break if the inner loop did break (valid connection found)
break
else:
# If no valid connections could be made with any of the atoms, return None
return None
return combined_mol
#### connect fragements
def compute_molecular_metrics(molecule_list, targets, train_smiles, stat_ref, dataset_info, task_evaluator, comput_config):
""" molecule_list: (dict) """
atom_decoder = dataset_info.atom_decoder
active_atoms = dataset_info.active_atoms
ensure_connected = dataset_info.ensure_connected
metrics = BasicMolecularMetrics(atom_decoder, train_smiles, stat_ref, task_evaluator, **comput_config)
evaluated_res = metrics.evaluate(molecule_list, targets, ensure_connected, active_atoms)
all_smiles = evaluated_res[-3]
all_metrics = evaluated_res[-2]
targets_log = evaluated_res[-1]
unique_smiles = evaluated_res[0]
return unique_smiles, all_smiles, all_metrics, targets_log
if __name__ == '__main__':
smiles_mol = 'C1CCC1'
print("Smiles mol %s" % smiles_mol)
chem_mol = Chem.MolFromSmiles(smiles_mol)
print(block_mol)

View File

@@ -0,0 +1,222 @@
import os
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Geometry import Point3D
from rdkit import RDLogger
import imageio
import networkx as nx
import numpy as np
import rdkit.Chem
import matplotlib.pyplot as plt
class MolecularVisualization:
def __init__(self, dataset_infos):
self.dataset_infos = dataset_infos
def mol_from_graphs(self, node_list, adjacency_matrix):
"""
Convert graphs to rdkit molecules
node_list: the nodes of a batch of nodes (bs x n)
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
"""
# dictionary to map integer value to the char of atom
atom_decoder = self.dataset_infos.atom_decoder
# [list(self.dataset_infos.atom_decoder.keys())[0]]
# create empty editable mol object
mol = Chem.RWMol()
# add atoms to mol and keep track of index
node_to_idx = {}
for i in range(len(node_list)):
if node_list[i] == -1:
continue
a = Chem.Atom(atom_decoder[int(node_list[i])])
molIdx = mol.AddAtom(a)
node_to_idx[i] = molIdx
for ix, row in enumerate(adjacency_matrix):
for iy, bond in enumerate(row):
# only traverse half the symmetric matrix
if iy <= ix:
continue
if bond == 1:
bond_type = Chem.rdchem.BondType.SINGLE
elif bond == 2:
bond_type = Chem.rdchem.BondType.DOUBLE
elif bond == 3:
bond_type = Chem.rdchem.BondType.TRIPLE
elif bond == 4:
bond_type = Chem.rdchem.BondType.AROMATIC
else:
continue
mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
try:
mol = mol.GetMol()
except rdkit.Chem.KekulizeException:
print("Can't kekulize molecule")
mol = None
return mol
def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'):
# define path to save figures
if not os.path.exists(path):
os.makedirs(path)
# visualize the final molecules
print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}")
if num_molecules_to_visualize > len(molecules):
print(f"Shortening to {len(molecules)}")
num_molecules_to_visualize = len(molecules)
for i in range(num_molecules_to_visualize):
file_path = os.path.join(path, 'molecule_{}.png'.format(i))
mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
try:
Draw.MolToFile(mol, file_path)
except rdkit.Chem.KekulizeException:
print("Can't kekulize molecule")
def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None):
RDLogger.DisableLog('rdApp.*')
# convert graphs to the rdkit molecules
mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
# find the coordinates of atoms in the final molecule
final_molecule = mols[-1]
AllChem.Compute2DCoords(final_molecule)
coords = []
for i, atom in enumerate(final_molecule.GetAtoms()):
positions = final_molecule.GetConformer().GetAtomPosition(i)
coords.append((positions.x, positions.y, positions.z))
# align all the molecules
for i, mol in enumerate(mols):
AllChem.Compute2DCoords(mol)
conf = mol.GetConformer()
for j, atom in enumerate(mol.GetAtoms()):
x, y, z = coords[j]
conf.SetAtomPosition(j, Point3D(x, y, z))
# draw gif
save_paths = []
num_frams = nodes_list.shape[0]
for frame in range(num_frams):
file_name = os.path.join(path, 'fram_{}.png'.format(frame))
Draw.MolToFile(mols[frame], file_name, size=(300, 300), legend=f"Frame {frame}")
save_paths.append(file_name)
imgs = [imageio.imread(fn) for fn in save_paths]
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
imgs.extend([imgs[-1]] * 10)
imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
# draw grid image
try:
img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200))
img.save(os.path.join(path, '{}_grid_image.png'.format(path.split('/')[-1])))
except Chem.rdchem.KekulizeException:
print("Can't kekulize molecule")
return mols
def visualize_by_smiles(self, path: str, smiles_list: list, num_to_visualize: int):
os.makedirs(path, exist_ok=True)
print(f"Visualizing corrected {num_to_visualize} of {len(smiles_list)}")
if num_to_visualize > len(smiles_list):
print(f"Shortening to {len(smiles_list)}")
num_to_visualize = len(smiles_list)
for i in range(num_to_visualize):
file_path = os.path.join(path, 'molecule_corrected_{}.png'.format(i))
if smiles_list[i] is None:
continue
mol = Chem.MolFromSmiles(smiles_list[i])
try:
Draw.MolToFile(mol, file_path)
except rdkit.Chem.KekulizeException:
print("Can't kekulize molecule")
class NonMolecularVisualization:
def to_networkx(self, node_list, adjacency_matrix):
"""
Convert graphs to networkx graphs
node_list: the nodes of a batch of nodes (bs x n)
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
"""
graph = nx.Graph()
for i in range(len(node_list)):
if node_list[i] == -1:
continue
graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])
rows, cols = np.where(adjacency_matrix >= 1)
edges = zip(rows.tolist(), cols.tolist())
for edge in edges:
edge_type = adjacency_matrix[edge[0]][edge[1]]
graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type)
return graph
def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False):
if largest_component:
CGs = [graph.subgraph(c) for c in nx.connected_components(graph)]
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
graph = CGs[0]
# Plot the graph structure with colors
if pos is None:
pos = nx.spring_layout(graph, iterations=iterations)
# Set node colors based on the eigenvectors
w, U = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray())
vmin, vmax = np.min(U[:, 1]), np.max(U[:, 1])
m = max(np.abs(vmin), vmax)
vmin, vmax = -m, m
plt.figure()
nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=U[:, 1],
cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey')
plt.tight_layout()
plt.savefig(path)
plt.close("all")
def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
# define path to save figures
if not os.path.exists(path):
os.makedirs(path)
# visualize the final molecules
for i in range(num_graphs_to_visualize):
file_path = os.path.join(path, 'graph_{}.png'.format(i))
graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
im = plt.imread(file_path)
def visualize_chain(self, path, nodes_list, adjacency_matrix):
# convert graphs to networkx
graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
# find the coordinates of atoms in the final molecule
final_graph = graphs[-1]
final_pos = nx.spring_layout(final_graph, seed=0)
# draw gif
save_paths = []
num_frams = nodes_list.shape[0]
for frame in range(num_frams):
file_name = os.path.join(path, 'fram_{}.png'.format(frame))
self.visualize_non_molecule(graph=graphs[frame], pos=final_pos, path=file_name)
save_paths.append(file_name)
imgs = [imageio.imread(fn) for fn in save_paths]
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
imgs.extend([imgs[-1]] * 10)
imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)