Skip to content

Commit

Permalink
small changes to make code more readable (#42)
Browse files Browse the repository at this point in the history
* use norm if possible

* rename
  • Loading branch information
ismael-mendoza authored Nov 15, 2024
1 parent 9ebc63b commit 0409d02
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 24 deletions.
32 changes: 12 additions & 20 deletions bpd/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import jax.numpy as jnp
import jax.scipy as jsp
from jax import grad, vmap
from jax.numpy.linalg import norm
from jax.typing import ArrayLike

from bpd.prior import inv_shear_func1, inv_shear_func2, inv_shear_transformation

_grad_fnc1 = vmap(vmap(grad(inv_shear_func1), in_axes=(0, None)), in_axes=(0, None))
_grad_fnc2 = vmap(vmap(grad(inv_shear_func2), in_axes=(0, None)), in_axes=(0, None))


def shear_loglikelihood_unreduced(
g: tuple[float, float], e_post, prior: Callable, interim_prior: Callable
Expand All @@ -17,36 +21,24 @@ def shear_loglikelihood_unreduced(
# the priors are callables for now on only ellipticities
# the interim_prior should have been used when obtaining e_obs from the chain (i.e. for now same sigma)
# normalization in priors can be ignored for now as alpha is fixed.
_, K, _ = e_post.shape # (N, K, 2)
_, _, _ = e_post.shape # (N, K, 2)

e_post_mag = jnp.sqrt(e_post[..., 0] ** 2 + e_post[..., 1] ** 2)
e_post_mag = norm(e_post, axis=-1)
denom = interim_prior(e_post_mag) # (N, K), can ignore angle in prior as uniform

# for num, use trick
# p(w_n' | g, alpha ) = p(w_n' \cross^{-1} g | alpha ) = p(w_n | alpha) * |jac(w_n / w_n')|

# shape = (N, K, 2)
grad1 = vmap(
vmap(grad(inv_shear_func1, argnums=0), in_axes=(0, None)),
in_axes=(0, None),
)(e_post, g)

grad2 = vmap(
vmap(grad(inv_shear_func2, argnums=0), in_axes=(0, None)),
in_axes=(0, None),
)(e_post, g)

grad1 = _grad_fnc1(e_post, g)
grad2 = _grad_fnc2(e_post, g)
absjacdet = jnp.abs(grad1[..., 0] * grad2[..., 1] - grad1[..., 1] * grad2[..., 0])

e_post_unsheared = inv_shear_transformation(e_post, g)
e_obs_unsheared_mag = jnp.sqrt(
e_post_unsheared[..., 0] ** 2 + e_post_unsheared[..., 1] ** 2
)
num = prior(e_obs_unsheared_mag) * absjacdet # (N, K)

ratio = jnp.log((1 / K)) + jsp.special.logsumexp(
jnp.log(num) - jnp.log(denom), axis=-1
)
e_post_unsheared_mag = norm(e_post_unsheared, axis=-1)
num = prior(e_post_unsheared_mag) * absjacdet # (N, K)

ratio = jsp.special.logsumexp(jnp.log(num) - jnp.log(denom), axis=-1)
return ratio


Expand Down
5 changes: 2 additions & 3 deletions bpd/pipelines/toy_ellips.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from jax import Array, random, vmap
from jax import jit as jjit
from jax._src.prng import PRNGKeyArray
from jax.numpy.linalg import norm

from bpd.chains import run_inference_nuts
from bpd.prior import ellip_mag_prior, sample_synthetic_sheared_ellips_unclipped
Expand All @@ -23,9 +24,7 @@ def logtarget(

# ignore angle prior assumed uniform
# prior enforces magnitude < 1.0 for posterior samples
e_sheared_mag = jnp.sqrt(e_sheared[0] ** 2 + e_sheared[1] ** 2)
prior = jnp.log(interim_prior(e_sheared_mag))

prior = jnp.log(interim_prior(norm(e_sheared)))
likelihood = jnp.sum(jsp.stats.norm.logpdf(e_obs, loc=e_sheared, scale=sigma_m))
return prior + likelihood

Expand Down
3 changes: 2 additions & 1 deletion bpd/prior.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax.numpy as jnp
from jax import Array, random
from jax.numpy.linalg import norm


def ellip_mag_prior(e, sigma: float):
Expand Down Expand Up @@ -136,7 +137,7 @@ def sample_synthetic_sheared_ellips_clipped(
# clip magnitude to < 1
# preserve angle after noise added when clipping
beta = jnp.arctan2(e_obs[:, :, 1], e_obs[:, :, 0]) / 2
e_obs_mag = jnp.sqrt(e_obs[:, :, 0] ** 2 + e_obs[:, :, 1] ** 2)
e_obs_mag = norm(e_obs, axis=-1)
e_obs_mag = jnp.clip(e_obs_mag, 0, e_tol) # otherwise likelihood explodes

final_eobs1 = e_obs_mag * jnp.cos(2 * beta)
Expand Down

0 comments on commit 0409d02

Please sign in to comment.