add the idea of guidance
@@ -134,7 +134,7 @@ class Graph_DiT(pl.LightningModule):
        loss = self.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
                               true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
                               log=i % self.log_every_steps == 0)

        # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
        self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
                           log=i % self.log_every_steps == 0)
        self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
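The train_loss / train_metrics calls above take masked node and edge logits together with the ground-truth tensors and a node_mask; the loss module itself is outside this diff. A minimal sketch of the kind of masked cross-entropy such a call typically reduces to (function name and edge weighting are illustrative, not the repository's actual TrainLoss):

import torch
import torch.nn.functional as F

def masked_graph_ce(masked_pred_X, masked_pred_E, true_X, true_E, lambda_E=5.0):
    # Flatten node logits/targets and keep only rows that were not masked out
    # (masked positions carry an all-zero one-hot target).
    pX, tX = masked_pred_X.reshape(-1, masked_pred_X.size(-1)), true_X.reshape(-1, true_X.size(-1))
    keep = tX.sum(dim=-1) > 0
    loss_X = F.cross_entropy(pX[keep], tX[keep].argmax(dim=-1))

    # Same treatment for the edge logits/targets.
    pE, tE = masked_pred_E.reshape(-1, masked_pred_E.size(-1)), true_E.reshape(-1, true_E.size(-1))
    keep = tE.sum(dim=-1) > 0
    loss_E = F.cross_entropy(pE[keep], tE[keep].argmax(dim=-1))

    return loss_X + lambda_E * loss_E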
@@ -601,7 +601,8 @@ class Graph_DiT(pl.LightningModule):

        # Normalize predictions
        pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
        pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0
        # gradient

        # Retrieve transitions matrix
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
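get_Qt_bar returns the cumulative transition matrices of the discrete diffusion; get_prob (next hunk) combines them with the softmaxed prediction to form the per-node and per-edge distribution over the previous step. A rough D3PM-style sketch of that computation for one categorical variable, under the assumption that the usual posterior form is used (helper name, arguments and the simplified normalisation are assumptions, not lifted from this repository):

def posterior_probs(p0, x_t, Qt, Qsb):
    # p0:  (bs, n, d) predicted distribution over the clean state x_0
    # x_t: (bs, n, d) one-hot current noisy state
    # Qt:  (bs, d, d) one-step transition matrix; Qsb: (bs, d, d) cumulative matrix up to s = t - 1
    left = x_t @ Qt.transpose(-1, -2)    # q(x_t | x_{t-1}) as a row over x_{t-1}
    right = p0 @ Qsb                     # sum over x_0 of p0(x_0) * q(x_{t-1} | x_0)
    unnorm = left * right                # unnormalised posterior over x_{t-1}
    return unnorm / unnorm.sum(dim=-1, keepdim=True).clamp_min(1e-10)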
@@ -629,25 +630,52 @@ class Graph_DiT(pl.LightningModule):
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

            return prob_X, prob_E

        # diffusion nag: P_t(G_{t-1} | G_t, C) ∝ P_t(G_{t-1} | G_t) * P_t(C | G_{t-1}, G_t)
        # with condition C = y, i.e. P_t(A_{t-1} | A_t, y)
        prob_X, prob_E = get_prob(noisy_data)
        ### Guidance
        if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
            uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True)
            prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale
            prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-10)
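        # Net effect of the block above, written in log space: classifier-free guidance,
        #   log p = log p_uncond + guide_scale * (log p_cond - log p_uncond),
        # followed by renormalisation over the last (class) dimension.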
        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())

        # sample multiple times and keep the best-scoring arch
        sample_num = 100
        best_arch = None
        best_score = -1e8

        for i in range(sample_num):
            sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
            score = get_score(sampled_s)
            if score > best_score:
                best_score = score
                best_arch = sampled_s
        sampled_s = best_arch  # carry the best-scoring sample forward

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()

        # NASWOT score
        target_score = torch.tensor([3000.0])

        # compute loss = mse(cur_score, target_score)
        # backprop the loss to get gradients
        # read the gradients w.r.t. prob_X, prob_E
        # update prob_X, prob_E using those gradients

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
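get_score is not defined in this hunk; the `# NASWOT score` comment points at a training-free scorer in the style of NASWOT (Mellor et al., "Neural Architecture Search without Training"). A minimal sketch, assuming the sampled graph has already been instantiated as a ReLU network `net` and is scored on a single minibatch `x` (both names are illustrative, not part of the repository):

import torch

def naswot_score(net, x):
    codes = []

    def hook(_module, _inputs, output):
        # record the binary ReLU activation pattern of each sample in the batch
        codes.append((output.flatten(1) > 0).float())

    handles = [m.register_forward_hook(hook)
               for m in net.modules() if isinstance(m, torch.nn.ReLU)]
    with torch.no_grad():
        net(x)
    for h in handles:
        h.remove()

    c = torch.cat(codes, dim=1)                   # (batch, total ReLU units)
    k = c @ c.t() + (1.0 - c) @ (1.0 - c).t()     # pairwise agreement of activation patterns
    return torch.slogdet(k)[1]                    # log |det K|; higher means more separable inputs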
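The four plan comments in the hunk (mse loss against target_score, backward, read the prob_X / prob_E gradients, update) describe a gradient-based guidance step that is not yet implemented in this commit. One way it could look, assuming a differentiable surrogate `score_fn` that maps the soft node/edge probabilities to a scalar score (score_fn, eta and the clamping are illustrative assumptions):

import torch
import torch.nn.functional as F

def nudge_probs(prob_X, prob_E, score_fn, target_score, eta=0.1):
    prob_X = prob_X.detach().requires_grad_(True)
    prob_E = prob_E.detach().requires_grad_(True)

    cur_score = score_fn(prob_X, prob_E)         # differentiable score of the soft graph
    loss = F.mse_loss(cur_score, target_score)   # mse(cur_score, target_score)
    loss.backward()                              # gradients land in prob_X.grad / prob_E.grad

    with torch.no_grad():
        # take one step toward the target score, then project back onto the simplex
        new_X = (prob_X - eta * prob_X.grad).clamp_min(0)
        new_E = (prob_E - eta * prob_E.grad).clamp_min(0)
        new_X = new_X / new_X.sum(dim=-1, keepdim=True).clamp_min(1e-10)
        new_E = new_E / new_E.sum(dim=-1, keepdim=True).clamp_min(1e-10)
    return new_X, new_E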