Skip to content

Commit

Permalink
Add convergence test interim samples from image (#36)
Browse files Browse the repository at this point in the history
* function to simplify later steps

* refactor, still not done until later pr

* test draft

* less stuff to run

* bug fix

* separate out slow and quick tests

* add slow marker

* flag I alwasy wante

* register mark
  • Loading branch information
ismael-mendoza authored Nov 5, 2024
1 parent 94a0c89 commit 38903ee
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 17 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ jobs:
- name: Run Ruff
run: ruff check --output-format=github .

- name: Run Tests
- name: Run fast tests
run: |
pytest --durations=0
pytest -m "not slow" --durations=0
- name: Run slow tests
run: |
pytest -m "slow" --durations=0
16 changes: 14 additions & 2 deletions bpd/pipelines/image_ellips.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from bpd.chains import run_inference_nuts
from bpd.draw import draw_gaussian, draw_gaussian_galsim
from bpd.noise import add_noise
from bpd.prior import ellip_mag_prior, sample_ellip_prior
from bpd.prior import ellip_mag_prior, sample_ellip_prior, scalar_shear_transformation


def get_target_galaxy_params_simple(
Expand All @@ -37,6 +37,18 @@ def get_target_galaxy_params_simple(
}


def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]):
true_params = {**galaxy_params}
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
g1, g2 = true_params.pop("g1"), true_params.pop("g2")

e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2))
true_params["e1"] = e1_prime
true_params["e2"] = e2_prime

return true_params # don't add g1,g2 back as we are not inferring those


def get_target_images_single(
rng_key: PRNGKeyArray,
n_samples: int,
Expand Down Expand Up @@ -112,7 +124,7 @@ def pipeline_image_interim_samples_one_galaxy(
max_num_doublings: int = 5,
initial_step_size: float = 1e-3,
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = False,
is_mass_matrix_diagonal: bool = True,
slen: int = 53,
fft_size: int = 256,
background: float = 1.0,
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,6 @@ exclude = ["*.ipynb", "scripts/one_galaxy_shear.py", "scripts/benchmarks/*.py"]

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra"
addopts = "-ra -v --strict-markers"
filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"]
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
12 changes: 2 additions & 10 deletions scripts/one_galaxy_shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from bpd.pipelines.image_ellips import (
get_target_galaxy_params_simple,
get_target_images_single,
get_true_params_from_galaxy_params,
pipeline_image_interim_samples,
)
from bpd.pipelines.shear_inference import pipeline_shear_inference
from bpd.prior import scalar_shear_transformation

init_fnc = init_with_truth

Expand Down Expand Up @@ -53,19 +53,11 @@ def main(
nkey,
n_samples=n_gals,
single_galaxy_params=galaxy_params,
psf_hlr=psf_hlr,
background=background,
slen=slen,
pixel_scale=pixel_scale,
)

true_params = {**galaxy_params}
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
g1, g2 = true_params.pop("g1"), true_params.pop("g2")

e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2))
true_params["e1"] = e1_prime
true_params["e2"] = e2_prime
true_params = get_true_params_from_galaxy_params(galaxy_params)

# prepare pipelines
pipe1 = partial(
Expand Down
71 changes: 69 additions & 2 deletions tests/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@
from jax import random, vmap

from bpd.chains import run_inference_nuts
from bpd.initialization import init_with_truth
from bpd.pipelines.image_ellips import (
get_target_galaxy_params_simple,
get_target_images_single,
get_true_params_from_galaxy_params,
pipeline_image_interim_samples_one_galaxy,
)
from bpd.pipelines.shear_inference import pipeline_shear_inference
from bpd.pipelines.toy_ellips import logtarget as logtarget_toy_ellips
from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples
from bpd.prior import ellip_mag_prior, sample_synthetic_sheared_ellips_unclipped


@pytest.mark.parametrize("seed", [1234, 4567])
def test_interim_ellipticity_posterior_convergence(seed):
def test_interim_toy_convergence(seed):
"""Check efficiency and convergence of chains for 100 galaxies."""
g1, g2 = 0.02, 0.0
sigma_m = 1e-4
Expand Down Expand Up @@ -74,7 +81,7 @@ def test_interim_ellipticity_posterior_convergence(seed):


@pytest.mark.parametrize("seed", [1234, 4567])
def test_shear_posterior_convergence(seed):
def test_toy_shear_convergence(seed):
g1, g2 = 0.02, 0.0
sigma_m = 1e-4
sigma_e = 1e-3
Expand Down Expand Up @@ -124,3 +131,63 @@ def test_shear_posterior_convergence(seed):

assert ess > 0.5 * 4000
assert jnp.abs(rhat - 1) < 0.01


@pytest.mark.slow
@pytest.mark.parametrize("seed", [1234, 4567])
def test_low_noise_single_galaxy_interim_samples(seed):
lf = 6.0
hlr = 1.0
g1, g2 = 0.02, 0.0
sigma_e = 1e-3
sigma_e_int = 3e-2
n_samples = 500
background = 1.0
slen = 53
fft_size = 256
init_fnc = init_with_truth

rng_key = random.key(seed)
pkey, nkey, gkey = random.split(rng_key, 3)

galaxy_params = get_target_galaxy_params_simple(
pkey, lf=lf, g1=g1, g2=g2, hlr=hlr, shape_noise=sigma_e
)

draw_params = {**galaxy_params}
draw_params["f"] = 10 ** draw_params.pop("lf")
target_image = get_target_images_single(
nkey,
n_samples=1,
single_galaxy_params=draw_params,
background=background,
slen=slen,
)[0]
true_params = get_true_params_from_galaxy_params(galaxy_params)

pipe1 = partial(
pipeline_image_interim_samples_one_galaxy,
initialization_fnc=init_fnc,
sigma_e_int=sigma_e_int,
n_samples=n_samples,
slen=slen,
fft_size=fft_size,
n_warmup_steps=300,
)
vpipe1 = vmap(jjit(pipe1), (0, 0, None))

# chain initialization
# one galaxy, test convergence, so 4 random seeds
keys = random.split(gkey, 4)
init_positions = vmap(init_fnc, (0, None))(keys, true_params)

samples = vpipe1(keys, init_positions, target_image)

# check each component
for _, v in samples.items():
assert v.shape == (4, n_samples)
ess = effective_sample_size(v)
rhat = potential_scale_reduction(v)

assert ess > 0.5 * n_samples
assert jnp.abs(rhat - 1) < 0.01

0 comments on commit 38903ee

Please sign in to comment.