Skip to content

Commit 83547ca

Browse files
Make imports compatible with ArviZ 1.0
Closes #127
1 parent 366ab2d commit 83547ca

File tree

6 files changed

+27
-17
lines changed

6 files changed

+27
-17
lines changed

CITATION.cff

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ abstract: >-
2525
only MCMC draws but also sampler statistics, and are
2626
compatible with sparse data, or varying dimensionality.
2727
MCMC chains stored with McBackend can be queried directly,
28-
or convert to the popular ArviZ InferenceData objects.
28+
or convert to the popular ArviZ inference data objects.
2929
keywords:
3030
- mcmc
3131
- arviz

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ The `mcbackend` package consists of three parts:
1515
### Part 1: A schema for MCMC run & chain metadata
1616
No matter which programming language your favorite PPL is written in, the [ProtocolBuffers](https://developers.google.com/protocol-buffers/) from McBackend can be used to generate code in languages like C++, C#, Python and many more to represent commonly used metadata about MCMC runs, chains and model variables.
1717

18-
The definitions in [`protobufs/meta.proto`](./protobufs/meta.proto) are designed to maximize compatibility with [`ArviZ`](https://github.com/arviz-devs/arviz) objects, making it easy to transform MCMC draws stored according to the McBackend schema to `InferenceData` objects for plotting & analysis.
18+
The definitions in [`protobufs/meta.proto`](./protobufs/meta.proto) are designed to maximize compatibility with [`ArviZ`](https://github.com/arviz-devs/arviz) objects, making it easy to transform MCMC draws stored according to the McBackend schema to `xarray.DataTree` objects for plotting & analysis.
1919

2020
### Part 2: A storage backend interface
2121
The `draws` and `stats` created by MCMC sampling algorithms at runtime need to be stored _somewhere_.
@@ -83,10 +83,10 @@ chain = run.get_chains()[0]
8383
chain.get_draws("my favorite variable")
8484
# >>> array([ ... ])
8585

86-
# Convert everything to `InferenceData`
86+
# Convert everything to an inference data structure
8787
idata = run.to_inferencedata()
8888
print(idata)
89-
# >>> Inference data with groups:
89+
# >>> DataTree:
9090
# >>> > posterior
9191
# >>> > sample_stats
9292
# >>> > observed_data
@@ -113,7 +113,7 @@ Getting rid of `MultiTrace` was a [long-term goal](https://github.com/pymc-devs/
113113
First clone the repository and set up a development environment containing the protobuf compiler.
114114

115115
```bash
116-
mamba create -n mcb python=3.11 grpcio-tools protobuf -y
116+
mamba create -n mcb python=3.13 grpcio-tools protobuf -y
117117
activate mcb
118118
pip install -r requirements-dev.txt
119119
pip install --pre "betterproto[compiler]"

mcbackend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
except ModuleNotFoundError:
1515
pass
1616

17-
__version__ = "0.5.3"
17+
__version__ = "0.5.4"
1818
__all__ = [
1919
"NumPyBackend",
2020
"NullBackend",

mcbackend/core.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@
1414
from .utils import as_array_from_ragged
1515

1616
try:
17-
from arviz import InferenceData, from_dict
17+
from arviz import from_dict
18+
19+
try:
20+
from arviz import InferenceData as IData
21+
except ImportError:
22+
from xarray import DataTree as IData
1823

1924
_HAS_ARVIZ = True
2025
except ModuleNotFoundError:
21-
InferenceData = TypeVar("InferenceData") # type: ignore
26+
IData = TypeVar("IData") # type: ignore
2227
_HAS_ARVIZ = False
2328

2429
Shape = Sequence[int]
@@ -192,8 +197,8 @@ def constant_data(self) -> Dict[str, numpy.ndarray]:
192197
def observed_data(self) -> Dict[str, numpy.ndarray]:
193198
return {dv.name: ndarray_to_numpy(dv.value) for dv in self.meta.data if dv.is_observed}
194199

195-
def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) -> InferenceData:
196-
"""Creates an ArviZ ``InferenceData`` object from this run.
200+
def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) -> IData:
201+
"""Creates an ArviZ inference data structure from this run.
197202
198203
Parameters
199204
----------
@@ -204,8 +209,9 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
204209
205210
Returns
206211
-------
207-
idata : arviz.InferenceData
212+
idata
208213
Samples and metadata of this inference run.
214+
``az.InferenceData`` (ArviZ <1) or ``xarray.DataTree`` (ArviZ >1).
209215
"""
210216
if not _HAS_ARVIZ:
211217
raise ModuleNotFoundError("ArviZ is not installed.")
@@ -216,7 +222,7 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
216222
nonrigid_vars = {var for var in variables if var.undefined_ndim or not is_rigid(var.shape)}
217223
if nonrigid_vars:
218224
raise NotImplementedError(
219-
"Creating InferenceData from runs with non-rigid variables is not supported."
225+
"Creating inference data from runs with non-rigid variables is not supported."
220226
f" The non-rigid variables are: {nonrigid_vars}."
221227
)
222228

@@ -226,11 +232,11 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
226232
if not equalize_chain_lengths:
227233
msg += (
228234
"\nArviZ does not properly support uneven chain lengths (see ArviZ issue #2094)."
229-
"\nWe'll try to give you an InferenceData, but best case the chain & draw dimensions"
235+
"\nWe'll try to give you a DataTree, but best case the chain & draw dimensions"
230236
" will be messed-up as {'chain': 1, 'draws': n_chains}."
231-
"\nYou won't be able to save this InferenceData to a file"
237+
"\nYou might not be able to save this DataTree to a file"
232238
" and you should expect many ArviZ functions to choke on it."
233-
"\nSpecify `to_inferencedata(equalize_chain_lengths=True)` to get regular InferenceData."
239+
"\nSpecify `to_inferencedata(equalize_chain_lengths=True)` to get regular inference data."
234240
)
235241
else:
236242
msg += "\nTruncating to the length of the shortest chain."

mcbackend/test_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy
99
import pandas
1010
import pytest
11+
import xarray
1112

1213
import mcbackend
1314
from mcbackend import utils as mutils
@@ -299,7 +300,10 @@ def test__to_inferencedata(self, tstatname, caplog):
299300
chain.append(d, s)
300301

301302
idata = run.to_inferencedata()
302-
assert isinstance(idata, arviz.InferenceData)
303+
try:
304+
assert isinstance(idata, arviz.InferenceData)
305+
except AttributeError:
306+
assert isinstance(idata, xarray.DataTree)
303307
assert idata.warmup_posterior.dims["chain"] == 1
304308
assert idata.posterior.dims["chain"] == 1
305309
if tstatname == "nottune":

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
arviz
1+
arviz>=1.0.0rc0
22
clickhouse-driver
33
flake8
44
pre-commit

0 commit comments

Comments
 (0)