Skip to content

Commit 6fe4a6f

Browse files
authored
Improve docstrings and type hints in scheduling_ddim.py (#12622)
* Improve docstrings and type hints in scheduling_ddim.py
  - Add complete type hints for all function parameters
  - Enhance docstrings to follow project conventions
  - Add missing parameter descriptions

  Fixes #9567
* Enhance docstrings and type hints in scheduling_ddim.py
  - Update parameter types and descriptions for clarity
  - Improve explanations in method docstrings to align with project standards
  - Add optional annotations for parameters where applicable
* Refine type hints and docstrings in scheduling_ddim.py
  - Update parameter types to use Literal for specific string options
  - Enhance docstring descriptions for clarity and consistency
  - Ensure all parameters have appropriate type annotations and defaults
* Apply review feedback on scheduling_ddim.py
  - Replace "prevent singularities" with "avoid numerical instability" for better clarity
  - Add backticks around `alpha_bar` variable name for consistent formatting
  - Convert Imagen Video paper URLs to Hugging Face papers references
* Propagate changes using 'make fix-copies'
* Add missing Literal
1 parent 40de88a commit 6fe4a6f

11 files changed

+77
-40
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import math
1919
from dataclasses import dataclass
20-
from typing import List, Optional, Tuple, Union
20+
from typing import List, Literal, Optional, Tuple, Union
2121

2222
import numpy as np
2323
import torch
@@ -92,11 +92,10 @@ def alpha_bar_fn(t):
9292
return torch.tensor(betas, dtype=torch.float32)
9393

9494

95-
def rescale_zero_terminal_snr(betas):
95+
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
9696
"""
9797
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
9898
99-
10099
Args:
101100
betas (`torch.Tensor`):
102101
the betas that the scheduler is being initialized with.
@@ -143,9 +142,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
143142
The starting `beta` value of inference.
144143
beta_end (`float`, defaults to 0.02):
145144
The final `beta` value.
146-
beta_schedule (`str`, defaults to `"linear"`):
147-
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
148-
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
145+
beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
146+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one
147+
of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
149148
trained_betas (`np.ndarray`, *optional*):
150149
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
151150
clip_sample (`bool`, defaults to `True`):
@@ -158,20 +157,21 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
158157
otherwise it uses the alpha value at step 0.
159158
steps_offset (`int`, defaults to 0):
160159
An offset added to the inference steps, as required by some model families.
161-
prediction_type (`str`, defaults to `epsilon`, *optional*):
162-
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
163-
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
164-
Video](https://imagen.research.google/video/paper.pdf) paper).
160+
prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`):
161+
Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
162+
process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
163+
Video](https://huggingface.co/papers/2210.02303) paper).
165164
thresholding (`bool`, defaults to `False`):
166165
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
167166
as Stable Diffusion.
168167
dynamic_thresholding_ratio (`float`, defaults to 0.995):
169168
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
170169
sample_max_value (`float`, defaults to 1.0):
171170
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
172-
timestep_spacing (`str`, defaults to `"leading"`):
173-
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
174-
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
171+
timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`):
172+
The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to
173+
Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are
174+
Flawed](https://huggingface.co/papers/2305.08891) for more information.
175175
rescale_betas_zero_snr (`bool`, defaults to `False`):
176176
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
177177
dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -187,17 +187,17 @@ def __init__(
187187
num_train_timesteps: int = 1000,
188188
beta_start: float = 0.0001,
189189
beta_end: float = 0.02,
190-
beta_schedule: str = "linear",
190+
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
191191
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
192192
clip_sample: bool = True,
193193
set_alpha_to_one: bool = True,
194194
steps_offset: int = 0,
195-
prediction_type: str = "epsilon",
195+
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
196196
thresholding: bool = False,
197197
dynamic_thresholding_ratio: float = 0.995,
198198
clip_sample_range: float = 1.0,
199199
sample_max_value: float = 1.0,
200-
timestep_spacing: str = "leading",
200+
timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
201201
rescale_betas_zero_snr: bool = False,
202202
):
203203
if trained_betas is not None:
@@ -250,7 +250,25 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
250250
"""
251251
return sample
252252

253-
def _get_variance(self, timestep, prev_timestep):
253+
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
254+
"""
255+
Computes the variance of the noise added at a given diffusion step.
256+
257+
For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
258+
literature:
259+
var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
260+
where alpha_prod is the cumulative product of alphas (`alphas_cumprod`) and beta_prod = 1 - alpha_prod.
261+
262+
Args:
263+
timestep (`int`):
264+
The current timestep in the diffusion process.
265+
prev_timestep (`int`):
266+
The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
267+
268+
Returns:
269+
`torch.Tensor`:
270+
The variance for the current timestep.
271+
"""
254272
alpha_prod_t = self.alphas_cumprod[timestep]
255273
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
256274
beta_prod_t = 1 - alpha_prod_t
@@ -294,13 +312,18 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
294312

295313
return sample
296314

297-
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
315+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
298316
"""
299317
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
300318
301319
Args:
302320
num_inference_steps (`int`):
303321
The number of diffusion steps used when generating samples with a pre-trained model.
322+
device (`Union[str, torch.device]`, *optional*):
323+
The device to use for the timesteps.
324+
325+
Raises:
326+
ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
304327
"""
305328

306329
if num_inference_steps > self.config.num_train_timesteps:
@@ -346,7 +369,7 @@ def step(
346369
sample: torch.Tensor,
347370
eta: float = 0.0,
348371
use_clipped_model_output: bool = False,
349-
generator=None,
372+
generator: Optional[torch.Generator] = None,
350373
variance_noise: Optional[torch.Tensor] = None,
351374
return_dict: bool = True,
352375
) -> Union[DDIMSchedulerOutput, Tuple]:
@@ -357,20 +380,21 @@ def step(
357380
Args:
358381
model_output (`torch.Tensor`):
359382
The direct output from learned diffusion model.
360-
timestep (`float`):
383+
timestep (`int`):
361384
The current discrete timestep in the diffusion chain.
362385
sample (`torch.Tensor`):
363386
A current instance of a sample created by the diffusion process.
364-
eta (`float`):
365-
The weight of noise for added noise in diffusion step.
366-
use_clipped_model_output (`bool`, defaults to `False`):
387+
eta (`float`, *optional*, defaults to 0.0):
388+
The weight of the noise added in each diffusion step. A value of 0 corresponds to DDIM (deterministic)
389+
and 1 corresponds to DDPM (fully stochastic).
390+
use_clipped_model_output (`bool`, *optional*, defaults to `False`):
367391
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
368392
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
369393
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
370394
`use_clipped_model_output` has no effect.
371395
generator (`torch.Generator`, *optional*):
372-
A random number generator.
373-
variance_noise (`torch.Tensor`):
396+
A random number generator for reproducible sampling.
397+
variance_noise (`torch.Tensor`, *optional*):
374398
Alternative to generating noise with `generator` by directly providing the noise for the variance
375399
itself. Useful for methods such as [`CycleDiffusion`].
376400
return_dict (`bool`, *optional*, defaults to `True`):
@@ -517,5 +541,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
517541
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
518542
return velocity
519543

520-
def __len__(self):
544+
def __len__(self) -> int:
521545
return self.config.num_train_timesteps

src/diffusers/schedulers/scheduling_ddim_inverse.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def rescale_zero_terminal_snr(betas):
9595
"""
9696
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
9797
98-
9998
Args:
10099
betas (`torch.Tensor`):
101100
the betas that the scheduler is being initialized with.

src/diffusers/schedulers/scheduling_ddim_parallel.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import math
1919
from dataclasses import dataclass
20-
from typing import List, Optional, Tuple, Union
20+
from typing import List, Literal, Optional, Tuple, Union
2121

2222
import numpy as np
2323
import torch
@@ -97,7 +97,6 @@ def rescale_zero_terminal_snr(betas):
9797
"""
9898
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
9999
100-
101100
Args:
102101
betas (`torch.Tensor`):
103102
the betas that the scheduler is being initialized with.
@@ -194,17 +193,17 @@ def __init__(
194193
num_train_timesteps: int = 1000,
195194
beta_start: float = 0.0001,
196195
beta_end: float = 0.02,
197-
beta_schedule: str = "linear",
196+
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
198197
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
199198
clip_sample: bool = True,
200199
set_alpha_to_one: bool = True,
201200
steps_offset: int = 0,
202-
prediction_type: str = "epsilon",
201+
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
203202
thresholding: bool = False,
204203
dynamic_thresholding_ratio: float = 0.995,
205204
clip_sample_range: float = 1.0,
206205
sample_max_value: float = 1.0,
207-
timestep_spacing: str = "leading",
206+
timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
208207
rescale_betas_zero_snr: bool = False,
209208
):
210209
if trained_betas is not None:
@@ -324,6 +323,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
324323
Args:
325324
num_inference_steps (`int`):
326325
The number of diffusion steps used when generating samples with a pre-trained model.
326+
device (`Union[str, torch.device]`, *optional*):
327+
The device to use for the timesteps.
328+
329+
Raises:
330+
ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
327331
"""
328332

329333
if num_inference_steps > self.config.num_train_timesteps:

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def rescale_zero_terminal_snr(betas):
9494
"""
9595
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
9696
97-
9897
Args:
9998
betas (`torch.Tensor`):
10099
the betas that the scheduler is being initialized with.

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def rescale_zero_terminal_snr(betas):
9696
"""
9797
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
9898
99-
10099
Args:
101100
betas (`torch.Tensor`):
102101
the betas that the scheduler is being initialized with.

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def rescale_zero_terminal_snr(betas):
8080
"""
8181
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
8282
83-
8483
Args:
8584
betas (`torch.Tensor`):
8685
the betas that the scheduler is being initialized with.

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def rescale_zero_terminal_snr(betas):
9797
"""
9898
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
9999
100-
101100
Args:
102101
betas (`torch.Tensor`):
103102
the betas that the scheduler is being initialized with.

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def rescale_zero_terminal_snr(betas):
100100
"""
101101
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
102102
103-
104103
Args:
105104
betas (`torch.Tensor`):
106105
the betas that the scheduler is being initialized with.

src/diffusers/schedulers/scheduling_lcm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
9999
"""
100100
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
101101
102-
103102
Args:
104103
betas (`torch.Tensor`):
105104
the betas that the scheduler is being initialized with.

src/diffusers/schedulers/scheduling_tcd.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
9898
"""
9999
Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
100100
101-
102101
Args:
103102
betas (`torch.Tensor`):
104103
the betas that the scheduler is being initialized with.
@@ -316,6 +315,24 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
316315

317316
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance
318317
def _get_variance(self, timestep, prev_timestep):
318+
"""
319+
Computes the variance of the noise added at a given diffusion step.
320+
321+
For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
322+
literature:
323+
var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
324+
where alpha_prod is the cumulative product of alphas (`alphas_cumprod`) and beta_prod = 1 - alpha_prod.
325+
326+
Args:
327+
timestep (`int`):
328+
The current timestep in the diffusion process.
329+
prev_timestep (`int`):
330+
The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
331+
332+
Returns:
333+
`torch.Tensor`:
334+
The variance for the current timestep.
335+
"""
319336
alpha_prod_t = self.alphas_cumprod[timestep]
320337
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
321338
beta_prod_t = 1 - alpha_prod_t

0 commit comments

Comments
 (0)