Continue implementation

Kevin Black
2023-06-23 21:08:32 -07:00
parent 6d848c3cdc
commit 92fc030123
3 changed files with 57 additions and 25 deletions


@@ -1,6 +1,9 @@
 # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py
 # with the following modifications:
-# -
+# - It computes and returns the log prob of `prev_sample` given the UNet prediction.
+# - Instead of `variance_noise`, it takes `prev_sample` as an optional argument. If `prev_sample` is provided,
+#   it uses it to compute the log prob.
+# - Timesteps can be a batched torch.Tensor.
 from typing import Optional, Tuple, Union
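
The first bullet is the core change: the step function returns the log probability of `prev_sample` under the Gaussian DDIM transition. As a rough sketch of what that computation looks like (illustrative only, not code from this commit; the function name and the reduction over non-batch dimensions are assumptions):

import math
import torch

# Illustrative sketch, not part of the diff: log prob of `prev_sample`
# under N(prev_sample_mean, std_dev_t**2); all arguments are assumed to
# be broadcastable tensors.
def gaussian_log_prob(prev_sample, prev_sample_mean, std_dev_t):
    log_prob = (
        -((prev_sample - prev_sample_mean) ** 2) / (2 * std_dev_t**2)
        - torch.log(std_dev_t)
        - math.log(math.sqrt(2 * math.pi))
    )
    # reduce over all but the batch dimension -> one log prob per sample
    return log_prob.mean(dim=tuple(range(1, log_prob.ndim)))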
@@ -11,6 +14,19 @@ from diffusers.utils import randn_tensor
 from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
 
 
+def _get_variance(self, timestep, prev_timestep):
+    alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
+    alpha_prod_t_prev = torch.where(
+        prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
+    ).to(timestep.device)
+    beta_prod_t = 1 - alpha_prod_t
+    beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+    return variance
+
+
 def ddim_step_with_logprob(
     self: DDIMScheduler,
     model_output: torch.FloatTensor,
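
The added `_get_variance` mirrors `DDIMScheduler._get_variance` but accepts batched tensor timesteps, gathering from `alphas_cumprod` on CPU. A hypothetical usage sketch (the scheduler construction and timestep values are assumptions):

import torch
from diffusers import DDIMScheduler

# Illustrative only: exercise the module-level _get_variance from the
# patch above with a batch of three timesteps.
scheduler = DDIMScheduler()  # defaults to 1000 training timesteps
scheduler.set_timesteps(50)

timestep = torch.tensor([980, 500, 20])  # one timestep per batch element
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
prev_timestep = torch.clamp(prev_timestep, 0, scheduler.config.num_train_timesteps - 1)

variance = _get_variance(scheduler, timestep, prev_timestep)  # shape (3,)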
@@ -66,16 +82,13 @@ def ddim_step_with_logprob(
     # 1. get previous step value (=t-1)
     prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
 
     # 2. compute alphas, betas
-    self.alphas_cumprod = self.alphas_cumprod.to(timestep.device)
-    self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device)
-    alpha_prod_t = self.alphas_cumprod.gather(0, timestep)
-    alpha_prod_t_prev = torch.where(prev_timestep >= 0, self.alphas_cumprod.gather(0, prev_timestep), self.final_alpha_cumprod)
-    print(timestep)
-    print(alpha_prod_t)
-    print(alpha_prod_t_prev)
-    print(prev_timestep)
+    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()).to(timestep.device)
+    alpha_prod_t_prev = torch.where(
+        prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
+    ).to(timestep.device)
 
     beta_prod_t = 1 - alpha_prod_t
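
The rewritten block gathers from `alphas_cumprod` on CPU and only then moves the result to the timesteps' device, and the new `torch.clamp` keeps the gather indices in-bounds; the `torch.where` substitutes `final_alpha_cumprod` wherever the true previous timestep would be negative. The same clamp-then-mask pattern in isolation (schedule values are made up):

import torch

# Illustrative only: clamp keeps the gather in-bounds, where masks the
# out-of-range entries back to the boundary value.
alphas_cumprod = torch.linspace(0.999, 0.01, 1000)  # stand-in noise schedule
final_alpha_cumprod = torch.tensor(1.0)             # alpha-bar for t-1 < 0

prev_timestep = torch.tensor([960, 0, -20])         # last entry precedes step 0
safe_index = torch.clamp(prev_timestep, 0, 999)
alpha_prod_t_prev = torch.where(
    prev_timestep >= 0,
    alphas_cumprod.gather(0, safe_index),
    final_alpha_cumprod,
)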
@@ -106,7 +119,7 @@ def ddim_step_with_logprob(
     # 5. compute variance: "sigma_t(η)" -> see formula (16)
     # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-    variance = self._get_variance(timestep, prev_timestep)
+    variance = _get_variance(self, timestep, prev_timestep)
     std_dev_t = eta * variance ** (0.5)
 
     if use_clipped_model_output:
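
Putting it together, a hypothetical sampling loop around the patched step; the `unet` call, the `(prev_sample, log_prob)` return shape, and the latent dimensions are assumptions based on the header comment, not code from this commit:

import torch

# Illustrative only: `scheduler` and `unet` are assumed to exist, e.g. a
# DDIMScheduler with set_timesteps(50) and a diffusers UNet.
latents = torch.randn(4, 4, 64, 64)             # assumed latent shape
for t in scheduler.timesteps:
    t_batch = t.repeat(latents.shape[0])        # batched timestep tensor
    noise_pred = unet(latents, t_batch).sample  # assumed UNet interface
    latents, log_prob = ddim_step_with_logprob(
        scheduler, noise_pred, t_batch, latents, eta=1.0
    )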