
Commit 92d44da

Enhance docstrings and type hints in PNDMScheduler class
- Updated parameter descriptions to include default values and specific types using Literal for better clarity.
- Improved docstring formatting and consistency across methods, including detailed explanations for the `_get_prev_sample` method.
- Added type hints for method return types to enhance code readability and maintainability.
1 parent 3579fda commit 92d44da

File tree

1 file changed (+48, -34 lines)

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 48 additions & 34 deletions
@@ -79,15 +79,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         skip_prk_steps (`bool`, defaults to `False`):
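
Note: the three `beta_schedule` values now enumerated in the docstring are different mappings from the `(beta_start, beta_end)` range to a beta sequence. A minimal sketch of the usual conventions, assuming `torch`; the helper name is hypothetical and this is an illustration, not the library's exact code:

```python
import math

import torch


def make_betas(beta_start: float, beta_end: float, num_train_timesteps: int, beta_schedule: str) -> torch.Tensor:
    # Illustrative sketch of the three schedules named in the docstring.
    if beta_schedule == "linear":
        # Evenly spaced betas between beta_start and beta_end.
        return torch.linspace(beta_start, beta_end, num_train_timesteps)
    if beta_schedule == "scaled_linear":
        # Linear in sqrt(beta) space, then squared.
        return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
    if beta_schedule == "squaredcos_cap_v2":
        # Cosine schedule: betas derived from a squared-cosine alpha-bar curve, capped at 0.999.
        def alpha_bar(t: float) -> float:
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        betas = []
        for i in range(num_train_timesteps):
            t1, t2 = i / num_train_timesteps, (i + 1) / num_train_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), 0.999))
        return torch.tensor(betas)
    raise ValueError(f"unknown beta_schedule: {beta_schedule}")
```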
@@ -97,14 +96,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
             there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
             otherwise it uses the alpha value at step 0.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
-            or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
-            paper).
-        timestep_spacing (`str`, defaults to `"leading"`):
+            or `v_prediction` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
+        steps_offset (`int`, defaults to `0`):
             An offset added to the inference steps, as required by some model families.
     """
 
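
Note: the three `timestep_spacing` strategies listed above differ only in how `num_inference_steps` indices are placed inside the `num_train_timesteps` range. A rough sketch following Table 2 of the linked paper; `space_timesteps` is a hypothetical helper and exact rounding conventions vary, so this is not the scheduler's internal code:

```python
import numpy as np


def space_timesteps(num_train_timesteps: int, num_inference_steps: int, timestep_spacing: str) -> np.ndarray:
    if timestep_spacing == "linspace":
        # Evenly spaced floats over the full range, rounded to integers.
        return np.linspace(0, num_train_timesteps - 1, num_inference_steps).round().astype(np.int64)
    if timestep_spacing == "leading":
        # Regular stride anchored at step 0; the end of the range may go unused.
        step_ratio = num_train_timesteps // num_inference_steps
        return (np.arange(0, num_inference_steps) * step_ratio).astype(np.int64)
    if timestep_spacing == "trailing":
        # Regular stride anchored at the last training step instead of the first.
        step_ratio = num_train_timesteps / num_inference_steps
        return (np.round(np.arange(num_train_timesteps, 0, -step_ratio)) - 1).astype(np.int64)
    raise ValueError(f"unknown timestep_spacing: {timestep_spacing}")


print(space_timesteps(1000, 10, "leading"))   # [  0 100 200 ... 900]
print(space_timesteps(1000, 10, "trailing"))  # [999 899 ...  99]
```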

@@ -117,12 +115,12 @@ def __init__(
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
-        prediction_type: str = "epsilon",
-        timestep_spacing: str = "leading",
+        prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
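
Note: the switch from `str` to `typing.Literal` means static checkers can reject invalid option strings before runtime. A small illustration with a hypothetical `configure` function:

```python
from typing import Literal

BetaSchedule = Literal["linear", "scaled_linear", "squaredcos_cap_v2"]


def configure(beta_schedule: BetaSchedule = "linear") -> str:
    # With a plain `str` annotation a typo only fails at runtime (if at all);
    # with `Literal`, mypy/pyright reject it at check time.
    return beta_schedule


configure("scaled_linear")  # OK
configure("cosine")         # flagged by a static type checker: not a valid literal
```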
@@ -164,7 +162,7 @@ def __init__(
         self.plms_timesteps = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
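
Note: the `Optional[...]` annotation now matches the existing `None` default. A short usage sketch against the public diffusers API:

```python
import torch
from diffusers import PNDMScheduler

scheduler = PNDMScheduler()  # defaults as documented in the class docstring above

# `device` is Optional: omit it for CPU, or pass a string / torch.device.
scheduler.set_timesteps(num_inference_steps=50)
scheduler.set_timesteps(num_inference_steps=50, device="cuda" if torch.cuda.is_available() else "cpu")

print(scheduler.timesteps[:5])  # first few discrete timesteps of the inference schedule
```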
@@ -243,7 +241,7 @@ def step(
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
 
         Returns:
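
Note: `return_dict` defaults to `True` across `step`, `step_prk`, and `step_plms`. A hedged sketch of both return forms (tensor shapes are stand-ins; each `step` call advances the scheduler's internal PLMS state, so a real loop calls it once per timestep):

```python
import torch
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(skip_prk_steps=True)
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 3, 8, 8)        # stand-in for a latent/image tensor
model_output = torch.randn(1, 3, 8, 8)  # stand-in for the model's noise prediction
t = scheduler.timesteps[0]

# Default (return_dict=True): a SchedulerOutput dataclass with a named field.
out = scheduler.step(model_output, t, sample)
prev_sample = out.prev_sample

# return_dict=False: a plain tuple whose first element is the sample tensor.
(prev_sample,) = scheduler.step(model_output, t, sample, return_dict=False)
```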
@@ -276,14 +274,13 @@ def step_prk(
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
-
         """
         if self.num_inference_steps is None:
             raise ValueError(
@@ -335,14 +332,13 @@ def step_plms(
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
-
         """
         if self.num_inference_steps is None:
             raise ValueError(
@@ -403,19 +399,37 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         return sample
 
-    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
-        # See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778
-        # this function computes x_(t−δ) using the formula of (9)
-        # Note that x_t needs to be added to both sides of the equation
-
-        # Notation (<variable name> -> <name in paper>
-        # alpha_prod_t -> α_t
-        # alpha_prod_t_prev -> α_(t−δ)
-        # beta_prod_t -> (1 - α_t)
-        # beta_prod_t_prev -> (1 - α_(t−δ))
-        # sample -> x_t
-        # model_output -> e_θ(x_t, t)
-        # prev_sample -> x_(t−δ)
+    def _get_prev_sample(
+        self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the PNDM paper.
+
+        See formula (9) of [PNDM paper](https://huggingface.co/papers/2202.09778)
+
+        Notation (<variable name> -> <name in paper>):
+        - alpha_prod_t -> α_t
+        - alpha_prod_t_prev -> α_(t−δ)
+        - beta_prod_t -> (1 - α_t)
+        - beta_prod_t_prev -> (1 - α_(t−δ))
+        - sample -> x_t
+        - model_output -> e_θ(x_t, t)
+        - prev_sample -> x_(t−δ)
+
+        Args:
+            sample (`torch.Tensor`):
+                The current sample x_t.
+            timestep (`int`):
+                The current timestep t.
+            prev_timestep (`int`):
+                The previous timestep (t-δ).
+            model_output (`torch.Tensor`):
+                The model output e_θ(x_t, t).
+
+        Returns:
+            `torch.Tensor`:
+                The previous sample x_(t-δ).
+        """
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
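
Note: the new docstring's notation table maps one-to-one onto formula (9). Read as a standalone function (hypothetical name, extracted here purely for illustration), the transfer rule the method implements looks roughly like this:

```python
import torch


def prev_sample_formula_9(
    sample: torch.Tensor,             # x_t
    model_output: torch.Tensor,       # e_θ(x_t, t)
    alpha_prod_t: torch.Tensor,       # α_t
    alpha_prod_t_prev: torch.Tensor,  # α_(t−δ)
) -> torch.Tensor:
    """Sketch of formula (9) in the docstring's notation; not the method itself."""
    beta_prod_t = 1 - alpha_prod_t            # (1 - α_t)
    beta_prod_t_prev = 1 - alpha_prod_t_prev  # (1 - α_(t−δ))

    # Coefficient on x_t: sqrt(α_(t−δ) / α_t).
    sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** 0.5

    # Denominator coefficient on e_θ(x_t, t):
    # α_t * sqrt(1 - α_(t−δ)) + sqrt(α_t * (1 - α_t) * α_(t−δ)).
    denom = alpha_prod_t * beta_prod_t_prev**0.5 + (alpha_prod_t * beta_prod_t * alpha_prod_t_prev) ** 0.5

    # x_(t−δ) = sqrt(α_(t−δ)/α_t) * x_t - (α_(t−δ) - α_t) * e_θ(x_t, t) / denom
    return sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / denom
```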
@@ -489,5 +503,5 @@ def add_noise(
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
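
Note: taken together, the annotated pieces form the usual sampling loop. A minimal end-to-end sketch with a toy stand-in for the trained model:

```python
import torch
from diffusers import PNDMScheduler

# Toy stand-in for a trained UNet: any callable mapping (sample, timestep)
# to a noise prediction with the same shape as `sample`.
def toy_model(sample: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return torch.zeros_like(sample)

scheduler = PNDMScheduler(skip_prk_steps=True)
scheduler.set_timesteps(num_inference_steps=25)

sample = torch.randn(1, 3, 8, 8)  # start from pure noise
for t in scheduler.timesteps:
    model_output = toy_model(sample, t)
    sample = scheduler.step(model_output, t, sample).prev_sample

print(sample.shape)              # the fully "denoised" sample
print(len(scheduler))            # __len__: 1000 training timesteps by default
```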
