
Commit 3c1ca86

Improve docstrings and type hints in scheduling_ddpm.py (#12651)
* Enhance type hints and docstrings in scheduling_ddpm.py
  - Added type hints for function parameters and return types across the DDPMScheduler class and related functions.
  - Improved docstrings for clarity, including detailed descriptions of parameters and return values.
  - Updated the `alpha_transform_type` and `beta_schedule` parameters to use `Literal` types for better type safety.
  - Refined the `_get_variance` and `previous_timestep` methods with comprehensive documentation.

* Refactor docstrings and type hints in scheduling_ddpm.py
  - Cleaned up whitespace in the `rescale_zero_terminal_snr` function.
  - Enhanced the `variance_type` parameter in the DDPMScheduler class with improved formatting for better readability.
  - Updated the docstring for the `compute_variance` method to maintain consistency and clarity in parameter descriptions and return values.

* Apply `make fix-copies`

* Refactor type hints across multiple scheduler files
  - Updated type hints to include `Literal` for improved type safety in various scheduling files.
  - Ensured consistency in type hinting for parameters and return types across the affected modules.
  - This change enhances code clarity and maintainability.
1 parent 6fe4a6f commit 3c1ca86

27 files changed: +887 -342 lines changed
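The `Literal` change called out in the commit message lets static type checkers reject invalid schedule names instead of deferring everything to a runtime `ValueError`. A minimal sketch of the pattern (the function name and body here are illustrative, not part of the diff):

```python
from typing import Literal

import torch


def make_schedule(alpha_transform_type: Literal["cosine", "exp"] = "cosine") -> torch.Tensor:
    # A type checker (mypy, pyright) now flags make_schedule("linear") at analysis time,
    # while the runtime check below still guards callers that pass dynamic strings.
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
    return torch.zeros(10)
```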

src/diffusers/schedulers/scheduling_consistency_decoder.py

Lines changed: 13 additions & 12 deletions
```diff
@@ -1,6 +1,6 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union
 
 import torch
 
@@ -12,27 +12,28 @@
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
-    num_diffusion_timesteps,
-    max_beta=0.999,
-    alpha_transform_type="cosine",
-):
+    num_diffusion_timesteps: int,
+    max_beta: float = 0.999,
+    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
 
     Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
     to that part of the diffusion process.
 
-
     Args:
-        num_diffusion_timesteps (`int`): the number of betas to produce.
-        max_beta (`float`): the maximum beta to use; use values lower than 1 to
-            prevent singularities.
-        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
-            Choose from `cosine` or `exp`
+        num_diffusion_timesteps (`int`):
+            The number of betas to produce.
+        max_beta (`float`, defaults to `0.999`):
+            The maximum beta to use; use values lower than 1 to avoid numerical instability.
+        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
 
     Returns:
-        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+        `torch.Tensor`:
+            The betas used by the scheduler to step the model outputs.
     """
     if alpha_transform_type == "cosine":
```
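The hunk above only touches the signature and docstring of `betas_for_alpha_bar`; the body that discretizes `alpha_bar` into per-step betas is unchanged and not shown. As a rough standalone sketch of what such a function computes (the transform constants follow the common cosine schedule and may not match the library byte-for-byte):

```python
import math
from typing import Literal

import torch


def betas_for_alpha_bar_sketch(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    # alpha_bar(t) is the cumulative product of (1 - beta) up to continuous time t in [0, 1].
    if alpha_transform_type == "cosine":
        def alpha_bar(t: float) -> float:
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    elif alpha_transform_type == "exp":
        def alpha_bar(t: float) -> float:
            return math.exp(t * -12.0)
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # Each beta is the relative drop of alpha_bar over one discrete step, capped at max_beta.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
```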

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 53 additions & 11 deletions
```diff
@@ -49,27 +49,28 @@ class DDIMSchedulerOutput(BaseOutput):
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
-    num_diffusion_timesteps,
-    max_beta=0.999,
-    alpha_transform_type="cosine",
-):
+    num_diffusion_timesteps: int,
+    max_beta: float = 0.999,
+    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
 
     Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
     to that part of the diffusion process.
 
-
     Args:
-        num_diffusion_timesteps (`int`): the number of betas to produce.
-        max_beta (`float`): the maximum beta to use; use values lower than 1 to
-            prevent singularities.
-        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
-            Choose from `cosine` or `exp`
+        num_diffusion_timesteps (`int`):
+            The number of betas to produce.
+        max_beta (`float`, defaults to `0.999`):
+            The maximum beta to use; use values lower than 1 to avoid numerical instability.
+        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
 
     Returns:
-        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+        `torch.Tensor`:
+            The betas used by the scheduler to step the model outputs.
     """
     if alpha_transform_type == "cosine":
```

```diff
@@ -281,13 +282,23 @@ def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
+        Apply dynamic thresholding to the predicted sample.
+
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
         s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
         pixels from saturation at each step. We find that dynamic thresholding results in significantly better
         photorealism as well as better image-text alignment, especially when using very large guidance weights."
 
         https://huggingface.co/papers/2205.11487
+
+        Args:
+            sample (`torch.Tensor`):
+                The predicted sample to be thresholded.
+
+        Returns:
+            `torch.Tensor`:
+                The thresholded sample.
         """
         dtype = sample.dtype
         batch_size, channels, *remaining_dims = sample.shape
```
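The quoted passage is the dynamic thresholding rule from the Imagen paper (https://huggingface.co/papers/2205.11487). A compact standalone sketch of that rule, simplified relative to the library's `_threshold_sample` (which also handles dtype casting and non-image shapes):

```python
import torch


def dynamic_threshold(sample: torch.Tensor, percentile: float = 0.995) -> torch.Tensor:
    # Work per batch element: flatten all non-batch dimensions.
    batch_size = sample.shape[0]
    flat = sample.reshape(batch_size, -1).abs().float()
    # s = chosen percentile of absolute values; thresholding only activates when s > 1.
    s = torch.quantile(flat, percentile, dim=1).clamp(min=1.0)
    s = s.view(batch_size, *([1] * (sample.ndim - 1)))
    # Clamp to [-s, s], then divide by s so the result stays within [-1, 1].
    return sample.clamp(-s, s) / s
```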
```diff
@@ -501,6 +512,22 @@ def add_noise(
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
+        """
+        Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+        diffusion process).
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples to which noise will be added.
+            noise (`torch.Tensor`):
+                The noise to add to the samples.
+            timesteps (`torch.IntTensor`):
+                The timesteps indicating the noise level for each sample.
+
+        Returns:
+            `torch.Tensor`:
+                The noisy samples.
+        """
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
```
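The docstring added here describes the standard closed-form forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. A minimal sketch of that computation, assuming `alphas_cumprod` is a 1-D tensor indexed by timestep (the real method also moves `alphas_cumprod` to the sample's device, as its comments note):

```python
import torch


def add_noise_sketch(
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
    alphas_cumprod: torch.Tensor,
) -> torch.Tensor:
    # Gather alpha_bar_t for each sample in the batch and reshape for broadcasting.
    alpha_bar = alphas_cumprod[timesteps].to(original_samples.dtype)
    while alpha_bar.ndim < original_samples.ndim:
        alpha_bar = alpha_bar.unsqueeze(-1)
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    return alpha_bar.sqrt() * original_samples + (1 - alpha_bar).sqrt() * noise
```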
```diff
@@ -523,6 +550,21 @@ def add_noise(
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
     def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+        """
+        Compute the velocity prediction from the sample and noise according to the velocity formula.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample.
+            noise (`torch.Tensor`):
+                The noise tensor.
+            timesteps (`torch.IntTensor`):
+                The timesteps for velocity computation.
+
+        Returns:
+            `torch.Tensor`:
+                The computed velocity.
+        """
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
```
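`get_velocity` documents the v-prediction target from progressive distillation, v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0. A sketch under the same assumptions as the `add_noise` example above:

```python
import torch


def get_velocity_sketch(
    sample: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
    alphas_cumprod: torch.Tensor,
) -> torch.Tensor:
    # Gather alpha_bar_t per sample and broadcast over the remaining dimensions.
    alpha_bar = alphas_cumprod[timesteps].to(sample.dtype)
    while alpha_bar.ndim < sample.ndim:
        alpha_bar = alpha_bar.unsqueeze(-1)
    # v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0
    return alpha_bar.sqrt() * noise - (1 - alpha_bar).sqrt() * sample
```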

src/diffusers/schedulers/scheduling_ddim_cogvideox.py

Lines changed: 44 additions & 12 deletions
```diff
@@ -18,7 +18,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -49,27 +49,28 @@ class DDIMSchedulerOutput(BaseOutput):
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
-    num_diffusion_timesteps,
-    max_beta=0.999,
-    alpha_transform_type="cosine",
-):
+    num_diffusion_timesteps: int,
+    max_beta: float = 0.999,
+    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
 
     Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
     to that part of the diffusion process.
 
-
     Args:
-        num_diffusion_timesteps (`int`): the number of betas to produce.
-        max_beta (`float`): the maximum beta to use; use values lower than 1 to
-            prevent singularities.
-        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
-            Choose from `cosine` or `exp`
+        num_diffusion_timesteps (`int`):
+            The number of betas to produce.
+        max_beta (`float`, defaults to `0.999`):
+            The maximum beta to use; use values lower than 1 to avoid numerical instability.
+        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
 
     Returns:
-        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+        `torch.Tensor`:
+            The betas used by the scheduler to step the model outputs.
     """
     if alpha_transform_type == "cosine":
@@ -408,6 +409,22 @@ def add_noise(
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
+        """
+        Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+        diffusion process).
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples to which noise will be added.
+            noise (`torch.Tensor`):
+                The noise to add to the samples.
+            timesteps (`torch.IntTensor`):
+                The timesteps indicating the noise level for each sample.
+
+        Returns:
+            `torch.Tensor`:
+                The noisy samples.
+        """
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
@@ -430,6 +447,21 @@ def add_noise(
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+        """
+        Compute the velocity prediction from the sample and noise according to the velocity formula.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample.
+            noise (`torch.Tensor`):
+                The noise tensor.
+            timesteps (`torch.IntTensor`):
+                The timesteps for velocity computation.
+
+        Returns:
+            `torch.Tensor`:
+                The computed velocity.
+        """
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
```

src/diffusers/schedulers/scheduling_ddim_inverse.py

Lines changed: 13 additions & 12 deletions
```diff
@@ -16,7 +16,7 @@
 # and https://github.com/hojonathanho/diffusion
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -47,27 +47,28 @@ class DDIMSchedulerOutput(BaseOutput):
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
-    num_diffusion_timesteps,
-    max_beta=0.999,
-    alpha_transform_type="cosine",
-):
+    num_diffusion_timesteps: int,
+    max_beta: float = 0.999,
+    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
 
     Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
     to that part of the diffusion process.
 
-
     Args:
-        num_diffusion_timesteps (`int`): the number of betas to produce.
-        max_beta (`float`): the maximum beta to use; use values lower than 1 to
-            prevent singularities.
-        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
-            Choose from `cosine` or `exp`
+        num_diffusion_timesteps (`int`):
+            The number of betas to produce.
+        max_beta (`float`, defaults to `0.999`):
+            The maximum beta to use; use values lower than 1 to avoid numerical instability.
+        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+            The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
 
     Returns:
-        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+        `torch.Tensor`:
+            The betas used by the scheduler to step the model outputs.
     """
     if alpha_transform_type == "cosine":
```
