Skip to content

Commit 27b2052

Browse files
plot_hdi raise exception when x is string (#2412) (#2413)
* plot_hdi - add exception if x is string type * add test that plot_hdi raises when x is string * run black * update changelog * Change TypeError to NotImplementedError, modify error message * Fix typo in error message * remove backticks from error message * escape special characters in pytest.raise * pylint and black * run isort
1 parent ebd6c6d commit 27b2052

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- Only emit a warning for custom groups in `InferenceData` when explicitly requested ([2401](https://github.com/arviz-devs/arviz/pull/2401))
1010
- Splits Bayes Factor computation out from `az.plot_bf` into `az.bayes_factor` ([2402](https://github.com/arviz-devs/arviz/issues/2402))
1111
- Update `method="sd"` of `mcse` to not use normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))
12+
- Add exception in `az.plot_hdi` for `x` of type `str` ([2413](https://github.com/arviz-devs/arviz/pull/2413))
1213

1314
### Documentation
1415
- Add example of ECDF comparison plot to gallery ([2178](https://github.com/arviz-devs/arviz/pull/2178))

arviz/plots/hdiplot.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ def plot_hdi(
136136
x = np.asarray(x)
137137
x_shape = x.shape
138138

139+
if isinstance(x[0], str):
140+
raise NotImplementedError(
141+
"The `arviz.plot_hdi()` function does not support categorical data. "
142+
"Consider using `arviz.plot_forest()`."
143+
)
139144
if y is None and hdi_data is None:
140145
raise ValueError("One of {y, hdi_data} is required")
141146
if hdi_data is not None and y is not None:

arviz/tests/base_tests/test_plots_matplotlib.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,22 @@
22

33
# pylint: disable=redefined-outer-name,too-many-lines
44
import os
5+
import re
56
from copy import deepcopy
67

78
import matplotlib.pyplot as plt
89
import numpy as np
910
import pytest
11+
import xarray as xr
1012
from matplotlib import animation
1113
from pandas import DataFrame
1214
from scipy.stats import gaussian_kde, norm
13-
import xarray as xr
1415

1516
from ...data import from_dict, load_arviz_data
1617
from ...plots import (
1718
plot_autocorr,
18-
plot_bpv,
1919
plot_bf,
20+
plot_bpv,
2021
plot_compare,
2122
plot_density,
2223
plot_dist,
@@ -43,20 +44,20 @@
4344
plot_ts,
4445
plot_violin,
4546
)
47+
from ...plots.dotplot import wilkinson_algorithm
48+
from ...plots.plot_utils import plot_point_interval
4649
from ...rcparams import rc_context, rcParams
4750
from ...stats import compare, hdi, loo, waic
4851
from ...stats.density_utils import kde as _kde
49-
from ...utils import _cov, BehaviourChangeWarning
50-
from ...plots.plot_utils import plot_point_interval
51-
from ...plots.dotplot import wilkinson_algorithm
52+
from ...utils import BehaviourChangeWarning, _cov
5253
from ..helpers import ( # pylint: disable=unused-import
54+
RandomVariableTestClass,
5355
create_model,
5456
create_multidimensional_model,
5557
does_not_warn,
5658
eight_schools_params,
5759
models,
5860
multidim_models,
59-
RandomVariableTestClass,
6061
)
6162

6263
rcParams["data.load"] = "eager"
@@ -1236,6 +1237,23 @@ def test_plot_hdi_dataset_error(models):
12361237
plot_hdi(np.arange(8), hdi_data=hdi_data)
12371238

12381239

1240+
def test_plot_hdi_string_error():
1241+
"""Check x as type string raises an error."""
1242+
x_data = ["a", "b", "c", "d"]
1243+
y_data = np.random.normal(0, 5, (1, 200, len(x_data)))
1244+
hdi_data = hdi(y_data)
1245+
with pytest.raises(
1246+
NotImplementedError,
1247+
match=re.escape(
1248+
(
1249+
"The `arviz.plot_hdi()` function does not support categorical data. "
1250+
"Consider using `arviz.plot_forest()`."
1251+
)
1252+
),
1253+
):
1254+
plot_hdi(x=x_data, y=y_data, hdi_data=hdi_data)
1255+
1256+
12391257
def test_plot_hdi_datetime_error():
12401258
"""Check x as datetime raises an error."""
12411259
x_data = np.arange(start="2022-01-01", stop="2022-03-01", dtype=np.datetime64)

0 commit comments

Comments
 (0)