update_name
This commit is contained in:
222
graph_dit/analysis/visualization.py
Normal file
222
graph_dit/analysis/visualization.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import os
|
||||
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import Draw, AllChem
|
||||
from rdkit.Geometry import Point3D
|
||||
from rdkit import RDLogger
|
||||
import imageio
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import rdkit.Chem
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class MolecularVisualization:
|
||||
def __init__(self, dataset_infos):
|
||||
self.dataset_infos = dataset_infos
|
||||
|
||||
def mol_from_graphs(self, node_list, adjacency_matrix):
|
||||
"""
|
||||
Convert graphs to rdkit molecules
|
||||
node_list: the nodes of a batch of nodes (bs x n)
|
||||
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
|
||||
"""
|
||||
# dictionary to map integer value to the char of atom
|
||||
atom_decoder = self.dataset_infos.atom_decoder
|
||||
# [list(self.dataset_infos.atom_decoder.keys())[0]]
|
||||
|
||||
# create empty editable mol object
|
||||
mol = Chem.RWMol()
|
||||
|
||||
# add atoms to mol and keep track of index
|
||||
node_to_idx = {}
|
||||
for i in range(len(node_list)):
|
||||
if node_list[i] == -1:
|
||||
continue
|
||||
a = Chem.Atom(atom_decoder[int(node_list[i])])
|
||||
molIdx = mol.AddAtom(a)
|
||||
node_to_idx[i] = molIdx
|
||||
|
||||
for ix, row in enumerate(adjacency_matrix):
|
||||
for iy, bond in enumerate(row):
|
||||
# only traverse half the symmetric matrix
|
||||
if iy <= ix:
|
||||
continue
|
||||
if bond == 1:
|
||||
bond_type = Chem.rdchem.BondType.SINGLE
|
||||
elif bond == 2:
|
||||
bond_type = Chem.rdchem.BondType.DOUBLE
|
||||
elif bond == 3:
|
||||
bond_type = Chem.rdchem.BondType.TRIPLE
|
||||
elif bond == 4:
|
||||
bond_type = Chem.rdchem.BondType.AROMATIC
|
||||
else:
|
||||
continue
|
||||
mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
|
||||
|
||||
try:
|
||||
mol = mol.GetMol()
|
||||
except rdkit.Chem.KekulizeException:
|
||||
print("Can't kekulize molecule")
|
||||
mol = None
|
||||
return mol
|
||||
|
||||
def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'):
|
||||
# define path to save figures
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
# visualize the final molecules
|
||||
print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}")
|
||||
if num_molecules_to_visualize > len(molecules):
|
||||
print(f"Shortening to {len(molecules)}")
|
||||
num_molecules_to_visualize = len(molecules)
|
||||
|
||||
for i in range(num_molecules_to_visualize):
|
||||
file_path = os.path.join(path, 'molecule_{}.png'.format(i))
|
||||
mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
|
||||
try:
|
||||
Draw.MolToFile(mol, file_path)
|
||||
except rdkit.Chem.KekulizeException:
|
||||
print("Can't kekulize molecule")
|
||||
|
||||
def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None):
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
# convert graphs to the rdkit molecules
|
||||
mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
|
||||
|
||||
# find the coordinates of atoms in the final molecule
|
||||
final_molecule = mols[-1]
|
||||
AllChem.Compute2DCoords(final_molecule)
|
||||
|
||||
coords = []
|
||||
for i, atom in enumerate(final_molecule.GetAtoms()):
|
||||
positions = final_molecule.GetConformer().GetAtomPosition(i)
|
||||
coords.append((positions.x, positions.y, positions.z))
|
||||
|
||||
# align all the molecules
|
||||
for i, mol in enumerate(mols):
|
||||
AllChem.Compute2DCoords(mol)
|
||||
conf = mol.GetConformer()
|
||||
for j, atom in enumerate(mol.GetAtoms()):
|
||||
x, y, z = coords[j]
|
||||
conf.SetAtomPosition(j, Point3D(x, y, z))
|
||||
|
||||
# draw gif
|
||||
save_paths = []
|
||||
num_frams = nodes_list.shape[0]
|
||||
|
||||
for frame in range(num_frams):
|
||||
file_name = os.path.join(path, 'fram_{}.png'.format(frame))
|
||||
Draw.MolToFile(mols[frame], file_name, size=(300, 300), legend=f"Frame {frame}")
|
||||
save_paths.append(file_name)
|
||||
|
||||
imgs = [imageio.imread(fn) for fn in save_paths]
|
||||
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
|
||||
imgs.extend([imgs[-1]] * 10)
|
||||
imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
|
||||
|
||||
# draw grid image
|
||||
try:
|
||||
img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200))
|
||||
img.save(os.path.join(path, '{}_grid_image.png'.format(path.split('/')[-1])))
|
||||
except Chem.rdchem.KekulizeException:
|
||||
print("Can't kekulize molecule")
|
||||
return mols
|
||||
|
||||
def visualize_by_smiles(self, path: str, smiles_list: list, num_to_visualize: int):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
print(f"Visualizing corrected {num_to_visualize} of {len(smiles_list)}")
|
||||
if num_to_visualize > len(smiles_list):
|
||||
print(f"Shortening to {len(smiles_list)}")
|
||||
num_to_visualize = len(smiles_list)
|
||||
|
||||
for i in range(num_to_visualize):
|
||||
file_path = os.path.join(path, 'molecule_corrected_{}.png'.format(i))
|
||||
if smiles_list[i] is None:
|
||||
continue
|
||||
mol = Chem.MolFromSmiles(smiles_list[i])
|
||||
try:
|
||||
Draw.MolToFile(mol, file_path)
|
||||
except rdkit.Chem.KekulizeException:
|
||||
print("Can't kekulize molecule")
|
||||
|
||||
class NonMolecularVisualization:
|
||||
def to_networkx(self, node_list, adjacency_matrix):
|
||||
"""
|
||||
Convert graphs to networkx graphs
|
||||
node_list: the nodes of a batch of nodes (bs x n)
|
||||
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
|
||||
"""
|
||||
graph = nx.Graph()
|
||||
|
||||
for i in range(len(node_list)):
|
||||
if node_list[i] == -1:
|
||||
continue
|
||||
graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])
|
||||
|
||||
rows, cols = np.where(adjacency_matrix >= 1)
|
||||
edges = zip(rows.tolist(), cols.tolist())
|
||||
for edge in edges:
|
||||
edge_type = adjacency_matrix[edge[0]][edge[1]]
|
||||
graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type)
|
||||
|
||||
return graph
|
||||
|
||||
def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False):
|
||||
if largest_component:
|
||||
CGs = [graph.subgraph(c) for c in nx.connected_components(graph)]
|
||||
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
|
||||
graph = CGs[0]
|
||||
|
||||
# Plot the graph structure with colors
|
||||
if pos is None:
|
||||
pos = nx.spring_layout(graph, iterations=iterations)
|
||||
|
||||
# Set node colors based on the eigenvectors
|
||||
w, U = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray())
|
||||
vmin, vmax = np.min(U[:, 1]), np.max(U[:, 1])
|
||||
m = max(np.abs(vmin), vmax)
|
||||
vmin, vmax = -m, m
|
||||
|
||||
plt.figure()
|
||||
nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=U[:, 1],
|
||||
cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(path)
|
||||
plt.close("all")
|
||||
|
||||
def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
|
||||
# define path to save figures
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
# visualize the final molecules
|
||||
for i in range(num_graphs_to_visualize):
|
||||
file_path = os.path.join(path, 'graph_{}.png'.format(i))
|
||||
graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
|
||||
self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
|
||||
im = plt.imread(file_path)
|
||||
|
||||
def visualize_chain(self, path, nodes_list, adjacency_matrix):
|
||||
# convert graphs to networkx
|
||||
graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
|
||||
# find the coordinates of atoms in the final molecule
|
||||
final_graph = graphs[-1]
|
||||
final_pos = nx.spring_layout(final_graph, seed=0)
|
||||
|
||||
# draw gif
|
||||
save_paths = []
|
||||
num_frams = nodes_list.shape[0]
|
||||
|
||||
for frame in range(num_frams):
|
||||
file_name = os.path.join(path, 'fram_{}.png'.format(frame))
|
||||
self.visualize_non_molecule(graph=graphs[frame], pos=final_pos, path=file_name)
|
||||
save_paths.append(file_name)
|
||||
|
||||
imgs = [imageio.imread(fn) for fn in save_paths]
|
||||
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
|
||||
imgs.extend([imgs[-1]] * 10)
|
||||
imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
|
Reference in New Issue
Block a user