update some score function

This commit is contained in:
mhz
2024-08-05 21:45:15 +02:00
parent f5d00be56e
commit 205f43291b
4 changed files with 37 additions and 17 deletions

View File

@@ -86,7 +86,7 @@ class Denoiser(nn.Module):
"""
def forward(self, x, e, node_mask, y, t, unconditioned):
print("Denoiser Forward")
# print("Denoiser Forward")
# print(x.shape, e.shape, y.shape, t.shape, unconditioned)
force_drop_id = torch.zeros_like(y.sum(-1))
# drop the nan values