add somecomments
This commit is contained in:
@@ -76,12 +76,16 @@ class Graph_DiT(pl.LightningModule):
|
||||
timesteps=cfg.model.diffusion_steps)
|
||||
|
||||
|
||||
print("__init__")
|
||||
print("dataset_info.node_types", self.dataset_info.node_types)
|
||||
# dataset_info.node_types tensor([7.4826e-01, 2.6870e-02, 9.3930e-02, 4.4959e-02, 5.2982e-03, 7.5689e-04, 5.3739e-03, 1.5138e-03, 7.5689e-05, 4.3143e-03, 6.8650e-02])
|
||||
x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())
|
||||
|
||||
e_marginals = self.dataset_info.edge_types.float() / torch.sum(self.dataset_info.edge_types.float())
|
||||
x_marginals = x_marginals / (x_marginals ).sum()
|
||||
e_marginals = e_marginals / (e_marginals ).sum()
|
||||
|
||||
# transition e is the probability of transitioning from x1 to x2 with e
|
||||
xe_conditions = self.dataset_info.transition_E.float()
|
||||
xe_conditions = xe_conditions[self.active_index][:, self.active_index]
|
||||
|
||||
|
Reference in New Issue
Block a user