# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Diffusion/score matching utilities.

In this file, we are generally dealing with the following joint model:

Q(x) Q(z_0 | x) \prod_{t=1}^T Q(z_t | z_{t - 1})

T could be infinite depending on the model used. The goal is to learn a model
for Q(x), which we only have access to via samples from it. We do this by
learning P(z_{t - 1} | z_t; t), which we parameterize using a trainable
`denoise_fn(z_t, f(t))` with some function `f` (often the log signal-to-noise
ratio).
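
#### Example

A minimal end-to-end sketch, assuming a hypothetical trained
`my_denoise_fn(z_t, log_snr_t)` that returns `(predicted_noise, ())`:

```python
seed = jax.random.PRNGKey(0)
x = jnp.zeros([16])  # A sample from Q(x).

# Training: minimize the expectation of this loss over x and t ~ U[0, 1].
loss, _ = vdm_diffusion_loss(
    t=jax.random.uniform(seed, []),
    num_steps=None,
    x=x,
    log_snr_fn=linear_log_snr,
    denoise_fn=my_denoise_fn,
    seed=seed,
)

# Sampling: start from a standard normal draw and iteratively denoise it.
x_hat, _ = vdm_sample(
    z_t=jax.random.normal(seed, x.shape),
    num_steps=1000,
    log_snr_fn=linear_log_snr,
    denoise_fn=my_denoise_fn,
    seed=seed,
)
```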
"""

import enum
from typing import Any, Callable

from flax import struct
import jax
import jax.numpy as jnp
from discussion.robust_inverse_graphics import saving
from fun_mc import using_jax as fun_mc


__all__ = [
    "linear_log_snr",
    "variance_preserving_forward_process",
    "vdm_diffusion_loss",
    "vdm_sample",
    "VDMDiffusionLossExtra",
    "VDMSampleExtra",
    "DenoiseOutputType",
]


Extra = Any
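# `LogSnrFn` maps a time `t` in [0, 1] to the log signal-to-noise ratio.
# `DenoiseFn` maps `(z_t, log_snr_t)` to a `(denoised output, extra)` tuple.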
LogSnrFn = Callable[[jnp.ndarray], jnp.ndarray]
DenoiseFn = Callable[[jnp.ndarray, jnp.ndarray], tuple[jnp.ndarray, Extra]]


class DenoiseOutputType(enum.Enum):
  """How to interpret the output of `denoise_fn`.

  NOISE: The output is the predicted noise.
  ANGULAR_VELOCITY: The output is the angular velocity, defined as:
    `alpha_t * noise - sigma_t * x`.
  DIRECT: The output is the denoised value.
| 60 | + """ |
| 61 | + |
| 62 | + NOISE = "noise" |
| 63 | + ANGULAR_VELOCITY = "angular_velocity" |
| 64 | + DIRECT = "direct" |
| 65 | + |
| 66 | + |
| 67 | +def variance_preserving_forward_process( |
| 68 | + z_0: jnp.ndarray, noise: jnp.ndarray, log_snr_t: jnp.ndarray |
| 69 | +) -> jnp.ndarray: |
| 70 | + """Variance preserving forward process. |
| 71 | +
|
  This produces a sample from Q(z_t | z_0), given the desired noise level
  (`log_snr_t`) and the source of randomness (`noise`).

  Args:
    z_0: Un-noised inputs.
    noise: Noise.
    log_snr_t: Log signal-to-noise ratio at time t.

  Returns:
    Value of z_t.
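
  #### Example

  A minimal illustrative sketch. At a log SNR of 0, the signal and the noise
  contribute equal variance to `z_t`:

  ```python
  z_0 = jnp.zeros([16])
  noise = jax.random.normal(jax.random.PRNGKey(0), z_0.shape)
  # z_t = sqrt(0.5) * z_0 + sqrt(0.5) * noise at log_snr_t = 0.
  z_t = variance_preserving_forward_process(z_0, noise, jnp.array(0.))
  ```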
  """
  var_t = jax.nn.sigmoid(-log_snr_t)
  alpha_t = jnp.sqrt(jax.nn.sigmoid(log_snr_t))  # sqrt(1 - var_t)
  return alpha_t * z_0 + jnp.sqrt(var_t) * noise


@saving.register
@struct.dataclass
class VDMDiffusionLossExtra:
  """Extra outputs from `vdm_diffusion_loss`.

  Attributes:
    noise: The added noise.
    recon_noise: The reconstructed noise (only set if `denoise_output` is
      NOISE).
    target: Target value for the loss to reconstruct.
    recon: Output of `denoise_fn`.
    extra: Extra outputs from `denoise_fn`.
  """

  noise: jnp.ndarray
  recon_noise: jnp.ndarray | None
  target: jnp.ndarray
  recon: jnp.ndarray
  extra: Extra


def vdm_diffusion_loss(
    t: jnp.ndarray,
    num_steps: int | None,
    x: jnp.ndarray,
    log_snr_fn: LogSnrFn,
    denoise_fn: DenoiseFn,
    seed: jax.Array,
    denoise_output: DenoiseOutputType = DenoiseOutputType.NOISE,
) -> tuple[jnp.ndarray, VDMDiffusionLossExtra]:
  r"""The diffusion loss of the variational diffusion model (VDM).

  This uses the parameterization from [1]. The typical procedure minimizes the
  expectation of this function, averaging across examples (`z_0`) and times
  (sampled uniformly from [0, 1]).

  When `denoise_output` is NOISE, and when the loss is minimized,
  `denoise_fn(z_t, log_snr_t) \propto -grad log Q(z_t; log_snr_t)` where `z_t`
  is sampled from the forward process (`variance_preserving_forward_process`)
  and `Q(.; log_snr_t)` is the marginal density of `z_t`.

  Args:
    t: Time in [0, 1].
    num_steps: If None, use the continuous-time parameterization. Otherwise,
      discretize `t` to this many bins.
    x: Un-noised inputs.
    log_snr_fn: Takes in time in [0, 1] and returns the log signal-to-noise
      ratio.
    denoise_fn: Function that denoises `z_t` given `log_snr_t`. Its output is
      interpreted based on the value of `denoise_output`.
    seed: Random seed.
    denoise_output: How to interpret the output of `denoise_fn`.

  Returns:
    A tuple of the loss and `VDMDiffusionLossExtra` extra outputs.
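
  #### Example

  A minimal sketch of estimating the expected loss over a batch, assuming a
  hypothetical `my_denoise_fn(z_t, log_snr_t)` that returns
  `(predicted_noise, ())`:

  ```python
  def loss_fn(x, t, seed):
    loss, _ = vdm_diffusion_loss(
        t=t,
        num_steps=None,
        x=x,
        log_snr_fn=linear_log_snr,
        denoise_fn=my_denoise_fn,
        seed=seed,
    )
    return loss

  xs = jnp.zeros([8, 16])  # A batch of examples.
  ts = jax.random.uniform(jax.random.PRNGKey(0), [8])
  seeds = jax.random.split(jax.random.PRNGKey(1), 8)
  mean_loss = jax.vmap(loss_fn)(xs, ts, seeds).mean()
  ```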

  #### References

  [1] Kingma, D. P., Salimans, T., Poole, B., & Ho, J. (2021). Variational
  Diffusion Models. In arXiv [cs.LG]. arXiv. http://arxiv.org/abs/2107.00630
  """

  if num_steps is not None:
    t = jnp.ceil(t * num_steps) / num_steps

  log_snr_t = log_snr_fn(t)
  noise = jax.random.normal(seed, x.shape)
  z_t = variance_preserving_forward_process(x, noise, log_snr_t)

  recon, extra = denoise_fn(z_t, log_snr_t)

  match denoise_output:
    case DenoiseOutputType.NOISE:
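      # Continuous time: loss = -0.5 * d(log SNR)/dt * ||noise - recon||^2.
      # Discrete time: loss = 0.5 * T * (SNR(s)/SNR(t) - 1) *
      # ||noise - recon||^2, with the SNR ratio computed via expm1 for
      # numerical stability.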
      target = noise
      recon_noise = recon
      sq_error = 0.5 * jnp.square(target - recon).sum()
      if num_steps is None:
        log_snr_t_grad = jax.grad(log_snr_fn)(t)
        loss = -log_snr_t_grad * sq_error
      else:
        s = t - (1.0 / num_steps)
        log_snr_s = log_snr_fn(s)
        loss = num_steps * jnp.expm1(log_snr_s - log_snr_t) * sq_error
    case DenoiseOutputType.ANGULAR_VELOCITY:
      # Plug in x_hat = alpha_t * z_t - sigma_t * v into equation (13) or (15)
      # and simplify to get the loss being (SNR(s) - SNR(t)) * sigma_t**2 * MSE
      # for the discrete case and -SNR'(t) * sigma_t**2 * MSE for the
      # continuous case.
      recon_noise = None
      var_t = jax.nn.sigmoid(-log_snr_t)
      sigma_t = jnp.sqrt(var_t)
      alpha_t_2 = jax.nn.sigmoid(log_snr_t)
      alpha_t = jnp.sqrt(alpha_t_2)
      v = alpha_t * noise - sigma_t * x

      target = v
      sq_error = 0.5 * jnp.square(target - recon).sum()
      if num_steps is None:
        log_snr_t_grad = jax.grad(log_snr_fn)(t)
        loss = -alpha_t_2 * log_snr_t_grad * sq_error
      else:
        s = t - (1.0 / num_steps)
        log_snr_s = log_snr_fn(s)
        loss = (
            num_steps * jnp.expm1(log_snr_s - log_snr_t) * alpha_t_2 * sq_error
        )
    case DenoiseOutputType.DIRECT:
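      # Substituting x_hat directly for x yields (SNR(s) - SNR(t)) * MSE for
      # the discrete case and -SNR'(t) * MSE for the continuous case.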
      recon_noise = None
      target = x
      sq_error = 0.5 * jnp.square(target - recon).sum()
      if num_steps is None:
        snr_t_grad = jax.grad(lambda t: jnp.exp(log_snr_fn(t)))(t)
        loss = -snr_t_grad * sq_error
      else:
        s = t - (1.0 / num_steps)
        snr_t = jnp.exp(log_snr_t)
        # TODO(siege): Not sure this is more stable than doing snr_s - snr_t
        # directly.
        log_snr_s = log_snr_fn(s)
        loss = num_steps * snr_t * jnp.expm1(log_snr_s - log_snr_t) * sq_error
    case _:
      raise ValueError(f"Unknown denoise_output: {denoise_output}")

  return loss, VDMDiffusionLossExtra(
      noise=noise,
      recon_noise=recon_noise,
      target=target,
      recon=recon,
      extra=extra,
  )


def _vdm_sample_step(
    z_t: jnp.ndarray,
    step: jnp.ndarray,
    num_steps: int,
    log_snr_fn: LogSnrFn,
    denoise_fn: DenoiseFn,
    seed: jax.Array,
    denoise_output: DenoiseOutputType,
    t_start: jnp.ndarray,
) -> tuple[jnp.ndarray, Extra]:
  """One step of the sampling process."""
  t = t_start * (step / num_steps)
  s = t_start * ((step - 1) / num_steps)

  log_snr_t = log_snr_fn(t)
  log_snr_s = log_snr_fn(s)
  recon, extra = denoise_fn(z_t, log_snr_t)

  zeta = jax.random.normal(seed, z_t.shape)

  alpha_s_2 = jax.nn.sigmoid(log_snr_s)
  alpha_s = jnp.sqrt(alpha_s_2)
  alpha_t_2 = jax.nn.sigmoid(log_snr_t)
  alpha_t = jnp.sqrt(alpha_t_2)
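  # Variance of q(z_t | z_s) divided by the variance of q(z_t | x), i.e.
  # sigma_{t|s}**2 / sigma_t**2 = 1 - SNR(t) / SNR(s).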
  var_t_s_div_var_t = -jnp.expm1(log_snr_t - log_snr_s)
  var_s = jax.nn.sigmoid(-log_snr_s)
  var_t = jax.nn.sigmoid(-log_snr_t)
  sigma_t = jnp.sqrt(var_t)

  match denoise_output:
    case DenoiseOutputType.NOISE:
      recon_noise = recon
      mu = jnp.sqrt(alpha_s_2 / alpha_t_2) * (
          z_t - sigma_t * var_t_s_div_var_t * recon_noise
      )
    case DenoiseOutputType.ANGULAR_VELOCITY:
      # We use the expression for q(z_s | z_t, x) directly with x_hat
      # substituted for x.
      # TODO(siege): Try simplifying this further for better numerics.
      x_hat = alpha_t * z_t - sigma_t * recon
      alpha_t_s = alpha_t / alpha_s

      mu = alpha_t_s * var_s / var_t * z_t + alpha_s * var_t_s_div_var_t * x_hat
    case DenoiseOutputType.DIRECT:
      x_hat = recon
      alpha_t_s = alpha_t / alpha_s

      mu = alpha_t_s * var_s / var_t * z_t + alpha_s * var_t_s_div_var_t * x_hat
    case _:
      raise ValueError(f"Unknown denoise_output: {denoise_output}")
  sigma = jnp.sqrt(var_t_s_div_var_t * var_s)
  z_s = mu + sigma * zeta
  return z_s, extra


@saving.register
@struct.dataclass
class VDMSampleExtra:
  """Extra outputs from `vdm_sample`.

  Attributes:
    z_s: A trace of samples.
  """

  z_s: jnp.ndarray | None


def vdm_sample(
    z_t: jnp.ndarray,
    num_steps: int,
    log_snr_fn: LogSnrFn,
    denoise_fn: DenoiseFn,
    seed: jax.Array,
    trace_z_s: bool = False,
    denoise_output: DenoiseOutputType = DenoiseOutputType.NOISE,
    t_start: jnp.ndarray | float = 1.0,
) -> tuple[jnp.ndarray, VDMSampleExtra]:
  """Generates a sample from the variational diffusion model (VDM).

  This uses the sampler from [1]. See `vdm_diffusion_loss` for the requirements
  on `denoise_fn`.

  Args:
    z_t: The initial noised sample. Should have the same distribution as
      `variance_preserving_forward_process(x, noise, log_snr_fn(t_start))`.
    num_steps: Number of steps to take. The more steps taken, the more
      accurate the sample. 1000 is a common value.
    log_snr_fn: Takes in time in [0, 1] and returns the log signal-to-noise
      ratio.
    denoise_fn: Function that denoises `z_t` given `log_snr_t`. Its output is
      interpreted based on the value of `denoise_output`.
    seed: Random seed.
    trace_z_s: Whether to trace intermediate samples.
    denoise_output: How to interpret the output of `denoise_fn`.
    t_start: The value of t in z_t. Typically this is 1, signifying that z_t is
      a sample from a standard normal.

  Returns:
    A tuple of the sample and `VDMSampleExtra` extra outputs.
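
  #### Example

  A minimal sketch, assuming a hypothetical trained
  `my_denoise_fn(z_t, log_snr_t)` that returns `(predicted_noise, ())`:

  ```python
  seed = jax.random.PRNGKey(0)
  # With the default t_start=1, z_1 should be roughly standard normal.
  z_1 = jax.random.normal(seed, [16])
  x_hat, extra = vdm_sample(
      z_t=z_1,
      num_steps=1000,
      log_snr_fn=linear_log_snr,
      denoise_fn=my_denoise_fn,
      seed=seed,
      trace_z_s=True,
  )
  # extra.z_s holds the trace of intermediate samples, one per step.
  ```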

  #### References

  [1] Kingma, D. P., Salimans, T., Poole, B., & Ho, J. (2021). Variational
  Diffusion Models. In arXiv [cs.LG]. arXiv. http://arxiv.org/abs/2107.00630
  """

  def body(z_t, step, seed):
    sample_seed, seed = jax.random.split(seed)
    z_s, _ = _vdm_sample_step(
        z_t=z_t,
        step=step,
        num_steps=num_steps,
        log_snr_fn=log_snr_fn,
        denoise_fn=denoise_fn,
        seed=sample_seed,
        denoise_output=denoise_output,
        t_start=t_start,
    )
    if trace_z_s:
      trace = {"z_s": z_s}
    else:
      trace = {}
    return (z_s, step - 1, seed), trace

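  # The step counter counts down from num_steps to 1, so the final z_s
  # corresponds to s = 0, i.e. the returned z_0.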
  (z_0, _, _), trace = fun_mc.trace((z_t, num_steps, seed), body, num_steps)

  return z_0, VDMSampleExtra(z_s=trace.get("z_s"))


def linear_log_snr(
    t: jnp.ndarray,
    log_snr_start: jax.typing.ArrayLike = 6.0,
    log_snr_end: jax.typing.ArrayLike = -6.0,
) -> jnp.ndarray:
  """Linear log signal-to-noise ratio function."""
  return log_snr_start + (log_snr_end - log_snr_start) * t  # pytype: disable=bad-return-type  # numpy-scalars