comment some output statements and record dimension infos
This commit is contained in:
@@ -103,16 +103,25 @@ class MarginalTransition:
|
||||
self.e_marginals = e_marginals # Dx, De
|
||||
self.xe_conditions = xe_conditions
|
||||
|
||||
self.u_x = x_marginals.unsqueeze(0).expand(self.X_classes, -1).unsqueeze(0) # 1, Dx, Dx
|
||||
self.u_e = e_marginals.unsqueeze(0).expand(self.E_classes, -1).unsqueeze(0) # 1, De, De
|
||||
self.u_xe = xe_conditions.unsqueeze(0) # 1, Dx, De
|
||||
self.u_ex = ex_conditions.unsqueeze(0) # 1, De, Dx
|
||||
self.u_x = x_marginals.unsqueeze(0).expand(self.X_classes, -1).unsqueeze(0) # 1, Dx, Dx 1 7 7
|
||||
self.u_e = e_marginals.unsqueeze(0).expand(self.E_classes, -1).unsqueeze(0) # 1, De, De 1 2 2
|
||||
self.u_xe = xe_conditions.unsqueeze(0) # 1, Dx, De 1 7 2
|
||||
self.u_ex = ex_conditions.unsqueeze(0) # 1, De, Dx 1 2 7
|
||||
self.u = self.get_union_transition(self.u_x, self.u_e, self.u_xe, self.u_ex, n_nodes) # 1, Dx + n*De, Dx + n*De
|
||||
# print(f"Shape of u_x: {self.u_x.shape}")
|
||||
# print(f"Shape of u_e: {self.u_e.shape}")
|
||||
# print(f"Shape of u_xe: {self.u_xe.shape}")
|
||||
# print(f"Shape of u_ex: {self.u_ex.shape}")
|
||||
# print(f"Shape of u: {self.u.shape}")
|
||||
|
||||
def get_union_transition(self, u_x, u_e, u_xe, u_ex, n_nodes):
|
||||
# print(f"before processing Shape of u_e: {u_e.shape}")
|
||||
# print(f"before processing Shape of u_ex: {u_ex.shape}")
|
||||
u_e = u_e.repeat(1, n_nodes, n_nodes) # (1, n*de, n*de)
|
||||
u_xe = u_xe.repeat(1, 1, n_nodes) # (1, dx, n*de)
|
||||
u_ex = u_ex.repeat(1, n_nodes, 1) # (1, n*de, dx)
|
||||
# print(f"After processing Shape of u_ex: {u_ex.shape}")
|
||||
# print(f"After processing Shape of u_e: {u_e.shape}")
|
||||
u0 = torch.cat([u_x, u_xe], dim=2) # (1, dx, dx + n*de)
|
||||
u1 = torch.cat([u_ex, u_e], dim=2) # (1, n*de, dx + n*de)
|
||||
u = torch.cat([u0, u1], dim=1) # (1, dx + n*de, dx + n*de)
|
||||
|
Reference in New Issue
Block a user