```diff
@@ -79,15 +79,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
            The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
            The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         skip_prk_steps (`bool`, defaults to `False`):
@@ -97,14 +96,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
             there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
             otherwise it uses the alpha value at step 0.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
-            or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
-            paper).
-        timestep_spacing (`str`, defaults to `"leading"`):
+            or `v_prediction` (see section 2.4 of the [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
+        steps_offset (`int`, defaults to `0`):
             An offset added to the inference steps, as required by some model families.
     """
 
```
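Since `timestep_spacing` now advertises its three options in the docstring, a quick illustration of how they differ may help. This is an illustrative NumPy sketch of the three spacing conventions as described in the cited paper, not a quote of this file's `set_timesteps` code; the variable names are mine:

```python
import numpy as np

num_train_timesteps, num_inference_steps, steps_offset = 1000, 10, 0

# "linspace": spread the inference steps evenly across the whole [0, T-1] range
linspace = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round().astype(np.int64)

# "leading": integer strides counted up from 0 (plus steps_offset); the final
# training timestep T-1 is never visited, which the cited paper flags as flawed
leading = np.arange(num_inference_steps) * (num_train_timesteps // num_inference_steps) + steps_offset

# "trailing": strides counted back from T, so sampling starts at timestep T-1
trailing = (np.arange(num_train_timesteps, 0, -num_train_timesteps / num_inference_steps).round() - 1).astype(np.int64)

print(linspace)  # [  0 111 222 333 444 555 666 777 888 999]
print(leading)   # [  0 100 200 300 400 500 600 700 800 900]
print(trailing)  # [999 899 799 699 599 499 399 299 199  99]
```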
```diff
@@ -117,12 +115,12 @@ def __init__(
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
-        prediction_type: str = "epsilon",
-        timestep_spacing: str = "leading",
+        prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
```
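With the `Literal` annotations (this assumes `Literal` is imported from `typing` elsewhere in the change), the valid option strings become visible in the signature and to type checkers. A small construction sketch:

```python
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(
    beta_schedule="scaled_linear",   # "linear" | "scaled_linear" | "squaredcos_cap_v2"
    prediction_type="v_prediction",  # "epsilon" | "v_prediction"
    timestep_spacing="trailing",     # "linspace" | "leading" | "trailing"
)

# A type checker such as mypy can now flag typos before runtime, e.g.:
# PNDMScheduler(beta_schedule="cosine")  # error: not a valid Literal value
```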
```diff
@@ -164,7 +162,7 @@ def __init__(
         self.plms_timesteps = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
```
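The `Optional` annotation matches the existing `None` default, so behavior is unchanged. A brief usage sketch of the updated signature:

```python
from diffusers import PNDMScheduler

scheduler = PNDMScheduler()
scheduler.set_timesteps(num_inference_steps=50)   # device omitted: timesteps stay on CPU
# scheduler.set_timesteps(50, device="cuda")      # or place them on a GPU, if available
print(scheduler.timesteps.dtype, scheduler.timesteps.device)
```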
```diff
@@ -243,7 +241,7 @@ def step(
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
 
         Returns:
```
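For clarity, a small sketch of what the `return_dict` flag changes (the `model_output`, `t`, and `sample` variables are assumed from a surrounding denoising loop):

```python
# return_dict=True (the default) wraps the result in a SchedulerOutput dataclass
out = scheduler.step(model_output, t, sample)
prev = out.prev_sample

# return_dict=False returns a plain tuple whose first element is the same tensor
(prev,) = scheduler.step(model_output, t, sample, return_dict=False)
```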
```diff
@@ -276,14 +274,13 @@ def step_prk(
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
-
         """
         if self.num_inference_steps is None:
             raise ValueError(
@@ -335,14 +332,13 @@ def step_plms(
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
-
         """
         if self.num_inference_steps is None:
             raise ValueError(
```
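These two hunks make the same `return_dict` fix in the Runge-Kutta (PRK) and linear multistep (PLMS) step docstrings. As background, PNDM runs a few PRK warm-up steps to prime the multistep history unless `skip_prk_steps=True`; a small probe sketch (the `prk_timesteps`/`plms_timesteps` attribute names are assumed from the fields initialized earlier in this file):

```python
from diffusers import PNDMScheduler

sched = PNDMScheduler(skip_prk_steps=False)
sched.set_timesteps(10)
# With PRK warm-up enabled, the schedule contains extra Runge-Kutta sub-steps
# before the plain PLMS portion begins.
print(len(sched.prk_timesteps), len(sched.plms_timesteps), len(sched.timesteps))
```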
```diff
@@ -403,19 +399,37 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         return sample
 
-    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
-        # See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778
-        # this function computes x_(t−δ) using the formula of (9)
-        # Note that x_t needs to be added to both sides of the equation
-
-        # Notation (<variable name> -> <name in paper>
-        # alpha_prod_t -> α_t
-        # alpha_prod_t_prev -> α_(t−δ)
-        # beta_prod_t -> (1 - α_t)
-        # beta_prod_t_prev -> (1 - α_(t−δ))
-        # sample -> x_t
-        # model_output -> e_θ(x_t, t)
-        # prev_sample -> x_(t−δ)
+    def _get_prev_sample(
+        self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute the previous sample x_(t−δ) from the current sample x_t using formula (9) from the PNDM paper.
+
+        See formula (9) of the [PNDM paper](https://huggingface.co/papers/2202.09778).
+
+        Notation (<variable name> -> <name in paper>):
+        - alpha_prod_t -> α_t
+        - alpha_prod_t_prev -> α_(t−δ)
+        - beta_prod_t -> (1 - α_t)
+        - beta_prod_t_prev -> (1 - α_(t−δ))
+        - sample -> x_t
+        - model_output -> e_θ(x_t, t)
+        - prev_sample -> x_(t−δ)
+
+        Args:
+            sample (`torch.Tensor`):
+                The current sample x_t.
+            timestep (`int`):
+                The current timestep t.
+            prev_timestep (`int`):
+                The previous timestep (t−δ).
+            model_output (`torch.Tensor`):
+                The model output e_θ(x_t, t).
+
+        Returns:
+            `torch.Tensor`:
+                The previous sample x_(t−δ).
+        """
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
```
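The new docstring spells out the notation; for reference, here is a minimal standalone sketch of formula (9) in that notation (the function name and framing are mine, not this file's):

```python
import torch

def prev_sample_formula_9(
    sample: torch.Tensor,             # x_t
    model_output: torch.Tensor,       # e_θ(x_t, t)
    alpha_prod_t: torch.Tensor,       # α_t
    alpha_prod_t_prev: torch.Tensor,  # α_(t−δ)
) -> torch.Tensor:
    # coefficient on x_t: (α_(t−δ) / α_t) ** 0.5
    sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
    # denominator of the e_θ term:
    # α_t * (1 − α_(t−δ)) ** 0.5 + (α_t * (1 − α_t) * α_(t−δ)) ** 0.5
    denom = alpha_prod_t * (1 - alpha_prod_t_prev) ** 0.5 + (
        alpha_prod_t * (1 - alpha_prod_t) * alpha_prod_t_prev
    ) ** 0.5
    # x_(t−δ) = sample_coeff * x_t − (α_(t−δ) − α_t) * e_θ(x_t, t) / denom
    return sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / denom
```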
```diff
@@ -489,5 +503,5 @@ def add_noise(
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
```
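Finally, an end-to-end sketch of the scheduler in a denoising loop, using a random tensor as a stand-in for a real UNet prediction (the latent shape is hypothetical):

```python
import torch
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(skip_prk_steps=True)
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 3, 64, 64)  # hypothetical latent shape
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for a UNet's noise prediction
    sample = scheduler.step(model_output, t, sample).prev_sample

print(len(scheduler))  # 1000, i.e. config.num_train_timesteps via the annotated __len__
```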