Skip to content

Commit 519b8db

Browse files
Require equal chain lengths with ArviZ >=1
Due to API and behavior changes of `az.from_dict`, it is no longer possible to pass lists of arrays.
1 parent fce33fe commit 519b8db

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ disable =
99
R0912, # too many branches
1010
R0913, # too many arguments
1111
R0914, # too many local variables
12+
R0915, # too many statements
1213
R1711, # useless return is okay

mcbackend/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,11 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
241241
" and you should expect many ArviZ functions to choke on it."
242242
"\nSpecify `to_inferencedata(equalize_chain_lengths=True)` to get regular inference data."
243243
)
244+
if _ARVIZ_VERSION > 0:
245+
raise NotImplementedError(
246+
"ArviZ 1.0 no longer supports uneven chain lengths."
247+
" See discussion in https://github.com/pymc-devs/mcbackend/pull/128."
248+
)
244249
else:
245250
msg += "\nTruncating to the length of the shortest chain."
246251
_log.warning(msg)
@@ -288,7 +293,7 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
288293
w_ss = cast(Dict[str, Union[Sequence, numpy.ndarray]], warmup_sample_stats)
289294
pst = cast(Dict[str, Union[Sequence, numpy.ndarray]], posterior)
290295
ss = cast(Dict[str, Union[Sequence, numpy.ndarray]], sample_stats)
291-
if not equalize_chain_lengths:
296+
if not equalize_chain_lengths or _ARVIZ_VERSION > 0:
292297
# Convert ragged arrays to object-dtyped ndarray because NumPy >=1.24.0 no longer does that automatically
293298
w_pst = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()}
294299
w_ss = {k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items()}

mcbackend/test_backend_clickhouse.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
column_spec_for,
1818
create_chain_table,
1919
)
20-
from mcbackend.core import Run, chain_id
20+
from mcbackend.core import _ARVIZ_VERSION, Run, chain_id
2121
from mcbackend.meta import ChainMeta, RunMeta, Variable
2222
from mcbackend.test_utils import CheckBehavior, CheckPerformance, make_runmeta
2323

@@ -378,6 +378,11 @@ def test_to_inferencedata_equalize_chain_lengths(self, caplog):
378378
assert "Truncating to" in caplog.records[0].message
379379
assert len(idata_even.posterior.draw) == 14
380380

381+
if _ARVIZ_VERSION > 0:
382+
with pytest.raises(NotImplementedError, match="ArviZ 1.0 no longer supports"):
383+
run.to_inferencedata(equalize_chain_lengths=False)
384+
return
385+
381386
# With equalize=False the "draw" dim has the length of the longest chain (here: 8-3 = 5)
382387
caplog.clear()
383388
with caplog.at_level(logging.WARNING):

0 commit comments

Comments
 (0)