Adds Pair plot #287
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #287      +/-   ##
==========================================
+ Coverage   87.82%   88.02%   +0.19%
==========================================
  Files          44       45       +1
  Lines        5102     5303     +201
==========================================
+ Hits         4481     4668     +187
- Misses        621      635      +14
View full report in Codecov by Sentry.
@aloctavodia @OriolAbril please review this. I will add documentation and tests once we have settled on the final version of the function.
LGTM
9c7ab36 to 4361a25
src/arviz_plots/plots/pair_plot.py (outdated)
    sample_dims=None,
    plot_matrix=None,
Not sure if we want that as arguments or if we want multiple `plot_pair` variants (or maybe a bit of both). Here are some things that could be arguments specific to the plot:
- Which triangle to populate with plots. Legacy ArviZ defaults to the lower triangle only. We can change the default, but we should definitely allow control over that.
- Adding marginals to the diagonal, similar to the legacy version (I think it was also the default). This one would also need to be combined with x/y labels on the bottom/leftmost plots.
To keep consistency with the other functions, `marginal`, `marginal_kind` and `triangle` should go between `sample_dims` and `plot_matrix` (as it is the equivalent to `plot_collection`): https://arviz-plots.readthedocs.io/en/latest/contributing/new_plot.html
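As a rough illustration of the ordering described above, the stub below places the plot-specific arguments between `sample_dims` and `plot_matrix`. It is a hypothetical signature, not the actual arviz_plots one: the function name, the `triangle` default and the trailing arguments are assumptions made only to show the order.

```python
import inspect


def plot_pair_signature_sketch(
    dt,
    var_names=None,
    coords=None,
    sample_dims=None,
    # plot-specific arguments go here, right after sample_dims
    marginal=True,
    marginal_kind="kde",
    triangle="lower",  # placeholder default, an assumption
    # plot_matrix plays the role plot_collection has in other plots
    plot_matrix=None,
    backend=None,
    visuals=None,
    aes_by_visuals=None,
    stats=None,
    **pc_kwargs,
):
    """Hypothetical stub: only the argument order matters here."""


# Inspect the order of the parameters as declared above
param_order = list(inspect.signature(plot_pair_signature_sketch).parameters)
print(param_order)
```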
    pairs = tuple(
        xarray_sel_iter(
            distribution, skip_dims={dim for dim in distribution.dims if dim in sample_dims}
        )
    )
    n_pairs = len(pairs)
    pc_kwargs = set_grid_layout(
        pc_kwargs, plot_bknd, distribution, num_rows=n_pairs, num_cols=n_pairs
    )
This makes me more convinced we should do #284
98c176c to ca67395
I have incorporated all the changes suggested above.

MAJOR CHANGES

We are getting a warning in Bokeh whenever we use the upper or lower triangle; it says a renderer wasn't found for some plots.

Examples:

Code 1

from arviz_base import load_arviz_data
import arviz_plots as azp
azp.style.use("arviz-variat")
data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=True,
    marginal_kind="kde",
    triangle="both",
    visuals={"divergence": True},
)
pc.show()

Result 1

Code 2

from arviz_base import load_arviz_data
import arviz_plots as azp
azp.style.use("arviz-variat")
data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=True,
    marginal_kind="hist",
    triangle="both",
    visuals={"divergence": True},
)
pc.show()

Result 2

Code 3

from arviz_base import load_arviz_data
import arviz_plots as azp
azp.style.use("arviz-variat")
data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=True,
    marginal_kind="ecdf",
    triangle="both",
    visuals={"divergence": True},
)
pc.show()

Result 3

Code 4

from arviz_base import load_arviz_data
import arviz_plots as azp
azp.style.use("arviz-variat")
data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=True,
    marginal_kind="kde",
    triangle="upper",
    visuals={"divergence": True},
)
pc.show()

Code 5

from arviz_base import load_arviz_data
import arviz_plots as azp
azp.style.use("arviz-variat")
data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=True,
    marginal_kind="kde",
    triangle="lower",
    visuals={"divergence": True},
)
pc.show()

Result 5

Code 6

from arviz_base import load_arviz_data
import arviz_plots as azp
azp.style.use("arviz-variat")
data = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    data,
    backend="bokeh",
    var_names=["mu", "tau", "theta"],
    coords={"school": ["Choate", "Deerfield"]},
    marginal=False,
    triangle="both",
    visuals={"divergence": True, "remove_axis": {"axis": "both"}},
)
pc.show()

Result 6
src/arviz_plots/plots/pair_plot.py (outdated)
    pc_kwargs["figure_kwargs"].setdefault("sharey", True)
    pc_kwargs["figure_kwargs"].setdefault("sharex", True)
    if marginal:
        pc_kwargs["figure_kwargs"]["sharey"] = False
I just realized we don't want True. I complained about the axis limits for tau and then said to use True, which is the reason the axis for tau includes negative values. We want `sharex="col"` and `sharey="row"` (when present).
Re the `if marginal` case, I think it is OK, but I hesitate a bit between having this hardcoded or still as setdefault (that is, only call setdefault for sharey when "not marginal"). It is a design question about whether we want to allow users to make terrible dataviz choices when they change defaults vs enforcing all kwargs the users provide.
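The two designs weighed in this comment can be sketched with plain dictionaries. This is only an illustration under assumptions: `resolve_sharing` and its `enforce` flag are hypothetical names, and the real function works on `pc_kwargs["figure_kwargs"]` inside `plot_pair`.

```python
def resolve_sharing(figure_kwargs, marginal, enforce=True):
    """Return a copy of figure_kwargs with axis sharing filled in.

    Hypothetical helper, not part of arviz_plots.
    """
    resolved = dict(figure_kwargs)
    # x axes are shared within each column unless the user says otherwise
    resolved.setdefault("sharex", "col")
    if not marginal:
        # without marginals, y axes are shared within each row by default
        resolved.setdefault("sharey", "row")
    elif enforce:
        # hardcoded variant: marginals on the diagonal never share y
        resolved["sharey"] = False
    else:
        # setdefault variant: the user-provided value wins, even if unwise
        resolved.setdefault("sharey", False)
    return resolved


print(resolve_sharing({}, marginal=False))
print(resolve_sharing({"sharey": "row"}, marginal=True))
print(resolve_sharing({"sharey": "row"}, marginal=True, enforce=False))
```

With `enforce=True` a user-supplied `sharey` is silently overridden when marginals are drawn; with `enforce=False` it is kept, which is the trade-off the comment describes.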
This required a small logic change in the Bokeh backend's create_plotting_grid function. Previously, the logic was such that we couldn't have sharex="col" and sharey="row" simultaneously: they were mutually exclusive, so requesting both raised an error. That is now fixed in the backend.
Also, having just text on the diagonal plots when marginal is False was a bit awkward, because we had to give the x and y location of the text, which varied between variables (in case we want it centered). So I have now added a scatter plot to the diagonal plots as well when marginal is False. Otherwise the diagonal holds the marginals.
> This required a small logic change in the Bokeh backend's create_plotting_grid function.

Thanks! Bokeh has always been the one that generates the most trouble for us, so it ends up being less used and worked on. The refactor is making things much better on this end, but we are still not fully there.

> Also, having just text on the diagonal plots when marginal is False was a bit awkward, because we had to give the x and y location of the text, which varied between variables (in case we want it centered). So I have now added a scatter plot to the diagonal plots as well when marginal is False. Otherwise the diagonal holds the marginals.

I don't think there is any case where we want the scatter against itself. It is OK as a byproduct in `plot_pair_focus`, where removing it would greatly increase the complexity of the code, but I would not add it explicitly.
I see two options on this. One is doing some updates to the initialization of PlotMatrix in order to request a grid with one less column but the same number of rows, and storing it like so in the "plot" dataarray (which would still have the same shape and coordinates) with `None` on the diagonal. Example for 3 variables, we generate a 3x2 grid:
None | 0, 0 | 0, 1
1, 0 | None | 1, 1
2, 0 | 2, 1 | None
This means there are no diagonal elements to be filled at all.
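The index mapping in the 3x2 example above can be sketched as follows. This is a minimal sketch of the bookkeeping only, assuming a hypothetical helper name (`offdiagonal_layout`); it does not touch PlotMatrix or any backend.

```python
def offdiagonal_layout(n_vars):
    """Map each logical (row, col) cell to a physical grid position.

    The physical grid has n_vars rows and n_vars - 1 columns; diagonal
    cells have no physical plot and are stored as None.
    """
    layout = []
    for row in range(n_vars):
        cells = []
        for col in range(n_vars):
            if col == row:
                cells.append(None)
            elif col < row:
                # left of the diagonal: physical column equals logical column
                cells.append((row, col))
            else:
                # right of the diagonal: shift one column to the left
                cells.append((row, col - 1))
        layout.append(cells)
    return layout


for cells in offdiagonal_layout(3):
    print(cells)
```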
Alternatively, adapt the code for the labels along the diagonal. The backend text function takes vertical and horizontal alignment, which we can set to center the text around the provided `x` and `y`. And as that would be for the case without marginals with axis sharing, the x and y would actually be the same value, and we can compute it as:
text_center = (distribution.max(dim=sample_dims) + distribution.min(dim=sample_dims)) / 2
which should get something that looks centered.
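The midrange computation above can be tried without xarray. In this sketch plain Python lists stand in for the per-variable draws, and `text_centers` is a hypothetical helper name, so it only mirrors the arithmetic of the `text_center` expression.

```python
def text_centers(draws_by_variable):
    """Midpoint between the smallest and largest draw of each variable."""
    return {
        name: (max(draws) + min(draws)) / 2
        for name, draws in draws_by_variable.items()
    }


# Made-up draws, for illustration only
centers = text_centers({"mu": [-2.0, 1.0, 6.0], "tau": [0.5, 3.0, 9.5]})
print(centers)
```

Because the axes are shared, the same value would serve as both the x and the y position of the label in a diagonal cell.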
Next round of review. I think it is shaping up very well.
src/arviz_plots/plots/pair_plot.py (outdated)
    marginal_kwargs = copy(visuals.get("marginal", {}))
    dist_plot_visuals = {}
    dist_plot_aes_by_visuals = {}
    dist_plot_stats = {}
    dist_plot_visuals["dist"] = marginal_kwargs
    dist_plot_visuals["title"] = False
    dist_plot_visuals["point_estimate"] = False
    dist_plot_visuals["point_estimate_text"] = False
    dist_plot_visuals["credible_interval"] = False
    if remove_axis_bool is False:
        dist_plot_visuals["remove_axis"] = False
    dist_plot_visuals["rug"] = False
    dist_plot_aes_by_visuals["dist"] = aes_by_visuals.get("marginal", {})
    dist_plot_stats["dist"] = stats.get("marginal", {})
Maybe we should allow users to customize `plot_dist` on the diagonals, to add the credible interval or the point estimate too.
src/arviz_plots/plots/pair_plot.py (outdated)

    # scatter

    aes_by_visuals["scatter"] = {"overlay"}.union(aes_by_visuals.get("scatter", {}))
I would recommend defining all aes_by_visuals defaults in a block here, right after defining plot_collection. This makes it much easier to know what mappings are active by default in general. We need to properly think about these defaults too.
For `scatter`, for example, I think we want all defined aesthetic mappings by default, with the provided ones + overlay otherwise. That is, like `trace` in `plot_trace_dist` (not like `divergence` in `plot_trace_dist`). We also want the default to disable `overlay` for the `marginal` element.
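One way to read the defaults proposed in this comment is sketched below, with everything resolved in a single block. It is an interpretation under assumptions: `resolve_aes_by_visuals` is a hypothetical helper, and `defined_aes` stands in for whatever set of aesthetics the plot collection has defined.

```python
def resolve_aes_by_visuals(aes_by_visuals, defined_aes):
    """Fill in default aesthetic mappings for the scatter and marginal visuals."""
    resolved = dict(aes_by_visuals or {})
    defined_aes = set(defined_aes)

    # scatter: every defined mapping by default,
    # user-provided mappings plus "overlay" otherwise
    if "scatter" in resolved:
        resolved["scatter"] = {"overlay"}.union(resolved["scatter"])
    else:
        resolved["scatter"] = set(defined_aes)

    # marginal: overlay is disabled by default, but users can opt back in
    resolved.setdefault("marginal", defined_aes - {"overlay"})
    return resolved


defaults = resolve_aes_by_visuals(None, {"color", "overlay"})
custom = resolve_aes_by_visuals({"scatter": ["color"]}, {"color", "overlay"})
print(defaults)
print(custom)
```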
src/arviz_plots/plot_matrix.py (outdated)
    else:
        var_name_x = var_name
        selection_x = selection
Given the behaviour of `subset_matrix_da`, I think this means that the simple fact of `var_name_y` and `selection_y` being None will already behave like "diagonal" even if the `orientation` attribute is not None. I think it is actually fine and we only need to check orientation `row` or `col`, with no orientation meaning "diagonal", which I think is a good default.
src/arviz_plots/plot_matrix.py (outdated)
@@ -355,6 +418,9 @@ def map_triangle(
            artist_dims=artist_dims,
            ignore_aes=ignore_aes,
        )

        self.orientation = None
I would actually raise if `orientation` is set here (see other comments too).
src/arviz_plots/plot_matrix.py (outdated)
    @property
    def orientation(self):
        """Orientation of targets we want to fetch through :meth:`~PlotMatrix.get_target`.

        Orientation is always taken into account when using
        :meth:`~PlotMatrix.get_target` to fetch target for plotting.
        """
        return self._orientation

    @orientation.setter
    def orientation(self, value):
        self._orientation = value

    @property
    def fixed_var_name(self):
        """Fixed variable's name of targets we fetch through :meth:`~PlotMatrix.get_target`.

        fixed_var_name is taken into account when using :meth:`~PlotMatrix.get_target`
        to fetch target for plotting and :attr:`~PlotMatrix.orientation` is
        set to either `row` or `col`.

        It tells the fixed variable name in a `row` or `col`
        """
        return self._fixed_var_name

    @fixed_var_name.setter
    def fixed_var_name(self, value):
        self._fixed_var_name = value

    @property
    def fixed_selection(self):
        """Fixed dictionary for subsetting of targets."""
        return self._fixed_selection

    @fixed_selection.setter
    def fixed_selection(self, value):
        self._fixed_selection = value

    @fixed_var_name.setter
    def fixed_var_name(self, value):
        self._fixed_var_name = value
I think we could have them as "simple" attributes set to None in `__init__`. Moreover, I don't think we expect/want users to interact with them directly, but more so as an internal thing to reduce duplication between `map` and `map_row_col`.
src/arviz_plots/plot_matrix.py (outdated)
@@ -473,11 +632,11 @@ def map(
        --------
        arviz_plots.PlotMatrix.map_triangle
        """
        self.set_fixed_var_attributes(index, orientation)
Here also revert the 3 attributes to None before returning.
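The review suggestions about these three attributes (plain attributes set in `__init__`, reverted to None before returning, and an error in `map_triangle` if an orientation is still set) could fit together roughly as below. `ToyPlotMatrix` is a toy stand-in written for this sketch; the method bodies, the underscore-prefixed names and the exception types are assumptions, not the arviz_plots implementation.

```python
class ToyPlotMatrix:
    """Toy stand-in for PlotMatrix, used only to illustrate the state handling."""

    def __init__(self):
        # simple internal attributes instead of public properties
        self._orientation = None
        self._fixed_var_name = None
        self._fixed_selection = None

    def map_row_col(self, fun, var_name, selection, orientation):
        if orientation not in ("row", "col"):
            raise ValueError("orientation must be 'row' or 'col'")
        self._orientation = orientation
        self._fixed_var_name = var_name
        self._fixed_selection = selection
        try:
            return fun(self)
        finally:
            # revert the three attributes even if `fun` raises
            self._orientation = None
            self._fixed_var_name = None
            self._fixed_selection = None

    def map_triangle(self, fun):
        # an orientation left over from another call indicates a bug
        if self._orientation is not None:
            raise RuntimeError("map_triangle called while an orientation is set")
        return fun(self)


pm = ToyPlotMatrix()
seen = pm.map_row_col(lambda m: (m._orientation, m._fixed_var_name), "mu", {}, "row")
print(seen, pm._orientation)
```

Using `try`/`finally` is one way to guarantee the reset; resetting explicitly before each `return` would work too, but is easier to forget on an error path.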
src/arviz_plots/plots/pair_plot.py (outdated)
    coords=None,
    sample_dims=None,
    marginal=True,
    marginal_kind="kde",
Default to the rcParam, or even None, and pass it to plot_dist so that plot_dist changes None to the rcParam.
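The None-passthrough pattern suggested here can be sketched as follows. The dictionary and both functions are stand-ins written for this illustration; in particular the `"plot.density_kind"` key is an assumption about the rcParam name, and neither function is the real `plot_pair` or `plot_dist`.

```python
# Stand-in for the library's rcParams; the key name is an assumption
RC_PARAMS = {"plot.density_kind": "kde"}


def dist_sketch(kind=None):
    """Resolve None to the configured default, as plot_dist would."""
    if kind is None:
        kind = RC_PARAMS["plot.density_kind"]
    return kind


def pair_sketch(marginal_kind=None):
    """Forward marginal_kind untouched so the default lives in one place."""
    return dist_sketch(kind=marginal_kind)


print(pair_sketch())
print(pair_sketch("ecdf"))
```

Keeping the default out of the pair plot signature means a change to the rcParam is picked up by both functions without further edits.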
Mostly test and doc polishing left, and then we can merge.
    marginal=True,
    marginal_kind="kde",
I would remove these two arguments given they are the defaults
dt = load_arviz_data("centered_eight")
pc = azp.plot_pair(
    dt,
    var_names=["theta", "mu"],
-    var_names=["theta", "mu"],
+    var_names=["theta", "tau"],
Tau is the interesting one to visualize with divergences.
    marginal=True,
    marginal_kind="kde",
Maybe change this so the plot looks more different from the one in the distributions section:
-    marginal=True,
-    marginal_kind="kde",
+    marginal=False,
@@ -317,31 +317,33 @@ def create_plotting_grid(
                subplot_kws_i["y_range"] = shared_yrange[col]
            if width_ratios is not None:
                subplot_kws["width"] = plot_widths[col]

            if row * cols + (col + 1) > number:
We should add tests for this, but we don't yet have the setup for backend-specific tests. I will ping you once I start working on this issue so we can both add tests for backend functions.
sure
src/arviz_plots/plot_matrix.py (outdated)
    Dictionary of {coordinate names : coordinate values} that should
    be used to subset the aes, data and viz objects before any faceting
    or aesthetics mapping is applied.
ignore_aes : set, optional
- ignore_aes : set, optional
+ ignore_aes : str or set of str, default "all"
src/arviz_plots/plot_matrix.py (outdated)
artist_dims : mapping of {hashable : int}, optional
    Dictionary of sizes for proper allocation and storage when using
    ``map`` with functions that return an array of :term:`visual`.
**kwargs : mapping, optional
- **kwargs : mapping, optional
+ **kwargs

See Also
--------
arviz_plots.PlotMatrix.map_triangle
I think all 4: map, map_triangle, map_row and map_col should reference the other three in their see also sections, with the first one being the most relevant one, map_row<->map_col and map<->map_triangle.
@@ -332,6 +333,59 @@ def test_plot_mcse_models(self, datatree, datatree2, backend):
        assert "model" in pc.aes["color"].dims
        assert "/x" in pc.aes.groups

    def test_plot_pair(self, datatree, backend):
Either here or in follow-up tests we should make sure that the `triangle` argument is taken into account. That is, with triangle both, `scatter` only has `None` on the diagonal; with lower, on the diagonal and above... Same thing for marginal=False.
@@ -0,0 +1,27 @@
"""
# Scatterplot all variable against each other
- # Scatterplot all variable against each other
+ # Scatterplot all variables against each other
Adds Pair Plot
Key change required in PlotMatrix:
When we sent a mask through `map_triangle` (for divergences), it didn't perform subsetting on the `mask` data, which caused a dimension conflict. It was directly adding `**kwargs` to `fun_kwargs`. Now it performs subsetting before adding the relevant keys of `kwargs` (DataArrays or Datasets) into `fun_kwargs`.
For the following code:
we get the following plot:
📚 Documentation preview 📚: https://arviz-plots--287.org.readthedocs.build/en/287/
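The PlotMatrix change described in the summary above (subsetting array-like keyword arguments before they reach the plotting function) can be illustrated with a toy example. `ToyArray` and `build_fun_kwargs` are hypothetical names invented for this sketch; the actual code would presumably check for xarray DataArray or Dataset objects and use their own selection methods.

```python
class ToyArray:
    """Minimal stand-in for a labelled array with a single named dimension."""

    def __init__(self, dim, data):
        self.dim = dim
        self.data = dict(data)

    def sel(self, selection):
        # pick the entry matching the selection, if it targets our dimension
        if self.dim in selection:
            return self.data[selection[self.dim]]
        return self


def build_fun_kwargs(kwargs, selection):
    """Subset array-like kwargs for the current plot; pass the rest through."""
    fun_kwargs = {}
    for key, value in kwargs.items():
        if isinstance(value, ToyArray):
            fun_kwargs[key] = value.sel(selection)
        else:
            fun_kwargs[key] = value
    return fun_kwargs


# Made-up divergence mask with one entry per school
mask = ToyArray("school", {"Choate": [True, False], "Deerfield": [False, False]})
out = build_fun_kwargs({"mask": mask, "color": "black"}, {"school": "Choate"})
print(out)
```

Without the subsetting step the whole mask would be handed to every plot, which is the kind of dimension conflict the summary mentions.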