try to run the graph, commented sampling codes

This commit is contained in:
mhz
2024-06-25 00:09:27 +02:00
parent e04ad5fbe7
commit 82299e5213
5 changed files with 80 additions and 209 deletions

View File

@@ -116,7 +116,7 @@ class AbstractDatasetInfos:
def compute_input_output_dims(self, datamodule):
example_batch = datamodule.example_batch()
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float()
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()
self.input_dims = {'X': example_batch_x.size(1),
'E': example_batch_edge_attr.size(1),