diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index d24e69d6542..9a6037cf3c4 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -520,6 +520,7 @@ Index.stack Index.unstack Index.create_variables + Index.should_add_coord_to_array Index.to_pandas_index Index.isel Index.sel diff --git a/doc/api.rst b/doc/api.rst index 74c0831f26b..be64e3eac3a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1645,6 +1645,7 @@ Exceptions :toctree: generated/ AlignmentError + CoordinateValidationError MergeError SerializationWarning diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 76fb5d42aa9..0f7d957e10a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,10 @@ New Features - Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This includes ``datatree`` support, and removing slashes from dimension names. By `Miguel Jimenez-Urias `_. +- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding + :py:meth:`Index.should_add_coord_to_array`. For example, this enables support for CF boundaries coordinate (e.g., + ``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`). + By `Benoit Bovy `_. - Improved support pandas categorical extension as indices (i.e., :py:class:`pandas.IntervalIndex`). (:issue:`9661`, :pull:`9671`) By `Ilan Gold `_. - Improved checks and errors raised when trying to align objects with conflicting indexes. diff --git a/xarray/__init__.py b/xarray/__init__.py index b08729f7478..d1001b4470a 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -28,7 +28,7 @@ ) from xarray.conventions import SerializationWarning, decode_cf from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like -from xarray.core.coordinates import Coordinates +from xarray.core.coordinates import Coordinates, CoordinateValidationError from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree @@ -129,6 +129,7 @@ "Variable", # Exceptions "AlignmentError", + "CoordinateValidationError", "InvalidTreeError", "MergeError", "NotFoundInTreeError", diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 0972b04f1fc..13fe0a791bb 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool: return self.to_dataset().identical(other.to_dataset()) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: # redirect to DatasetCoordinates._update_coords self._data.coords._update_coords(coords, indexes) @@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset: return self._data._copy_listed(names) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: variables = self._data._variables.copy() variables.update(coords) @@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset: return self._data.dataset._copy_listed(self._names) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: from xarray.core.datatree import check_alignment @@ -964,22 +964,14 @@ def __getitem__(self, key: Hashable) -> T_DataArray: return self._data._getitem_coord(key) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: - coords_plus_data = coords.copy() - coords_plus_data[_THIS_ARRAY] = self._data.variable - dims = calculate_dimensions(coords_plus_data) - if not set(dims) <= set(self.dims): - raise ValueError( - "cannot add coordinates with new dimensions to a DataArray" - ) - self._data._coords = coords + validate_dataarray_coords( + self._data.shape, Coordinates._construct_direct(coords, indexes), self.dims + ) - # TODO(shoyer): once ._indexes is always populated by a dict, modify - # it to update inplace instead. - original_indexes = dict(self._data.xindexes) - original_indexes.update(indexes) - self._data._indexes = original_indexes + self._data._coords = coords + self._data._indexes = indexes def _drop_coords(self, coord_names): # should drop indexed coordinates only @@ -1154,9 +1146,58 @@ def create_coords_with_default_indexes( return new_coords -def _coordinates_from_variable(variable: Variable) -> Coordinates: - from xarray.core.indexes import create_default_index_implicit +class CoordinateValidationError(ValueError): + """Error class for Xarray coordinate validation failures.""" + + +def validate_dataarray_coords( + shape: tuple[int, ...], + coords: Coordinates | Mapping[Hashable, Variable], + dim: tuple[Hashable, ...], +): + """Validate coordinates ``coords`` to include in a DataArray defined by + ``shape`` and dimensions ``dim``. + + If a coordinate is associated with an index, the validation is performed by + the index. By default the coordinate dimensions must match (a subset of) the + array dimensions (in any order) to conform to the DataArray model. The index + may override this behavior with other validation rules, though. + + Non-index coordinates must all conform to the DataArray model. Scalar + coordinates are always valid. + """ + sizes = dict(zip(dim, shape, strict=True)) + dim_set = set(dim) + + indexes: Mapping[Hashable, Index] + if isinstance(coords, Coordinates): + indexes = coords.xindexes + else: + indexes = {} + + for k, v in coords.items(): + if k in indexes: + invalid = not indexes[k].should_add_coord_to_array(k, v, dim_set) + else: + invalid = any(d not in dim for d in v.dims) + + if invalid: + raise CoordinateValidationError( + f"coordinate {k} has dimensions {v.dims}, but these " + "are not a subset of the DataArray " + f"dimensions {dim}" + ) + + for d, s in v.sizes.items(): + if d in sizes and s != sizes[d]: + raise CoordinateValidationError( + f"conflicting sizes for dimension {d!r}: " + f"length {sizes[d]} on the data but length {s} on " + f"coordinate {k!r}" + ) + +def coordinates_from_variable(variable: Variable) -> Coordinates: (name,) = variable.dims new_index, index_vars = create_default_index_implicit(variable) indexes = dict.fromkeys(index_vars, new_index) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f49b8b8cb48..5a578128d1a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -33,6 +33,7 @@ DataArrayCoordinates, assert_coordinate_consistent, create_coords_with_default_indexes, + validate_dataarray_coords, ) from xarray.core.dataset import Dataset from xarray.core.extension_array import PandasExtensionArray @@ -124,25 +125,6 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset) -def _check_coords_dims(shape, coords, dim): - sizes = dict(zip(dim, shape, strict=True)) - for k, v in coords.items(): - if any(d not in dim for d in v.dims): - raise ValueError( - f"coordinate {k} has dimensions {v.dims}, but these " - "are not a subset of the DataArray " - f"dimensions {dim}" - ) - - for d, s in v.sizes.items(): - if s != sizes[d]: - raise ValueError( - f"conflicting sizes for dimension {d!r}: " - f"length {sizes[d]} on the data but length {s} on " - f"coordinate {k!r}" - ) - - def _infer_coords_and_dims( shape: tuple[int, ...], coords: ( @@ -206,7 +188,7 @@ def _infer_coords_and_dims( var.dims = (dim,) new_coords[dim] = var.to_index_variable() - _check_coords_dims(shape, new_coords, dims_tuple) + validate_dataarray_coords(shape, new_coords, dims_tuple) return new_coords, dims_tuple diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4e7cfc8c49b..24cc7ec4008 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1159,7 +1159,15 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: coords: dict[Hashable, Variable] = {} # preserve ordering for k in self._variables: - if k in self._coord_names and set(self._variables[k].dims) <= needed_dims: + if k in self._indexes: + add_coord = self._indexes[k].should_add_coord_to_array( + k, self._variables[k], needed_dims + ) + else: + var_dims = set(self._variables[k].dims) + add_coord = k in self._coord_names and var_dims <= needed_dims + + if add_coord: coords[k] = self._variables[k] indexes = filter_indexes_from_coords(self._indexes, set(coords)) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 0621b170e46..05a19819735 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -23,7 +23,7 @@ DatasetGroupByAggregations, ) from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce -from xarray.core.coordinates import Coordinates, _coordinates_from_variable +from xarray.core.coordinates import Coordinates, coordinates_from_variable from xarray.core.duck_array_ops import where from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( @@ -1147,7 +1147,7 @@ def _flox_reduce( new_coords.append( # Using IndexVariable here ensures we reconstruct PandasMultiIndex with # all associated levels properly. - _coordinates_from_variable( + coordinates_from_variable( IndexVariable( dims=grouper.name, data=output_index, diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index bc934132f1c..8babb885a5e 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -196,6 +196,49 @@ def create_variables( else: return {} + def should_add_coord_to_array( + self, + name: Hashable, + var: Variable, + dims: set[Hashable], + ) -> bool: + """Define whether or not an index coordinate variable should be added to + a new DataArray. + + This method is called repeatedly for each Variable associated with this + index when creating a new DataArray (via its constructor or from a + Dataset) or updating an existing one. The variables associated with this + index are the ones passed to :py:meth:`Index.from_variables` and/or + returned by :py:meth:`Index.create_variables`. + + By default returns ``True`` if the dimensions of the coordinate variable + are a subset of the array dimensions and ``False`` otherwise (DataArray + model). This default behavior may be overridden in Index subclasses to + bypass strict conformance with the DataArray model. This is useful for + example to include the (n+1)-dimensional cell boundary coordinate + associated with an interval index. + + Returning ``False`` will either: + + - raise a :py:class:`CoordinateValidationError` when passing the + coordinate directly to a new or an existing DataArray, e.g., via + ``DataArray.__init__()`` or ``DataArray.assign_coords()`` + + - drop the coordinate (and therefore drop the index) when a new + DataArray is constructed by indexing a Dataset + + Parameters + ---------- + name : Hashable + Name of a coordinate variable associated to this index. + var : Variable + Coordinate variable object. + dims: tuple + Dimensions of the new DataArray object being created. + + """ + return all(d in dims for d in var.dims) + def to_pandas_index(self) -> pd.Index: """Cast this xarray index to a pandas.Index object or raise a ``TypeError`` if this is not supported. diff --git a/xarray/groupers.py b/xarray/groupers.py index 96c1b0f55d1..fb7ff9311da 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -18,7 +18,7 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.computation.apply_ufunc import apply_ufunc -from xarray.core.coordinates import Coordinates, _coordinates_from_variable +from xarray.core.coordinates import Coordinates, coordinates_from_variable from xarray.core.dataarray import DataArray from xarray.core.duck_array_ops import array_all, isnull from xarray.core.groupby import T_Group, _DummyGroup @@ -115,7 +115,7 @@ def __init__( if coords is None: assert not isinstance(self.unique_coord, _DummyGroup) - self.coords = _coordinates_from_variable(self.unique_coord) + self.coords = coordinates_from_variable(self.unique_coord) else: self.coords = coords @@ -252,7 +252,7 @@ def _factorize_unique(self) -> EncodedGroups: codes=codes, full_index=full_index, unique_coord=unique_coord, - coords=_coordinates_from_variable(unique_coord), + coords=coordinates_from_variable(unique_coord), ) def _factorize_dummy(self) -> EncodedGroups: @@ -280,7 +280,7 @@ def _factorize_dummy(self) -> EncodedGroups: else: if TYPE_CHECKING: assert isinstance(unique_coord, Variable) - coords = _coordinates_from_variable(unique_coord) + coords = coordinates_from_variable(unique_coord) return EncodedGroups( codes=codes, @@ -417,7 +417,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: codes=codes, full_index=full_index, unique_coord=unique_coord, - coords=_coordinates_from_variable(unique_coord), + coords=coordinates_from_variable(unique_coord), ) @@ -551,7 +551,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: group_indices=group_indices, full_index=full_index, unique_coord=unique_coord, - coords=_coordinates_from_variable(unique_coord), + coords=coordinates_from_variable(unique_coord), ) diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index b911bbfb6e6..e524603c9a5 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -401,12 +401,12 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool): assert isinstance(da._coords, dict), da._coords assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords - assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), ( - da.dims, - {k: v.dims for k, v in da._coords.items()}, - ) if check_default_indexes: + assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), ( + da.dims, + {k: v.dims for k, v in da._coords.items()}, + ) assert all( isinstance(v, IndexVariable) for (k, v) in da._coords.items() diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8d0d5011026..e7acdcdd4f3 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -33,8 +33,12 @@ from xarray.coders import CFDatetimeCoder from xarray.core import dtypes from xarray.core.common import full_like -from xarray.core.coordinates import Coordinates -from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords +from xarray.core.coordinates import Coordinates, CoordinateValidationError +from xarray.core.indexes import ( + Index, + PandasIndex, + filter_indexes_from_coords, +) from xarray.core.types import QueryEngineOptions, QueryParserOptions from xarray.core.utils import is_scalar from xarray.testing import _assert_internal_invariants @@ -418,9 +422,13 @@ def test_constructor_invalid(self) -> None: with pytest.raises(TypeError, match=r"is not hashable"): DataArray(data, dims=["x", []]) # type: ignore[list-item] - with pytest.raises(ValueError, match=r"conflicting sizes for dim"): + with pytest.raises( + CoordinateValidationError, match=r"conflicting sizes for dim" + ): DataArray([1, 2, 3], coords=[("x", [0, 1])]) - with pytest.raises(ValueError, match=r"conflicting sizes for dim"): + with pytest.raises( + CoordinateValidationError, match=r"conflicting sizes for dim" + ): DataArray([1, 2], coords={"x": [0, 1], "y": ("x", [1])}, dims="x") with pytest.raises(ValueError, match=r"conflicting MultiIndex"): @@ -529,6 +537,25 @@ class CustomIndex(Index): ... # test coordinate variables copied assert da.coords["x"] is not coords.variables["x"] + def test_constructor_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + def should_add_coord_to_array(self, name, var, dims): + return True + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + actual = DataArray([1.0, 2.0], coords=coords, dims="x") + + assert_identical(actual.coords, coords, check_default_indexes=False) + assert "x_bnds" not in actual.dims + def test_equals_and_identical(self) -> None: orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") @@ -1602,11 +1629,11 @@ def test_assign_coords(self) -> None: # GH: 2112 da = xr.DataArray([0, 1, 2], dims="x") - with pytest.raises(ValueError): + with pytest.raises(CoordinateValidationError): da["x"] = [0, 1, 2, 3] # size conflict - with pytest.raises(ValueError): + with pytest.raises(CoordinateValidationError): da.coords["x"] = [0, 1, 2, 3] # size conflict - with pytest.raises(ValueError): + with pytest.raises(CoordinateValidationError): da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray def test_assign_coords_existing_multiindex(self) -> None: @@ -1634,6 +1661,27 @@ def test_assign_coords_no_default_index(self) -> None: assert_identical(actual.coords, coords, check_default_indexes=False) assert "y" not in actual.xindexes + def test_assign_coords_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + def should_add_coord_to_array(self, name, var, dims): + return True + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + da = DataArray([1.0, 2.0], dims="x") + actual = da.assign_coords(coords) + expected = DataArray([1.0, 2.0], coords=coords, dims="x") + + assert_identical(actual, expected, check_default_indexes=False) + assert "x_bnds" not in actual.dims + def test_coords_alignment(self) -> None: lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 52c60a77066..ac186a7d351 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4262,6 +4262,26 @@ def test_getitem_multiple_dtype(self) -> None: dataset = Dataset({key: ("dim0", range(1)) for key in keys}) assert_identical(dataset, dataset[keys]) + def test_getitem_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + def should_add_coord_to_array(self, name, var, dims): + return True + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords) + actual = ds["foo"] + + assert_identical(actual.coords, coords, check_default_indexes=False) + assert "x_bnds" not in actual.dims + def test_virtual_variables_default_coords(self) -> None: dataset = Dataset({"foo": ("x", range(10))}) expected1 = DataArray(range(10), dims="x", name="x")