
Commit d1aadb8

SiegeLordEx authored and tensorflower-gardener committed
RIG OSS (3/?): Open source ProbNerf, diffusion + models.py.
PiperOrigin-RevId: 652642011
1 parent 566f6d5 commit d1aadb8

File tree: 7 files changed, +2296 −0 lines changed


Diff for: discussion/robust_inverse_graphics/diffusion.py

+355 lines
@@ -0,0 +1,355 @@
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Diffusion/score matching utilities.

In this file, we are generally dealing with the following joint model:

  Q(x) Q(z_0 | x) \prod_{t=1}^T Q(z_t | z_{t - 1})

T could be infinite depending on the model used. The goal is to learn a model
for Q(x), which we only have access to via samples from it. We do this by
learning P(z_{t - 1} | z_t; t), which we parameterize using a trainable
`denoise_fn(z_t, f(t))` with some function `f` (often the log signal-to-noise
ratio).
"""
import enum
from typing import Any, Callable

from flax import struct
import jax
import jax.numpy as jnp
from discussion.robust_inverse_graphics import saving
from fun_mc import using_jax as fun_mc


__all__ = [
    "linear_log_snr",
    "variance_preserving_forward_process",
    "vdm_diffusion_loss",
    "vdm_sample",
    "VDMDiffusionLossExtra",
    "VDMSampleExtra",
    "DenoiseOutputType",
]


Extra = Any
LogSnrFn = Callable[[jnp.ndarray], jnp.ndarray]
DenoiseFn = Callable[[jnp.ndarray, jnp.ndarray], tuple[jnp.ndarray, Extra]]


class DenoiseOutputType(enum.Enum):
  """How to interpret the output of `denoise_fn`.

  NOISE: The output is the predicted noise.
  ANGULAR_VELOCITY: The output is the angular velocity, defined as:
    `alpha_t * noise - sigma_t * x`.
  DIRECT: The output is the denoised value.
  """

  NOISE = "noise"
  ANGULAR_VELOCITY = "angular_velocity"
  DIRECT = "direct"

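# --- Illustrative example (not part of this commit) ---
# A minimal sketch of how the three parameterizations relate. Since
# `alpha_t**2 + sigma_t**2 = 1`, the denoised value can be recovered from the
# angular velocity via `x_hat = alpha_t * z_t - sigma_t * v`. All shapes and
# values below are made up.
def _denoise_output_relations_example():
  x = jnp.array([1.5, -0.5])  # Hypothetical clean data.
  noise = jax.random.normal(jax.random.PRNGKey(0), x.shape)
  log_snr_t = jnp.array(2.0)
  alpha_t = jnp.sqrt(jax.nn.sigmoid(log_snr_t))
  sigma_t = jnp.sqrt(jax.nn.sigmoid(-log_snr_t))
  z_t = alpha_t * x + sigma_t * noise  # Variance preserving forward sample.
  v = alpha_t * noise - sigma_t * x  # The ANGULAR_VELOCITY target.
  x_hat = alpha_t * z_t - sigma_t * v  # Equals x up to float error.
  assert jnp.allclose(x_hat, x, atol=1e-5)
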
def variance_preserving_forward_process(
    z_0: jnp.ndarray, noise: jnp.ndarray, log_snr_t: jnp.ndarray
) -> jnp.ndarray:
  """Variance preserving forward process.

  This produces a sample from Q(z_t | z_0) given the desired level of noise
  and randomness.

  Args:
    z_0: Un-noised inputs.
    noise: Noise.
    log_snr_t: Log signal-to-noise ratio at time t.

  Returns:
    Value of z_t.
  """
  var_t = jax.nn.sigmoid(-log_snr_t)
  alpha_t = jnp.sqrt(jax.nn.sigmoid(log_snr_t))  # sqrt(1 - var_t)
  return alpha_t * z_0 + jnp.sqrt(var_t) * noise

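# --- Illustrative example (not part of this commit) ---
# A minimal sketch of noising data with the forward process. At a high
# log-SNR, z_t stays close to z_0; at a very negative log-SNR, it is close to
# pure noise. The shapes and seed are made up.
def _forward_process_example():
  z_0 = jnp.ones([4])
  noise = jax.random.normal(jax.random.PRNGKey(1), z_0.shape)
  z_almost_clean = variance_preserving_forward_process(
      z_0, noise, log_snr_t=jnp.array(10.0)
  )
  z_almost_noise = variance_preserving_forward_process(
      z_0, noise, log_snr_t=jnp.array(-10.0)
  )
  return z_almost_clean, z_almost_noise
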
@saving.register
@struct.dataclass
class VDMDiffusionLossExtra:
  """Extra outputs from `vdm_diffusion_loss`.

  Attributes:
    noise: The added noise.
    recon_noise: The reconstructed noise (only set if `denoise_output` is
      NOISE).
    target: Target value for the loss to reconstruct.
    recon: Output of `denoise_fn`.
    extra: Extra outputs from `denoise_fn`.
  """

  noise: jnp.ndarray
  recon_noise: jnp.ndarray | None
  target: jnp.ndarray
  recon: jnp.ndarray
  extra: Extra

def vdm_diffusion_loss(
    t: jnp.ndarray,
    num_steps: int | None,
    x: jnp.ndarray,
    log_snr_fn: LogSnrFn,
    denoise_fn: DenoiseFn,
    seed: jax.Array,
    denoise_output: DenoiseOutputType = DenoiseOutputType.NOISE,
) -> tuple[jnp.ndarray, VDMDiffusionLossExtra]:
  r"""The diffusion loss of the variational diffusion model (VDM).

  This uses the parameterization from [1]. The typical procedure minimizes the
  expectation of this function, averaging across examples (`z_0`) and times
  (sampled uniformly from [0, 1]).

  When `denoise_output` is NOISE, and when the loss is minimized,
  `denoise_fn(z_t, log_snr_t) \propto -grad log Q(z_t; log_snr_t)` where `z_t`
  is sampled from the forward process (`variance_preserving_forward_process`)
  and `Q(.; log_snr_t)` is the marginal density of `z_t`.

  Args:
    t: Time in [0, 1].
    num_steps: If None, use the continuous time parameterization. Otherwise,
      discretize `t` to this many bins.
    x: Un-noised inputs.
    log_snr_fn: Takes in time in [0, 1] and returns the log signal-to-noise
      ratio.
    denoise_fn: Function that denoises `z_t` given the `log_snr_t`. Its output
      is interpreted based on the value of `denoise_output`.
    seed: Random seed.
    denoise_output: How to interpret the output of `denoise_fn`.

  Returns:
    A tuple of the loss and `VDMDiffusionLossExtra` extra outputs.

  #### References

  [1] Kingma, D. P., Salimans, T., Poole, B., & Ho, J. (2021). Variational
      Diffusion Models. In arXiv [cs.LG]. arXiv. http://arxiv.org/abs/2107.00630
  """

  if num_steps is not None:
    t = jnp.ceil(t * num_steps) / num_steps

  log_snr_t = log_snr_fn(t)
  noise = jax.random.normal(seed, x.shape)
  z_t = variance_preserving_forward_process(x, noise, log_snr_t)

  recon, extra = denoise_fn(z_t, log_snr_t)

  match denoise_output:
    case DenoiseOutputType.NOISE:
      target = noise
      recon_noise = recon
      sq_error = 0.5 * jnp.square(target - recon).sum()
      if num_steps is None:
        log_snr_t_grad = jax.grad(log_snr_fn)(t)
        loss = -log_snr_t_grad * sq_error
      else:
        s = t - (1.0 / num_steps)
        log_snr_s = log_snr_fn(s)
        loss = num_steps * jnp.expm1(log_snr_s - log_snr_t) * sq_error
    case DenoiseOutputType.ANGULAR_VELOCITY:
      # Plug in x_hat = alpha_t * z_t - sigma_t * v into equation (13) or (15)
      # and simplify to get the loss being (SNR(s) - SNR(t)) * sigma_t**2 *
      # MSE for the discrete case and SNR'(t) * sigma_t**2 * MSE for the
      # continuous case.
      recon_noise = None
      var_t = jax.nn.sigmoid(-log_snr_t)
      sigma_t = jnp.sqrt(var_t)
      alpha_t_2 = jax.nn.sigmoid(log_snr_t)
      alpha_t = jnp.sqrt(alpha_t_2)
      v = alpha_t * noise - sigma_t * x

      target = v
      sq_error = 0.5 * jnp.square(target - recon).sum()
      if num_steps is None:
        log_snr_t_grad = jax.grad(log_snr_fn)(t)
        loss = -alpha_t_2 * log_snr_t_grad * sq_error
      else:
        s = t - (1.0 / num_steps)
        log_snr_s = log_snr_fn(s)
        loss = (
            num_steps * jnp.expm1(log_snr_s - log_snr_t) * alpha_t_2 * sq_error
        )
    case DenoiseOutputType.DIRECT:
      recon_noise = None
      target = x
      sq_error = 0.5 * jnp.square(target - recon).sum()
      if num_steps is None:
        snr_t_grad = jax.grad(lambda t: jnp.exp(log_snr_fn(t)))(t)
        loss = -snr_t_grad * sq_error
      else:
        s = t - (1.0 / num_steps)
        snr_t = jnp.exp(log_snr_t)
        # TODO(siege): Not sure this is more stable than doing snr_s - snr_t
        # directly.
        log_snr_s = log_snr_fn(s)
        loss = num_steps * snr_t * jnp.expm1(log_snr_s - log_snr_t) * sq_error
    case _:
      raise ValueError(f"Unknown denoise_output: {denoise_output}")

  return loss, VDMDiffusionLossExtra(
      noise=noise,
      recon_noise=recon_noise,
      target=target,
      recon=recon,
      extra=extra,
  )

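# --- Illustrative example (not part of this commit) ---
# A minimal sketch of the "typical procedure" described in the docstring:
# average `vdm_diffusion_loss` over a batch of examples and times sampled
# uniformly from [0, 1]. `my_denoise_fn` is a hypothetical stand-in for a
# trainable network; a real model would condition on `log_snr_t`.
def _diffusion_loss_example(x_batch: jnp.ndarray, seed: jax.Array):
  def my_denoise_fn(z_t, log_snr_t):
    del log_snr_t  # A real denoiser would use this.
    return z_t, ()  # Placeholder noise prediction and empty extra.

  t_seed, noise_seed = jax.random.split(seed)
  batch_size = x_batch.shape[0]
  t = jax.random.uniform(t_seed, [batch_size])
  noise_seeds = jax.random.split(noise_seed, batch_size)
  losses, _ = jax.vmap(
      lambda t, x, seed: vdm_diffusion_loss(
          t=t,
          num_steps=None,  # Continuous-time parameterization.
          x=x,
          log_snr_fn=linear_log_snr,  # Defined at the end of this file.
          denoise_fn=my_denoise_fn,
          seed=seed,
      )
  )(t, x_batch, noise_seeds)
  return losses.mean()
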
def _vdm_sample_step(
    z_t: jnp.ndarray,
    step: jnp.ndarray,
    num_steps: int,
    log_snr_fn: LogSnrFn,
    denoise_fn: DenoiseFn,
    seed: jax.Array,
    denoise_output: DenoiseOutputType,
    t_start: jnp.ndarray,
) -> tuple[jnp.ndarray, Extra]:
  """One step of the sampling process."""
  t = t_start * (step / num_steps)
  s = t_start * ((step - 1) / num_steps)

  log_snr_t = log_snr_fn(t)
  log_snr_s = log_snr_fn(s)
  recon, extra = denoise_fn(z_t, log_snr_t)

  zeta = jax.random.normal(seed, z_t.shape)

  alpha_s_2 = jax.nn.sigmoid(log_snr_s)
  alpha_s = jnp.sqrt(alpha_s_2)
  alpha_t_2 = jax.nn.sigmoid(log_snr_t)
  alpha_t = jnp.sqrt(alpha_t_2)
  var_t_s_div_var_t = -jnp.expm1(log_snr_t - log_snr_s)
  var_s = jax.nn.sigmoid(-log_snr_s)
  var_t = jax.nn.sigmoid(-log_snr_t)
  sigma_t = jnp.sqrt(var_t)

  match denoise_output:
    case DenoiseOutputType.NOISE:
      recon_noise = recon
      mu = jnp.sqrt(alpha_s_2 / alpha_t_2) * (
          z_t - sigma_t * var_t_s_div_var_t * recon_noise
      )
    case DenoiseOutputType.ANGULAR_VELOCITY:
      # We use the expression for q(z_s | z_t, x) directly with x_hat
      # substituted for x.
      # TODO(siege): Try simplifying this further for better numerics.
      x_hat = alpha_t * z_t - sigma_t * recon
      alpha_t_s = alpha_t / alpha_s

      mu = alpha_t_s * var_s / var_t * z_t + alpha_s * var_t_s_div_var_t * x_hat
    case DenoiseOutputType.DIRECT:
      x_hat = recon
      alpha_t_s = alpha_t / alpha_s

      mu = alpha_t_s * var_s / var_t * z_t + alpha_s * var_t_s_div_var_t * x_hat
    case _:
      raise ValueError(f"Unknown denoise_output: {denoise_output}")
  sigma = jnp.sqrt(var_t_s_div_var_t * var_s)
  z_s = mu + sigma * zeta
  return z_s, extra

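# Reference note (added for exposition): the ancestral step above samples from
# the VDM posterior of [1] (see `vdm_diffusion_loss`) with x_hat substituted
# for x:
#
#   q(z_s | z_t, x = x_hat) = N(mu, sigma**2), where
#   mu = (alpha_t / alpha_s) * (var_s / var_t) * z_t
#        + alpha_s * (1 - SNR(t) / SNR(s)) * x_hat
#   sigma**2 = var_s * (1 - SNR(t) / SNR(s))
#
# and `var_t_s_div_var_t = -expm1(log_snr_t - log_snr_s)` computes
# `1 - SNR(t) / SNR(s)` in a numerically stable way.
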
@saving.register
@struct.dataclass
class VDMSampleExtra:
  """Extra outputs from `vdm_sample`.

  Attributes:
    z_s: A trace of samples.
  """

  z_s: jnp.ndarray | None

def vdm_sample(
    z_t: jnp.ndarray,
    num_steps: int,
    log_snr_fn: LogSnrFn,
    denoise_fn: DenoiseFn,
    seed: jax.Array,
    trace_z_s: bool = False,
    denoise_output: DenoiseOutputType = DenoiseOutputType.NOISE,
    t_start: jnp.ndarray | float = 1.0,
) -> tuple[jnp.ndarray, VDMSampleExtra]:
  """Generates a sample from the variational diffusion model (VDM).

  This uses the sampler from [1]. See `vdm_diffusion_loss` for the
  requirements on `denoise_fn`.

  Args:
    z_t: The initial noised sample. Should have the same distribution as
      `variance_preserving_forward_process(x, noise, log_snr_fn(t_start))`.
    num_steps: Number of steps to take. The more steps taken, the more
      accurate the sample. 1000 is a common value.
    log_snr_fn: Takes in time in [0, 1] and returns the log signal-to-noise
      ratio.
    denoise_fn: Function that denoises `z_t` given the `log_snr_t`. Its output
      is interpreted based on the value of `denoise_output`.
    seed: Random seed.
    trace_z_s: Whether to trace intermediate samples.
    denoise_output: How to interpret the output of `denoise_fn`.
    t_start: The value of t in z_t. Typically this is 1, signifying that z_t
      is a sample from a standard normal.

  Returns:
    A tuple of the sample and `VDMSampleExtra` extra outputs.

  #### References

  [1] Kingma, D. P., Salimans, T., Poole, B., & Ho, J. (2021). Variational
      Diffusion Models. In arXiv [cs.LG]. arXiv. http://arxiv.org/abs/2107.00630
  """

  def body(z_t, step, seed):
    sample_seed, seed = jax.random.split(seed)
    z_s, _ = _vdm_sample_step(
        z_t=z_t,
        step=step,
        num_steps=num_steps,
        log_snr_fn=log_snr_fn,
        denoise_fn=denoise_fn,
        seed=sample_seed,
        denoise_output=denoise_output,
        t_start=t_start,
    )
    if trace_z_s:
      trace = {"z_s": z_s}
    else:
      trace = {}
    return (z_s, step - 1, seed), trace

  (z_0, _, _), trace = fun_mc.trace((z_t, num_steps, seed), body, num_steps)

  return z_0, VDMSampleExtra(z_s=trace.get("z_s"))

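# --- Illustrative example (not part of this commit) ---
# A minimal sketch of sampling: start from standard normal noise (t_start =
# 1) and run the reverse process. `my_denoise_fn` is a hypothetical stand-in
# for a trained network.
def _vdm_sample_example(seed: jax.Array):
  def my_denoise_fn(z_t, log_snr_t):
    del log_snr_t  # A trained denoiser would use this.
    return z_t, ()  # Placeholder noise prediction and empty extra.

  init_seed, sample_seed = jax.random.split(seed)
  z_1 = jax.random.normal(init_seed, [4])  # Matches t_start = 1.
  x_sample, _ = vdm_sample(
      z_t=z_1,
      num_steps=1000,  # The common value suggested in the docstring.
      log_snr_fn=linear_log_snr,
      denoise_fn=my_denoise_fn,
      seed=sample_seed,
  )
  return x_sample
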
def linear_log_snr(
    t: jnp.ndarray,
    log_snr_start: jax.typing.ArrayLike = 6.0,
    log_snr_end: jax.typing.ArrayLike = -6.0,
) -> jnp.ndarray:
  """Linear log signal-to-noise ratio function."""
  return log_snr_start + (log_snr_end - log_snr_start) * t  # pytype: disable=bad-return-type  # numpy-scalars
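# --- Illustrative note (not part of this commit) ---
# With the defaults, the schedule interpolates linearly from a log-SNR of 6.0
# at t = 0 (nearly noise-free) down to -6.0 at t = 1 (nearly pure noise),
# e.g. linear_log_snr(jnp.array(0.5)) == 0.0.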
