Continue implementation
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py
|
||||
# with the following modifications:
|
||||
# -
|
||||
# - It computes and returns the log prob of `prev_sample` given the UNet prediction.
|
||||
# - Instead of `variance_noise`, it takes `prev_sample` as an optional argument. If `prev_sample` is provided,
|
||||
# it uses it to compute the log prob.
|
||||
# - Timesteps can be a batched torch.Tensor.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
@@ -11,6 +14,19 @@ from diffusers.utils import randn_tensor
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
|
||||
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep):
|
||||
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
|
||||
alpha_prod_t_prev = torch.where(
|
||||
prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
|
||||
).to(timestep.device)
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||
|
||||
return variance
|
||||
|
||||
|
||||
def ddim_step_with_logprob(
|
||||
self: DDIMScheduler,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -66,16 +82,13 @@ def ddim_step_with_logprob(
|
||||
|
||||
# 1. get previous step value (=t-1)
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
||||
prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
|
||||
|
||||
# 2. compute alphas, betas
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(timestep.device)
|
||||
self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device)
|
||||
alpha_prod_t = self.alphas_cumprod.gather(0, timestep)
|
||||
alpha_prod_t_prev = torch.where(prev_timestep >= 0, self.alphas_cumprod.gather(0, prev_timestep), self.final_alpha_cumprod)
|
||||
print(timestep)
|
||||
print(alpha_prod_t)
|
||||
print(alpha_prod_t_prev)
|
||||
print(prev_timestep)
|
||||
alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()).to(timestep.device)
|
||||
alpha_prod_t_prev = torch.where(
|
||||
prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
|
||||
).to(timestep.device)
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
@@ -106,7 +119,7 @@ def ddim_step_with_logprob(
|
||||
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
variance = self._get_variance(timestep, prev_timestep)
|
||||
variance = _get_variance(self, timestep, prev_timestep)
|
||||
std_dev_t = eta * variance ** (0.5)
|
||||
|
||||
if use_clipped_model_output:
|
||||
|
Reference in New Issue
Block a user