Adds Pair plot #287

Open: wants to merge 18 commits into base: main

Conversation

The-Broken-Keyboard
Contributor

@The-Broken-Keyboard The-Broken-Keyboard commented Jun 13, 2025

Adds Pair Plot

Key change required in PlotMatrix:
When a mask is passed through map_triangle (for divergences), the mask data was not subset, which caused a dimension conflict: **kwargs was added directly to fun_kwargs. Now the DataArray or Dataset values in kwargs are subset before they are added to fun_kwargs.
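
A minimal sketch of the idea described above, assuming xarray inputs. The helper name `subset_kwargs` and the toy mask are hypothetical and are not the code in this PR; the sketch only shows xarray-valued kwargs being restricted to the current selection before they reach the plotting function, while other kwargs pass through untouched.

```python
import numpy as np
import xarray as xr


def subset_kwargs(kwargs, selection):
    """Return a copy of kwargs with xarray values restricted to ``selection``.

    Hypothetical helper for illustration only.
    """
    fun_kwargs = {}
    for key, value in kwargs.items():
        if isinstance(value, (xr.DataArray, xr.Dataset)):
            # Only select along dimensions the object actually has.
            sel = {dim: val for dim, val in selection.items() if dim in value.dims}
            value = value.sel(sel)
        fun_kwargs[key] = value
    return fun_kwargs


# Toy mask with one row per school; real divergence masks may look different.
mask = xr.DataArray(
    np.array([[True, False, True], [False, False, True]]),
    dims=["school", "draw"],
    coords={"school": ["Choate", "Deerfield"]},
)
out = subset_kwargs({"mask": mask, "color": "C0"}, {"school": "Choate"})
```

After subsetting, the mask has the same shape as the data of the facet being plotted, so no dimension conflict arises.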

For the following code:

from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")
visuals = {"divergence": True}
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    visuals=visuals,
)
pc.show()

We get the following plot:

image


📚 Documentation preview 📚: https://arviz-plots--287.org.readthedocs.build/en/287/

@The-Broken-Keyboard The-Broken-Keyboard marked this pull request as draft June 13, 2025 10:45
@codecov-commenter

codecov-commenter commented Jun 13, 2025

Codecov Report

Attention: Patch coverage is 90.94828% with 21 lines in your changes missing coverage. Please review.

Project coverage is 88.02%. Comparing base (734c61b) to head (127c32b).
Report is 1 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/arviz_plots/plots/pair_plot.py | 88.73% | 16 Missing ⚠️ |
| src/arviz_plots/visuals/__init__.py | 90.00% | 2 Missing ⚠️ |
| src/arviz_plots/backend/bokeh/__init__.py | 90.90% | 1 Missing ⚠️ |
| src/arviz_plots/backend/none/__init__.py | 83.33% | 1 Missing ⚠️ |
| src/arviz_plots/plot_matrix.py | 97.14% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #287      +/-   ##
==========================================
+ Coverage   87.82%   88.02%   +0.19%     
==========================================
  Files          44       45       +1     
  Lines        5102     5303     +201     
==========================================
+ Hits         4481     4668     +187     
- Misses        621      635      +14     

@The-Broken-Keyboard
Contributor Author

@aloctavodia @OriolAbril please review this. I will add documentation and tests once we have settled on the final version of the function.

Contributor

@aloctavodia aloctavodia left a comment

LGTM

Comment on lines 25 to 37
sample_dims=None,
plot_matrix=None,
Member

Not sure if we want that as arguments or if we want multiple plot_pair variants (or maybe a bit of both). Here are some things that could be arguments specific to the plot:

  • which triangle to populate plots into. Legacy arviz defaults to lower triangle only. We can change the default but we should definitely allow control over that.
  • Adding marginals to the diagonal. Similar to legacy version (I think it was also the default). This one would also need combining with x/y labels on the bottom/leftmost plots

Member

To keep consistency with the other functions, marginal, marginal_kind and triangle should go between sample_dims and plot_matrix (as it is the equivalent to plot_collection): https://arviz-plots.readthedocs.io/en/latest/contributing/new_plot.html

Comment on lines +130 to +175
pairs = tuple(
    xarray_sel_iter(
        distribution, skip_dims={dim for dim in distribution.dims if dim in sample_dims}
    )
)
n_pairs = len(pairs)
pc_kwargs = set_grid_layout(
    pc_kwargs, plot_bknd, distribution, num_rows=n_pairs, num_cols=n_pairs
)
Member

This makes me more convinced we should do #284

@The-Broken-Keyboard
Contributor Author

The-Broken-Keyboard commented Jun 20, 2025

@OriolAbril @aloctavodia

I have incorporated all the changes suggested above

MAJOR CHANGES

  • Added a map_row_col method that takes a variable, a selection dict and an orientation (row or col) and maps the given function over the whole row or column of the matrix. Since PlotMatrix now has this method, a similar one could be added to PlotCollection. It would also handle the conflicting xlabels issue, so the workaround implemented in "fixes shared x property for appropriate plots ( like rank plot, trace plot etc ) #208" would no longer be needed. I have removed that workaround for now so the new function can be incorporated.
  • Added the triangle parameter and remove_axis (in visuals).
  • Added marginals on the diagonal (supports kind="kde", "hist" and "ecdf").
  • Added diagonal xlabels when triangle="upper" is given.
  • Diagonal labels are used automatically if marginal=False.
  • x labels for the bottom plots and y labels for the left plots are used only if marginal=True; otherwise the diagonal labels already label the rows and columns (as mentioned above).
  • remove_axis=False is enforced if marginal=True. With marginals we need xlabels and ylabels to label the rows and columns, since there are no diagonal labels and removing the axis would remove the axis labels too.
  • sharex and sharey are set to True. If marginal=True we can't have sharey=True, so it is forced to False.
  • Added a new remove_matrix_axis function. It does the same as remove_axis in visuals, but remove_axis wasn't compatible with map_triangle of PlotMatrix, so a compatible version was needed.
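
The interplay between marginal, remove_axis and axis sharing listed above can be summarised in a short sketch. The function `resolve_pair_options` and the keys of the returned dict are hypothetical names chosen for illustration, and only the rules stated in this comment are encoded (the sharing defaults are revisited later in the review).

```python
def resolve_pair_options(marginal=True, triangle="both", remove_axis=False):
    """Resolve the effective options from the rules in the list above.

    Hypothetical helper: it is not part of arviz_plots.
    """
    if triangle not in {"both", "upper", "lower"}:
        raise ValueError(f"Unknown triangle value: {triangle!r}")
    return {
        # With marginals the axis labels identify rows and columns, so keep the axes.
        "remove_axis": False if marginal else remove_axis,
        "sharex": True,
        # Marginals on the diagonal have a different y scale than the scatter plots.
        "sharey": not marginal,
        # Without marginals, text labels on the diagonal identify rows and columns.
        "diagonal_labels": not marginal,
        # Bottom x labels and left y labels are only used together with marginals.
        "edge_labels": marginal,
    }


with_marginals = resolve_pair_options(marginal=True, remove_axis=True)
without_marginals = resolve_pair_options(marginal=False, remove_axis=True)
```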

We get a warning in Bokeh whenever we use the upper or lower triangle; it says a renderer wasn't found for some plots.

Examples:

Code 1

from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=True,
    marginal_kind="kde",
    triangle="both",
    visuals={"divergence": True},
)
pc.show()

result 1

image

code 2

from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=True,
    marginal_kind="hist",
    triangle="both",
    visuals={"divergence": True},
)
pc.show()

result 2

image

code 3

from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=True,
    marginal_kind="ecdf",
    triangle="both",
    visuals={"divergence": True},
)
pc.show()

result 3

image

code 4

from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=True,
    marginal_kind="kde",
    triangle="upper",
    visuals={"divergence": True},
)
pc.show()

result 4

image

code 5

from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=True,
    marginal_kind="kde",
    triangle="lower",
    visuals={"divergence": True},
)
pc.show()

result 5

image

code 6

from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=False,
    triangle="both",
    visuals={"divergence": True, "remove_axis": {"axis": "both"}},
)
pc.show()

result 6

image

Comment on lines 178 to 181
pc_kwargs["figure_kwargs"].setdefault("sharey", True)
pc_kwargs["figure_kwargs"].setdefault("sharex", True)
if marginal:
    pc_kwargs["figure_kwargs"]["sharey"] = False
Member

I just realized we don't want True. I complained about the axis limits for tau and then said to use True, which is the reason the axis for tau includes negative values. We want sharex="col" and sharey="row" (when present).

Re the if marginal case, I think it is OK but I hesitate a bit between having this hardcoded or still as setdefault (that is, only call setdefault for sharey when "not marginal"). It is a design question about whether we want to allow users to make terrible dataviz choices when they change defaults vs enforcing all kwargs the users provide.
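
To illustrate why sharing everything gives tau an axis with negative values, here is a small sketch with made-up numbers. The helper `shared_limits` is hypothetical and no plotting backend is involved; it only compares one range shared by every plot against one range per variable, which is what sharex="col" and sharey="row" amount to in a pair plot where each column (or row) holds a single variable.

```python
def shared_limits(samples, share):
    """Compute (min, max) axis limits per variable under a sharing policy.

    ``share=True`` uses one range for every plot; any other value
    ("col" or "row") keeps one range per variable. Hypothetical helper.
    """
    per_variable = {name: (min(vals), max(vals)) for name, vals in samples.items()}
    if share is True:
        low = min(limits[0] for limits in per_variable.values())
        high = max(limits[1] for limits in per_variable.values())
        return {name: (low, high) for name in samples}
    return per_variable


# Made-up draws: mu can be negative, tau is strictly positive.
samples = {"mu": [-2.0, 4.0, 9.0], "tau": [0.5, 3.0, 12.0]}
global_limits = shared_limits(samples, True)
per_column_limits = shared_limits(samples, "col")
```

With `share=True` the tau axis inherits the negative lower bound from mu, whereas the per-column limits stay within the range of tau itself.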

Contributor Author

This required a small logic change in the bokeh backend's create_plotting_grid function. Previously, sharex="col" and sharey="row" were treated as mutually exclusive, so using both at once raised an error. That is now fixed in the backend.

Also, having only text on the diagonal plots when marginal is False was a bit awkward, because the text needs an x and y location, which varies between variables (if we want it centred). So I have now added a scatter plot on the diagonal plots as well when marginal is False. Otherwise we get marginals.

Member

> This required a little bit logic change in bokeh backend's create_plotting_grid function

Thanks! Bokeh has always been the one to generate more trouble for us so it ends up being less used and worked on. The refactor is making things much better on this end but still not fully there.

> Also since having just text on diagonal plots when marginal is False was little bit weird because we required to give x and y location of text which varied for different variables ( in case if we want them to be at centre). So now I added scatter plot for diagonal plots as well when marginal is False. Otherwise we will have marginals

I don't think there is any case where we want the scatter against itself. It is ok as a byproduct in plot_pair_focus when removing it greatly increases the complexity of the code, but I would not add it explicitly.

I see two options on this. One is doing some updates to the initialization of PlotMatrix in order to request a grid with one less column but same number of rows, and storing it like so in the "plot" dataarray (which would still have the same shape and coordinates) with None on the diagonal. Example for 3 variables, we generate a 3x2 grid:

None | 0, 0 | 0, 1
1, 0 | None | 1, 1
2, 0 | 2, 1 | None

This means there are no diagonal elements to be filled at all.
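
The index mapping behind this first option could be sketched as follows. The function `off_diagonal_layout` is a hypothetical name and the sketch only reproduces the table above; it does not touch PlotMatrix itself.

```python
def off_diagonal_layout(n_vars):
    """Map each cell of an n x n matrix to a slot in an n x (n - 1) grid.

    Diagonal cells map to None; cells right of the diagonal shift one
    column to the left. Hypothetical helper for illustration.
    """
    layout = []
    for row in range(n_vars):
        cells = []
        for col in range(n_vars):
            if col == row:
                cells.append(None)
            else:
                grid_col = col if col < row else col - 1
                cells.append((row, grid_col))
        layout.append(cells)
    return layout


layout = off_diagonal_layout(3)
```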

Alternatively, adapt the code for the labels along the diagonal. The backend text function takes vertical and horizontal alignment which we can set to center the text around the provided x and y. And as that would be for the case without marginals with axis sharing, the x and y would actually be the same value, and we can compute it as:

text_center = (distribution.max(dim=sample_dims) + distribution.min(dim=sample_dims)) / 2

which should get something that looks centered

@The-Broken-Keyboard The-Broken-Keyboard marked this pull request as ready for review June 23, 2025 12:17
Member

@OriolAbril OriolAbril left a comment

next round of review. I think it is shaping up very well

Comment on lines 238 to 251
marginal_kwargs = copy(visuals.get("marginal", {}))
dist_plot_visuals = {}
dist_plot_aes_by_visuals = {}
dist_plot_stats = {}
dist_plot_visuals["dist"] = marginal_kwargs
dist_plot_visuals["title"] = False
dist_plot_visuals["point_estimate"] = False
dist_plot_visuals["point_estimate_text"] = False
dist_plot_visuals["credible_interval"] = False
if remove_axis_bool is False:
    dist_plot_visuals["remove_axis"] = False
dist_plot_visuals["rug"] = False
dist_plot_aes_by_visuals["dist"] = aes_by_visuals.get("marginal", {})
dist_plot_stats["dist"] = stats.get("marginal", {})
Member

maybe we should allow users to customize plot_dist on the diagonals to add the credible interval too or the point estimate.


# scatter

aes_by_visuals["scatter"] = {"overlay"}.union(aes_by_visuals.get("scatter", {}))
Member

I would recommend defining all aes_by_visuals defaults in a block here, right after defining plot_collection. This makes it much easier to know what mappings are active by default in general. We need to properly think about these defaults too.

For scatter, for example, I think we want all defined aesthetic mappings by default, with the provided ones+overlay otherwise. That is, like trace in plot_trace_dist (not like divergence in plot_trace_dist). We also want the default to disable overlay for the marginal element.
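
One possible reading of these defaults, written as a single block. The helper name `default_aes_by_visuals` is hypothetical, and the sketch assumes the set of defined aesthetics is available as a plain set that may include "overlay"; it is not the code in this PR.

```python
def default_aes_by_visuals(aes_set, aes_by_visuals=None):
    """Resolve default aesthetic mappings for the scatter and marginal visuals.

    Hypothetical helper; ``aes_set`` stands for all defined aesthetics.
    """
    user = dict(aes_by_visuals or {})
    resolved = {}
    if "scatter" in user:
        # User-provided mappings always get overlay added.
        resolved["scatter"] = {"overlay"}.union(user["scatter"])
    else:
        # By default scatter uses every defined aesthetic mapping.
        resolved["scatter"] = set(aes_set)
    # The marginal default uses everything except overlay.
    resolved["marginal"] = set(user.get("marginal", set(aes_set) - {"overlay"}))
    return resolved


defaults = default_aes_by_visuals({"color", "overlay"})
custom = default_aes_by_visuals({"color", "overlay"}, {"scatter": ["color"]})
```

Keeping every default in one place like this makes it easy to see which mappings are active without reading the whole function body.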

Comment on lines 236 to 238
else:
    var_name_x = var_name
    selection_x = selection
Member

Given the behaviour of subset_matrix_da I think this means that the simple fact of var_name_y and selection_y being None will already behave like "diagonal" even if the orientation attribute is not None. I think it is actually fine and we only need to check orientation row or col, with no orientation meaning "diagonal" which I think is a good default.

@@ -355,6 +418,9 @@ def map_triangle(
artist_dims=artist_dims,
ignore_aes=ignore_aes,
)

self.orientation = None
Member

I would actually raise if orientation is set here (see other comments too).

Comment on lines 163 to 203
@property
def orientation(self):
    """Orientation of targets we want to fetch through :meth:`~PlotMatrix.get_target`.

    Orientation is always taken into account when using
    :meth:`~PlotMatrix.get_target` to fetch target for plotting.
    """
    return self._orientation

@orientation.setter
def orientation(self, value):
    self._orientation = value

@property
def fixed_var_name(self):
    """Fixed variable's name of targets we fetch through :meth:`~PlotMatrix.get_target`.

    fixed_var_name is taken into account when using :meth:`~PlotMatrix.get_target`
    to fetch target for plotting and :attr:`~PlotMatrix.orientation` is
    set to either `row` or `col`.

    It tells the fixed variable name in a `row` or `col`
    """
    return self._fixed_var_name

@fixed_var_name.setter
def fixed_var_name(self, value):
    self._fixed_var_name = value

@property
def fixed_selection(self):
    """Fixed dictionary for subsetting of targets."""
    return self._fixed_selection

@fixed_selection.setter
def fixed_selection(self, value):
    self._fixed_selection = value

@fixed_var_name.setter
def fixed_var_name(self, value):
    self._fixed_var_name = value
Member

I think we could have them as "simple" attributes set to None in __init__. Moreover, I don't think we expect/want users to interact with them directly but more so as an internal thing to reduce duplication between map and map_row_col

@@ -473,11 +632,11 @@ def map(
--------
arviz_plots.PlotMatrix.map_triangle
"""
self.set_fixed_var_attributes(index, orientation)
Member

Here also revert the 3 attributes to None before returning.
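
A rough sketch of how the suggestions in this review could fit together: plain attributes initialised to None, a reset before returning, and a guard in map_triangle. The class `PlotMatrixSketch` and its method bodies are hypothetical stand-ins, not the PlotMatrix implementation.

```python
class PlotMatrixSketch:
    """Hypothetical stand-in for PlotMatrix, only to show the attribute handling."""

    def __init__(self):
        # Internal state, not meant to be set by users directly.
        self._orientation = None
        self._fixed_var_name = None
        self._fixed_selection = None

    def _reset_fixed(self):
        self._orientation = None
        self._fixed_var_name = None
        self._fixed_selection = None

    def map_row_col(self, fun, var_name, selection, orientation):
        self._orientation = orientation
        self._fixed_var_name = var_name
        self._fixed_selection = selection
        try:
            return fun(self)
        finally:
            # Revert the three attributes even if ``fun`` raises.
            self._reset_fixed()

    def map_triangle(self, fun):
        if self._orientation is not None:
            raise RuntimeError("orientation must be unset when calling map_triangle")
        return fun(self)


pm = PlotMatrixSketch()
seen = pm.map_row_col(lambda m: m._orientation, "mu", {}, "row")
```

Using try/finally is a design choice of this sketch; it keeps the matrix in a clean state when the mapped function fails partway through.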

coords=None,
sample_dims=None,
marginal=True,
marginal_kind="kde",
Member

Default to the rcParam, or even None, and pass it to plot_dist so plot_dist changes None to the rcParam.
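
The None-to-rcParam pattern mentioned here might look like the sketch below. The dictionary `RCPARAMS` is a stand-in for the real configuration object, and the key name "plot.density_kind" is an assumption made for illustration.

```python
# Stand-in for the library configuration; the key name is an assumption.
RCPARAMS = {"plot.density_kind": "kde"}


def resolve_marginal_kind(marginal_kind=None, rcparams=RCPARAMS):
    """Return the explicit kind if given, otherwise fall back to the configured default."""
    if marginal_kind is None:
        return rcparams["plot.density_kind"]
    return marginal_kind


default_kind = resolve_marginal_kind()
explicit_kind = resolve_marginal_kind("ecdf")
```

Defaulting to None in the signature keeps the documented default in one place, so changing the configuration changes the behaviour of every caller that did not pass a value.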

Member

@OriolAbril OriolAbril left a comment

mostly test and doc polishing left and we can merge

Comment on lines 23 to 24
marginal=True,
marginal_kind="kde",
Member

I would remove these two arguments given they are the defaults

dt = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    dt,
    var_names=["theta", "mu"],
Member

Suggested change
- var_names=["theta", "mu"],
+ var_names=["theta", "tau"],

Tau is the interesting one to visualize with divergences

Comment on lines 25 to 26
marginal=True,
marginal_kind="kde",
Member

maybe change this so the plot looks more different than the one in the distributions section

Suggested change
- marginal=True,
- marginal_kind="kde",
+ marginal=False,

@@ -317,31 +317,33 @@ def create_plotting_grid(
subplot_kws_i["y_range"] = shared_yrange[col]
if width_ratios is not None:
subplot_kws["width"] = plot_widths[col]

if row * cols + (col + 1) > number:
Member

we should add tests for this, but we don't yet have the setup for backend specific tests. I will ping you once I start working on this issue so we can both add tests for backend functions

Contributor Author

sure

Dictionary of {coordinate names : coordinate values} that should
be used to subset the aes, data and viz objects before any faceting
or aesthetics mapping is applied.
ignore_aes : set, optional
Member

Suggested change
- ignore_aes : set, optional
+ ignore_aes : str or set of str, default "all"

artist_dims : mapping of {hashable : int}, optional
Dictionary of sizes for proper allocation and storage when using
``map`` with functions that return an array of :term:`visual`.
**kwargs : mapping, optional
Member

Suggested change
- **kwargs : mapping, optional
+ **kwargs


See Also
--------
arviz_plots.PlotMatrix.map_triangle
Member

I think all 4: map, map_triangle, map_row and map_col should reference the other three in their see also sections, with the first one being the most relevant one, map_row<->map_col and map<->map_triangle.

@@ -332,6 +333,59 @@ def test_plot_mcse_models(self, datatree, datatree2, backend):
assert "model" in pc.aes["color"].dims
assert "/x" in pc.aes.groups

def test_plot_pair(self, datatree, backend):
Member

Either here or in follow-up tests we should make sure that the triangle argument is taken into account. That is, with triangle="both", scatter only has None on the diagonal; with triangle="lower", on the diagonal and above... Same thing for marginal=False.
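
A test along these lines needs to know which cells should be empty for each triangle value. The helper below is a hypothetical sketch of that expectation, assuming cells are indexed as (row, col) and that the diagonal never holds a scatter plot; it does not inspect a real PlotMatrix.

```python
def expected_empty_cells(n_vars, triangle):
    """Return the (row, col) cells where the scatter visual should be None.

    Hypothetical helper for building test expectations.
    """
    if triangle not in {"both", "upper", "lower"}:
        raise ValueError(f"Unknown triangle value: {triangle!r}")
    empty = set()
    for row in range(n_vars):
        for col in range(n_vars):
            if row == col:
                empty.add((row, col))
            elif triangle == "lower" and col > row:
                # Cells above the diagonal stay empty.
                empty.add((row, col))
            elif triangle == "upper" and col < row:
                # Cells below the diagonal stay empty.
                empty.add((row, col))
    return empty


empty_both = expected_empty_cells(3, "both")
empty_lower = expected_empty_cells(3, "lower")
empty_upper = expected_empty_cells(3, "upper")
```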

@@ -0,0 +1,27 @@
"""
# Scatterplot all variable against each other
Member

Suggested change
- # Scatterplot all variable against each other
+ # Scatterplot all variables against each other

4 participants