Compare commits
2 Commits
deadd50d24
...
7e83bf1401
Author | SHA1 | Date | |
---|---|---|---|
7e83bf1401 | |||
9601a3c18d |
@ -692,8 +692,8 @@ class Dataset(InMemoryDataset):
|
|||||||
if adj[start][end] == 1:
|
if adj[start][end] == 1:
|
||||||
edges_list.append((start, end))
|
edges_list.append((start, end))
|
||||||
edge_type.append(1)
|
edge_type.append(1)
|
||||||
edges_list.append((end, start))
|
# edges_list.append((end, start))
|
||||||
edge_type.append(1)
|
# edge_type.append(1)
|
||||||
|
|
||||||
edge_index = torch.tensor(edges_list, dtype=torch.long).t()
|
edge_index = torch.tensor(edges_list, dtype=torch.long).t()
|
||||||
edge_type = torch.tensor(edge_type, dtype=torch.long)
|
edge_type = torch.tensor(edge_type, dtype=torch.long)
|
||||||
|
@ -139,7 +139,7 @@ class PlaceHolder:
|
|||||||
self.E = self.E * e_mask1 * e_mask2
|
self.E = self.E * e_mask1 * e_mask2
|
||||||
# print(f"X: {self.X.shape}, E: {self.E.shape}")
|
# print(f"X: {self.X.shape}, E: {self.E.shape}")
|
||||||
# print(f"X: {self.X}, E: {self.E}")
|
# print(f"X: {self.X}, E: {self.E}")
|
||||||
assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
|
# assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user