Skip to content

Commit 9e231c9

Browse files
committed
Change to use jax.Array and jax.typing
1 parent e830c2f commit 9e231c9

File tree

13 files changed

+96
-103
lines changed

13 files changed

+96
-103
lines changed

pmwd/boltzmann.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ def transfer_fit(k, cosmo, conf):
3434
3535
Parameters
3636
----------
37-
k : array_like
37+
k : ArrayLike
3838
Wavenumbers in [1/L].
3939
cosmo : Cosmology
4040
conf : Configuration
4141
4242
Returns
4343
-------
44-
T : jax.numpy.ndarray of (k * 1.).dtype
44+
T : jax.Array of (k * 1.).dtype
4545
Matter transfer function.
4646
4747
.. _Transfer Function:
@@ -129,14 +129,14 @@ def transfer(k, cosmo, conf):
129129
130130
Parameters
131131
----------
132-
k : array_like
132+
k : ArrayLike
133133
Wavenumbers in [1/L].
134134
cosmo : Cosmology
135135
conf : Configuration
136136
137137
Returns
138138
-------
139-
T : jax.numpy.ndarray of (k * 1.).dtype
139+
T : jax.Array of (k * 1.).dtype
140140
Matter transfer function.
141141
142142
Raises
@@ -238,7 +238,7 @@ def growth(a, cosmo, conf, order=1, deriv=0):
238238
239239
Parameters
240240
----------
241-
a : array_like
241+
a : ArrayLike
242242
Scale factors.
243243
cosmo : Cosmology
244244
conf : Configuration
@@ -249,7 +249,7 @@ def growth(a, cosmo, conf, order=1, deriv=0):
249249
250250
Returns
251251
-------
252-
D : jax.numpy.ndarray of (a * 1.).dtype
252+
D : jax.Array of (a * 1.).dtype
253253
Growth functions or derivatives.
254254
255255
Raises
@@ -297,16 +297,16 @@ def varlin(R, a, cosmo, conf):
297297
298298
Parameters
299299
----------
300-
R : array_like
300+
R : ArrayLike
301301
Scales in [L].
302-
a : array_like or None
302+
a : ArrayLike or None
303303
Scale factors. If None, output is not scaled by growth.
304304
cosmo : Cosmology
305305
conf : Configuration
306306
307307
Returns
308308
-------
309-
sigma2 : jax.numpy.ndarray of (k * a * 1.).dtype
309+
sigma2 : jax.Array of (k * a * 1.).dtype
310310
Linear matter overdensity variance.
311311
312312
Raises
@@ -401,16 +401,16 @@ def linear_power(k, a, cosmo, conf):
401401
402402
Parameters
403403
----------
404-
k : array_like
404+
k : ArrayLike
405405
Wavenumbers in [1/L].
406-
a : array_like or None
406+
a : ArrayLike or None
407407
Scale factors. If None, output is not scaled by growth.
408408
cosmo : Cosmology
409409
conf : Configuration
410410
411411
Returns
412412
-------
413-
Plin : jax.numpy.ndarray of (k * a * 1.).dtype
413+
Plin : jax.Array of (k * a * 1.).dtype
414414
Linear matter power spectrum in [L^3].
415415
416416
Raises

pmwd/configuration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import math
33
from typing import ClassVar, Optional, Tuple, Union
44

5-
from numpy.typing import DTypeLike
65
import jax
76
from jax import ensure_compile_time_eval
7+
from jax.typing import DTypeLike
88
import jax.numpy as jnp
99
from jax.tree_util import tree_map
1010
from mcfit import TophatVar

pmwd/cosmology.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
from dataclasses import field
22
from functools import partial
33
from operator import add, sub
4-
from typing import ClassVar, Optional, Union
4+
from typing import ClassVar, Optional
55

6-
import numpy as np
7-
from jax import value_and_grad
6+
from jax import Array, value_and_grad
7+
from jax.typing import ArrayLike
88
import jax.numpy as jnp
99
from jax.tree_util import tree_map
1010

1111
from pmwd.tree_util import pytree_dataclass
1212
from pmwd.configuration import Configuration
1313

1414

15-
FloatParam = Union[float, jnp.ndarray]
16-
17-
1815
@partial(pytree_dataclass, aux_fields="conf", frozen=True)
1916
class Cosmology:
2017
"""Cosmological and configuration parameters, "immutable" as a frozen dataclass.
@@ -34,45 +31,45 @@ class Cosmology:
3431
----------
3532
conf : Configuration
3633
Configuration parameters.
37-
A_s_1e9 : float or jax.numpy.ndarray
34+
A_s_1e9 : float ArrayLike
3835
Primordial scalar power spectrum amplitude, multiplied by 1e9.
39-
n_s : float or jax.numpy.ndarray
36+
n_s : float ArrayLike
4037
Primordial scalar power spectrum spectral index.
41-
Omega_m : float or jax.numpy.ndarray
38+
Omega_m : float ArrayLike
4239
Total matter density parameter today.
43-
Omega_b : float or jax.numpy.ndarray
40+
Omega_b : float ArrayLike
4441
Baryonic matter density parameter today.
45-
Omega_k_ : None, float, or jax.numpy.ndarray, optional
42+
Omega_k_ : None or float ArrayLike, optional
4643
Spatial curvature density parameter today. Default is None.
47-
w_0_ : None, float, or jax.numpy.ndarray, optional
44+
w_0_ : None or float ArrayLike, optional
4845
Dark energy equation of state constant parameter. Default is None.
49-
w_a_ : None, float, or jax.numpy.ndarray, optional
46+
w_a_ : None or float ArrayLike, optional
5047
Dark energy equation of state linear parameter. Default is None.
51-
h : float or jax.numpy.ndarray
48+
h : float ArrayLike
5249
Hubble constant in unit of 100 [km/s/Mpc].
5350
5451
"""
5552

