update analysis code from diffsize branch

This commit is contained in:
mhz
2024-08-14 16:39:18 +02:00
parent b36ecd3ad0
commit 4959c6c176
2 changed files with 184 additions and 0 deletions

View File

@@ -10,7 +10,41 @@ import numpy as np
import rdkit.Chem
import matplotlib.pyplot as plt
class GraphVisualization:
    """Turn generated graphs into networkx graphs and save them as PNG figures."""

    def __init__(self, dataset_infos):
        # Stored as-is; none of the methods visible in this class read it.
        self.dataset_infos = dataset_infos

    def graph_from_graphs(self, node_list, adjency_matrix):
        """Convert one graph, given as arrays, into a ``networkx.Graph``.

        NOTE(review): an earlier docstring described batched inputs
        (``bs x n`` / ``bs x n x n``), but the code indexes a single graph.

        Args:
            node_list: Node labels of a single graph, indexed by node
                position. Entries equal to ``-1`` are skipped.
            adjency_matrix: 2-D NumPy array indexed by node position. Every
                entry ``>= 1`` yields an edge whose type is the entry value.

        Returns:
            An undirected graph. Nodes carry ``number``, ``symbol`` and
            ``color_val`` attributes; edges carry ``color`` (the edge type as
            a float) and ``weight`` (three times the edge type).
        """
        graph = nx.Graph()
        for idx, label in enumerate(node_list):
            if label == -1:
                continue
            graph.add_node(idx, number=idx, symbol=label, color_val=label)

        # Both (i, j) and (j, i) are visited for a symmetric matrix; adding
        # the same undirected edge twice only rewrites its attributes.
        # NOTE: networkx creates missing endpoints on add_edge, so an edge
        # touching a skipped (-1) index re-creates that node without attributes.
        rows, cols = np.where(adjency_matrix >= 1)
        for src, dst in zip(rows.tolist(), cols.tolist()):
            edge_type = adjency_matrix[src][dst]
            graph.add_edge(src, dst, color=float(edge_type), weight=3 * edge_type)
        return graph

    def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
        """Save the first graphs of ``graphs`` as ``graph_<i>.png`` under ``path``.

        Args:
            path: Output directory; created (including parents) if missing.
            graphs: Sequence of ``(nodes, adjacency)`` pairs whose elements
                expose ``.numpy()``.
            num_graphs_to_visualize: Maximum number of graphs to draw. Values
                larger than ``len(graphs)`` are clamped instead of raising
                ``IndexError``.
            log: Unused by this method; kept so existing callers keep working.

        Drawing is delegated to ``self.visualize_graph``, which is not defined
        in the visible part of this class and must be provided elsewhere.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)

        count = min(num_graphs_to_visualize, len(graphs))
        for i in range(count):
            file_path = os.path.join(path, 'graph_{}.png'.format(i))
            nodes, adjacency = graphs[i][0], graphs[i][1]
            graph = self.graph_from_graphs(nodes.numpy(), adjacency.numpy())
            self.visualize_graph(graph=graph, pos=None, path=file_path)
class MolecularVisualization:
def __init__(self, dataset_infos):
self.dataset_infos = dataset_infos