comment some output statements and record dimension infos

This commit is contained in:
mhz
2024-07-01 10:05:45 +02:00
parent 7147679c42
commit 4d1dea1179
2 changed files with 20 additions and 7 deletions

View File

@@ -65,10 +65,11 @@ def reverse_tensor(x):
def sample_discrete_features(probX, probE, node_mask, step=None, add_nose=True):
''' Sample features from multinomial distribution with given probabilities (probX, probE, proby)
:param probX: bs, n, dx_out node features
:param probE: bs, n, n, de_out edge features
:param proby: bs, dy_out global features.
:param probX: bs, n, dx_out node features 1200 8 7
:param probE: bs, n, n, de_out edge features 1200 8 8 2
:param proby: bs, dy_out global features. 1200 8
'''
# print(f"sample_discrete_features in: probX: {probX.shape}, probE: {probE.shape}, node_mask: {node_mask.shape}")
bs, n, _ = probX.shape
# Noise X
@@ -97,8 +98,11 @@ def sample_discrete_features(probX, probE, node_mask, step=None, add_nose=True):
# Sample E
E_t = probE.multinomial(1).reshape(bs, n, n) # (bs, n, n)
# print(f"sample_discrete_features out: X_t: {X_t.shape}, E_t: {E_t.shape}")
E_t = torch.triu(E_t, diagonal=1)
# print(f"sample_discrete_features out: X_t: {X_t.shape}, E_t: {E_t.shape}")
E_t = (E_t + torch.transpose(E_t, 1, 2))
# print(f"sample_discrete_features out: X_t: {X_t.shape}, E_t: {E_t.shape}")
return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0).type_as(X_t))