5653
conf: Configuration = field(repr=False)
5754

58-
A_s_1e9: FloatParam
59-
n_s: FloatParam
60-
Omega_m: FloatParam
61-
Omega_b: FloatParam
62-
h: FloatParam
55+
A_s_1e9: ArrayLike
56+
n_s: ArrayLike
57+
Omega_m: ArrayLike
58+
Omega_b: ArrayLike
59+
h: ArrayLike
6360

64-
Omega_k_: Optional[FloatParam] = None
61+
Omega_k_: Optional[ArrayLike] = None
6562
Omega_k_fixed: ClassVar[float] = 0
66-
w_0_: Optional[FloatParam] = None
63+
w_0_: Optional[ArrayLike] = None
6764
w_0_fixed: ClassVar[float] = -1
68-
w_a_: Optional[FloatParam] = None
65+
w_a_: Optional[ArrayLike] = None
6966
w_a_fixed: ClassVar[float] = 0
7067

71-
transfer: Optional[jnp.ndarray] = field(default=None, compare=False)
68+
transfer: Optional[Array] = field(default=None, compare=False)
7269

73-
growth: Optional[jnp.ndarray] = field(default=None, compare=False)
70+
growth: Optional[Array] = field(default=None, compare=False)
7471

75-
varlin: Optional[jnp.ndarray] = field(default=None, compare=False)
72+
varlin: Optional[Array] = field(default=None, compare=False)
7673

7774
def __post_init__(self):
7875
if self._is_transforming():
@@ -191,13 +188,13 @@ def E2(a, cosmo):
191188
192189
Parameters
193190
----------
194-
a : array_like
191+
a : ArrayLike
195192
Scale factors.
196193
cosmo : Cosmology
197194
198195
Returns
199196
-------
200-
E2 : jax.numpy.ndarray of cosmo.conf.cosmo_dtype
197+
E2 : jax.Array of cosmo.conf.cosmo_dtype
201198
Squared Hubble parameter time scaling factors.
202199
203200
Notes
@@ -229,13 +226,13 @@ def H_deriv(a, cosmo):
229226
230227
Parameters
231228
----------
232-
a : array_like
229+
a : ArrayLike
233230
Scale factors.
234231
cosmo : Cosmology
235232
236233
Returns
237234
-------
238-
dlnH_dlna : jax.numpy.ndarray of cosmo.conf.cosmo_dtype
235+
dlnH_dlna : jax.Array of cosmo.conf.cosmo_dtype
239236
Hubble parameter derivatives.
240237
241238
"""
@@ -250,13 +247,13 @@ def Omega_m_a(a, cosmo):
250247
251248
Parameters
252249
----------
253-
a : array_like
250+
a : ArrayLike
254251
Scale factors.
255252
cosmo : Cosmology
256253
257254
Returns
258255
-------
259-
Omega : jax.numpy.ndarray of cosmo.conf.cosmo_dtype
256+
Omega : jax.Array of cosmo.conf.cosmo_dtype
260257
Matter density parameters.
261258
262259
Notes

pmwd/gather.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@ def gather(ptcl, conf, mesh, val=0, offset=0, cell_size=None):
1212
----------
1313
ptcl : Particles
1414
conf : Configuration
15-
mesh : array_like
15+
mesh : ArrayLike
1616
Input mesh.
17-
val : array_like, optional
17+
val : ArrayLike, optional
1818
Input values, can be 0D.
19-
offset : array_like, optional
19+
offset : ArrayLike, optional
2020
Offset of mesh to particle grid. If 0D, the value is used in each dimension.
2121
cell_size : float, optional
2222
Mesh cell size in [L]. Default is ``conf.cell_size``.
2323
2424
Returns
2525
-------
26-
val : jax.numpy.ndarray
26+
val : jax.Array
2727
Output values.
2828
2929
"""

pmwd/lpt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def levi_civita(indices):
8181
8282
Parameters
8383
----------
84-
indices : array_like
84+
indices : ArrayLike
8585
8686
Returns
8787
-------
@@ -140,7 +140,7 @@ def lpt(modes, cosmo, conf):
140140
141141
Parameters
142142
----------
143-
modes : jax.numpy.ndarray
143+
modes : jax.Array
144144
Linear matter overdensity Fourier modes in [L^3].
145145
cosmo : Cosmology
146146
conf : Configuration

pmwd/modes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def white_noise(seed, conf, real=False, unit_abs=False, negate=False):
2727
2828
Returns
2929
-------
30-
modes : jax.numpy.ndarray of conf.float_dtype
30+
modes : jax.Array of conf.float_dtype
3131
White noise modes.
3232
3333
"""
@@ -75,7 +75,7 @@ def linear_modes(modes, cosmo, conf, a=None):
7575
7676
Parameters
7777
----------
78-
modes : jax.numpy.ndarray
78+
modes : jax.Array
7979
Fourier or real modes with white noise prior.
8080
cosmo : Cosmology
8181
conf : Configuration
@@ -84,7 +84,7 @@ def linear_modes(modes, cosmo, conf, a=None):
8484
8585
Returns
8686
-------
87-
modes : jax.numpy.ndarray of conf.float_dtype
87+
modes : jax.Array of conf.float_dtype
8888
Linear matter overdensity Fourier modes in [L^3].
8989
9090
Notes

0 commit comments

Comments
 (0)