Skip to content

Add Index.validate_dataarray_coord #10137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0707a8b
typing fixes and tweaks
benbovy Mar 17, 2025
75086ef
add Index.validate_dataarray_coord()
benbovy Mar 17, 2025
8aaf2b8
Dataset._construct_dataarray: validate index coord
benbovy Mar 17, 2025
c9b4baa
DataArray init: validate index coord
benbovy Mar 17, 2025
a47523f
clean-up old TODO
benbovy Mar 17, 2025
551808a
refactor dataarray coord update
benbovy Mar 17, 2025
818b7f5
docstring tweaks
benbovy Mar 17, 2025
e8df9b5
add tests
benbovy Mar 13, 2025
678c013
assert invariants: skip check IndexVariable ...
benbovy Mar 14, 2025
0f822b5
update cherry-picked tests
benbovy Mar 17, 2025
43c44ea
update assert datarray invariants
benbovy Mar 17, 2025
3b33263
doc: add Index.validate_dataarray_coords to API
benbovy Mar 17, 2025
a8e6e20
typo
benbovy Mar 17, 2025
f1440c4
update whats new
benbovy Mar 17, 2025
5da014e
add CoordinateValidationError
benbovy Mar 18, 2025
6026656
docstrings tweaks
benbovy Mar 18, 2025
1eeec9c
nit refactor
benbovy Mar 18, 2025
426ddce
small refactor
benbovy Mar 18, 2025
5c0cc0f
Merge branch 'main' into index-validate-dataarray-coords
benbovy Mar 27, 2025
4399036
docstrings improvements
benbovy Mar 31, 2025
828a4cc
docstrings improvements
benbovy Mar 31, 2025
273d70c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2025
f49c83a
Merge branch 'main' into index-validate-dataarray-coords
benbovy Apr 23, 2025
3e55af0
refactor index check method
benbovy Apr 24, 2025
073c0a2
small refactor
benbovy Apr 24, 2025
8d43dcc
forgot updating API docs and whats new
benbovy Apr 24, 2025
4e7c70a
nit docstrings
benbovy Apr 24, 2025
b0f6782
Merge branch 'main' into index-validate-dataarray-coords
dcherian Apr 26, 2025
bf557f8
Merge branch 'main' into index-validate-dataarray-coords
dcherian Apr 26, 2025
df828b8
Merge branch 'main' into index-validate-dataarray-coords
dcherian Apr 28, 2025
fa574bc
rename method to Index.should_add_coord_to_array
benbovy May 5, 2025
524b7dc
Merge branch 'main' into index-validate-dataarray-coords
benbovy May 5, 2025
15e4159
review suggestion
benbovy May 6, 2025
67bf943
review suggestion 2
benbovy May 6, 2025
d3cbb3a
more docstrings tweaks
benbovy May 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api-hidden.rst
Original file line number Diff line number Diff line change
@@ -520,6 +520,7 @@
Index.stack
Index.unstack
Index.create_variables
Index.should_add_coord_to_array
Index.to_pandas_index
Index.isel
Index.sel
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
@@ -1645,6 +1645,7 @@ Exceptions
:toctree: generated/

AlignmentError
CoordinateValidationError
MergeError
SerializationWarning

