Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor defaults and cli #35

Merged
merged 28 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
92b576a
avoid repeated nuts code throughout codebase
ismael-mendoza Nov 4, 2024
55e1a0a
refactor and use new function
ismael-mendoza Nov 4, 2024
caeefeb
sigma seems off by sqrt(2)
ismael-mendoza Nov 4, 2024
4656c43
rename
ismael-mendoza Nov 4, 2024
1ab0594
being careful with defaults
ismael-mendoza Nov 4, 2024
29abcda
being careful with defaults and renaming
ismael-mendoza Nov 4, 2024
194432e
renaming
ismael-mendoza Nov 4, 2024
5d2a8f1
reorder
ismael-mendoza Nov 4, 2024
be03c71
avoid certain defaults for now
ismael-mendoza Nov 4, 2024
4475bec
we don't need these ones anymore
ismael-mendoza Nov 4, 2024
1db59ba
need to remove prior and move to target
ismael-mendoza Nov 4, 2024
534117e
fix test after refactoring
ismael-mendoza Nov 4, 2024
e0803fa
use typer
ismael-mendoza Nov 4, 2024
d9dcba6
typo
ismael-mendoza Nov 4, 2024
50016fd
typo
ismael-mendoza Nov 4, 2024
a47cae3
fix tests
ismael-mendoza Nov 4, 2024
12eb215
rename and use typer
ismael-mendoza Nov 4, 2024
5890174
rename
ismael-mendoza Nov 4, 2024
1ddfe9d
more judicious
ismael-mendoza Nov 4, 2024
0f62810
step size, no default
ismael-mendoza Nov 4, 2024
8886023
need more arguments now
ismael-mendoza Nov 4, 2024
1c10d0c
rtol change to be appropriate in test
ismael-mendoza Nov 4, 2024
4cb3d75
fix
ismael-mendoza Nov 4, 2024
d27a729
high snr is the default
ismael-mendoza Nov 4, 2024
9bdac2f
add typer, but finish in next PR
ismael-mendoza Nov 4, 2024
68965df
rename
ismael-mendoza Nov 4, 2024
87da98e
rename
ismael-mendoza Nov 4, 2024
da62ed4
fix corresponding slurm script
ismael-mendoza Nov 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions bpd/chains.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from functools import partial
from typing import Callable

import blackjax
import jax
from jax import random
from jax._src.prng import PRNGKeyArray
from jax.typing import ArrayLike


def inference_loop(rng_key, initial_state, kernel, n_samples: int):
Expand All @@ -12,3 +19,36 @@ def one_step(state, rng_key):
_, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

return (states, infos)


def run_inference_nuts(
    rng_key: PRNGKeyArray,
    init_positions: ArrayLike,
    data: ArrayLike,
    *,
    logtarget: Callable,
    n_samples: int,
    initial_step_size: float,
    max_num_doublings: int,
    n_warmup_steps: int = 500,
    target_acceptance_rate: float = 0.80,
    is_mass_matrix_diagonal: bool = True,
):
    """Run a window-adapted NUTS chain and return its posterior samples.

    Args:
        rng_key: PRNG key split between warmup and sampling phases.
        init_positions: Initial position(s) of the chain.
        data: Observed data; bound to ``logtarget``'s ``data`` keyword.
        logtarget: Log target density, callable as ``logtarget(params, data=...)``.
        n_samples: Number of posterior samples to draw after warmup.
        initial_step_size: Starting step size for window adaptation.
        max_num_doublings: NUTS maximum tree-depth doublings.
        n_warmup_steps: Number of window-adaptation steps.
        target_acceptance_rate: Acceptance rate targeted by adaptation.
        is_mass_matrix_diagonal: Whether adaptation tunes a diagonal mass matrix.

    Returns:
        The ``position`` field of the sampled states.
    """
    warmup_key, sampling_key = random.split(rng_key)

    # Condition the target density on the observed data once, up front.
    logdensity = partial(logtarget, data=data)

    adaptation = blackjax.window_adaptation(
        blackjax.nuts,
        logdensity,
        progress_bar=False,
        is_mass_matrix_diagonal=is_mass_matrix_diagonal,
        max_num_doublings=max_num_doublings,
        initial_step_size=initial_step_size,
        target_acceptance_rate=target_acceptance_rate,
    )

    # Warmup yields both the final chain state and the tuned NUTS parameters.
    (last_warmup_state, tuned), _ = adaptation.run(
        warmup_key, init_positions, n_warmup_steps
    )

    states, _ = inference_loop(
        sampling_key,
        last_warmup_state,
        kernel=blackjax.nuts(logdensity, **tuned).step,
        n_samples=n_samples,
    )
    return states.position
