diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2cdc6cf..cc8ee22 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - name: Checkout github repo @@ -26,12 +26,15 @@ jobs: cache: 'pip' - name: Install dependencies - run: | + run: | python -m pip install -U pip python -m pip install . python -m pip install .[dev] python -m pip install git+https://github.com/GalSim-developers/JAX-GalSim.git + - name: Run Ruff + run: ruff check --output-format=github . + - name: Run Tests run: | pytest --durations=0 diff --git a/bpd/diagnostics.py b/bpd/diagnostics.py index ef5b418..8231c4b 100644 --- a/bpd/diagnostics.py +++ b/bpd/diagnostics.py @@ -16,7 +16,7 @@ def get_contour_plot( kde: bool | int | float = False, ) -> Figure: c = ChainConsumer() - for name, samples in zip(names, samples_list): + for name, samples in zip(names, samples_list, strict=False): df = pd.DataFrame.from_dict(samples) c.add_chain(Chain(samples=df, name=name, kde=kde)) c.add_truth(Truth(location=truth)) @@ -24,7 +24,7 @@ def get_contour_plot( def get_gauss_pc_fig( - ax: Axes, samples: np.ndarray, truth: float, param_name: str = None + ax: Axes, samples: np.ndarray, truth: float, param_name: str | None = None ) -> None: """Get a marginal pc figure assuming Gaussian distribution of samples.""" assert samples.ndim == 2 # (n_chains, n_samples) @@ -49,7 +49,7 @@ def get_gauss_pc_fig( def get_pc_fig( - ax: Axes, samples: np.ndarray, truth: float, param_name: str = None + ax: Axes, samples: np.ndarray, truth: float, param_name: str | None = None ) -> None: """Get a marginal probability calibration figure using `hpdi` from `numpyro`.""" assert samples.ndim == 2 # (n_chains, n_samples) diff --git a/bpd/initialization.py b/bpd/initialization.py index 1f406bd..e15a050 100644 --- a/bpd/initialization.py +++ b/bpd/initialization.py @@ -26,7 +26,7 @@ def init_with_ball( """Sample ball given offset of each parameter.""" new = {} keys = random.split(rng_key, len(true_params.keys())) - rng_key_dict = {p: k for p, k in zip(true_params, keys)} + rng_key_dict = {p: k for p, k in zip(true_params, keys, strict=False)} for p, centr in true_params.items(): offset = offset_dict[p] diff --git a/bpd/io.py b/bpd/io.py index 3d59190..aeca783 100644 --- a/bpd/io.py +++ b/bpd/io.py @@ -11,7 +11,6 @@ def save_dataset( ds: dict[str, Array], fpath: str | Path, overwrite: bool = False ) -> None: - if Path(fpath).exists() and not overwrite: raise IOError("overwriting existing ds") assert Path(fpath).suffix == ".npz" @@ -22,6 +21,7 @@ def save_dataset( def load_dataset(fpath: str) -> dict[str, Array]: assert Path(fpath).exists(), "file path does not exists" assert Path(fpath).suffix == ".npz" + ds = {} npzfile = jnp.load(fpath) diff --git a/bpd/likelihood.py b/bpd/likelihood.py index b778449..c1f55f6 100644 --- a/bpd/likelihood.py +++ b/bpd/likelihood.py @@ -17,7 +17,7 @@ def shear_loglikelihood_unreduced( # the priors are callables for now on only ellipticities # the interim_prior should have been used when obtaining e_obs from the chain (i.e. for now same sigma) # normalizatoin in priors can be ignored for now as alpha is fixed. - N, K, _ = e_post.shape + _, K, _ = e_post.shape # (N, K, 2) e_post_mag = jnp.sqrt(e_post[..., 0] ** 2 + e_post[..., 1] ** 2) denom = interim_prior(e_post_mag) # (N, K), can ignore angle in prior as uniform diff --git a/bpd/pipelines/image_ellips.py b/bpd/pipelines/image_ellips.py index 7868c70..8492195 100644 --- a/bpd/pipelines/image_ellips.py +++ b/bpd/pipelines/image_ellips.py @@ -3,16 +3,14 @@ import blackjax import jax.numpy as jnp -from jax import Array +from jax import Array, random from jax import jit as jjit -from jax import random from jax._src.prng import PRNGKeyArray from jax.scipy import stats from bpd.chains import inference_loop from bpd.draw import draw_gaussian, draw_gaussian_galsim from bpd.noise import add_noise -from bpd.pipelines.toy_ellips import do_inference from bpd.prior import ellip_mag_prior, sample_ellip_prior @@ -122,7 +120,6 @@ def do_inference( target_acceptance_rate: float = 0.80, n_samples: int = 100, ): - key1, key2 = random.split(rng_key) _logdensity = partial(logtarget_fnc, data=data) @@ -164,7 +161,6 @@ def pipeline_image_interim_samples( background: float = 1.0, fft_size: int = 256, ): - k1, k2 = random.split(rng_key) init_position = initialization_fnc(k1, true_params=true_params, data=target_image) diff --git a/bpd/pipelines/shear_inference.py b/bpd/pipelines/shear_inference.py index 3a5104c..c4b39bd 100644 --- a/bpd/pipelines/shear_inference.py +++ b/bpd/pipelines/shear_inference.py @@ -2,9 +2,8 @@ from typing import Callable import blackjax -from jax import Array +from jax import Array, random from jax import jit as jjit -from jax import random from jax._src.prng import PRNGKeyArray from jax.scipy import stats diff --git a/bpd/pipelines/toy_ellips.py b/bpd/pipelines/toy_ellips.py index 826cc5b..cbf8fb0 100644 --- a/bpd/pipelines/toy_ellips.py +++ b/bpd/pipelines/toy_ellips.py @@ -4,9 +4,8 @@ import blackjax import jax.numpy as jnp import jax.scipy as jsp -from jax import Array +from jax import Array, random, vmap from jax import jit as jjit -from jax import random, vmap from jax._src.prng import PRNGKeyArray from bpd.chains import inference_loop @@ -73,7 +72,6 @@ def pipeline_toy_ellips_samples( k: int, n_warmup_steps: int = 500, ): - k1, k2 = random.split(key) true_g = jnp.array([g1, g2]) diff --git a/pyproject.toml b/pyproject.toml index 231b227..79cc4d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,7 @@ requires = ["setuptools>=45"] build-backend = "setuptools.build_meta" + [project] name = "BPD" authors = [{ name = "Ismael Mendoza" }] @@ -9,35 +10,81 @@ description = "Bayesian Pixel Domain method for shear inference." version = "0.0.1" license = { file = "LICENSE" } readme = "README.md" -dependencies = [ - "numpy >=1.18.0", - "galsim >=2.3.0", - "jax >=0.4.30", - "jaxlib", - "blackjax >=1.2.0", - "numpyro >=0.15.0", -] +dependencies = ["numpy >=1.18.0", "galsim >=2.3.0", "jax >=0.4.30", "jaxlib", "blackjax >=1.2.0"] + [project.optional-dependencies] -dev = ["pytest", "click"] +dev = ["pytest", "click", "ruff", "ChainConsumer"] + [project.urls] home = "https://github.com/LSSTDESC/BPD" + [tool.setuptools.packages.find] -include = ["bpd*", "scripts*"] +include = ["bpd*"] + +[tool.ruff] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +# Same as Black. +line-length = 88 +indent-width = 4 + +# Assume Python 3.8 +target-version = "py310" + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" + +exclude = ["*.ipynb"] -[tool.flake8] -max-line-length = 88 -ignore = ["C901", "E203", "W503"] -per-file-ignores = ["__init__.py:F401"] +[tool.ruff.lint] +select = ["E", "F", "I", "W", "B", "SIM", "PLE", "PLC", "PLW", "RUF"] +ignore = ["C901", "E203", "E501", "E731", "PLC0206", "RUF027"] +preview = true +exclude = ["*.ipynb", "scripts/image_shear_gpu_one_galaxy.py", "scripts/benchmarks/*.py"] -[tool.isort] -profile = "black" -multi_line_output = 3 -include_trailing_comma = true -use_parentheses = true -ensure_newline_before_comments = true -line_length = 88 +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-rav" +filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"] diff --git a/scripts/benchmarks/benchmark1.py b/scripts/benchmarks/benchmark1.py index 805c05b..08c0f4d 100755 --- a/scripts/benchmarks/benchmark1.py +++ b/scripts/benchmarks/benchmark1.py @@ -82,7 +82,7 @@ def sample_ball(rng_key, center_params: dict): new = {} keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys)} + rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} for p in center_params: centr = center_params[p] if p == "f": @@ -128,7 +128,6 @@ def draw_gal(f, hlr, g1, g2, x, y): def _logprob_fn(params, data): - # prior prior = jnp.array(0.0, device=GPU) for p in ("f", "hlr", "g1", "g2"): # uniform priors @@ -148,7 +147,7 @@ def _logprob_fn(params, data): def _log_setup(snr: float): - with open(LOG_FILE, "a") as f: + with open(LOG_FILE, "a", encoding="utf-8") as f: print(file=f) print( f"""Running benchmark 1 with configuration as follows. Variable number of chains. @@ -186,7 +185,6 @@ def _log_setup(snr: float): # vmap only rng_key def do_warmup(rng_key, init_position: dict, data): - _logdensity = partial(_logprob_fn, data=data) warmup = blackjax.window_adaptation( @@ -258,7 +256,6 @@ def main(): _init_positions = {p: q[:n_chains] for p, q in all_init_positions.items()} if ii == 0: - # compilation times t1 = time.time() (_sts, _tp), _ = jax.block_until_ready( @@ -310,7 +307,7 @@ def main(): jnp.save(filepath, results) _log_setup(snr) - with open(LOG_FILE, "a") as f: + with open(LOG_FILE, "a", encoding="utf-8") as f: print(file=f) print(f"results were saved to {filepath}", file=f) diff --git a/scripts/benchmarks/benchmark2.py b/scripts/benchmarks/benchmark2.py index 88e2a2d..2b7b761 100755 --- a/scripts/benchmarks/benchmark2.py +++ b/scripts/benchmarks/benchmark2.py @@ -84,7 +84,7 @@ def sample_ball(rng_key, center_params: dict): new = {} keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys)} + rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} for p in center_params: centr = center_params[p] if p == "f": @@ -130,7 +130,6 @@ def draw_gal(f, hlr, g1, g2, x, y): def _logprob_fn(params, data): - # prior prior = jnp.array(0.0, device=GPU) for p in ("f", "hlr", "g1", "g2"): # uniform priors @@ -150,15 +149,15 @@ def _logprob_fn(params, data): def _log_setup(snr: float): - with open(LOG_FILE, "a") as f: + with open(LOG_FILE, "a", encoding='utf-8') as f: print(file=f) print( f"""Running benchmark 2 with configuration as follows. Variable number of chains. - + The sampler used is NUTS with standard warmup. TAG: {TAG} - SEED: {SEED} + SEED: {SEED} Overall sampler configuration (fixed): max doublings: {MAX_DOUBLINGS} @@ -188,7 +187,6 @@ def _log_setup(snr: float): # vmap only rng_key def do_warmup(rng_key, init_position: dict, data): - _logdensity = partial(_logprob_fn, data=data) warmup = blackjax.window_adaptation( @@ -271,7 +269,6 @@ def main(): _data_ii = data_gpu[:n_obj] if ii == 0: - # compilation times t1 = time.time() (_sts, _tp), _ = jax.block_until_ready( @@ -323,7 +320,7 @@ def main(): jnp.save(filepath, results) _log_setup(snr) - with open(LOG_FILE, "a") as f: + with open(LOG_FILE, "a", encoding='utf-8') as f: print(file=f) print(f"results were saved to {filepath}", file=f) diff --git a/scripts/benchmarks/benchmark2_7.py b/scripts/benchmarks/benchmark2_7.py index d4b3854..893aa68 100755 --- a/scripts/benchmarks/benchmark2_7.py +++ b/scripts/benchmarks/benchmark2_7.py @@ -79,7 +79,7 @@ def sample_ball(rng_key, center_params: dict): new = {} keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys)} + rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} for p in center_params: centr = center_params[p] if p == "f": @@ -125,7 +125,6 @@ def draw_gal(f, hlr, g1, g2, x, y): def _logprob_fn(params, data): - # prior prior = jnp.array(0.0, device=GPU) for p in ("f", "hlr", "g1", "g2"): # uniform priors @@ -147,7 +146,7 @@ def _logprob_fn(params, data): def _log_setup(snr: float): - with open(LOG_FILE, "a") as f: + with open(LOG_FILE, "a", encoding="utf-8") as f: print(file=f) print( f"""Running benchmark 2.7 with configuration as follows @@ -185,7 +184,6 @@ def _log_setup(snr: float): # vmap only rng_key def do_warmup(rng_key, init_position: dict, data): - _logdensity = partial(_logprob_fn, data=data) warmup = blackjax.window_adaptation( @@ -304,7 +302,7 @@ def main(): filepath = SCRATCH_DIR.joinpath(filename) jnp.save(filepath, results) - with open(LOG_FILE, "a") as f: + with open(LOG_FILE, "a", encoding="utf-8") as f: print(file=f) print(f"results were saved to {filepath}", file=f) diff --git a/scripts/benchmarks/benchmark2_72.py b/scripts/benchmarks/benchmark2_72.py index 8dded5f..367e8bc 100755 --- a/scripts/benchmarks/benchmark2_72.py +++ b/scripts/benchmarks/benchmark2_72.py @@ -73,7 +73,7 @@ def sample_ball(rng_key, center_params: dict): new = {} keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys)} + rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} for p in center_params: centr = center_params[p] if p == "f": @@ -119,7 +119,6 @@ def draw_gal(f, hlr, g1, g2, x, y): def _logprob_fn(params, data): - # prior prior = jnp.array(0.0, device=GPU) for p in ("f", "hlr", "g1", "g2"): # uniform priors @@ -147,7 +146,6 @@ def _logprob_fn(params, data): # vmap only rng_key def do_warmup(rng_key, init_position: dict, data): - _logdensity = partial(_logprob_fn, data=data) warmup = blackjax.window_adaptation( diff --git a/scripts/benchmarks/benchmark2_8.py b/scripts/benchmarks/benchmark2_8.py index ca4d725..2a1c5eb 100755 --- a/scripts/benchmarks/benchmark2_8.py +++ b/scripts/benchmarks/benchmark2_8.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -"""Here we run multiple chains each on one galaxy, same noise realization. +"""Here we run multiple chains each on one galaxy, same noise realization. To get a sense of efficiency. @@ -73,7 +73,7 @@ def sample_ball(rng_key, center_params: dict): new = {} keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys)} + rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} for p in center_params: centr = center_params[p] if p == "f": @@ -119,7 +119,6 @@ def draw_gal(f, hlr, g1, g2, x, y): def _logprob_fn(params, data): - # prior prior = jnp.array(0.0, device=GPU) for p in ("f", "hlr", "g1", "g2"): # uniform priors @@ -150,7 +149,7 @@ def _logprob_fn(params, data): def _log_setup(snr: float): - with open(LOG_FILE, "a") as f: + with open(LOG_FILE, "a", encoding="utf-8") as f: print(file=f) print( f"""Running benchmark 2.8 with configuration as follows @@ -184,7 +183,6 @@ def _log_setup(snr: float): # vmap only rng_key def do_warmup(rng_key, init_position: dict, data): - _logdensity = partial(_logprob_fn, data=data) warmup = blackjax.window_adaptation( @@ -301,7 +299,7 @@ def main(): filepath = SCRATCH_DIR.joinpath(filename) jnp.save(filepath, results) - with open(LOG_FILE, "a") as f: + with open(LOG_FILE, "a", encoding="utf-8") as f: print(file=f) print(f"results were saved to {filepath}", file=f) diff --git a/scripts/benchmarks/benchmark_chees1.py b/scripts/benchmarks/benchmark_chees1.py index 61546a3..601fddd 100755 --- a/scripts/benchmarks/benchmark_chees1.py +++ b/scripts/benchmarks/benchmark_chees1.py @@ -19,7 +19,6 @@ import jax_galsim as xgalsim import numpy as np import optax -from jax import jit as jjit from jax import random, vmap from jax.scipy import stats @@ -87,7 +86,7 @@ def sample_ball(rng_key, center_params: dict): new = {} keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys)} + rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} for p in center_params: centr = center_params[p] if p == "f": @@ -133,7 +132,6 @@ def draw_gal(f, hlr, g1, g2, x, y): def _logprob_fn(params, data): - # prior prior = jnp.array(0.0, device=GPU) for p in ("f", "hlr", "g1", "g2"): # uniform priors @@ -153,15 +151,15 @@ def _logprob_fn(params, data): def _log_setup(snr: float): - with open(LOG_FILE, "a") as f: + with open(LOG_FILE, "a", encoding="utf-8") as f: print(file=f) print( f"""Running benchmark chees 1 with configuration as follows. Variable number of chains. - + The sampler used is NUTS with standard warmup. TAG: {TAG} - SEED: {SEED} + SEED: {SEED} Overall sampler configuration (fixed): n_samples: {N_SAMPLES} @@ -183,7 +181,7 @@ def _log_setup(snr: float): other parameters: slen: {SLEN} psf_hlr: {PSF_HLR} - background: {BACKGROUND} + background: {BACKGROUND} snr: {snr} """, file=f, @@ -292,7 +290,7 @@ def main(): jnp.save(filepath, results) _log_setup(snr) - with open(LOG_FILE, "a") as f: + with open(LOG_FILE, "a", encoding="utf-8") as f: print(file=f) print(f"results were saved to {filepath}", file=f) diff --git a/scripts/get_shear_from_post_ellips.py b/scripts/get_shear_from_post_ellips.py index 27c384d..5d5dcd2 100755 --- a/scripts/get_shear_from_post_ellips.py +++ b/scripts/get_shear_from_post_ellips.py @@ -35,7 +35,6 @@ def main( trim: int, overwrite: bool, ): - # directory structure dirpath = DATA_DIR / "cache_chains" / tag assert dirpath.exists() @@ -44,9 +43,8 @@ def main( old_seed = _extract_seed(e_samples_fname) fpath = DATA_DIR / "cache_chains" / tag / f"g_samples_{old_seed}_{seed}.npy" - if fpath.exists(): - if not overwrite: - raise IOError("overwriting...") + if fpath.exists() and not overwrite: + raise IOError("overwriting...") samples_dataset = load_dataset(e_samples_fpath) e_post = samples_dataset["e_post"][:, ::trim, :] diff --git a/scripts/image_shear_gpu_one_galaxy.py b/scripts/image_shear_gpu_one_galaxy.py index 2376acd..0444337 100644 --- a/scripts/image_shear_gpu_one_galaxy.py +++ b/scripts/image_shear_gpu_one_galaxy.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 from functools import partial -from math import ceil import click import jax.numpy as jnp @@ -103,7 +102,6 @@ def main( slen=slen, pixel_scale=pixel_scale, fft_size=fft_size, - background=background, ) vpipe1 = vmap(jjit(pipe1), (0, None, 0)) @@ -125,7 +123,6 @@ def main( def main(): - pipe1 = partial( pipeline_toy_ellips_samples, g1=g1, diff --git a/scripts/mp_toy_shear_cpu.py b/scripts/mp_toy_shear_cpu.py index 101fffb..05b3392 100755 --- a/scripts/mp_toy_shear_cpu.py +++ b/scripts/mp_toy_shear_cpu.py @@ -66,7 +66,7 @@ def main( sigma_e=shape_noise, n_samples=n_samples_shear, ) - g_samples_list = pool.starmap(task2, zip(seeds, e_post_list)) + g_samples_list = pool.starmap(task2, zip(seeds, e_post_list, strict=False)) print("INFO: Shear samples obtained") print("INFO: Saving shear samples to disk...") diff --git a/scripts/slurm/slurm_job.py b/scripts/slurm/slurm_job.py index 967d693..c35b7ba 100755 --- a/scripts/slurm/slurm_job.py +++ b/scripts/slurm/slurm_job.py @@ -27,7 +27,7 @@ def setup_sbatch_job_gpu( job_dir = Path(JOB_DIR) jobfile = job_dir.joinpath(jobfile_name) - with open(jobfile, "w") as f: + with open(jobfile, "w", encoding="utf-8") as f: f.writelines( "#!/bin/bash\n" f"#SBATCH --job-name={jobname}\n" diff --git a/scripts/slurm/slurm_toy_shear_gpu.py b/scripts/slurm/slurm_toy_shear_gpu.py index c5148a9..3a386a4 100755 --- a/scripts/slurm/slurm_toy_shear_gpu.py +++ b/scripts/slurm/slurm_toy_shear_gpu.py @@ -44,7 +44,6 @@ def main( add_extra: bool, qos: str, ): - tagpath = DATA_DIR / "cache_chains" / jobname if not add_extra: assert not tagpath.exists() @@ -66,7 +65,7 @@ def main( ) # append to jobfile the commands. - with open(jobfile, "a") as f: + with open(jobfile, "a", encoding="utf-8") as f: f.write("\n") for ii in range(4): @@ -74,13 +73,13 @@ def main( cmd = base_cmd.format(seed=cmd_seed) srun_cmd = f"srun --exact -u -n 1 -c 1 --gpus-per-task 1 --mem-per-gpu={mem_per_gpu} {cmd} &\n" - with open(jobfile, "a") as f: + with open(jobfile, "a", encoding="utf-8") as f: f.write(srun_cmd) - with open(jobfile, "a") as f: + with open(jobfile, "a", encoding="utf-8") as f: f.write("\nwait") - subprocess.run(f"sbatch {jobfile.as_posix()}", shell=True) + subprocess.run(f"sbatch {jobfile.as_posix()}", shell=True, check=False) if __name__ == "__main__": diff --git a/tests/test_shear_inference.py b/tests/test_shear_inference.py index bb5f497..1c9786f 100644 --- a/tests/test_shear_inference.py +++ b/tests/test_shear_inference.py @@ -4,13 +4,12 @@ import pytest from jax import random -from scripts.get_shear_from_post_ellips import pipeline_shear_inference -from scripts.get_toy_ellip_samples import pipeline_toy_ellips_samples +from bpd.pipelines.shear_inference import pipeline_shear_inference +from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples @pytest.mark.parametrize("seed", [1234, 4567]) def test_shear_inference_toy_ellipticities(seed): - key = random.key(seed) k1, k2 = random.split(key) diff --git a/tests/test_shear_trans.py b/tests/test_shear_trans.py index b0641fe..93726f3 100644 --- a/tests/test_shear_trans.py +++ b/tests/test_shear_trans.py @@ -14,7 +14,6 @@ def test_scalar_inverse(): - # scalar version ellips = (0.0, 0.1, 0.2, -0.1, -0.2) shears = (0.0, -0.01, 0.01, -0.02, 0.02)