4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
@@ -34,6 +34,10 @@ New Features
- Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This
includes ``datatree`` support, and removing slashes from dimension names. By
`Miguel Jimenez-Urias <https://github.com/Mikejmnez>`_.
- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding
:py:meth:`Index.should_add_coord_to_array`. For example, this enables support for CF boundaries coordinate (e.g.,
``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`).
By `Benoit Bovy <https://github.com/benbovy>`_.
- Improved support pandas categorical extension as indices (i.e., :py:class:`pandas.IntervalIndex`). (:issue:`9661`, :pull:`9671`)
By `Ilan Gold <https://github.com/ilan-gold>`_.
- Improved checks and errors raised when trying to align objects with conflicting indexes.
3 changes: 2 additions & 1 deletion xarray/__init__.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@
)
from xarray.conventions import SerializationWarning, decode_cf
from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like
from xarray.core.coordinates import Coordinates
from xarray.core.coordinates import Coordinates, CoordinateValidationError
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
@@ -129,6 +129,7 @@
"Variable",
# Exceptions
"AlignmentError",
"CoordinateValidationError",
"InvalidTreeError",
"MergeError",
"NotFoundInTreeError",
79 changes: 60 additions & 19 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
@@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool:
return self.to_dataset().identical(other.to_dataset())

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
# redirect to DatasetCoordinates._update_coords
self._data.coords._update_coords(coords, indexes)
@@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset:
return self._data._copy_listed(names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
variables = self._data._variables.copy()
variables.update(coords)
@@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset:
return self._data.dataset._copy_listed(self._names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
from xarray.core.datatree import check_alignment

@@ -964,22 +964,14 @@ def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
coords_plus_data = coords.copy()
coords_plus_data[_THIS_ARRAY] = self._data.variable
dims = calculate_dimensions(coords_plus_data)
if not set(dims) <= set(self.dims):
raise ValueError(
"cannot add coordinates with new dimensions to a DataArray"
)
self._data._coords = coords
validate_dataarray_coords(
self._data.shape, Coordinates._construct_direct(coords, indexes), self.dims
)

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = dict(self._data.xindexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes
self._data._coords = coords
self._data._indexes = indexes

def _drop_coords(self, coord_names):
# should drop indexed coordinates only
@@ -1154,9 +1146,58 @@ def create_coords_with_default_indexes(
return new_coords


def _coordinates_from_variable(variable: Variable) -> Coordinates:
from xarray.core.indexes import create_default_index_implicit
class CoordinateValidationError(ValueError):
"""Error class for Xarray coordinate validation failures."""


def validate_dataarray_coords(
shape: tuple[int, ...],
coords: Coordinates | Mapping[Hashable, Variable],
dim: tuple[Hashable, ...],
):
"""Validate coordinates ``coords`` to include in a DataArray defined by
``shape`` and dimensions ``dim``.

If a coordinate is associated with an index, the validation is performed by
the index. By default the coordinate dimensions must match (a subset of) the
array dimensions (in any order) to conform to the DataArray model. The index
may override this behavior with other validation rules, though.

Non-index coordinates must all conform to the DataArray model. Scalar
coordinates are always valid.
"""
sizes = dict(zip(dim, shape, strict=True))
dim_set = set(dim)

indexes: Mapping[Hashable, Index]
if isinstance(coords, Coordinates):
indexes = coords.xindexes
else:
indexes = {}

for k, v in coords.items():
if k in indexes:
invalid = not indexes[k].should_add_coord_to_array(k, v, dim_set)
else:
invalid = any(d not in dim for d in v.dims)

if invalid:
raise CoordinateValidationError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dim}"
)

for d, s in v.sizes.items():
if d in sizes and s != sizes[d]:
raise CoordinateValidationError(
f"conflicting sizes for dimension {d!r}: "
f"length {sizes[d]} on the data but length {s} on "
f"coordinate {k!r}"
)


def coordinates_from_variable(variable: Variable) -> Coordinates:
(name,) = variable.dims
new_index, index_vars = create_default_index_implicit(variable)
indexes = dict.fromkeys(index_vars, new_index)
22 changes: 2 additions & 20 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@
DataArrayCoordinates,
assert_coordinate_consistent,
create_coords_with_default_indexes,
validate_dataarray_coords,
)
from xarray.core.dataset import Dataset
from xarray.core.extension_array import PandasExtensionArray
@@ -124,25 +125,6 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset)


def _check_coords_dims(shape, coords, dim):
sizes = dict(zip(dim, shape, strict=True))
for k, v in coords.items():
if any(d not in dim for d in v.dims):
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dim}"
)

for d, s in v.sizes.items():
if s != sizes[d]:
raise ValueError(
f"conflicting sizes for dimension {d!r}: "
f"length {sizes[d]} on the data but length {s} on "
f"coordinate {k!r}"
)


