add somecomments

This commit is contained in:
Hanzhang Ma
2024-06-08 21:35:35 +02:00
parent 2c00828630
commit 4f8945ca07
6 changed files with 72 additions and 3 deletions

View File

@@ -76,12 +76,16 @@ class Graph_DiT(pl.LightningModule):
timesteps=cfg.model.diffusion_steps)
print("__init__")
print("dataset_info.node_types", self.dataset_info.node_types)
# dataset_info.node_types tensor([7.4826e-01, 2.6870e-02, 9.3930e-02, 4.4959e-02, 5.2982e-03, 7.5689e-04, 5.3739e-03, 1.5138e-03, 7.5689e-05, 4.3143e-03, 6.8650e-02])
x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())
e_marginals = self.dataset_info.edge_types.float() / torch.sum(self.dataset_info.edge_types.float())
x_marginals = x_marginals / (x_marginals ).sum()
e_marginals = e_marginals / (e_marginals ).sum()
# transition e is the probability of transitioning from x1 to x2 with e
xe_conditions = self.dataset_info.transition_E.float()
xe_conditions = xe_conditions[self.active_index][:, self.active_index]