
Commit 76923c0

Enhance type hints and docstrings in LMSDiscreteScheduler class
Updated type hints for function parameters and return types to improve code clarity and maintainability. Enhanced docstrings for several methods with clearer descriptions of their functionality and expected arguments. Notable changes include `Literal` types for constrained string parameters and consistent return type annotations across the class.
1 parent 3579fda commit 76923c0

1 file changed (+57, -25 lines)

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 57 additions & 25 deletions
@@ -99,15 +99,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear` or `scaled_linear`.
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -118,14 +117,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_beta_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
             Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
             Video](https://imagen.research.google/video/paper.pdf) paper).
-        timestep_spacing (`str`, defaults to `"linspace"`):
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
+        steps_offset (`int`, defaults to `0`):
             An offset added to the inference steps, as required by some model families.
     """
 
@@ -138,13 +137,13 @@ def __init__(
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
-        prediction_type: str = "epsilon",
-        timestep_spacing: str = "linspace",
+        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
         steps_offset: int = 0,
     ):
         if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
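
The `Literal` annotations above let a static type checker reject invalid option strings at the call site instead of at runtime. A minimal sketch of what this buys (hypothetical call, not part of the commit):

from diffusers import LMSDiscreteScheduler

# Accepted: every argument is a member of its declared Literal type.
scheduler = LMSDiscreteScheduler(
    beta_schedule="scaled_linear",
    prediction_type="v_prediction",
    timestep_spacing="trailing",
)

# Flagged by mypy/pyright before the code ever runs:
# scheduler = LMSDiscreteScheduler(beta_schedule="cosine")
# error: incompatible type "Literal['cosine']";
#        expected "Literal['linear', 'scaled_linear', 'squaredcos_cap_v2']"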
@@ -183,29 +182,45 @@ def __init__(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
-    def init_noise_sigma(self):
+    def init_noise_sigma(self) -> Union[float, torch.Tensor]:
+        """
+        The standard deviation of the initial noise distribution.
+
+        Returns:
+            `float` or `torch.Tensor`:
+                The standard deviation of the initial noise distribution, computed based on the maximum sigma value and
+                the timestep spacing configuration.
+        """
         # standard deviation of the initial noise distribution
         if self.config.timestep_spacing in ["linspace", "trailing"]:
             return self.sigmas.max()
 
         return (self.sigmas.max() ** 2 + 1) ** 0.5
 
     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
-        The index counter for current timestep. It will increase 1 after each scheduler step.
+        The index counter for current timestep. It will increase by 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index, or `None` if not initialized.
         """
         return self._step_index
 
     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index for the scheduler, or `None` if not set.
         """
         return self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
 
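
With `step_index` and `begin_index` now annotated as `Optional[int]`, callers are reminded to guard against `None` before doing arithmetic on them. A small sketch of the pattern the new docstrings describe (assumed usage, not from the diff):

from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=25)

# step_index stays None until the first scheduler step initializes it,
# so a None check is needed before using it as an int.
if scheduler.step_index is None:
    remaining = len(scheduler.timesteps)
else:
    remaining = len(scheduler.timesteps) - scheduler.step_index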
@@ -239,14 +254,21 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
         self.is_scale_input_called = True
         return sample
 
-    def get_lms_coefficient(self, order, t, current_order):
+    def get_lms_coefficient(self, order: int, t: int, current_order: int) -> float:
         """
         Compute the linear multistep coefficient.
 
         Args:
-            order ():
-            t ():
-            current_order ():
+            order (`int`):
+                The order of the linear multistep method.
+            t (`int`):
+                The current timestep index.
+            current_order (`int`):
+                The current order for which to compute the coefficient.
+
+        Returns:
+            `float`:
+                The computed linear multistep coefficient.
         """
 
         def lms_derivative(tau):
@@ -261,7 +283,7 @@ def lms_derivative(tau):
 
         return integrated_coeff
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
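
For context on the newly documented arguments: `get_lms_coefficient` integrates a Lagrange basis polynomial over one sigma interval, which is what `lms_derivative` and `integrated_coeff` above compute. A standalone sketch of the same math with toy sigma values (not the scheduler's own code):

import numpy as np
from scipy import integrate

sigmas = np.array([14.6, 9.7, 6.5, 4.3])  # toy descending sigma schedule
order, t, current_order = 2, 1, 0

def lms_derivative(tau: float) -> float:
    # Lagrange basis polynomial: product over the other `order - 1` sigma points.
    prod = 1.0
    for k in range(order):
        if current_order == k:
            continue
        prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
    return prod

# Integrate from sigma_t to sigma_{t+1}; this value weights the stored derivatives.
coeff = integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]
print(coeff)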
@@ -367,7 +389,7 @@ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
             self._step_index = self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         """
         Convert sigma values to corresponding timestep values through interpolation.
 
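
`_sigma_to_t` recovers a (possibly fractional) timestep from a sigma value by piecewise-linear interpolation in log-sigma space. A standalone sketch of that interpolation with toy values (mirroring the euler-discrete helper this method is copied from):

import numpy as np

# Toy ascending training-sigma table; the index corresponds to the train timestep.
log_sigmas = np.log(np.array([0.1, 0.5, 2.0, 14.6]))
sigma = 1.0  # value to map back to a (fractional) timestep

log_sigma = np.log(sigma)
dists = log_sigma - log_sigmas[:, np.newaxis]

# Last table index whose log-sigma does not exceed log_sigma, clipped so high_idx stays in range.
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1

low, high = log_sigmas[low_idx], log_sigmas[high_idx]
w = np.clip((low - log_sigma) / (low - high), 0, 1)

t = (1 - w) * low_idx + w * high_idx
print(t)  # ~1.5: sigma=1.0 sits between entries 1 (0.5) and 2 (2.0)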
@@ -405,7 +427,17 @@ def _sigma_to_t(self, sigma, log_sigmas):
 
     # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
     def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+        """
+        Construct the noise schedule of Karras et al. (2022).
+
+        Args:
+            in_sigmas (`torch.Tensor`):
+                The input sigma values to be converted.
+
+        Returns:
+            `torch.Tensor`:
+                The converted sigma values following the Karras noise schedule.
+        """
 
         sigma_min: float = in_sigmas[-1].item()
         sigma_max: float = in_sigmas[0].item()
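
The Karras et al. (2022) schedule described in the new docstring interpolates linearly in sigma^(1/rho) space (rho = 7.0 in the paper). A minimal standalone sketch of that construction, with assumed toy endpoints:

import torch

def karras_sigmas(sigma_min: float, sigma_max: float, num_steps: int, rho: float = 7.0) -> torch.Tensor:
    # Ramp in (1/rho)-space, then map back: descends from sigma_max to sigma_min.
    ramp = torch.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(0.1, 10.0, 5))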
@@ -629,5 +661,5 @@ def add_noise(
         noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
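
Taken together, the annotations now cover the methods exercised in a typical sampling loop. A hedged end-to-end sketch (the noise predictor is a dummy standing in for a real UNet):

import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(beta_schedule="scaled_linear")
scheduler.set_timesteps(num_inference_steps=25)  # -> None, per the new annotation

def predict_noise(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return torch.randn_like(x)  # dummy model output

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    noise_pred = predict_noise(model_input, t)
    sample = scheduler.step(noise_pred, t, sample).prev_sample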
