11from dataclasses import field
22from functools import partial
33from operator import add , sub
4- from typing import ClassVar , Optional , Union
4+ from typing import ClassVar , Optional
55
6- import numpy as np
7- from jax import value_and_grad
6+ from jax import Array , value_and_grad
7+ from jax . typing import ArrayLike
88import jax .numpy as jnp
99from jax .tree_util import tree_map
1010
1111from pmwd .tree_util import pytree_dataclass
1212from pmwd .configuration import Configuration
1313
1414
15- FloatParam = Union [float , jnp .ndarray ]
16-
17-
1815@partial (pytree_dataclass , aux_fields = "conf" , frozen = True )
1916class Cosmology :
2017 """Cosmological and configuration parameters, "immutable" as a frozen dataclass.
@@ -34,45 +31,45 @@ class Cosmology:
3431 ----------
3532 conf : Configuration
3633 Configuration parameters.
37- A_s_1e9 : float or jax.numpy.ndarray
34+ A_s_1e9 : float ArrayLike
3835 Primordial scalar power spectrum amplitude, multiplied by 1e9.
39- n_s : float or jax.numpy.ndarray
36+ n_s : float ArrayLike
4037 Primordial scalar power spectrum spectral index.
41- Omega_m : float or jax.numpy.ndarray
38+ Omega_m : float ArrayLike
4239 Total matter density parameter today.
43- Omega_b : float or jax.numpy.ndarray
40+ Omega_b : float ArrayLike
4441 Baryonic matter density parameter today.
45- Omega_k_ : None, float, or jax.numpy.ndarray , optional
42+ Omega_k_ : None or float ArrayLike , optional
4643 Spatial curvature density parameter today. Default is None.
47- w_0_ : None, float, or jax.numpy.ndarray , optional
44+ w_0_ : None or float ArrayLike , optional
4845 Dark energy equation of state constant parameter. Default is None.
49- w_a_ : None, float, or jax.numpy.ndarray , optional
46+ w_a_ : None or float ArrayLike , optional
5047 Dark energy equation of state linear parameter. Default is None.
51- h : float or jax.numpy.ndarray
48+ h : float ArrayLike
5249 Hubble constant in unit of 100 [km/s/Mpc].
5350
5451 """
5552
5653 conf : Configuration = field (repr = False )
5754
58- A_s_1e9 : FloatParam
59- n_s : FloatParam
60- Omega_m : FloatParam
61- Omega_b : FloatParam
62- h : FloatParam
55+ A_s_1e9 : ArrayLike
56+ n_s : ArrayLike
57+ Omega_m : ArrayLike
58+ Omega_b : ArrayLike
59+ h : ArrayLike
6360
64- Omega_k_ : Optional [FloatParam ] = None
61+ Omega_k_ : Optional [ArrayLike ] = None
6562 Omega_k_fixed : ClassVar [float ] = 0
66- w_0_ : Optional [FloatParam ] = None
63+ w_0_ : Optional [ArrayLike ] = None
6764 w_0_fixed : ClassVar [float ] = - 1
68- w_a_ : Optional [FloatParam ] = None
65+ w_a_ : Optional [ArrayLike ] = None
6966 w_a_fixed : ClassVar [float ] = 0
7067
71- transfer : Optional [jnp . ndarray ] = field (default = None , compare = False )
68+ transfer : Optional [Array ] = field (default = None , compare = False )
7269
73- growth : Optional [jnp . ndarray ] = field (default = None , compare = False )
70+ growth : Optional [Array ] = field (default = None , compare = False )
7471
75- varlin : Optional [jnp . ndarray ] = field (default = None , compare = False )
72+ varlin : Optional [Array ] = field (default = None , compare = False )
7673
7774 def __post_init__ (self ):
7875 if self ._is_transforming ():
@@ -191,13 +188,13 @@ def E2(a, cosmo):
191188
192189 Parameters
193190 ----------
194- a : array_like
191+ a : ArrayLike
195192 Scale factors.
196193 cosmo : Cosmology
197194
198195 Returns
199196 -------
200- E2 : jax.numpy.ndarray of cosmo.conf.cosmo_dtype
197+ E2 : jax.Array of cosmo.conf.cosmo_dtype
201198 Squared Hubble parameter time scaling factors.
202199
203200 Notes
@@ -229,13 +226,13 @@ def H_deriv(a, cosmo):
229226
230227 Parameters
231228 ----------
232- a : array_like
229+ a : ArrayLike
233230 Scale factors.
234231 cosmo : Cosmology
235232
236233 Returns
237234 -------
238- dlnH_dlna : jax.numpy.ndarray of cosmo.conf.cosmo_dtype
235+ dlnH_dlna : jax.Array of cosmo.conf.cosmo_dtype
239236 Hubble parameter derivatives.
240237
241238 """
@@ -250,13 +247,13 @@ def Omega_m_a(a, cosmo):
250247
251248 Parameters
252249 ----------
253- a : array_like
250+ a : ArrayLike
254251 Scale factors.
255252 cosmo : Cosmology
256253
257254 Returns
258255 -------
259- Omega : jax.numpy.ndarray of cosmo.conf.cosmo_dtype
256+ Omega : jax.Array of cosmo.conf.cosmo_dtype
260257 Matter density parameters.
261258
262259 Notes
0 commit comments