Skip to content

Commit

Permalink
Add ruff (#33)
Browse files Browse the repository at this point in the history
* add ruff

* ruff all files

* ruff fix

* add ruff check to tests

* more options for ruff, no notebooks for now

* ruff

* ruff

* ruff

* ruffs

* ruff

* ruff

* we sometimes use {} without f

* some ruff

* space

* mistake

* fix imports

* options

* ignore tensorflow warnings
  • Loading branch information
ismael-mendoza authored Nov 3, 2024
1 parent b049873 commit f843748
Show file tree
Hide file tree
Showing 22 changed files with 114 additions and 93 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12", "3.13"]

steps:
- name: Checkout github repo
Expand All @@ -26,12 +26,15 @@ jobs:
cache: 'pip'

- name: Install dependencies
run: |
run: |
python -m pip install -U pip
python -m pip install .
python -m pip install .[dev]
python -m pip install git+https://github.com/GalSim-developers/JAX-GalSim.git
- name: Run Ruff
run: ruff check --output-format=github .

- name: Run Tests
run: |
pytest --durations=0
6 changes: 3 additions & 3 deletions bpd/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ def get_contour_plot(
kde: bool | int | float = False,
) -> Figure:
c = ChainConsumer()
for name, samples in zip(names, samples_list):
for name, samples in zip(names, samples_list, strict=False):
df = pd.DataFrame.from_dict(samples)
c.add_chain(Chain(samples=df, name=name, kde=kde))
c.add_truth(Truth(location=truth))
return c.plotter.plot(figsize=figsize)


def get_gauss_pc_fig(
ax: Axes, samples: np.ndarray, truth: float, param_name: str = None
ax: Axes, samples: np.ndarray, truth: float, param_name: str | None = None
) -> None:
"""Get a marginal pc figure assuming Gaussian distribution of samples."""
assert samples.ndim == 2 # (n_chains, n_samples)
Expand All @@ -49,7 +49,7 @@ def get_gauss_pc_fig(


def get_pc_fig(
ax: Axes, samples: np.ndarray, truth: float, param_name: str = None
ax: Axes, samples: np.ndarray, truth: float, param_name: str | None = None
) -> None:
"""Get a marginal probability calibration figure using `hpdi` from `numpyro`."""
assert samples.ndim == 2 # (n_chains, n_samples)
Expand Down
2 changes: 1 addition & 1 deletion bpd/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def init_with_ball(
"""Sample ball given offset of each parameter."""
new = {}
keys = random.split(rng_key, len(true_params.keys()))
rng_key_dict = {p: k for p, k in zip(true_params, keys)}
rng_key_dict = {p: k for p, k in zip(true_params, keys, strict=False)}

for p, centr in true_params.items():
offset = offset_dict[p]
Expand Down
2 changes: 1 addition & 1 deletion bpd/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
def save_dataset(
ds: dict[str, Array], fpath: str | Path, overwrite: bool = False
) -> None:

if Path(fpath).exists() and not overwrite:
raise IOError("overwriting existing ds")
assert Path(fpath).suffix == ".npz"
Expand All @@ -22,6 +21,7 @@ def save_dataset(
def load_dataset(fpath: str) -> dict[str, Array]:
assert Path(fpath).exists(), "file path does not exists"
assert Path(fpath).suffix == ".npz"

ds = {}

npzfile = jnp.load(fpath)
Expand Down
2 changes: 1 addition & 1 deletion bpd/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def shear_loglikelihood_unreduced(
# the priors are callables for now on only ellipticities
# the interim_prior should have been used when obtaining e_obs from the chain (i.e. for now same sigma)
# normalizatoin in priors can be ignored for now as alpha is fixed.
N, K, _ = e_post.shape
_, K, _ = e_post.shape # (N, K, 2)

e_post_mag = jnp.sqrt(e_post[..., 0] ** 2 + e_post[..., 1] ** 2)
denom = interim_prior(e_post_mag) # (N, K), can ignore angle in prior as uniform
Expand Down
6 changes: 1 addition & 5 deletions bpd/pipelines/image_ellips.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@

import blackjax
import jax.numpy as jnp
from jax import Array
from jax import Array, random
from jax import jit as jjit
from jax import random
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

from bpd.chains import inference_loop
from bpd.draw import draw_gaussian, draw_gaussian_galsim
from bpd.noise import add_noise
from bpd.pipelines.toy_ellips import do_inference
from bpd.prior import ellip_mag_prior, sample_ellip_prior


Expand Down Expand Up @@ -122,7 +120,6 @@ def do_inference(
target_acceptance_rate: float = 0.80,
n_samples: int = 100,
):

key1, key2 = random.split(rng_key)

_logdensity = partial(logtarget_fnc, data=data)
Expand Down Expand Up @@ -164,7 +161,6 @@ def pipeline_image_interim_samples(
background: float = 1.0,
fft_size: int = 256,
):

k1, k2 = random.split(rng_key)

init_position = initialization_fnc(k1, true_params=true_params, data=target_image)
Expand Down
3 changes: 1 addition & 2 deletions bpd/pipelines/shear_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from typing import Callable

import blackjax
from jax import Array
from jax import Array, random
from jax import jit as jjit
from jax import random
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

Expand Down
4 changes: 1 addition & 3 deletions bpd/pipelines/toy_ellips.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import blackjax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import Array
from jax import Array, random, vmap
from jax import jit as jjit
from jax import random, vmap
from jax._src.prng import PRNGKeyArray

from bpd.chains import inference_loop
Expand Down Expand Up @@ -73,7 +72,6 @@ def pipeline_toy_ellips_samples(
k: int,
n_warmup_steps: int = 500,
):

k1, k2 = random.split(key)

true_g = jnp.array([g1, g2])
Expand Down
89 changes: 68 additions & 21 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,89 @@
requires = ["setuptools>=45"]
build-backend = "setuptools.build_meta"


[project]
name = "BPD"
authors = [{ name = "Ismael Mendoza" }]
description = "Bayesian Pixel Domain method for shear inference."
version = "0.0.1"
license = { file = "LICENSE" }
readme = "README.md"
dependencies = [
"numpy >=1.18.0",
"galsim >=2.3.0",
"jax >=0.4.30",
"jaxlib",
"blackjax >=1.2.0",
"numpyro >=0.15.0",
]
dependencies = ["numpy >=1.18.0", "galsim >=2.3.0", "jax >=0.4.30", "jaxlib", "blackjax >=1.2.0"]


[project.optional-dependencies]
dev = ["pytest", "click"]
dev = ["pytest", "click", "ruff", "ChainConsumer"]


[project.urls]
home = "https://github.com/LSSTDESC/BPD"


[tool.setuptools.packages.find]
include = ["bpd*", "scripts*"]
include = ["bpd*"]

[tool.ruff]
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]

# Same as Black.
line-length = 88
indent-width = 4

# Assume Python 3.8
target-version = "py310"

[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

exclude = ["*.ipynb"]


[tool.flake8]
max-line-length = 88
ignore = ["C901", "E203", "W503"]
per-file-ignores = ["__init__.py:F401"]
[tool.ruff.lint]
select = ["E", "F", "I", "W", "B", "SIM", "PLE", "PLC", "PLW", "RUF"]
ignore = ["C901", "E203", "E501", "E731", "PLC0206", "RUF027"]
preview = true
exclude = ["*.ipynb", "scripts/image_shear_gpu_one_galaxy.py", "scripts/benchmarks/*.py"]


[tool.isort]
profile = "black"
multi_line_output = 3
include_trailing_comma = true
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-rav"
filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"]
9 changes: 3 additions & 6 deletions scripts/benchmarks/benchmark1.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
def sample_ball(rng_key, center_params: dict):
new = {}
keys = random.split(rng_key, len(center_params.keys()))
rng_key_dict = {p: k for p, k in zip(center_params, keys)}
rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)}
for p in center_params:
centr = center_params[p]
if p == "f":
Expand Down Expand Up @@ -128,7 +128,6 @@ def draw_gal(f, hlr, g1, g2, x, y):


def _logprob_fn(params, data):

# prior
prior = jnp.array(0.0, device=GPU)
for p in ("f", "hlr", "g1", "g2"): # uniform priors
Expand All @@ -148,7 +147,7 @@ def _logprob_fn(params, data):


def _log_setup(snr: float):
with open(LOG_FILE, "a") as f:
with open(LOG_FILE, "a", encoding="utf-8") as f:
print(file=f)
print(
f"""Running benchmark 1 with configuration as follows. Variable number of chains.
Expand Down Expand Up @@ -186,7 +185,6 @@ def _log_setup(snr: float):

# vmap only rng_key
def do_warmup(rng_key, init_position: dict, data):

_logdensity = partial(_logprob_fn, data=data)

warmup = blackjax.window_adaptation(
Expand Down Expand Up @@ -258,7 +256,6 @@ def main():
_init_positions = {p: q[:n_chains] for p, q in all_init_positions.items()}

if ii == 0:

# compilation times
t1 = time.time()
(_sts, _tp), _ = jax.block_until_ready(
Expand Down Expand Up @@ -310,7 +307,7 @@ def main():
jnp.save(filepath, results)

_log_setup(snr)
with open(LOG_FILE, "a") as f:
with open(LOG_FILE, "a", encoding="utf-8") as f:
print(file=f)
print(f"results were saved to {filepath}", file=f)

Expand Down
13 changes: 5 additions & 8 deletions scripts/benchmarks/benchmark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
def sample_ball(rng_key, center_params: dict):
new = {}
keys = random.split(rng_key, len(center_params.keys()))
rng_key_dict = {p: k for p, k in zip(center_params, keys)}
rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)}
for p in center_params:
centr = center_params[p]
if p == "f":
Expand Down Expand Up @@ -130,7 +130,6 @@ def draw_gal(f, hlr, g1, g2, x, y):


def _logprob_fn(params, data):

# prior
prior = jnp.array(0.0, device=GPU)
for p in ("f", "hlr", "g1", "g2"): # uniform priors
Expand All @@ -150,15 +149,15 @@ def _logprob_fn(params, data):


def _log_setup(snr: float):
with open(LOG_FILE, "a") as f:
with open(LOG_FILE, "a", encoding='utf-8') as f:
print(file=f)
print(
f"""Running benchmark 2 with configuration as follows. Variable number of chains.
The sampler used is NUTS with standard warmup.
TAG: {TAG}
SEED: {SEED}
SEED: {SEED}
Overall sampler configuration (fixed):
max doublings: {MAX_DOUBLINGS}
Expand Down Expand Up @@ -188,7 +187,6 @@ def _log_setup(snr: float):

# vmap only rng_key
def do_warmup(rng_key, init_position: dict, data):

_logdensity = partial(_logprob_fn, data=data)

warmup = blackjax.window_adaptation(
Expand Down Expand Up @@ -271,7 +269,6 @@ def main():
_data_ii = data_gpu[:n_obj]

if ii == 0:

# compilation times
t1 = time.time()
(_sts, _tp), _ = jax.block_until_ready(
Expand Down Expand Up @@ -323,7 +320,7 @@ def main():
jnp.save(filepath, results)

_log_setup(snr)
with open(LOG_FILE, "a") as f:
with open(LOG_FILE, "a", encoding='utf-8') as f:
print(file=f)
print(f"results were saved to {filepath}", file=f)

Expand Down
Loading

0 comments on commit f843748

Please sign in to comment.