update analysis code from diffsize branch
This commit is contained in:
@@ -10,7 +10,41 @@ import numpy as np
|
||||
import rdkit.Chem
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
class GraphVisualization:
|
||||
def __init__(self, dataset_infos):
|
||||
self.dataset_infos = dataset_infos
|
||||
def graph_from_graphs(self, node_list, adjency_matrix):
|
||||
"""
|
||||
Convert graphs to networkx graphs
|
||||
node_list: the nodes of a batch of nodes (bs x n)
|
||||
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
|
||||
"""
|
||||
graph = nx.Graph()
|
||||
|
||||
for i in range(len(node_list)):
|
||||
if node_list[i] == -1:
|
||||
continue
|
||||
graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])
|
||||
|
||||
rows, cols = np.where(adjency_matrix >= 1)
|
||||
edges = zip(rows.tolist(), cols.tolist())
|
||||
for edge in edges:
|
||||
edge_type = adjency_matrix[edge[0]][edge[1]]
|
||||
graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type)
|
||||
|
||||
return graph
|
||||
|
||||
def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
|
||||
# define path to save figures
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
# visualize the final molecules
|
||||
for i in range(num_graphs_to_visualize):
|
||||
file_path = os.path.join(path, 'graph_{}.png'.format(i))
|
||||
graph = self.graph_from_graphs(graphs[i][0].numpy(), graphs[i][1].numpy())
|
||||
self.visualize_graph(graph=graph, pos=None, path=file_path)
|
||||
im = plt.imread(file_path)
|
||||
class MolecularVisualization:
|
||||
def __init__(self, dataset_infos):
|
||||
self.dataset_infos = dataset_infos
|
||||
|
Reference in New Issue
Block a user