12 changes: 7 additions & 5 deletions bpd/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ def draw_gaussian(
g2: float,
x: float,
y: float,
pixel_scale: float = 0.2,
slen: int = 53,
*,
slen: int,
fft_size: int, # rule of thumb: at least 4 times `slen`
psf_hlr: float = 0.7,
fft_size: int = 256, # rule of thumb, at least 4 times `slen`
pixel_scale: float = 0.2,
):
gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)

Expand All @@ -39,9 +40,10 @@ def draw_gaussian_galsim(
g2: float,
x: float, # pixels
y: float,
pixel_scale: float = 0.2,
slen: int = 53,
*,
slen: int,
psf_hlr: float = 0.7,
pixel_scale: float = 0.2,
):
gal = galsim.Gaussian(flux=f, half_light_radius=hlr)
gal = gal.shear(g1=e1, g2=e2)
Expand Down
79 changes: 14 additions & 65 deletions bpd/pipelines/image_ellips.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from functools import partial
from typing import Callable

import blackjax
import jax.numpy as jnp
from jax import Array, random
from jax import jit as jjit
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

from bpd.chains import inference_loop
from bpd.chains import run_inference_nuts
from bpd.draw import draw_gaussian, draw_gaussian_galsim
from bpd.noise import add_noise
from bpd.prior import ellip_mag_prior, sample_ellip_prior
Expand All @@ -17,7 +16,7 @@
def get_target_galaxy_params_simple(
rng_key: PRNGKeyArray,
shape_noise: float = 1e-3,
lf: float = 3.0,
lf: float = 6.0,
hlr: float = 1.0,
x: float = 0.0, # pixels
y: float = 0.0,
Expand All @@ -42,29 +41,24 @@ def get_target_images_single(
rng_key: PRNGKeyArray,
n_samples: int,
single_galaxy_params: dict[str, float],
psf_hlr: float = 0.7,
background: float = 1.0,
slen: int = 53,
pixel_scale: float = 0.2,
*,
background: float,
slen: int,
):
"""In this case, we sample multiple noise realizations of the same galaxy."""
assert "f" in single_galaxy_params and "lf" not in single_galaxy_params

noiseless = draw_gaussian_galsim(
**single_galaxy_params,
pixel_scale=pixel_scale,
psf_hlr=psf_hlr,
slen=slen,
)
noiseless = draw_gaussian_galsim(**single_galaxy_params, slen=slen)
return add_noise(rng_key, noiseless, bg=background, n=n_samples), noiseless


# interim prior
def logprior(
params: dict[str, Array],
*,
sigma_e: float,
flux_bds: tuple = (-1.0, 9.0),
hlr_bds: tuple = (0.01, 5.0),
sigma_e: float = 3e-2,
sigma_x: float = 1.0, # pixels
) -> Array:
prior = jnp.array(0.0)
Expand Down Expand Up @@ -107,71 +101,27 @@ def logtarget(
return logprior_fnc(params) + loglikelihood_fnc(params, data)


def do_inference(
rng_key: PRNGKeyArray,
init_positions: dict[str, Array],
data: Array,
*,
logtarget_fnc: Callable,
is_mass_matrix_diagonal: bool = False,
n_warmup_steps: int = 500,
max_num_doublings: int = 5,
initial_step_size: float = 1e-3,
target_acceptance_rate: float = 0.80,
n_samples: int = 100,
):
key1, key2 = random.split(rng_key)

_logdensity = partial(logtarget_fnc, data=data)

warmup = blackjax.window_adaptation(
blackjax.nuts,
_logdensity,
progress_bar=False,
is_mass_matrix_diagonal=is_mass_matrix_diagonal,
max_num_doublings=max_num_doublings,
initial_step_size=initial_step_size,
target_acceptance_rate=target_acceptance_rate,
)

(init_states, tuned_params), _ = warmup.run(key1, init_positions, n_warmup_steps)

kernel = blackjax.nuts(_logdensity, **tuned_params).step
states, _ = inference_loop(key2, init_states, kernel=kernel, n_samples=n_samples)

return states.position


def pipeline_image_interim_samples(
def pipeline_image_interim_samples_one_galaxy(
rng_key: PRNGKeyArray,
true_params: dict[str, float],
target_image: Array,
*,
initialization_fnc: Callable,
sigma_e_int: float = 3e-2,
sigma_e_int: float,
n_samples: int = 100,
max_num_doublings: int = 5,
initial_step_size: float = 1e-3,
target_acceptance_rate: float = 0.80,
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = False,
slen: int = 53,
pixel_scale: float = 0.2,
psf_hlr: float = 0.7,
background: float = 1.0,
fft_size: int = 256,
background: float = 1.0,
):
k1, k2 = random.split(rng_key)

init_position = initialization_fnc(k1, true_params=true_params, data=target_image)

_draw_fnc = partial(
draw_gaussian,
pixel_scale=pixel_scale,
slen=slen,
psf_hlr=psf_hlr,
fft_size=fft_size,
)
_draw_fnc = partial(draw_gaussian, slen=slen, fft_size=fft_size)
_loglikelihood = partial(loglikelihood, draw_fnc=_draw_fnc, background=background)
_logprior = partial(logprior, sigma_e=sigma_e_int)

Expand All @@ -180,13 +130,12 @@ def pipeline_image_interim_samples(
)

_inference_fnc = partial(
do_inference,
logtarget_fnc=_logtarget,
run_inference_nuts,
logtarget=_logtarget,
is_mass_matrix_diagonal=is_mass_matrix_diagonal,
n_warmup_steps=n_warmup_steps,
max_num_doublings=max_num_doublings,
initial_step_size=initial_step_size,
target_acceptance_rate=target_acceptance_rate,
n_samples=n_samples,
)
_run_inference = jjit(_inference_fnc)
Expand Down
43 changes: 12 additions & 31 deletions bpd/pipelines/shear_inference.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,34 @@
from functools import partial
from typing import Callable

import blackjax
from jax import Array, random
from jax import Array
from jax import jit as jjit
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

from bpd.chains import inference_loop
from bpd.chains import run_inference_nuts
from bpd.likelihood import shear_loglikelihood
from bpd.prior import ellip_mag_prior


def logtarget_density(g: Array, e_post: Array, loglikelihood: Callable):
def logtarget_density(g: Array, *, data: Array, loglikelihood: Callable):
    """Log-posterior density of the shear ``g``.

    Args:
        g: Shear value(s) at which to evaluate the density.
        data: Interim posterior ellipticity samples. The keyword is named
            ``data`` for compatibility with ``run_inference_nuts``.
        loglikelihood: Callable ``loglikelihood(g, e_post)`` returning the
            shear log-likelihood given the ellipticity samples.

    Returns:
        Log-prior plus log-likelihood evaluated at ``g``.
    """
    e_post = data  # compatibility with `run_inference_nuts`
    loglike = loglikelihood(g, e_post)
    # Uniform prior on each shear component over [-0.1, 0.1].
    logprior = stats.uniform.logpdf(g, -0.1, 0.2).sum()
    return logprior + loglike


def do_inference(
    rng_key: PRNGKeyArray,
    init_g: Array,
    logtarget: Callable,
    n_samples: int,
    n_warmup_steps: int = 500,
):
    """Sample the shear posterior with window-adapted NUTS.

    Args:
        rng_key: PRNG key split between warmup and sampling phases.
        init_g: Initial shear value for the chain.
        logtarget: Log target density of the shear.
        n_samples: Number of posterior samples to draw after warmup.
        n_warmup_steps: Number of window-adaptation steps.

    Returns:
        The ``position`` field of the sampled states.
    """
    key1, key2 = random.split(rng_key)

    # Window adaptation tunes the step size and (diagonal) mass matrix;
    # max_num_doublings=2 and initial_step_size=1e-2 are fixed here.
    warmup = blackjax.window_adaptation(
        blackjax.nuts,
        logtarget,
        progress_bar=False,
        is_mass_matrix_diagonal=True,
        max_num_doublings=2,
        initial_step_size=1e-2,
        target_acceptance_rate=0.80,
    )

    # Warmup returns the final chain state and the tuned NUTS parameters.
    (init_states, tuned_params), _ = warmup.run(key1, init_g, n_warmup_steps)
    kernel = blackjax.nuts(logtarget, **tuned_params).step
    states, _ = inference_loop(key2, init_states, kernel=kernel, n_samples=n_samples)
    return states.position


def pipeline_shear_inference(
rng_key: PRNGKeyArray,
e_post: Array,
*,
true_g: Array,
sigma_e: float,
sigma_e_int: float,
n_samples: int,
initial_step_size: float,
n_warmup_steps: int = 500,
max_num_doublings: int = 2,
):
prior = partial(ellip_mag_prior, sigma=sigma_e)
interim_prior = partial(ellip_mag_prior, sigma=sigma_e_int)
Expand All @@ -59,13 +37,16 @@ def pipeline_shear_inference(
_loglikelihood = jjit(
partial(shear_loglikelihood, prior=prior, interim_prior=interim_prior)
)
_logtarget = partial(logtarget_density, loglikelihood=_loglikelihood, e_post=e_post)
_logtarget = partial(logtarget_density, loglikelihood=_loglikelihood)

_do_inference = partial(
do_inference,
run_inference_nuts,
data=e_post,
logtarget=_logtarget,
n_samples=n_samples,
n_warmup_steps=n_warmup_steps,
max_num_doublings=max_num_doublings,
initial_step_size=initial_step_size,
)

g_samples = _do_inference(rng_key, true_g)
Expand Down
Loading