add some comments
This commit is contained in:
@@ -123,4 +123,8 @@ class AbstractDatasetInfos:
|
||||
'y': example_batch['y'].size(1)}
|
||||
self.output_dims = {'X': example_batch_x.size(1),
|
||||
'E': example_batch_edge_attr.size(1),
|
||||
'y': example_batch['y'].size(1)}
|
||||
'y': example_batch['y'].size(1)}
|
||||
print('input dims')
|
||||
print(self.input_dims)
|
||||
print('output dims')
|
||||
print(self.output_dims)
|
@@ -28,19 +28,38 @@ class DataModule(AbstractDataModule):
|
||||
def __init__(self, cfg):
|
||||
self.datadir = cfg.dataset.datadir
|
||||
self.task = cfg.dataset.task_name
|
||||
print("DataModule")
|
||||
print("task", self.task)
|
||||
print("datadir`",self.datadir)
|
||||
super().__init__(cfg)
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
target = getattr(self.cfg.dataset, 'guidance_target', None)
|
||||
print("target", target)
|
||||
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
||||
root_path = os.path.join(base_path, self.datadir)
|
||||
self.root_path = root_path
|
||||
|
||||
batch_size = self.cfg.train.batch_size
|
||||
|
||||
num_workers = self.cfg.train.num_workers
|
||||
pin_memory = self.cfg.dataset.pin_memory
|
||||
|
||||
# Load the dataset into memory
|
||||
# Dataset has target property, root path, and transform
|
||||
dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
|
||||
print("len dataset", len(dataset))
|
||||
def print_data(dataset):
|
||||
print("dataset", dataset)
|
||||
print("dataset keys", dataset.keys)
|
||||
print("dataset x", dataset.x)
|
||||
print("dataset edge_index", dataset.edge_index)
|
||||
print("dataset edge_attr", dataset.edge_attr)
|
||||
print("dataset y", dataset.y)
|
||||
print("")
|
||||
print_data(dataset=dataset[0])
|
||||
print_data(dataset=dataset[1])
|
||||
|
||||
|
||||
if len(self.task.split('-')) == 2:
|
||||
train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
|
||||
@@ -53,8 +72,12 @@ class DataModule(AbstractDataModule):
|
||||
train_index = torch.cat([train_index, unlabeled_index], dim=0)
|
||||
|
||||
train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
|
||||
self.train_dataset = train_dataset
|
||||
self.train_dataset = train_dataset
|
||||
print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
|
||||
print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
|
||||
print('dataset len', len(dataset) , 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
|
||||
self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
|
||||
|
||||
self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
|
||||
self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
|
||||
|
||||
@@ -253,6 +276,9 @@ class DataInfos(AbstractDatasetInfos):
|
||||
|
||||
|
||||
def compute_meta(root, source_name, train_index, test_index):
|
||||
# initialize the periodic table
|
||||
# 118 elements + 1 for *
|
||||
# Initializes arrays to count the number of atoms per molecule, bond types, valencies, and transition probabilities between atom types.
|
||||
pt = Chem.GetPeriodicTable()
|
||||
atom_name_list = []
|
||||
atom_count_list = []
|
||||
@@ -267,11 +293,13 @@ def compute_meta(root, source_name, train_index, test_index):
|
||||
valencies = [0] * 500
|
||||
tansition_E = np.zeros((118, 118, 5))
|
||||
|
||||
# Load the data from the source file
|
||||
filename = f'{source_name}.csv.gz'
|
||||
df = pd.read_csv(f'{root}/{filename}')
|
||||
all_index = list(range(len(df)))
|
||||
non_test_index = list(set(all_index) - set(test_index))
|
||||
df = df.iloc[non_test_index]
|
||||
# extract the smiles from the dataframe
|
||||
tot_smiles = df['smiles'].tolist()
|
||||
|
||||
n_atom_list = []
|
||||
@@ -323,6 +351,11 @@ def compute_meta(root, source_name, train_index, test_index):
|
||||
bond_index = bond_type_to_index[bond_type]
|
||||
bond_count_list[bond_index] += 2
|
||||
|
||||
# Update the transition matrix
|
||||
# The transition matrix is symmetric, so we update both directions
|
||||
# We also update the temporary transition matrix to check for errors
|
||||
# in the atom count
|
||||
|
||||
tansition_E[start_index, end_index, bond_index] += 2
|
||||
tansition_E[end_index, start_index, bond_index] += 2
|
||||
tansition_E_temp[start_index, end_index, bond_index] += 2
|
||||
|
Reference in New Issue
Block a user