Compare commits

..

No commits in common. "7e83bf1401379626b0c53b4b4f075a33f76d131c" and "deadd50d243324fe94cb8c20a591ffbda036b35e" have entirely different histories.

2 changed files with 3 additions and 3 deletions

View File

@ -692,8 +692,8 @@ class Dataset(InMemoryDataset):
if adj[start][end] == 1: if adj[start][end] == 1:
edges_list.append((start, end)) edges_list.append((start, end))
edge_type.append(1) edge_type.append(1)
# edges_list.append((end, start)) edges_list.append((end, start))
# edge_type.append(1) edge_type.append(1)
edge_index = torch.tensor(edges_list, dtype=torch.long).t() edge_index = torch.tensor(edges_list, dtype=torch.long).t()
edge_type = torch.tensor(edge_type, dtype=torch.long) edge_type = torch.tensor(edge_type, dtype=torch.long)

View File

@ -139,7 +139,7 @@ class PlaceHolder:
self.E = self.E * e_mask1 * e_mask2 self.E = self.E * e_mask1 * e_mask2
# print(f"X: {self.X.shape}, E: {self.E.shape}") # print(f"X: {self.X.shape}, E: {self.E.shape}")
# print(f"X: {self.X}, E: {self.E}") # print(f"X: {self.X}, E: {self.E}")
# assert torch.allclose(self.E, torch.transpose(self.E, 1, 2)) assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
return self return self