Simple library that provides performant implementations of standard econometrics routines in the JAX ecosystem.
It uses `jax` arrays everywhere, `lineax` for solving linear systems, and `jaxopt` and `optax` for numerical optimization (Levenberg–Marquardt for NNLS-type problems and SGD for larger problems).

- Linear Regression with multiple solver backends (lineax, JAX, numpy)
- Fixed Effects Regression with JAX-accelerated alternating projections
- GMM and IV Estimation
- Causal Inference (IPW, AIPW, Entropy Balancing)
- Maximum Likelihood Estimation (Logistic, Poisson)
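
The alternating-projections idea behind the fixed effects item above can be illustrated with a small sketch. This is not the library's implementation: it is written in plain NumPy for readability, the function names `group_demean` and `alternating_demean` are hypothetical, and the data are synthetic. It sweeps over each fixed-effect dimension, subtracting group means, until the vector stops changing, then runs OLS on the demeaned data.

```python
import numpy as np


def group_demean(v, ids):
    """One projection: subtract each group's mean from v."""
    n_groups = int(ids.max()) + 1
    sums = np.bincount(ids, weights=v, minlength=n_groups)
    counts = np.bincount(ids, minlength=n_groups)
    means = sums / np.maximum(counts, 1)  # guard against empty groups
    return v - means[ids]


def alternating_demean(v, id_list, tol=1e-10, max_iter=1000):
    """Cycle through the fixed-effect dimensions until v stops changing."""
    v = np.asarray(v, dtype=float)
    for _ in range(max_iter):
        prev = v
        for ids in id_list:
            v = group_demean(v, ids)
        if np.max(np.abs(v - prev)) < tol:
            break
    return v


# Synthetic, noise-free unbalanced panel with a true slope of 2.0
rng = np.random.default_rng(0)
n_obs = 200
firm_ids = rng.integers(0, 10, size=n_obs)
year_ids = rng.integers(0, 5, size=n_obs)
firm_effect = rng.normal(size=10)
year_effect = rng.normal(size=5)
x = rng.normal(size=n_obs)
y = 2.0 * x + firm_effect[firm_ids] + year_effect[year_ids]

x_tilde = alternating_demean(x, [firm_ids, year_ids])
y_tilde = alternating_demean(y, [firm_ids, year_ids])

# OLS on the demeaned data recovers the slope (close to 2.0 here)
beta = (x_tilde @ y_tilde) / (x_tilde @ x_tilde)
```

Because the demeaning is a linear operation that removes both sets of effects, the slope estimated from `x_tilde` and `y_tilde` matches the one a regression with full firm and year dummies would give, without ever building the dummy matrix.
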
jaxonometrics supports high-performance fixed effects regression with multiple FE variables:

```python
from jaxonometrics import LinearRegression
import jax.numpy as jnp

# Your data
X = jnp.asarray(data)    # (n_obs, n_features)
y = jnp.asarray(target)  # (n_obs,)
firm_ids = jnp.asarray(firm_identifiers, dtype=jnp.int32)
year_ids = jnp.asarray(year_identifiers, dtype=jnp.int32)

# Two-way fixed effects
model = LinearRegression(solver="lineax")
model.fit(X, y, fe=[firm_ids, year_ids])
coefficients = model.params["coef"]
```

Install directly from GitHub with `uv`:

```bash
uv pip install git+https://github.com/py-econometrics/jaxonometrics
```

or clone the repository and install in editable mode.

Run the full test suite:

```bash
pytest tests/ -v
```

Run only fixed effects tests:

```bash
pytest tests/ -m fe -v
```

Run tests excluding slow ones:

```bash
pytest tests/ -m "not slow" -v
```
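
The `-m` selections above rely on pytest markers. As a rough sketch of how a test file might carry them, assuming the markers are named `fe` and `slow` as the commands suggest: the `demean` helper and both test functions below are hypothetical stand-ins, not tests from this repository.

```python
import pytest


def demean(values):
    """Subtract the mean from a list of floats (stand-in for code under test)."""
    mean = sum(values) / len(values)
    return [v - mean for v in values]


@pytest.mark.fe  # selected by: pytest tests/ -m fe
def test_demean_centers_values():
    assert abs(sum(demean([1.0, 2.0, 6.0]))) < 1e-12


@pytest.mark.fe
@pytest.mark.slow  # deselected by: pytest tests/ -m "not slow"
def test_demean_large_input():
    values = [float(i) for i in range(100_000)]
    assert abs(sum(demean(values))) < 1e-6
```

With markers like these, `-m fe` would run both tests, while `-m "not slow"` would run only the first. Custom markers are usually registered in the pytest configuration so that pytest does not warn about unknown marks.
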