From 16f4c8e21f6861cf2d3d4d5e946edae887dd9eb3 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 12:58:00 -0800 Subject: [PATCH] draft ready for testing --- scripts/one_galaxy_image_interim_samples.py | 22 ++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/scripts/one_galaxy_image_interim_samples.py b/scripts/one_galaxy_image_interim_samples.py index 6b82a5e..21a4536 100755 --- a/scripts/one_galaxy_image_interim_samples.py +++ b/scripts/one_galaxy_image_interim_samples.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 from functools import partial +from math import ceil import jax.numpy as jnp import typer @@ -44,8 +45,6 @@ def main( if not dirpath.exists(): dirpath.mkdir(exist_ok=True) - fpath = dirpath / f"e_post_{seed}.npy" - # get images galaxy_params = get_target_galaxy_params_simple( # default hlr, x, y pkey, lf=lf, g1=g1, g2=g2, shape_noise=shape_noise @@ -65,7 +64,7 @@ def main( true_params = get_true_params_from_galaxy_params(galaxy_params) # prepare pipelines - pipe1 = partial( + pipe = partial( pipeline_image_interim_samples_one_galaxy, initialization_fnc=init_fnc, sigma_e_int=sigma_e_int, @@ -75,7 +74,7 @@ def main( fft_size=fft_size, background=background, ) - vpipe1 = vmap(jjit(pipe1), (0, None, 0)) + vpipe = vmap(jjit(pipe), (0, None, 0)) # initialization gkey1, gkey2 = random.split(gkey, 2) @@ -83,11 +82,20 @@ def main( gkeys2 = random.split(gkey2, n_gals) init_positions = vmap(init_fnc, (0, None))(gkeys1, true_params) - galaxy_samples = vpipe1(gkeys2, true_params, target_images) + n_batch = ceil(len(n_gals) / n_vec) + + for ii in range(n_batch): + # slice + start, stop = ii * n_vec, (ii + 1) * n_vec + b_ipositions = {k: v[start:stop] for k, v in init_positions.items()} + bimages = target_images[start:stop] + _keys = gkeys2[start:stop] + _samples = vpipe(_keys, b_ipositions, bimages) - e_post = jnp.stack([galaxy_samples["e1"], galaxy_samples["e2"]], axis=-1) + e_post = jnp.stack([_samples["e1"], _samples["e2"]], axis=-1) + fpath = dirpath / f"e_post_{seed}_{ii}.npy" - jnp.save() + jnp.save(fpath, e_post) if __name__ == "__main__":