 
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -92,11 +92,10 @@ def alpha_bar_fn(t):
     return torch.tensor(betas, dtype=torch.float32)
 
 
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
     """
     Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
 
-
     Args:
         betas (`torch.Tensor`):
             the betas that the scheduler is being initialized with.
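
The helper annotated above implements the zero-terminal-SNR rescaling from Algorithm 1 of the referenced paper. As a rough illustration of its effect (a sketch, not part of the patch), a linear beta schedule can be rescaled and the terminal cumulative alpha checked; `rescale_zero_terminal_snr` refers to the function in this file, and the schedule values are only examples:

```python
import torch

# Example linear beta schedule, matching the scheduler's defaults (beta_start=0.0001, beta_end=0.02).
betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32)

# Before rescaling, the terminal cumulative alpha is small but non-zero,
# so the final timestep still carries a little signal (non-zero terminal SNR).
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(alphas_cumprod[-1])  # > 0

# After rescaling, the terminal cumulative alpha is zero: the final timestep is pure noise.
rescaled_betas = rescale_zero_terminal_snr(betas)
rescaled_cumprod = torch.cumprod(1.0 - rescaled_betas, dim=0)
print(rescaled_cumprod[-1])  # ~ 0
```
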
@@ -143,9 +142,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             The starting `beta` value of inference.
         beta_end (`float`, defaults to 0.02):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one
+            of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         clip_sample (`bool`, defaults to `True`):
@@ -158,20 +157,21 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             otherwise it uses the alpha value at step 0.
         steps_offset (`int`, defaults to 0):
             An offset added to the inference steps, as required by some model families.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
-            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
-            Video](https://imagen.research.google/video/paper.pdf) paper).
+        prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`):
+            Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
+            process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
+            Video](https://huggingface.co/papers/2210.02303) paper).
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
         dynamic_thresholding_ratio (`float`, defaults to 0.995):
             The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
         sample_max_value (`float`, defaults to 1.0):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
-        timestep_spacing (`str`, defaults to `"leading"`):
-            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
-            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+        timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`):
+            The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to
+            Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891) for more information.
         rescale_betas_zero_snr (`bool`, defaults to `False`):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
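
The `Literal` annotations above only tighten the documented option sets; behaviour is unchanged. A minimal construction sketch, assuming the `diffusers` package this file belongs to is installed, with the three documented choice parameters spelled out:

```python
from diffusers import DDIMScheduler

# Each string must be one of the values listed in the docstring above.
scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",    # "linear" | "scaled_linear" | "squaredcos_cap_v2"
    prediction_type="v_prediction",   # "epsilon" | "sample" | "v_prediction"
    timestep_spacing="trailing",      # "leading" | "trailing" | "linspace"
    rescale_betas_zero_snr=True,      # see rescale_zero_terminal_snr above
)
```
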
@@ -187,17 +187,17 @@ def __init__(
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
-        prediction_type: str = "epsilon",
+        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
         clip_sample_range: float = 1.0,
         sample_max_value: float = 1.0,
-        timestep_spacing: str = "leading",
+        timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
         rescale_betas_zero_snr: bool = False,
     ):
         if trained_betas is not None:
@@ -250,7 +250,25 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
250250 """
251251 return sample
252252
253- def _get_variance (self , timestep , prev_timestep ):
253+ def _get_variance (self , timestep : int , prev_timestep : int ) -> torch .Tensor :
254+ """
255+ Computes the variance of the noise added at a given diffusion step.
256+
257+ For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
258+ literature:
259+ var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
260+ where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively.
261+
262+ Args:
263+ timestep (`int`):
264+ The current timestep in the diffusion process.
265+ prev_timestep (`int`):
266+ The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
267+
268+ Returns:
269+ `torch.Tensor`:
270+ The variance for the current timestep.
271+ """
254272 alpha_prod_t = self .alphas_cumprod [timestep ]
255273 alpha_prod_t_prev = self .alphas_cumprod [prev_timestep ] if prev_timestep >= 0 else self .final_alpha_cumprod
256274 beta_prod_t = 1 - alpha_prod_t
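
For readers who want to sanity-check the formula in the new `_get_variance` docstring, the same computation can be reproduced standalone from the cumulative alpha products; the schedule and timestep pair below are made up for illustration:

```python
import torch

# Illustrative linear schedule and its cumulative alpha products.
betas = torch.linspace(0.0001, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
final_alpha_cumprod = torch.tensor(1.0)  # corresponds to set_alpha_to_one=True

# Example step pair, e.g. from a 50-step inference schedule over 1000 training steps.
timestep, prev_timestep = 801, 781

alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

# var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
print(variance)  # a small positive scalar tensor
```
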
@@ -294,13 +312,18 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
 
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
         Args:
             num_inference_steps (`int`):
                 The number of diffusion steps used when generating samples with a pre-trained model.
+            device (`Union[str, torch.device]`, *optional*):
+                The device to use for the timesteps.
+
+        Raises:
+            ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
         """
 
         if num_inference_steps > self.config.num_train_timesteps:
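
The new `Raises` entry can be exercised directly; a short usage sketch, assuming a default `DDIMScheduler` from this package:

```python
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)

# Valid: 50 <= num_train_timesteps; timesteps are placed on the requested device.
scheduler.set_timesteps(num_inference_steps=50, device="cpu")
print(scheduler.timesteps[:3])  # e.g. tensor([980, 960, 940]) with the default "leading" spacing

# Invalid: more inference steps than training timesteps raises the documented ValueError.
try:
    scheduler.set_timesteps(num_inference_steps=2000)
except ValueError as err:
    print(err)
```
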
@@ -346,7 +369,7 @@ def step(
         sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
-        generator=None,
+        generator: Optional[torch.Generator] = None,
         variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[DDIMSchedulerOutput, Tuple]:
@@ -357,20 +380,21 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`float`):
+            timestep (`int`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            eta (`float`):
-                The weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`, defaults to `False`):
+            eta (`float`, *optional*, defaults to 0.0):
+                The weight of noise for added noise in the diffusion step. A value of 0 corresponds to DDIM
+                (deterministic) and 1 corresponds to DDPM (fully stochastic).
+            use_clipped_model_output (`bool`, *optional*, defaults to `False`):
                 If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                 because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                 clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                 `use_clipped_model_output` has no effect.
             generator (`torch.Generator`, *optional*):
-                A random number generator.
-            variance_noise (`torch.Tensor`):
+                A random number generator for reproducible sampling.
+            variance_noise (`torch.Tensor`, *optional*):
                 Alternative to generating noise with `generator` by directly providing the noise for the variance
                 itself. Useful for methods such as [`CycleDiffusion`].
             return_dict (`bool`, *optional*, defaults to `True`):
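
The clarified `eta` wording is the key knob when calling `step`; a minimal denoising-loop sketch, where the model call is a stand-in (a real loop would use e.g. a UNet's predicted noise):

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 3, 64, 64)            # start from pure noise
generator = torch.Generator().manual_seed(0)  # only consulted when eta > 0

for t in scheduler.timesteps:
    # Stand-in for the learned model output (e.g. unet(sample, t).sample in a real pipeline).
    model_output = torch.randn_like(sample)

    # eta=0.0 -> deterministic DDIM update; eta=1.0 -> DDPM-like stochastic update.
    sample = scheduler.step(
        model_output, int(t), sample, eta=0.0, generator=generator
    ).prev_sample
```
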
@@ -517,5 +541,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
         return velocity
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps