Working on DGX
ddpo_pytorch/config/base.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import ml_collections


def get_config():
    config = ml_collections.ConfigDict()

    # misc
    config.seed = 42
    config.logdir = "logs"
    config.num_epochs = 100
    config.mixed_precision = "fp16"
    config.allow_tf32 = True

    # pretrained model initialization
    config.pretrained = pretrained = ml_collections.ConfigDict()
    pretrained.model = "runwayml/stable-diffusion-v1-5"
    pretrained.revision = "main"

    # training
    config.train = train = ml_collections.ConfigDict()
    train.batch_size = 1
    train.use_8bit_adam = False
    train.scale_lr = False
    train.learning_rate = 1e-4
    train.adam_beta1 = 0.9
    train.adam_beta2 = 0.999
    train.adam_weight_decay = 1e-4
    train.adam_epsilon = 1e-8
    train.gradient_accumulation_steps = 1
    train.max_grad_norm = 1.0
    train.num_inner_epochs = 1
    train.cfg = True
    train.adv_clip_max = 10
    train.clip_range = 1e-4

    # sampling
    config.sample = sample = ml_collections.ConfigDict()
    sample.num_steps = 30
    sample.eta = 1.0
    sample.guidance_scale = 5.0
    sample.batch_size = 1
    sample.num_batches_per_epoch = 1

    # prompting
    config.prompt_fn = "imagenet_animals"
    config.prompt_fn_kwargs = {}

    # rewards
    config.reward_fn = "jpeg_compressibility"

    config.per_prompt_stat_tracking = ml_collections.ConfigDict()
    config.per_prompt_stat_tracking.buffer_size = 64
    config.per_prompt_stat_tracking.min_count = 16

    return config
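
Usage note (not part of the commit): a config module like this is usually consumed through ml_collections' config_flags, so a training script can be pointed at base.py or at an override such as dgx.py from the command line, and individual fields can be overridden with flags like --config.train.batch_size=4. A minimal launcher sketch under that assumption; the script name and flag wiring are illustrative, not taken from this repository:

from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
# --config takes the path to a module exposing get_config(), e.g.
#   python train.py --config ddpo_pytorch/config/dgx.py
config_flags.DEFINE_config_file("config", "ddpo_pytorch/config/base.py", "Training configuration.")


def main(_):
    config = FLAGS.config  # the ConfigDict built by get_config()
    print(config.train.learning_rate, config.sample.num_steps)


if __name__ == "__main__":
    app.run(main)
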
ddpo_pytorch/config/dgx.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import ml_collections
from ddpo_pytorch.config import base


def get_config():
    config = base.get_config()

    config.mixed_precision = "bf16"
    config.allow_tf32 = True

    config.train.batch_size = 8
    config.train.gradient_accumulation_steps = 4

    # sampling
    config.sample.num_steps = 50
    config.sample.batch_size = 8
    config.sample.num_batches_per_epoch = 4

    config.per_prompt_stat_tracking = None

    return config
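
For orientation (not stated in the commit, just arithmetic on the overrides above), the per-GPU figures implied by this config are sketched below; multiply by the number of processes for global totals. The switch to bf16 and allow_tf32 presumably targets the A100-class GPUs in a DGX, which support both natively.

# Per-GPU figures implied by the DGX overrides (illustrative arithmetic only).
samples_per_epoch = 8 * 4           # sample.batch_size * sample.num_batches_per_epoch = 32
samples_per_optimizer_step = 8 * 4  # train.batch_size * train.gradient_accumulation_steps = 32
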
@@ -14,6 +14,11 @@ from diffusers.utils import randn_tensor
 from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
 
 
+def _left_broadcast(t, shape):
+    assert t.ndim <= len(shape)
+    return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
+
+
 def _get_variance(self, timestep, prev_timestep):
     alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
     alpha_prod_t_prev = torch.where(
@@ -82,13 +87,16 @@ ddim_step_with_logprob(
 
     # 1. get previous step value (=t-1)
     prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+    # to prevent OOB on gather
+    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
 
     # 2. compute alphas, betas
-    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()).to(timestep.device)
+    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
     alpha_prod_t_prev = torch.where(
         prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
-    ).to(timestep.device)
+    )
+    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
+    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)
 
     beta_prod_t = 1 - alpha_prod_t
 
@@ -121,6 +129,7 @@ ddim_step_with_logprob(
     # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
     variance = _get_variance(self, timestep, prev_timestep)
     std_dev_t = eta * variance ** (0.5)
+    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
 
     if use_clipped_model_output:
         # the pred_epsilon is always re-derived from the clipped x_0 in Glide
@@ -153,4 +162,4 @@ ddim_step_with_logprob(
     # mean along all but batch dimension
     log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
 
-    return prev_sample, log_prob
+    return prev_sample.type(sample.dtype), log_prob
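
A note on the helper added in the first hunk (explanation mine, not from the commit): the gathered alphas and std_dev_t carry one value per batch element, and PyTorch broadcasting aligns trailing dimensions, so a shape-[B] tensor would line up with the width of a [B, C, H, W] latent rather than its batch dimension. _left_broadcast pads trailing singleton dimensions so the values broadcast on the left instead. A standalone sketch of the behavior, with the helper body copied from the diff and illustrative shapes:

import torch


def _left_broadcast(t, shape):
    # Pad t with trailing singleton dims, then expand it to `shape`.
    assert t.ndim <= len(shape)
    return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)


alphas = torch.rand(4)              # one alpha_cumprod value per batch element
sample = torch.randn(4, 3, 64, 64)  # latents of shape [B, C, H, W]
print(_left_broadcast(alphas, sample.shape).shape)  # torch.Size([4, 3, 64, 64])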