update_name
This commit is contained in:
0
graph_dit/datasets/__init__.py
Normal file
0
graph_dit/datasets/__init__.py
Normal file
126
graph_dit/datasets/abstract_dataset.py
Normal file
126
graph_dit/datasets/abstract_dataset.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from diffusion.distributions import DistributionNodes
|
||||
import utils as utils
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from torch_geometric.loader import DataLoader
|
||||
|
||||
|
||||
class AbstractDataModule(pl.LightningDataModule):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.dataloaders = None
|
||||
self.input_dims = None
|
||||
self.output_dims = None
|
||||
|
||||
def prepare_data(self, datasets) -> None:
|
||||
batch_size = self.cfg.train.batch_size
|
||||
num_workers = self.cfg.train.num_workers
|
||||
self.dataloaders = {split: DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
|
||||
shuffle='debug' not in self.cfg.general.name)
|
||||
for split, dataset in datasets.items()}
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.dataloaders["train"]
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.dataloaders["val"]
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.dataloaders["test"]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.dataloaders['train'][idx]
|
||||
|
||||
def node_counts(self, max_nodes_possible=300):
|
||||
all_counts = torch.zeros(max_nodes_possible)
|
||||
for split in ['train', 'val', 'test']:
|
||||
for i, data in enumerate(self.dataloaders[split]):
|
||||
unique, counts = torch.unique(data.batch, return_counts=True)
|
||||
for count in counts:
|
||||
all_counts[count] += 1
|
||||
max_index = max(all_counts.nonzero())
|
||||
all_counts = all_counts[:max_index + 1]
|
||||
all_counts = all_counts / all_counts.sum()
|
||||
return all_counts
|
||||
|
||||
def node_types(self):
|
||||
num_classes = None
|
||||
for data in self.dataloaders['train']:
|
||||
num_classes = data.x.shape[1]
|
||||
break
|
||||
|
||||
counts = torch.zeros(num_classes)
|
||||
|
||||
for split in ['train', 'val', 'test']:
|
||||
for i, data in enumerate(self.dataloaders[split]):
|
||||
counts += data.x.sum(dim=0)
|
||||
|
||||
counts = counts / counts.sum()
|
||||
return counts
|
||||
|
||||
def edge_counts(self):
|
||||
num_classes = None
|
||||
for data in self.dataloaders['train']:
|
||||
num_classes = 5
|
||||
break
|
||||
|
||||
d = torch.Tensor(num_classes)
|
||||
|
||||
for split in ['train', 'val', 'test']:
|
||||
for i, data in enumerate(self.dataloaders[split]):
|
||||
unique, counts = torch.unique(data.batch, return_counts=True)
|
||||
|
||||
all_pairs = 0
|
||||
for count in counts:
|
||||
all_pairs += count * (count - 1)
|
||||
|
||||
num_edges = data.edge_index.shape[1]
|
||||
num_non_edges = all_pairs - num_edges
|
||||
|
||||
data_edge_attr = torch.nn.functional.one_hot(data.edge_attr, num_classes=5).float()
|
||||
edge_types = data_edge_attr.sum(dim=0)
|
||||
assert num_non_edges >= 0
|
||||
d[0] += num_non_edges
|
||||
d[1:] += edge_types[1:]
|
||||
|
||||
d = d / d.sum()
|
||||
return d
|
||||
|
||||
|
||||
class MolecularDataModule(AbstractDataModule):
|
||||
def valency_count(self, max_n_nodes):
|
||||
valencies = torch.zeros(3 * max_n_nodes - 2) # Max valency possible if everything is connected
|
||||
multiplier = torch.Tensor([0, 1, 2, 3, 1.5])
|
||||
for split in ['train', 'val', 'test']:
|
||||
for i, data in enumerate(self.dataloaders[split]):
|
||||
n = data.x.shape[0]
|
||||
for atom in range(n):
|
||||
data_edge_attr = torch.nn.functional.one_hot(data.edge_attr, num_classes=5).float()
|
||||
edges = data_edge_attr[data.edge_index[0] == atom]
|
||||
edges_total = edges.sum(dim=0)
|
||||
valency = (edges_total * multiplier).sum()
|
||||
valencies[valency.long().item()] += 1
|
||||
valencies = valencies / valencies.sum()
|
||||
return valencies
|
||||
|
||||
|
||||
class AbstractDatasetInfos:
|
||||
def complete_infos(self, n_nodes, node_types):
|
||||
self.input_dims = None
|
||||
self.output_dims = None
|
||||
self.num_classes = len(node_types)
|
||||
self.max_n_nodes = len(n_nodes) - 1
|
||||
self.nodes_dist = DistributionNodes(n_nodes)
|
||||
|
||||
def compute_input_output_dims(self, datamodule):
|
||||
example_batch = datamodule.example_batch()
|
||||
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
|
||||
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float()
|
||||
|
||||
self.input_dims = {'X': example_batch_x.size(1),
|
||||
'E': example_batch_edge_attr.size(1),
|
||||
'y': example_batch['y'].size(1)}
|
||||
self.output_dims = {'X': example_batch_x.size(1),
|
||||
'E': example_batch_edge_attr.size(1),
|
||||
'y': example_batch['y'].size(1)}
|
381
graph_dit/datasets/dataset.py
Normal file
381
graph_dit/datasets/dataset.py
Normal file
@@ -0,0 +1,381 @@
|
||||
|
||||
import sys
|
||||
sys.path.append('../')
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import pathlib
|
||||
import json
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from rdkit import Chem, RDLogger
|
||||
from rdkit.Chem.rdchem import BondType as BT
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from torch_geometric.data import Data, InMemoryDataset
|
||||
from torch_geometric.loader import DataLoader
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
import utils as utils
|
||||
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
||||
from diffusion.distributions import DistributionNodes
|
||||
|
||||
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
||||
|
||||
class DataModule(AbstractDataModule):
|
||||
def __init__(self, cfg):
|
||||
self.datadir = cfg.dataset.datadir
|
||||
self.task = cfg.dataset.task_name
|
||||
super().__init__(cfg)
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
target = getattr(self.cfg.dataset, 'guidance_target', None)
|
||||
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
||||
root_path = os.path.join(base_path, self.datadir)
|
||||
self.root_path = root_path
|
||||
|
||||
batch_size = self.cfg.train.batch_size
|
||||
num_workers = self.cfg.train.num_workers
|
||||
pin_memory = self.cfg.dataset.pin_memory
|
||||
|
||||
dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
|
||||
|
||||
if len(self.task.split('-')) == 2:
|
||||
train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
|
||||
else:
|
||||
train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)
|
||||
|
||||
self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
|
||||
train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
|
||||
if len(unlabeled_index) > 0:
|
||||
train_index = torch.cat([train_index, unlabeled_index], dim=0)
|
||||
|
||||
train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
|
||||
self.train_dataset = train_dataset
|
||||
self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
|
||||
self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
|
||||
self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
|
||||
|
||||
training_iterations = len(train_dataset) // batch_size
|
||||
self.training_iterations = training_iterations
|
||||
|
||||
def random_data_split(self, dataset):
|
||||
nan_count = torch.isnan(dataset.y[:, 0]).sum().item()
|
||||
labeled_len = len(dataset) - nan_count
|
||||
full_idx = list(range(labeled_len))
|
||||
train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
|
||||
train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
|
||||
train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
|
||||
unlabeled_index = list(range(labeled_len, len(dataset)))
|
||||
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
|
||||
return train_index, val_index, test_index, unlabeled_index
|
||||
|
||||
def fixed_split(self, dataset):
|
||||
if self.task == 'O2-N2':
|
||||
test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604]
|
||||
else:
|
||||
raise ValueError('Invalid task name: {}'.format(self.task))
|
||||
full_idx = list(range(len(dataset)))
|
||||
full_idx = list(set(full_idx) - set(test_index))
|
||||
train_ratio = 0.8
|
||||
train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)
|
||||
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
|
||||
return train_index, val_index, test_index, []
|
||||
|
||||
def get_train_smiles(self):
|
||||
filename = f'{self.task}.csv.gz'
|
||||
df = pd.read_csv(f'{self.root_path}/raw/{filename}')
|
||||
df_test = df.iloc[self.test_index]
|
||||
df = df.iloc[self.train_index]
|
||||
smiles_list = df['smiles'].tolist()
|
||||
smiles_list_test = df_test['smiles'].tolist()
|
||||
smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list]
|
||||
smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test]
|
||||
return smiles_list, smiles_list_test
|
||||
|
||||
def get_data_split(self):
|
||||
filename = f'{self.task}.csv.gz'
|
||||
df = pd.read_csv(f'{self.root_path}/raw/{filename}')
|
||||
df_val = df.iloc[self.val_index]
|
||||
df_test = df.iloc[self.test_index]
|
||||
df_train = df.iloc[self.train_index]
|
||||
return df_train, df_val, df_test
|
||||
|
||||
def example_batch(self):
|
||||
return next(iter(self.val_loader))
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.train_loader
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.val_loader
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.test_loader
|
||||
|
||||
|
||||
class Dataset(InMemoryDataset):
|
||||
def __init__(self, source, root, target_prop=None,
|
||||
transform=None, pre_transform=None, pre_filter=None):
|
||||
self.target_prop = target_prop
|
||||
self.source = source
|
||||
super().__init__(root, transform, pre_transform, pre_filter)
|
||||
self.data, self.slices = torch.load(self.processed_paths[0])
|
||||
|
||||
@property
|
||||
def raw_file_names(self):
|
||||
return [f'{self.source}.csv.gz']
|
||||
|
||||
@property
|
||||
def processed_file_names(self):
|
||||
return [f'{self.source}.pt']
|
||||
|
||||
def process(self):
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
data_path = osp.join(self.raw_dir, self.raw_file_names[0])
|
||||
data_df = pd.read_csv(data_path)
|
||||
|
||||
def mol_to_graph(mol, sa, sc, target, target2=None, target3=None, valid_atoms=None):
|
||||
type_idx = []
|
||||
heavy_atom_indices, active_atoms = [], []
|
||||
for atom in mol.GetAtoms():
|
||||
if atom.GetAtomicNum() != 1:
|
||||
type_idx.append(119-2) if atom.GetSymbol() == '*' else type_idx.append(atom.GetAtomicNum()-2)
|
||||
heavy_atom_indices.append(atom.GetIdx())
|
||||
active_atoms.append(atom.GetSymbol())
|
||||
if valid_atoms is not None:
|
||||
if not atom.GetSymbol() in valid_atoms:
|
||||
return None, None
|
||||
x = torch.LongTensor(type_idx)
|
||||
|
||||
edges_list = []
|
||||
edge_type = []
|
||||
for bond in mol.GetBonds():
|
||||
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
|
||||
if start in heavy_atom_indices and end in heavy_atom_indices:
|
||||
start_new, end_new = heavy_atom_indices.index(start), heavy_atom_indices.index(end)
|
||||
edges_list.append((start_new, end_new))
|
||||
edge_type.append(bonds[bond.GetBondType()])
|
||||
edges_list.append((end_new, start_new))
|
||||
edge_type.append(bonds[bond.GetBondType()])
|
||||
edge_index = torch.tensor(edges_list, dtype=torch.long).t()
|
||||
edge_type = torch.tensor(edge_type, dtype=torch.long)
|
||||
edge_attr = edge_type
|
||||
|
||||
if target3 is not None:
|
||||
y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1,-1)
|
||||
elif target2 is not None:
|
||||
y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1,-1)
|
||||
else:
|
||||
y = torch.tensor([sa, sc, target], dtype=torch.float).view(1,-1)
|
||||
|
||||
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
|
||||
if self.pre_transform is not None:
|
||||
data = self.pre_transform(data)
|
||||
return data, active_atoms
|
||||
|
||||
# Loop through every row in the DataFrame and apply the function
|
||||
data_list = []
|
||||
len_data = len(data_df)
|
||||
with tqdm(total=len_data) as pbar:
|
||||
# --- data processing start ---
|
||||
active_atoms = set()
|
||||
for i, (sms, df_row) in enumerate(data_df.iterrows()):
|
||||
if i == sms:
|
||||
sms = df_row['smiles']
|
||||
mol = Chem.MolFromSmiles(sms, sanitize=False)
|
||||
if len(self.target_prop.split('-')) == 2:
|
||||
target1, target2 = self.target_prop.split('-')
|
||||
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
|
||||
elif len(self.target_prop.split('-')) == 3:
|
||||
target1, target2, target3 = self.target_prop.split('-')
|
||||
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
|
||||
else:
|
||||
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
|
||||
active_atoms.update(cur_active_atoms)
|
||||
data_list.append(data)
|
||||
pbar.update(1)
|
||||
|
||||
torch.save(self.collate(data_list), self.processed_paths[0])
|
||||
|
||||
|
||||
class DataInfos(AbstractDatasetInfos):
|
||||
def __init__(self, datamodule, cfg):
|
||||
tasktype_dict = {
|
||||
'hiv_b': 'classification',
|
||||
'bace_b': 'classification',
|
||||
'bbbp_b': 'classification',
|
||||
'O2': 'regression',
|
||||
'N2': 'regression',
|
||||
'CO2': 'regression',
|
||||
}
|
||||
task_name = cfg.dataset.task_name
|
||||
self.task = task_name
|
||||
self.task_type = tasktype_dict.get(task_name, "regression")
|
||||
self.ensure_connected = cfg.model.ensure_connected
|
||||
|
||||
datadir = cfg.dataset.datadir
|
||||
|
||||
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
||||
meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json')
|
||||
data_root = os.path.join(base_path, datadir, 'raw')
|
||||
if os.path.exists(meta_filename):
|
||||
with open(meta_filename, 'r') as f:
|
||||
meta_dict = json.load(f)
|
||||
else:
|
||||
meta_dict = compute_meta(data_root, task_name, datamodule.train_index, datamodule.test_index)
|
||||
|
||||
self.base_path = base_path
|
||||
self.active_atoms = meta_dict['active_atoms']
|
||||
self.max_n_nodes = meta_dict['max_node']
|
||||
self.original_max_n_nodes = meta_dict['max_node']
|
||||
self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist'])
|
||||
self.edge_types = torch.Tensor(meta_dict['bond_type_dist'])
|
||||
self.transition_E = torch.Tensor(meta_dict['transition_E'])
|
||||
|
||||
self.atom_decoder = meta_dict['active_atoms']
|
||||
node_types = torch.Tensor(meta_dict['atom_type_dist'])
|
||||
active_index = (node_types > 0).nonzero().squeeze()
|
||||
self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index]
|
||||
self.nodes_dist = DistributionNodes(self.n_nodes)
|
||||
self.active_index = active_index
|
||||
|
||||
val_len = 3 * self.original_max_n_nodes - 2
|
||||
meta_val = torch.Tensor(meta_dict['valencies'])
|
||||
self.valency_distribution = torch.zeros(val_len)
|
||||
val_len = min(val_len, len(meta_val))
|
||||
self.valency_distribution[:val_len] = meta_val[:val_len]
|
||||
self.y_prior = None
|
||||
self.train_ymin = []
|
||||
self.train_ymax = []
|
||||
|
||||
|
||||
def compute_meta(root, source_name, train_index, test_index):
|
||||
pt = Chem.GetPeriodicTable()
|
||||
atom_name_list = []
|
||||
atom_count_list = []
|
||||
for i in range(2, 119):
|
||||
atom_name_list.append(pt.GetElementSymbol(i))
|
||||
atom_count_list.append(0)
|
||||
atom_name_list.append('*')
|
||||
atom_count_list.append(0)
|
||||
n_atoms_per_mol = [0] * 500
|
||||
bond_count_list = [0, 0, 0, 0, 0]
|
||||
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
||||
valencies = [0] * 500
|
||||
tansition_E = np.zeros((118, 118, 5))
|
||||
|
||||
filename = f'{source_name}.csv.gz'
|
||||
df = pd.read_csv(f'{root}/{filename}')
|
||||
all_index = list(range(len(df)))
|
||||
non_test_index = list(set(all_index) - set(test_index))
|
||||
df = df.iloc[non_test_index]
|
||||
tot_smiles = df['smiles'].tolist()
|
||||
|
||||
n_atom_list = []
|
||||
n_bond_list = []
|
||||
for i, sms in enumerate(tot_smiles):
|
||||
try:
|
||||
mol = Chem.MolFromSmiles(sms)
|
||||
except:
|
||||
continue
|
||||
|
||||
n_atom = mol.GetNumHeavyAtoms()
|
||||
n_bond = mol.GetNumBonds()
|
||||
n_atom_list.append(n_atom)
|
||||
n_bond_list.append(n_bond)
|
||||
|
||||
n_atoms_per_mol[n_atom] += 1
|
||||
cur_atom_count_arr = np.zeros(118)
|
||||
for atom in mol.GetAtoms():
|
||||
symbol = atom.GetSymbol()
|
||||
if symbol == 'H':
|
||||
continue
|
||||
elif symbol == '*':
|
||||
atom_count_list[-1] += 1
|
||||
cur_atom_count_arr[-1] += 1
|
||||
else:
|
||||
atom_count_list[atom.GetAtomicNum()-2] += 1
|
||||
cur_atom_count_arr[atom.GetAtomicNum()-2] += 1
|
||||
try:
|
||||
valencies[int(atom.GetExplicitValence())] += 1
|
||||
except:
|
||||
print('src', source_name,'int(atom.GetExplicitValence())', int(atom.GetExplicitValence()))
|
||||
|
||||
tansition_E_temp = np.zeros((118, 118, 5))
|
||||
for bond in mol.GetBonds():
|
||||
start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
|
||||
if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H':
|
||||
continue
|
||||
|
||||
if start_atom.GetSymbol() == '*':
|
||||
start_index = 117
|
||||
else:
|
||||
start_index = start_atom.GetAtomicNum() - 2
|
||||
if end_atom.GetSymbol() == '*':
|
||||
end_index = 117
|
||||
else:
|
||||
end_index = end_atom.GetAtomicNum() - 2
|
||||
|
||||
bond_type = bond.GetBondType()
|
||||
bond_index = bond_type_to_index[bond_type]
|
||||
bond_count_list[bond_index] += 2
|
||||
|
||||
tansition_E[start_index, end_index, bond_index] += 2
|
||||
tansition_E[end_index, start_index, bond_index] += 2
|
||||
tansition_E_temp[start_index, end_index, bond_index] += 2
|
||||
tansition_E_temp[end_index, start_index, bond_index] += 2
|
||||
|
||||
bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
|
||||
cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2 # 118 * 118
|
||||
cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2 # 118 * 118
|
||||
tansition_E[:, :, 0] += cur_tot_bond - tansition_E_temp.sum(axis=-1)
|
||||
assert (cur_tot_bond > tansition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
|
||||
|
||||
n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
|
||||
n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
|
||||
|
||||
atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
|
||||
print('processed meta info: ------', filename, '------')
|
||||
print('len atom_count_list', len(atom_count_list))
|
||||
print('len atom_name_list', len(atom_name_list))
|
||||
active_atoms = np.array(atom_name_list)[atom_count_list > 0]
|
||||
active_atoms = active_atoms.tolist()
|
||||
atom_count_list = atom_count_list.tolist()
|
||||
|
||||
bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
|
||||
bond_count_list = bond_count_list.tolist()
|
||||
valencies = np.array(valencies) / np.sum(valencies)
|
||||
valencies = valencies.tolist()
|
||||
|
||||
no_edge = np.sum(tansition_E, axis=-1) == 0
|
||||
first_elt = tansition_E[:, :, 0]
|
||||
first_elt[no_edge] = 1
|
||||
tansition_E[:, :, 0] = first_elt
|
||||
|
||||
tansition_E = tansition_E / np.sum(tansition_E, axis=-1, keepdims=True)
|
||||
|
||||
meta_dict = {
|
||||
'source': source_name,
|
||||
'num_graph': len(n_atom_list),
|
||||
'n_atoms_per_mol_dist': n_atoms_per_mol,
|
||||
'max_node': max(n_atom_list),
|
||||
'max_bond': max(n_bond_list),
|
||||
'atom_type_dist': atom_count_list,
|
||||
'bond_type_dist': bond_count_list,
|
||||
'valencies': valencies,
|
||||
'active_atoms': active_atoms,
|
||||
'num_atom_type': len(active_atoms),
|
||||
'transition_E': tansition_E.tolist(),
|
||||
}
|
||||
|
||||
with open(f'{root}/{source_name}.meta.json', "w") as f:
|
||||
json.dump(meta_dict, f)
|
||||
|
||||
return meta_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
Reference in New Issue
Block a user