Initial commit
Files added in this commit:

BIN   ddpo_pytorch/__pycache__/prompts.cpython-310.pyc (binary file not shown)
BIN   ddpo_pytorch/__pycache__/rewards.cpython-310.pyc (binary file not shown)
BIN   ddpo_pytorch/__pycache__/stat_tracking.cpython-310.pyc (binary file not shown)
1000  ddpo_pytorch/assets/imagenet_classes.txt (diff suppressed because it is too large)
143   ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
@@ -0,0 +1,143 @@
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py
# with the following modifications:
# -

from typing import Optional, Tuple

import math
import torch

from diffusers.utils import randn_tensor
from diffusers.schedulers.scheduling_ddim import DDIMScheduler


def ddim_step_with_logprob(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`): direct output from the learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of the sample being created by the diffusion process.
        eta (`float`): weight of noise for added noise in the diffusion step.
        use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
            predicted original sample. Necessary because the predicted original sample is clipped to [-1, 1] when
            `self.config.clip_sample` is `True`. If no clipping has happened, the "corrected" `model_output` would
            coincide with the one provided as input and `use_clipped_model_output` will have no effect.
        generator: random number generator.
        prev_sample (`torch.FloatTensor`, *optional*): instead of sampling the previous sample with `generator`,
            a fixed previous sample can be provided directly; the function then only computes its log probability
            under the current step's Gaussian. Cannot be passed together with `generator`.

    Returns:
        `tuple`: `(prev_sample, log_prob)` -- the previous sample x_t-1 and the per-example log probability of
        `prev_sample` under this step's Gaussian transition distribution.
    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read the DDIM paper in detail for a full understanding.

    # Notation (<variable name> -> <name in paper>)
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    self.alphas_cumprod = self.alphas_cumprod.to(timestep.device)
    self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device)
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep)
    # clamp the gather index so it stays in bounds on the final step; the `torch.where`
    # still selects `final_alpha_cumprod` whenever `prev_timestep` is negative
    alpha_prod_t_prev = torch.where(
        prev_timestep >= 0,
        self.alphas_cumprod.gather(0, prev_timestep.clamp(min=0)),
        self.final_alpha_cumprod,
    )

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # log prob of prev_sample given prev_sample_mean and std_dev_t
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample, log_prob
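For context, a minimal sketch of how this function can be used to re-evaluate the log-probability of a stored denoising step, which is the quantity a DDPO-style policy-gradient update compares against the behavior policy's log-prob. The names `scheduler`, `noise_pred`, `t`, `latents_t`, `latents_t_minus_1`, and `old_log_prob` are assumptions standing in for values taken from a previously sampled trajectory:

# Assumed setup: scheduler.set_timesteps(...) has been called, and the trajectory
# tensors below were stored during sampling. With eta > 0 each DDIM step is a proper
# Gaussian, so its log-density is well-defined (eta == 0 would make std_dev_t zero).
_, new_log_prob = ddim_step_with_logprob(
    scheduler,
    noise_pred,                      # fresh UNet output for (latents_t, t)
    t,                               # timestep tensor, as used during sampling
    latents_t,                       # latent at step t from the stored trajectory
    eta=1.0,
    prev_sample=latents_t_minus_1,   # fixed next latent; no new noise is drawn
)
ratio = torch.exp(new_log_prob - old_log_prob)  # importance ratio for PPO-style updates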

225   ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py
@@ -0,0 +1,225 @@
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
# with the following modifications:
# -

from typing import Any, Callable, Dict, List, Optional, Union

import torch

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
)
from .ddim_with_logprob import ddim_step_with_logprob


@torch.no_grad()
def pipeline_with_logprob(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. A higher guidance scale encourages images that are closely linked to the text `prompt`, usually
            at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Unused in this patched function; a plain tuple is always returned.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16.
            of the same paper. The guidance rescale factor should fix overexposure when using zero terminal SNR.

    Returns:
        `tuple`: `(images, has_nsfw_concepts, all_latents, all_log_probs)` -- the generated images, per-image
        flags denoting whether the corresponding image likely represents "not-safe-for-work" (nsfw) content
        according to the `safety_checker` (or `None` if decoding was skipped), the list of latents at every
        denoising step (including the initial noise), and the list of per-step log probabilities from
        `ddim_step_with_logprob`.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs)

            all_latents.append(latents)
            all_log_probs.append(log_prob)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs
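A hedged usage sketch (the model name and batch settings are illustrative, not prescribed by this commit): the function is called with the pipeline as its first argument, and with a nonzero `eta` so each step's log-probability is finite:

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

images, nsfw, all_latents, all_log_probs = pipeline_with_logprob(
    pipe,
    prompt=["a photo of a corgi"] * 4,
    num_inference_steps=50,
    eta=1.0,
    output_type="pt",
)
# all_latents: num_inference_steps + 1 tensors (x_T ... x_0),
#   each (batch, 4, 64, 64) for 512x512 images
# all_log_probs: num_inference_steps tensors, each of shape (batch,)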

54    ddpo_pytorch/prompts.py
@@ -0,0 +1,54 @@
from importlib import resources
import functools
import random
import inflect

IE = inflect.engine()
ASSETS_PATH = resources.files("ddpo_pytorch.assets")


@functools.cache
def load_lines(name):
    with ASSETS_PATH.joinpath(name).open() as f:
        return [line.strip() for line in f.readlines()]


def imagenet(low, high):
    return random.choice(load_lines("imagenet_classes.txt")[low:high]), {}


def imagenet_all():
    return imagenet(0, 1000)


def imagenet_animals():
    return imagenet(0, 398)


def imagenet_dogs():
    return imagenet(151, 269)


def nouns_activities(nouns_file, activities_file):
    nouns = load_lines(nouns_file)
    activities = load_lines(activities_file)
    return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}


def counting(nouns_file, low, high):
    nouns = load_lines(nouns_file)
    number = IE.number_to_words(random.randint(low, high))
    noun = random.choice(nouns)
    plural_noun = IE.plural(noun)
    prompt = f"{number} {plural_noun}"
    metadata = {
        "questions": [
            f"How many {plural_noun} are there in this image?",
            "What animal is in this image?",
        ],
        "answers": [
            number,
            noun,
        ],
    }
    return prompt, metadata
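Each prompt function returns a `(prompt, metadata)` pair, so a sampler can draw prompts on the fly and keep any evaluation metadata alongside. An illustrative sketch (the asset file name passed to `counting` is an assumption about what ships in `ddpo_pytorch/assets`):

prompt, metadata = imagenet_dogs()                       # e.g. ("golden retriever", {})
prompt, metadata = counting("simple_animals.txt", 1, 6)  # assumed asset file name
# e.g. ("three wolves", {"questions": [...], "answers": ["three", "wolf"]})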

29    ddpo_pytorch/rewards.py
@@ -0,0 +1,29 @@
from PIL import Image
import io
import numpy as np
import torch


def jpeg_incompressibility():
    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        images = [Image.fromarray(image) for image in images]
        buffers = [io.BytesIO() for _ in images]
        for image, buffer in zip(images, buffers):
            image.save(buffer, format="JPEG", quality=95)
        sizes = [buffer.tell() / 1000 for buffer in buffers]  # JPEG size in kilobytes
        return np.array(sizes), {}

    return _fn


def jpeg_compressibility():
    jpeg_fn = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        rew, meta = jpeg_fn(images, prompts, metadata)
        return -rew, meta

    return _fn
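A short sketch of the reward interface these factories implement: a callable taking `(images, prompts, metadata)` and returning `(rewards, info)`, where images may be a float tensor in [0, 1] (NCHW) or a uint8 NHWC numpy array. The inputs below are illustrative:

reward_fn = jpeg_compressibility()
images = torch.rand(4, 3, 512, 512)   # assumed: float images in [0, 1], NCHW
rewards, info = reward_fn(images, prompts=None, metadata=None)
# rewards: np.ndarray of shape (4,); higher reward = smaller JPEG (more compressible)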

34    ddpo_pytorch/stat_tracking.py
@@ -0,0 +1,34 @@
import numpy as np
from collections import deque


class PerPromptStatTracker:
    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.stats = {}

    def update(self, prompts, rewards):
        unique = np.unique(prompts)
        advantages = np.empty_like(rewards)
        for prompt in unique:
            prompt_rewards = rewards[prompts == prompt]
            if prompt not in self.stats:
                self.stats[prompt] = deque(maxlen=self.buffer_size)
            self.stats[prompt].extend(prompt_rewards)

            # fall back to batch-level statistics until enough rewards have been
            # collected for this prompt
            if len(self.stats[prompt]) < self.min_count:
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                mean = np.mean(self.stats[prompt])
                std = np.std(self.stats[prompt]) + 1e-6
            advantages[prompts == prompt] = (prompt_rewards - mean) / std

        return advantages

    def get_stats(self):
        return {
            k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)}
            for k, v in self.stats.items()
        }
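A minimal sketch of the intended call pattern; note that `prompts` must be a numpy array of strings so the boolean mask `prompts == prompt` selects per-prompt rewards elementwise:

tracker = PerPromptStatTracker(buffer_size=16, min_count=16)
prompts = np.array(["a corgi", "a cat", "a corgi"])
rewards = np.array([1.0, 0.5, 0.8])
advantages = tracker.update(prompts, rewards)  # rewards standardized per prompt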