def _infer_coords_and_dims(
shape: tuple[int, ...],
coords: (
@@ -206,7 +188,7 @@ def _infer_coords_and_dims(
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

_check_coords_dims(shape, new_coords, dims_tuple)
validate_dataarray_coords(shape, new_coords, dims_tuple)

return new_coords, dims_tuple

10 changes: 9 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
@@ -1159,7 +1159,15 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
coords: dict[Hashable, Variable] = {}
# preserve ordering
for k in self._variables:
if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
if k in self._indexes:
add_coord = self._indexes[k].should_add_coord_to_array(
k, self._variables[k], needed_dims
)
else:
var_dims = set(self._variables[k].dims)
add_coord = k in self._coord_names and var_dims <= needed_dims

if add_coord:
coords[k] = self._variables[k]

indexes = filter_indexes_from_coords(self._indexes, set(coords))
4 changes: 2 additions & 2 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
DatasetGroupByAggregations,
)
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.coordinates import Coordinates, coordinates_from_variable
from xarray.core.duck_array_ops import where
from xarray.core.formatting import format_array_flat
from xarray.core.indexes import (
@@ -1147,7 +1147,7 @@ def _flox_reduce(
new_coords.append(
# Using IndexVariable here ensures we reconstruct PandasMultiIndex with
# all associated levels properly.
_coordinates_from_variable(
coordinates_from_variable(
IndexVariable(
dims=grouper.name,
data=output_index,
43 changes: 43 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
@@ -196,6 +196,49 @@ def create_variables(
else:
return {}

def should_add_coord_to_array(
self,
name: Hashable,
var: Variable,
dims: set[Hashable],
) -> bool:
"""Define whether or not an index coordinate variable should be added to
a new DataArray.

This method is called repeatedly for each Variable associated with this
index when creating a new DataArray (via its constructor or from a
Dataset) or updating an existing one. The variables associated with this
index are the ones passed to :py:meth:`Index.from_variables` and/or
returned by :py:meth:`Index.create_variables`.

By default returns ``True`` if the dimensions of the coordinate variable
are a subset of the array dimensions and ``False`` otherwise (DataArray
model). This default behavior may be overridden in Index subclasses to
bypass strict conformance with the DataArray model. This is useful for
example to include the (n+1)-dimensional cell boundary coordinate
associated with an interval index.

Returning ``False`` will either:

- raise a :py:class:`CoordinateValidationError` when passing the
coordinate directly to a new or an existing DataArray, e.g., via
``DataArray.__init__()`` or ``DataArray.assign_coords()``

- drop the coordinate (and therefore drop the index) when a new
DataArray is constructed by indexing a Dataset

Parameters
----------
name : Hashable
Name of a coordinate variable associated to this index.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: could drop this, assume var.name will be set.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think we were discussing at some point about getting rid of IndexVariable.name? Also Variable has no name property but it is also the type of some variables associated with custom indexes (IndexVariable is still tightly related to a 1D variable wrapping a pandas.Index).

var : Variable
Coordinate variable object.
dims: tuple
Dimensions of the new DataArray object being created.

"""
return all(d in dims for d in var.dims)

def to_pandas_index(self) -> pd.Index:
"""Cast this xarray index to a pandas.Index object or raise a
``TypeError`` if this is not supported.
12 changes: 6 additions & 6 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@

from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
from xarray.computation.apply_ufunc import apply_ufunc
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.coordinates import Coordinates, coordinates_from_variable
from xarray.core.dataarray import DataArray
from xarray.core.duck_array_ops import array_all, isnull
from xarray.core.groupby import T_Group, _DummyGroup
@@ -115,7 +115,7 @@ def __init__(

if coords is None:
assert not isinstance(self.unique_coord, _DummyGroup)
self.coords = _coordinates_from_variable(self.unique_coord)
self.coords = coordinates_from_variable(self.unique_coord)
else:
self.coords = coords

@@ -252,7 +252,7 @@ def _factorize_unique(self) -> EncodedGroups:
codes=codes,
full_index=full_index,
unique_coord=unique_coord,
coords=_coordinates_from_variable(unique_coord),
coords=coordinates_from_variable(unique_coord),
)

def _factorize_dummy(self) -> EncodedGroups:
@@ -280,7 +280,7 @@ def _factorize_dummy(self) -> EncodedGroups:
else:
if TYPE_CHECKING:
assert isinstance(unique_coord, Variable)
coords = _coordinates_from_variable(unique_coord)
coords = coordinates_from_variable(unique_coord)

return EncodedGroups(
codes=codes,
@@ -417,7 +417,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
codes=codes,
full_index=full_index,
unique_coord=unique_coord,
coords=_coordinates_from_variable(unique_coord),
coords=coordinates_from_variable(unique_coord),
)


@@ -551,7 +551,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
group_indices=group_indices,
full_index=full_index,
unique_coord=unique_coord,
coords=_coordinates_from_variable(unique_coord),
coords=coordinates_from_variable(unique_coord),
)


8 changes: 4 additions & 4 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
@@ -401,12 +401,12 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool):

assert isinstance(da._coords, dict), da._coords
assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
da.dims,
{k: v.dims for k, v in da._coords.items()},
)

if check_default_indexes:
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
da.dims,
{k: v.dims for k, v in da._coords.items()},
)
assert all(
isinstance(v, IndexVariable)
for (k, v) in da._coords.items()
Loading
Loading