Skip to content

Commit f843748

Browse files
Add ruff (#33)
* add ruff * ruff all files * ruff fix * add ruff check to tests * more options for ruff, no notebooks for now * ruff * ruff * ruff * ruffs * ruff * ruff * we sometimes use {} without f * some ruff * space * mistake * fix imports * options * ignore tensorflow warnings
1 parent b049873 commit f843748

22 files changed

+114
-93
lines changed

.github/workflows/tests.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
runs-on: ubuntu-latest
1414
strategy:
1515
matrix:
16-
python-version: ["3.10", "3.11", "3.12"]
16+
python-version: ["3.10", "3.11", "3.12", "3.13"]
1717

1818
steps:
1919
- name: Checkout github repo
@@ -26,12 +26,15 @@ jobs:
2626
cache: 'pip'
2727

2828
- name: Install dependencies
29-
run: |
29+
run: |
3030
python -m pip install -U pip
3131
python -m pip install .
3232
python -m pip install .[dev]
3333
python -m pip install git+https://github.com/GalSim-developers/JAX-GalSim.git
3434
35+
- name: Run Ruff
36+
run: ruff check --output-format=github .
37+
3538
- name: Run Tests
3639
run: |
3740
pytest --durations=0

bpd/diagnostics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ def get_contour_plot(
1616
kde: bool | int | float = False,
1717
) -> Figure:
1818
c = ChainConsumer()
19-
for name, samples in zip(names, samples_list):
19+
for name, samples in zip(names, samples_list, strict=False):
2020
df = pd.DataFrame.from_dict(samples)
2121
c.add_chain(Chain(samples=df, name=name, kde=kde))
2222
c.add_truth(Truth(location=truth))
2323
return c.plotter.plot(figsize=figsize)
2424

2525

2626
def get_gauss_pc_fig(
27-
ax: Axes, samples: np.ndarray, truth: float, param_name: str = None
27+
ax: Axes, samples: np.ndarray, truth: float, param_name: str | None = None
2828
) -> None:
2929
"""Get a marginal pc figure assuming Gaussian distribution of samples."""
3030
assert samples.ndim == 2 # (n_chains, n_samples)
@@ -49,7 +49,7 @@ def get_gauss_pc_fig(
4949

5050

5151
def get_pc_fig(
52-
ax: Axes, samples: np.ndarray, truth: float, param_name: str = None
52+
ax: Axes, samples: np.ndarray, truth: float, param_name: str | None = None
5353
) -> None:
5454
"""Get a marginal probability calibration figure using `hpdi` from `numpyro`."""
5555
assert samples.ndim == 2 # (n_chains, n_samples)

bpd/initialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def init_with_ball(
2626
"""Sample ball given offset of each parameter."""
2727
new = {}
2828
keys = random.split(rng_key, len(true_params.keys()))
29-
rng_key_dict = {p: k for p, k in zip(true_params, keys)}
29+
rng_key_dict = {p: k for p, k in zip(true_params, keys, strict=False)}
3030

3131
for p, centr in true_params.items():
3232
offset = offset_dict[p]

bpd/io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
def save_dataset(
1212
ds: dict[str, Array], fpath: str | Path, overwrite: bool = False
1313
) -> None:
14-
1514
if Path(fpath).exists() and not overwrite:
1615
raise IOError("overwriting existing ds")
1716
assert Path(fpath).suffix == ".npz"
@@ -22,6 +21,7 @@ def save_dataset(
2221
def load_dataset(fpath: str) -> dict[str, Array]:
2322
assert Path(fpath).exists(), "file path does not exists"
2423
assert Path(fpath).suffix == ".npz"
24+
2525
ds = {}
2626

2727
npzfile = jnp.load(fpath)

bpd/likelihood.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def shear_loglikelihood_unreduced(
1717
# the priors are callables for now on only ellipticities
1818
# the interim_prior should have been used when obtaining e_obs from the chain (i.e. for now same sigma)
1919
# normalizatoin in priors can be ignored for now as alpha is fixed.
20-
N, K, _ = e_post.shape
20+
_, K, _ = e_post.shape # (N, K, 2)
2121

2222
e_post_mag = jnp.sqrt(e_post[..., 0] ** 2 + e_post[..., 1] ** 2)
2323
denom = interim_prior(e_post_mag) # (N, K), can ignore angle in prior as uniform

bpd/pipelines/image_ellips.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,14 @@
33

44
import blackjax
55
import jax.numpy as jnp
6-
from jax import Array
6+
from jax import Array, random
77
from jax import jit as jjit
8-
from jax import random
98
from jax._src.prng import PRNGKeyArray
109
from jax.scipy import stats
1110

1211
from bpd.chains import inference_loop
1312
from bpd.draw import draw_gaussian, draw_gaussian_galsim
1413
from bpd.noise import add_noise
15-
from bpd.pipelines.toy_ellips import do_inference
1614
from bpd.prior import ellip_mag_prior, sample_ellip_prior
1715

1816

@@ -122,7 +120,6 @@ def do_inference(
122120
target_acceptance_rate: float = 0.80,
123121
n_samples: int = 100,
124122
):
125-
126123
key1, key2 = random.split(rng_key)
127124

128125
_logdensity = partial(logtarget_fnc, data=data)
@@ -164,7 +161,6 @@ def pipeline_image_interim_samples(
164161
background: float = 1.0,
165162
fft_size: int = 256,
166163
):
167-
168164
k1, k2 = random.split(rng_key)
169165

170166
init_position = initialization_fnc(k1, true_params=true_params, data=target_image)

bpd/pipelines/shear_inference.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
from typing import Callable
33

44
import blackjax
5-
from jax import Array
5+
from jax import Array, random
66
from jax import jit as jjit
7-
from jax import random
87
from jax._src.prng import PRNGKeyArray
98
from jax.scipy import stats
109

bpd/pipelines/toy_ellips.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
import blackjax
55
import jax.numpy as jnp
66
import jax.scipy as jsp
7-
from jax import Array
7+
from jax import Array, random, vmap
88
from jax import jit as jjit
9-
from jax import random, vmap
109
from jax._src.prng import PRNGKeyArray
1110

1211
from bpd.chains import inference_loop
@@ -73,7 +72,6 @@ def pipeline_toy_ellips_samples(
7372
k: int,
7473
n_warmup_steps: int = 500,
7574
):
76-
7775
k1, k2 = random.split(key)
7876

7977
true_g = jnp.array([g1, g2])

pyproject.toml

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,89 @@
22
requires = ["setuptools>=45"]
33
build-backend = "setuptools.build_meta"
44

5+
56
[project]
67
name = "BPD"
78
authors = [{ name = "Ismael Mendoza" }]
89
description = "Bayesian Pixel Domain method for shear inference."
910
version = "0.0.1"
1011
license = { file = "LICENSE" }
1112
readme = "README.md"
12-
dependencies = [
13-
"numpy >=1.18.0",
14-
"galsim >=2.3.0",
15-
"jax >=0.4.30",
16-
"jaxlib",
17-
"blackjax >=1.2.0",
18-
"numpyro >=0.15.0",
19-
]
13+
dependencies = ["numpy >=1.18.0", "galsim >=2.3.0", "jax >=0.4.30", "jaxlib", "blackjax >=1.2.0"]
14+
2015

2116
[project.optional-dependencies]
22-
dev = ["pytest", "click"]
17+
dev = ["pytest", "click", "ruff", "ChainConsumer"]
18+
2319

2420
[project.urls]
2521
home = "https://github.com/LSSTDESC/BPD"
2622

23+
2724
[tool.setuptools.packages.find]
28-
include = ["bpd*", "scripts*"]
25+
include = ["bpd*"]
26+
27+
[tool.ruff]
28+
exclude = [
29+
".bzr",
30+
".direnv",
31+
".eggs",
32+
".git",
33+
".git-rewrite",
34+
".hg",
35+
".ipynb_checkpoints",
36+
".mypy_cache",
37+
".nox",
38+
".pants.d",
39+
".pyenv",
40+
".pytest_cache",
41+
".pytype",
42+
".ruff_cache",
43+
".svn",
44+
".tox",
45+
".venv",
46+
".vscode",
47+
"__pypackages__",
48+
"_build",
49+
"buck-out",
50+
"build",
51+
"dist",
52+
"node_modules",
53+
"site-packages",
54+
"venv",
55+
]
56+
57+
# Same as Black.
58+
line-length = 88
59+
indent-width = 4
60+
61+
# Assume Python 3.8
62+
target-version = "py310"
63+
64+
[tool.ruff.format]
65+
# Like Black, use double quotes for strings.
66+
quote-style = "double"
67+
68+
# Like Black, indent with spaces, rather than tabs.
69+
indent-style = "space"
70+
71+
# Like Black, respect magic trailing commas.
72+
skip-magic-trailing-comma = false
73+
74+
# Like Black, automatically detect the appropriate line ending.
75+
line-ending = "auto"
76+
77+
exclude = ["*.ipynb"]
2978

3079

31-
[tool.flake8]
32-
max-line-length = 88
33-
ignore = ["C901", "E203", "W503"]
34-
per-file-ignores = ["__init__.py:F401"]
80+
[tool.ruff.lint]
81+
select = ["E", "F", "I", "W", "B", "SIM", "PLE", "PLC", "PLW", "RUF"]
82+
ignore = ["C901", "E203", "E501", "E731", "PLC0206", "RUF027"]
83+
preview = true
84+
exclude = ["*.ipynb", "scripts/image_shear_gpu_one_galaxy.py", "scripts/benchmarks/*.py"]
3585

3686

37-
[tool.isort]
38-
profile = "black"
39-
multi_line_output = 3
40-
include_trailing_comma = true
41-
use_parentheses = true
42-
ensure_newline_before_comments = true
43-
line_length = 88
87+
[tool.pytest.ini_options]
88+
minversion = "6.0"
89+
addopts = "-rav"
90+
filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"]

scripts/benchmarks/benchmark1.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
def sample_ball(rng_key, center_params: dict):
8383
new = {}
8484
keys = random.split(rng_key, len(center_params.keys()))
85-
rng_key_dict = {p: k for p, k in zip(center_params, keys)}
85+
rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)}
8686
for p in center_params:
8787
centr = center_params[p]
8888
if p == "f":
@@ -128,7 +128,6 @@ def draw_gal(f, hlr, g1, g2, x, y):
128128

129129

130130
def _logprob_fn(params, data):
131-
132131
# prior
133132
prior = jnp.array(0.0, device=GPU)
134133
for p in ("f", "hlr", "g1", "g2"): # uniform priors
@@ -148,7 +147,7 @@ def _logprob_fn(params, data):
148147

149148

150149
def _log_setup(snr: float):
151-
with open(LOG_FILE, "a") as f:
150+
with open(LOG_FILE, "a", encoding="utf-8") as f:
152151
print(file=f)
153152
print(
154153
f"""Running benchmark 1 with configuration as follows. Variable number of chains.
@@ -186,7 +185,6 @@ def _log_setup(snr: float):
186185

187186
# vmap only rng_key
188187
def do_warmup(rng_key, init_position: dict, data):
189-
190188
_logdensity = partial(_logprob_fn, data=data)
191189

192190
warmup = blackjax.window_adaptation(
@@ -258,7 +256,6 @@ def main():
258256
_init_positions = {p: q[:n_chains] for p, q in all_init_positions.items()}
259257

260258
if ii == 0:
261-
262259
# compilation times
263260
t1 = time.time()
264261
(_sts, _tp), _ = jax.block_until_ready(
@@ -310,7 +307,7 @@ def main():
310307
jnp.save(filepath, results)
311308

312309
_log_setup(snr)
313-
with open(LOG_FILE, "a") as f:
310+
with open(LOG_FILE, "a", encoding="utf-8") as f:
314311
print(file=f)
315312
print(f"results were saved to {filepath}", file=f)
316313

0 commit comments

Comments
 (0)