comment some output statements and record dimension infos
This commit is contained in:
@@ -65,10 +65,11 @@ def reverse_tensor(x):
|
||||
|
||||
def sample_discrete_features(probX, probE, node_mask, step=None, add_nose=True):
|
||||
''' Sample features from multinomial distribution with given probabilities (probX, probE, proby)
|
||||
:param probX: bs, n, dx_out node features
|
||||
:param probE: bs, n, n, de_out edge features
|
||||
:param proby: bs, dy_out global features.
|
||||
:param probX: bs, n, dx_out node features 1200 8 7
|
||||
:param probE: bs, n, n, de_out edge features 1200 8 8 2
|
||||
:param proby: bs, dy_out global features. 1200 8
|
||||
'''
|
||||
# print(f"sample_discrete_features in: probX: {probX.shape}, probE: {probE.shape}, node_mask: {node_mask.shape}")
|
||||
bs, n, _ = probX.shape
|
||||
|
||||
# Noise X
|
||||
@@ -97,8 +98,11 @@ def sample_discrete_features(probX, probE, node_mask, step=None, add_nose=True):
|
||||
|
||||
# Sample E
|
||||
E_t = probE.multinomial(1).reshape(bs, n, n) # (bs, n, n)
|
||||
# print(f"sample_discrete_features out: X_t: {X_t.shape}, E_t: {E_t.shape}")
|
||||
E_t = torch.triu(E_t, diagonal=1)
|
||||
# print(f"sample_discrete_features out: X_t: {X_t.shape}, E_t: {E_t.shape}")
|
||||
E_t = (E_t + torch.transpose(E_t, 1, 2))
|
||||
# print(f"sample_discrete_features out: X_t: {X_t.shape}, E_t: {E_t.shape}")
|
||||
|
||||
return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0).type_as(X_t))
|
||||
|
||||
|
Reference in New Issue
Block a user