WIP: Add kwargs to InferenceData.to_netcdf() #2410

Open
wants to merge 3 commits into main

Conversation
Conversation

@cowirihy commented Jan 17, 2025

Relates to #2298; the solution is broadly along the lines of the one sketched out in that issue.

Added `**kwargs` to the `InferenceData.to_netcdf()` method, so that any of the parameters accepted by `xarray.Dataset.to_netcdf()` can be passed through.

E.g. for my use case I define an `encoding={'var_A': {"dtype": "int16", "scale_factor": 0.1}}` dict, so that `var_A` samples get stored as 16-bit integers with 1 decimal place of precision, economising on file size with an inconsequential loss of precision. Note this applies to `var_A` in any group in which it appears, e.g. both the `posterior` and `prior` groups if present.
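For illustration, here is a minimal sketch of this encoding behaviour using plain xarray (the variable name, values, and file path are made up); the PR would simply forward the same `encoding` dict through `InferenceData.to_netcdf()` to each group's `Dataset.to_netcdf()` call:

```python
import os
import tempfile

import numpy as np
import xarray as xr

# Toy dataset standing in for a posterior group; names are illustrative.
ds = xr.Dataset({"var_A": ("draw", np.array([0.123, 4.567, 8.912]))})

# The encoding described above: store var_A as int16 multiples of 0.1.
encoding = {"var_A": {"dtype": "int16", "scale_factor": 0.1, "_FillValue": -32768}}

path = os.path.join(tempfile.mkdtemp(), "encoded.nc")
ds.to_netcdf(path, encoding=encoding)

# On decode, xarray multiplies the stored integers back by scale_factor,
# so values come back quantised to roughly one decimal place.
roundtrip = xr.open_dataset(path).load()
print(roundtrip["var_A"].values)
```

The recovered values differ from the originals by at most half the scale factor, which is the "inconsequential loss of precision" trade-off described above.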

I've put in a placeholder for where a new unit test could be added, but am not so confident in defining this. What I envisage, and have tested via a separate script on my end, is the following:

  • Load data to define an InferenceData instance, reading from a netcdf file as other unit tests already do
  • Define some customisation, e.g. encoding settings for a couple of the RVs in the model to which the data relates
  • Write to a new netcdf file, passing encoding (and/or other params that would alter the behaviour of Dataset.to_netcdf)
  • Read back in from the second file and verify that approximately the same data is recovered, with the expected loss of precision

Help welcome in setting up the latter! It would also be worthwhile to verify via tests that the handling code I've included (populating the kwargs dict based on the `compress` and `engine` parameters, as before) works as intended and is backwards compatible; it should be! Perhaps the existing tests are adequate to prove this, though?

Checklist

  • Follows official PR format
  • [n/a] Includes a sample plot to visually illustrate the changes (only for plot-related functions)
  • [n/a] New features are properly documented (with an example if appropriate)
  • Includes new or updated tests to cover the new feature
  • Code style correct (follows pylint and black guidelines)
  • Changes are listed in changelog

📚 Documentation preview 📚: https://arviz--2410.org.readthedocs.build/en/2410/

@amaloney (Member) commented

Thanks @cowirihy, I'll try my best to review this week.

@OriolAbril (Member) left a comment
Thanks for the PR! Sorry about the delayed review. I hope it is helpful.

Comment on lines +486 to +488
Other keyword arguments will be passed to `xarray.Dataset.to_netcdf()`. If
provided these will serve to override dict items that relate to `compress` and
`engine` parameters described above.

This would be the parameter description according to numpydoc, so indented under a **kwargs parameter with no type: https://numpydoc.readthedocs.io/en/latest/format.html#parameters (the last paragraph of this section).

Also, if you use :meth:`xarray.Dataset.to_netcdf` or even `xarray.Dataset.to_netcdf` (given our sphinx configuration) it will be rendered as a link to the respective docs in the xarray website. You can check the rendered docstring preview from your PR at https://arviz--2410.org.readthedocs.build/en/2410/api/generated/arviz.InferenceData.to_netcdf.html
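A hedged sketch of what that docstring layout could look like (the function signature shown here is illustrative, not the actual arviz one):

```python
def to_netcdf(filename, compress=True, engine=None, **kwargs):
    """Write data to a netcdf file (illustrative docstring only).

    Parameters
    ----------
    filename : str
        Location of the netcdf file to write.
    **kwargs
        Passed to :meth:`xarray.Dataset.to_netcdf`. If provided, these
        override the items derived from the ``compress`` and ``engine``
        parameters described above.
    """
```

Per numpydoc, `**kwargs` is listed as a parameter with no type annotation, and its description is indented beneath it like any other parameter.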

Comment on lines +513 to +516
```python
try:
    encoding_kw2 = kwargs["encoding"]
except KeyError:
    encoding_kw2 = {}
```
We generally use `.get` for this kind of operation: `encoding_kw2 = kwargs.get("encoding", {})`

Comment on lines +540 to +545
```python
for var_name, kw1 in encoding_kw1.items():
    try:
        kw2 = encoding_kw2[var_name]
    except KeyError:
        kw2 = {}
    encoding_kw_merged[var_name] = {**kw1, **kw2}
```
I think the logic here would not work as expected when there are non-compressible types. My line of thought/duck debugging:

  1. encoding_kw2 is full and has elements for all variables
  2. encoding_kw_merged is empty
  3. We loop only over the variable names in encoding_kw1, which will only contain compressible variables. Then, for each of these variables only:
    • We merge the variable-specific entries of encoding_kw1 and encoding_kw2
  4. encoding_kw_merged has the same keys as encoding_kw1 and the merged dicts as values.
    • If there were no compressible variables, encoding_kw_merged would be empty even with encoding_kw2 being full

Potential proposal:

Suggested change:

```python
# before
for var_name, kw1 in encoding_kw1.items():
    try:
        kw2 = encoding_kw2[var_name]
    except KeyError:
        kw2 = {}
    encoding_kw_merged[var_name] = {**kw1, **kw2}
```

```python
# after
for var_name in data.data_vars:
    kw1 = encoding_kw1.get(var_name, {})
    kw2 = encoding_kw2.get(var_name, {})
    encoding_kw_merged[var_name] = kw1 | kw2
```
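A toy reproduction of this point (pure Python; the variable names and encoding entries are invented), showing how the original loop silently drops user encoding for a non-compressible variable while the proposed loop keeps it:

```python
# Stand-ins: encoding_kw1 holds auto-generated compression settings (only
# for compressible variables); encoding_kw2 holds user-supplied encoding.
data_vars = ["var_A", "var_B"]                 # all variables in the dataset
encoding_kw1 = {"var_A": {"zlib": True}}       # var_B is not compressible
encoding_kw2 = {"var_B": {"dtype": "int16"}}   # user encodes var_B only

# Original loop: iterates over encoding_kw1, so var_B's encoding is lost.
merged_old = {}
for var_name, kw1 in encoding_kw1.items():
    kw2 = encoding_kw2.get(var_name, {})
    merged_old[var_name] = {**kw1, **kw2}
# merged_old == {"var_A": {"zlib": True}} -- no var_B entry

# Proposed loop: iterates over all data variables instead.
merged_new = {}
for var_name in data_vars:
    kw1 = encoding_kw1.get(var_name, {})
    kw2 = encoding_kw2.get(var_name, {})
    merged_new[var_name] = kw1 | kw2
# merged_new == {"var_A": {"zlib": True}, "var_B": {"dtype": "int16"}}
```

Note the `dict | dict` merge operator requires Python 3.9+; `{**kw1, **kw2}` is the equivalent for older versions.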

Comment on lines +1490 to +1493
```python
# 1) define an InferenceData object (e.g. from file)
# 2) define different sets of `**kwargs` to pass
# 3) use inference_data.to_netcdf(filepath, **kwargs)
# 4) test these make it through to `data.to_netcdf()` as intended - TODO how?
```
I think what you propose is about right. Pseudocode idea:

```python
idata = load...
# store with encoding kwargs that mean a small but non-negligible loss of precision
# and, as in the previous test, check the requested filename exists
idata_encoded = load...
for group in idata.groups:
    # use xarray.testing.assert_allclose
    # https://docs.xarray.dev/en/stable/generated/xarray.testing.assert_allclose.html
    # once as
    with pytest.raises(AssertionError):
        assert_allclose(..., rtol=low/default)
    # then again as
    assert_allclose(..., rtol=high)
# clean up files
```
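The tolerance pattern above can be sketched at the array level, with `numpy.testing.assert_allclose` standing in for the xarray version (the data values and tolerances here are invented, and rounding to one decimal place stands in for the int16/scale_factor roundtrip):

```python
import numpy as np
from numpy.testing import assert_allclose

original = np.array([0.123, 4.567, 8.912])
# Simulate the precision loss of a scale_factor=0.1 encoding roundtrip.
roundtripped = np.round(original, 1)

# Tight tolerance: the encoded copy should NOT match exactly...
try:
    assert_allclose(roundtripped, original, rtol=0, atol=1e-6)
    raise RuntimeError("expected a precision mismatch")
except AssertionError:
    pass

# ...but a loose tolerance (half the scale factor) should pass.
assert_allclose(roundtripped, original, rtol=0, atol=0.051)
```

In the real test, `pytest.raises(AssertionError)` would replace the try/except, and each group of the reloaded `InferenceData` would be compared in turn.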

3 participants