Skip to content

Commit 73e0ce0

Browse files
split into two processes is easier
1 parent a562322 commit 73e0ce0

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#!/usr/bin/env python3
2+
"""This file creates toy samples of ellipticities and saves them to .hdf5 file."""
3+
4+
from pathlib import Path
5+
6+
import jax
7+
import jax.numpy as jnp
8+
import typer
9+
10+
from bpd import DATA_DIR
11+
from bpd.io import load_dataset
12+
from bpd.pipelines.shear_inference import pipeline_shear_inference
13+
14+
15+
def _extract_seed(fpath: str) -> int:
16+
name = Path(fpath).name
17+
first = name.find("_")
18+
second = name.find("_", first + 1)
19+
third = name.find(".")
20+
return int(name[second + 1 : third])
21+
22+
23+
def main(
24+
seed: int,
25+
tag: str,
26+
interim_samples_fname: str,
27+
sigma_e_int: float = 3e-2,
28+
initial_step_size: float = 1e-3,
29+
n_samples: int = 3000,
30+
trim: int = 1,
31+
overwrite: bool = False,
32+
):
33+
# directory structure
34+
dirpath = DATA_DIR / "cache_chains" / tag
35+
assert dirpath.exists()
36+
interim_samples_fpath = DATA_DIR / "cache_chains" / tag / interim_samples_fname
37+
assert interim_samples_fpath.exists(), "ellipticity samples file does not exist"
38+
old_seed = _extract_seed(interim_samples_fpath)
39+
fpath = DATA_DIR / "cache_chains" / tag / f"g_samples_{old_seed}_{seed}.npy"
40+
41+
if fpath.exists() and not overwrite:
42+
raise IOError("overwriting...")
43+
44+
samples_dataset = load_dataset(interim_samples_fpath)
45+
e_post = samples_dataset["e_post"][:, ::trim, :]
46+
true_g = samples_dataset["true_g"]
47+
sigma_e = samples_dataset["sigma_e"]
48+
49+
rng_key = jax.random.key(seed)
50+
g_samples = pipeline_shear_inference(
51+
rng_key,
52+
e_post,
53+
true_g=true_g,
54+
sigma_e=sigma_e,
55+
sigma_e_int=sigma_e_int,
56+
initial_step_size=initial_step_size,
57+
n_samples=n_samples,
58+
)
59+
60+
jnp.save(fpath, g_samples)
61+
62+
63+
if __name__ == "__main__":
64+
typer.run(main)

0 commit comments

Comments
 (0)