diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 43edc5ee33e..d9d4998d983 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -28,6 +28,8 @@ New Features
   By `Kai Mühlbauer `_.
 - support python 3.13 (no free-threading) (:issue:`9664`, :pull:`9681`)
   By `Justus Magin `_.
+- Added experimental support for coordinate transforms (not ready for public use yet!) (:pull:`9543`)
+  By `Benoit Bovy `_.
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py
new file mode 100644
index 00000000000..d9e09cea173
--- /dev/null
+++ b/xarray/core/coordinate_transform.py
@@ -0,0 +1,107 @@
+from collections.abc import Hashable, Iterable, Mapping
+from typing import Any
+
+import numpy as np
+
+
+class CoordinateTransform:
+    """Abstract coordinate transform with dimension & coordinate names.
+
+    EXPERIMENTAL (not ready for public use yet).
+
+    Parameters
+    ----------
+    coord_names : iterable of hashable
+        Names of the coordinates generated by the transform.
+    dim_size : mapping
+        Size of each dimension. The iteration order of its keys is used as
+        the default order of ``dims``.
+    dtype : optional
+        Data type of the coordinate labels (default: ``numpy.float64``).
+
+    """
+
+    coord_names: tuple[Hashable, ...]
+    dims: tuple[str, ...]
+    dim_size: dict[str, int]
+    dtype: Any
+
+    def __init__(
+        self,
+        coord_names: Iterable[Hashable],
+        dim_size: Mapping[str, int],
+        dtype: Any = None,
+    ):
+        self.coord_names = tuple(coord_names)
+        self.dims = tuple(dim_size)
+        self.dim_size = dict(dim_size)
+
+        if dtype is None:
+            dtype = np.dtype(np.float64)
+        self.dtype = dtype
+
+    def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]:
+        """Perform grid -> world coordinate transformation.
+
+        Parameters
+        ----------
+        dim_positions : dict
+            Grid location(s) along each dimension (axis).
+
+        Returns
+        -------
+        coord_labels : dict
+            World coordinate labels.
+
+        """
+        # TODO: cache the results in order to avoid re-computing
+        # all labels when accessing the values of each coordinate one at a time
+        raise NotImplementedError
+
+    def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]:
+        """Perform world -> grid coordinate reverse transformation.
+
+        Parameters
+        ----------
+        coord_labels : dict
+            World coordinate labels.
+
+        Returns
+        -------
+        dim_positions : dict
+            Grid relative location(s) along each dimension (axis).
+
+        """
+        raise NotImplementedError
+
+    def equals(self, other: "CoordinateTransform") -> bool:
+        """Check equality with another CoordinateTransform of the same kind."""
+        raise NotImplementedError
+
+    def generate_coords(
+        self, dims: tuple[str, ...] | None = None
+    ) -> dict[Hashable, Any]:
+        """Compute all coordinate labels at once.
+
+        Parameters
+        ----------
+        dims : tuple of str, optional
+            Order of the dimensions (axes) of the generated label arrays.
+            Default: ``self.dims``.
+
+        Returns
+        -------
+        coord_labels : dict
+            Result of :py:meth:`forward` evaluated at every grid position.
+
+        """
+        if dims is None:
+            dims = self.dims
+
+        positions = np.meshgrid(
+            *[np.arange(self.dim_size[d]) for d in dims],
+            indexing="ij",
+        )
+        dim_positions = dict(zip(dims, positions, strict=True))
+
+        return self.forward(dim_positions)
diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py
index a9ceeb08b96..47773ddfbb6 100644
--- a/xarray/core/coordinates.py
+++ b/xarray/core/coordinates.py
@@ -392,7 +392,7 @@ def from_xindex(cls, index: Index) -> Self:
     def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self:
         """Wrap a pandas multi-index as Xarray coordinates (dimension + levels).
 
-        The returned coordinates can be directly assigned to a
+        The returned coordinate variables can be directly assigned to a
         :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the
         ``coords`` argument of their constructor.
 
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index fbaef9729e3..43e231e84d4 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -10,7 +10,9 @@
 import pandas as pd
 
 from xarray.core import formatting, nputils, utils
+from xarray.core.coordinate_transform import CoordinateTransform
 from xarray.core.indexing import (
+    CoordinateTransformIndexingAdapter,
     IndexSelResult,
     PandasIndexingAdapter,
     PandasMultiIndexingAdapter,
@@ -1377,6 +1379,147 @@ def rename(self, name_dict, dims_dict):
         )
 
 
+class CoordinateTransformIndex(Index):
+    """Helper class for creating Xarray indexes based on coordinate transforms.
+
+    EXPERIMENTAL (not ready for public use yet).
+
+    - wraps a :py:class:`CoordinateTransform` instance
+    - takes care of creating the index (lazy) coordinates
+    - supports point-wise label-based selection
+    - supports exact alignment only, by comparing indexes based on their transform
+      (not on their explicit coordinate labels)
+
+    """
+
+    transform: CoordinateTransform
+
+    def __init__(
+        self,
+        transform: CoordinateTransform,
+    ):
+        self.transform = transform
+
+    def create_variables(
+        self, variables: Mapping[Any, Variable] | None = None
+    ) -> IndexVars:
+        """Create one lazy coordinate variable per transform coordinate name."""
+        from xarray.core.variable import Variable
+
+        new_variables = {}
+
+        for name in self.transform.coord_names:
+            # copy attributes, if any
+            attrs: Mapping[Hashable, Any] | None
+
+            if variables is not None and name in variables:
+                var = variables[name]
+                attrs = var.attrs
+            else:
+                attrs = None
+
+            data = CoordinateTransformIndexingAdapter(self.transform, name)
+            new_variables[name] = Variable(self.transform.dims, data, attrs=attrs)
+
+        return new_variables
+
+    def isel(
+        self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
+    ) -> Self | None:
+        """Return None: no index is computed for the selection yet (see TODO)."""
+        # TODO: support returning a new index (e.g., possible to re-calculate
+        # the transform or calculate another transform on a reduced dimension space)
+        return None
+
+    def sel(
+        self, labels: dict[Any, Any], method=None, tolerance=None
+    ) -> IndexSelResult:
+        """Point-wise, nearest-neighbor label-based selection.
+
+        ``labels`` must provide, for every coordinate of the transform, an
+        :py:class:`~xarray.DataArray` or :py:class:`~xarray.Variable`, all with
+        the same dimensions. ``tolerance`` is currently ignored.
+
+        Raises
+        ------
+        ValueError
+            If ``method`` is not "nearest", if the label of a coordinate is
+            missing or if the dimensions of the labels do not match.
+        TypeError
+            If a label is not a DataArray or a Variable.
+
+        """
+        from xarray.core.dataarray import DataArray
+        from xarray.core.variable import Variable
+
+        if method != "nearest":
+            raise ValueError(
+                "CoordinateTransformIndex only supports selection with method='nearest'"
+            )
+
+        labels_set = set(labels)
+        coord_names_set = set(self.transform.coord_names)
+
+        missing_labels = coord_names_set - labels_set
+        if missing_labels:
+            missing_labels_str = ",".join(sorted(str(name) for name in missing_labels))
+            raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.")
+
+        label0_obj = next(iter(labels.values()))
+        dim_size0 = getattr(label0_obj, "sizes", {})
+
+        is_xr_obj = [
+            isinstance(label, DataArray | Variable) for label in labels.values()
+        ]
+        if not all(is_xr_obj):
+            raise TypeError(
+                "CoordinateTransformIndex only supports advanced (point-wise) indexing "
+                "with either xarray.DataArray or xarray.Variable objects."
+            )
+        dim_size = [getattr(label, "sizes", {}) for label in labels.values()]
+        if any(ds != dim_size0 for ds in dim_size):
+            raise ValueError(
+                "CoordinateTransformIndex only supports advanced (point-wise) indexing "
+                "with xarray.DataArray or xarray.Variable objects of matching dimensions."
+            )
+
+        coord_labels = {
+            name: labels[name].values for name in self.transform.coord_names
+        }
+        dim_positions = self.transform.reverse(coord_labels)
+
+        results: dict[str, Variable | DataArray] = {}
+        dims0 = tuple(dim_size0)
+        for dim, pos in dim_positions.items():
+            # TODO: rounding the decimal positions is not always the behavior we expect
+            # (there are different ways to represent implicit intervals)
+            # we should probably make this customizable.
+            pos = np.round(pos).astype("int")
+            if isinstance(label0_obj, Variable):
+                results[dim] = Variable(dims0, pos)
+            else:
+                # dataarray
+                results[dim] = DataArray(pos, dims=dims0)
+
+        return IndexSelResult(results)
+
+    def equals(self, other: Index) -> bool:
+        """Compare the wrapped transforms (not the coordinate labels)."""
+        # an index of any other kind has no ``transform`` and is never equal
+        if not isinstance(other, CoordinateTransformIndex):
+            return False
+        return self.transform.equals(other.transform)
+
+    def rename(
+        self,
+        name_dict: Mapping[Any, Hashable],
+        dims_dict: Mapping[Any, Hashable],
+    ) -> Self:
+        """Return the index unchanged (the transform is not renamed, see TODO)."""
+        # TODO: maybe update self.transform coord_names, dim_size and dims attributes
+        return self
+
+
 def create_default_index_implicit(
     dim_variable: Variable,
     all_variables: Mapping | Iterable[Hashable] | None = None,
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index cf9d3885f08..521abcdfddd 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -17,6 +17,7 @@
 from packaging.version import Version
 
 from xarray.core import duck_array_ops
+from xarray.core.coordinate_transform import CoordinateTransform
 from xarray.core.nputils import NumpyVIndexAdapter
 from
xarray.core.options import OPTIONS
 from xarray.core.types import T_Xarray
@@ -1301,6 +1302,42 @@ def _decompose_outer_indexer(
     return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer)))
 
 
+def _posify_indices(indices: Any, size: int) -> np.ndarray:
+    """Convert negative indices by their equivalent positive indices.
+
+    Note: the resulting indices may still be out of bounds (< 0 or >= size).
+
+    """
+    return np.where(indices < 0, size + indices, indices)
+
+
+def _check_bounds(indices: Any, size: int) -> None:
+    """Check if the given indices are all within the array boundaries."""
+    if np.any((indices < 0) | (indices >= size)):
+        raise IndexError("out of bounds index")
+
+
+def _arrayize_outer_indexer(indexer: OuterIndexer, shape: _Shape) -> OuterIndexer:
+    """Return a similar outer indexer after replacing slices by arrays and
+    negative indices by their corresponding positive indices.
+
+    Also check if array indices are within bounds.
+
+    """
+    new_key = []
+
+    for axis, value in enumerate(indexer.tuple):
+        size = shape[axis]
+        if isinstance(value, slice):
+            value = _expand_slice(value, size)
+        else:
+            value = _posify_indices(value, size)
+            _check_bounds(value, size)
+        new_key.append(value)
+
+    return OuterIndexer(tuple(new_key))
+
+
 def _arrayize_vectorized_indexer(
     indexer: VectorizedIndexer, shape: _Shape
 ) -> VectorizedIndexer:
@@ -1981,3 +2018,127 @@ def copy(self, deep: bool = True) -> Self:
         # see PandasIndexingAdapter.copy
         array = self.array.copy(deep=True) if deep else self.array
         return type(self)(array, self._dtype, self.level)
+
+
+class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
+    """Wrap a CoordinateTransform as a lazy coordinate array.
+
+    Supports explicit indexing (both outer and vectorized).
+
+    """
+
+    _transform: CoordinateTransform
+    _coord_name: Hashable
+    _dims: tuple[str, ...]
+
+    def __init__(
+        self,
+        transform: CoordinateTransform,
+        coord_name: Hashable,
+        dims: tuple[str, ...] | None = None,
+    ):
+        self._transform = transform
+        self._coord_name = coord_name
+        self._dims = dims or transform.dims
+
+    @property
+    def dtype(self) -> np.dtype:
+        return self._transform.dtype
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        # follow the adapter's own dimension order, which may differ from the
+        # insertion order of ``dim_size`` (e.g., after ``transpose``)
+        return tuple(self._transform.dim_size[dim] for dim in self._dims)
+
+    def get_duck_array(self) -> np.ndarray:
+        """Compute and return all labels of the wrapped coordinate."""
+        all_coords = self._transform.generate_coords(dims=self._dims)
+        return np.asarray(all_coords[self._coord_name])
+
+    def _oindex_get(self, indexer: OuterIndexer):
+        """Compute the labels selected by an outer (orthogonal) indexer."""
+        expanded_indexer_ = OuterIndexer(expanded_indexer(indexer.tuple, self.ndim))
+        array_indexer = _arrayize_outer_indexer(expanded_indexer_, self.shape)
+
+        positions = np.meshgrid(*array_indexer.tuple, indexing="ij")
+        dim_positions = dict(zip(self._dims, positions, strict=False))
+
+        result = self._transform.forward(dim_positions)
+
+        # only drop the axes indexed with a scalar: a bare ``squeeze()`` would
+        # also drop length-1 axes selected with a slice or an array
+        scalar_axes = tuple(
+            axis for axis, key in enumerate(array_indexer.tuple) if np.ndim(key) == 0
+        )
+        return np.asarray(result[self._coord_name]).squeeze(axis=scalar_axes)
+
+    def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None:
+        raise TypeError(
+            "setting values is not supported on coordinate transform arrays."
+        )
+
+    def _vindex_get(self, indexer: VectorizedIndexer):
+        """Compute the labels selected by a vectorized (point-wise) indexer."""
+        expanded_indexer_ = VectorizedIndexer(
+            expanded_indexer(indexer.tuple, self.ndim)
+        )
+        array_indexer = _arrayize_vectorized_indexer(expanded_indexer_, self.shape)
+
+        dim_positions = {}
+        for i, (dim, pos) in enumerate(
+            zip(self._dims, array_indexer.tuple, strict=False)
+        ):
+            pos = _posify_indices(pos, self.shape[i])
+            _check_bounds(pos, self.shape[i])
+            dim_positions[dim] = pos
+
+        result = self._transform.forward(dim_positions)
+        return np.asarray(result[self._coord_name])
+
+    def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None:
+        raise TypeError(
+            "setting values is not supported on coordinate transform arrays."
+        )
+
+    def __getitem__(self, indexer: ExplicitIndexer):
+        # TODO: make it lazy (i.e., re-calculate and re-wrap the transform) when possible?
+        self._check_and_raise_if_non_basic_indexer(indexer)
+
+        # also works with basic indexing
+        return self._oindex_get(OuterIndexer(indexer.tuple))
+
+    def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None:
+        raise TypeError(
+            "setting values is not supported on coordinate transform arrays."
+        )
+
+    def transpose(self, order: Iterable[int]) -> Self:
+        """Return a new adapter with dimensions reordered by ``order``."""
+        new_dims = tuple([self._dims[i] for i in order])
+        return type(self)(self._transform, self._coord_name, new_dims)
+
+    def __repr__(self: Any) -> str:
+        return f"{type(self).__name__}(transform={self._transform!r})"
+
+    def _get_array_subset(self) -> np.ndarray:
+        """Return all values, or only the first and last ones for large arrays."""
+        threshold = max(100, OPTIONS["display_values_threshold"] + 2)
+        if self.size > threshold:
+            pos = threshold // 2
+            flat_indices = np.concatenate(
+                [np.arange(0, pos), np.arange(self.size - pos, self.size)]
+            )
+            subset = self.vindex[
+                VectorizedIndexer(np.unravel_index(flat_indices, self.shape))
+            ]
+        else:
+            subset = self
+
+        return np.asarray(subset)
+
+    def _repr_inline_(self, max_width: int) -> str:
+        """Good to see some labels even for a lazy coordinate."""
+        from xarray.core.formatting import format_array_flat
+
+        return format_array_flat(self._get_array_subset(), max_width)
diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py
index b1bf7a1af11..9073cbc2ed4 100644
--- a/xarray/indexes/__init__.py
+++ b/xarray/indexes/__init__.py
@@ -3,6 +3,10 @@
 
 """
 
-from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex
+from xarray.core.indexes import (
+    Index,
+    PandasIndex,
+    PandasMultiIndex,
+)
 
 __all__ = ["Index", "PandasIndex", "PandasMultiIndex"]
diff --git a/xarray/tests/test_coordinate_transform.py b/xarray/tests/test_coordinate_transform.py
new file mode 100644
index 00000000000..26746657dbc
--- /dev/null
+++ b/xarray/tests/test_coordinate_transform.py
@@ -0,0 +1,218 @@
+from collections.abc import Hashable
+from typing import Any
+
+import numpy as np
+import pytest
+
+import xarray as xr
+from
xarray.core.coordinate_transform import CoordinateTransform
+from xarray.core.indexes import CoordinateTransformIndex
+from xarray.tests import assert_equal
+
+
+class SimpleCoordinateTransform(CoordinateTransform):
+    """Simple uniform scale transform in a 2D space (x/y coordinates)."""
+
+    def __init__(self, shape: tuple[int, int], scale: float, dtype: Any = None):
+        super().__init__(("x", "y"), {"x": shape[1], "y": shape[0]}, dtype=dtype)
+
+        self.scale = scale
+
+        # array dimensions in reverse order (y = rows, x = cols)
+        self.xy_dims = tuple(self.dims)
+        self.dims = (self.dims[1], self.dims[0])
+
+    def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]:
+        assert set(dim_positions) == set(self.dims)
+        return {dim: dim_positions[dim] * self.scale for dim in self.xy_dims}
+
+    def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]:
+        return {dim: coord_labels[dim] / self.scale for dim in self.xy_dims}
+
+    def equals(self, other: "CoordinateTransform") -> bool:
+        if not isinstance(other, SimpleCoordinateTransform):
+            return False
+        return self.scale == other.scale
+
+    def __repr__(self) -> str:
+        return f"Scale({self.scale})"
+
+
+def test_abstract_coordinate_transform() -> None:
+    tr = CoordinateTransform(["x"], {"x": 5})
+
+    with pytest.raises(NotImplementedError):
+        tr.forward({"x": [1, 2]})
+
+    with pytest.raises(NotImplementedError):
+        tr.reverse({"x": [3.0, 4.0]})
+
+    with pytest.raises(NotImplementedError):
+        tr.equals(CoordinateTransform(["x"], {"x": 5}))
+
+
+def test_coordinate_transform_init() -> None:
+    tr = SimpleCoordinateTransform((4, 4), 2.0)
+
+    assert tr.coord_names == ("x", "y")
+    # array dimensions in reverse order (y = rows, x = cols)
+    assert tr.dims == ("y", "x")
+    assert tr.dim_size == {"x": 4, "y": 4}
+    assert tr.dtype == np.dtype(np.float64)
+
+    tr2 = SimpleCoordinateTransform((4, 4), 2.0, dtype=np.int64)
+    assert tr2.dtype == np.dtype(np.int64)
+
+
+@pytest.mark.parametrize("dims", [None, ("y", "x")])
+def test_coordinate_transform_generate_coords(dims) -> None:
+    tr = SimpleCoordinateTransform((2, 2), 2.0)
+
+    actual = tr.generate_coords(dims)
+    expected = {"x": [[0.0, 2.0], [0.0, 2.0]], "y": [[0.0, 0.0], [2.0, 2.0]]}
+    assert set(actual) == set(expected)
+    np.testing.assert_array_equal(actual["x"], expected["x"])
+    np.testing.assert_array_equal(actual["y"], expected["y"])
+
+
+def create_coords(scale: float, shape: tuple[int, int]) -> xr.Coordinates:
+    """Create x/y Xarray coordinate variables from a simple coordinate transform."""
+    tr = SimpleCoordinateTransform(shape, scale)
+    index = CoordinateTransformIndex(tr)
+    return xr.Coordinates.from_xindex(index)
+
+
+def test_coordinate_transform_variable() -> None:
+    coords = create_coords(scale=2.0, shape=(2, 2))
+
+    assert coords["x"].dtype == np.dtype(np.float64)
+    assert coords["y"].dtype == np.dtype(np.float64)
+    assert coords["x"].shape == (2, 2)
+    assert coords["y"].shape == (2, 2)
+
+    np.testing.assert_array_equal(np.array(coords["x"]), [[0.0, 2.0], [0.0, 2.0]])
+    np.testing.assert_array_equal(np.array(coords["y"]), [[0.0, 0.0], [2.0, 2.0]])
+
+    def assert_repr(var: xr.Variable):
+        assert (
+            repr(var._data)
+            == "CoordinateTransformIndexingAdapter(transform=Scale(2.0))"
+        )
+
+    assert_repr(coords["x"].variable)
+    assert_repr(coords["y"].variable)
+
+
+def test_coordinate_transform_variable_repr_inline() -> None:
+    var = create_coords(scale=2.0, shape=(2, 2))["x"].variable
+
+    actual = var._data._repr_inline_(70)  # type: ignore[union-attr]
+    assert actual == "0.0 2.0 0.0 2.0"
+
+    # truncated inline repr
+    var2 = create_coords(scale=2.0, shape=(10, 10))["x"].variable
+
+    actual2 = var2._data._repr_inline_(70)  # type: ignore[union-attr]
+    assert (
+        actual2 == "0.0 2.0 4.0 6.0 8.0 10.0 12.0 ... 6.0 8.0 10.0 12.0 14.0 16.0 18.0"
+    )
+
+
+def test_coordinate_transform_variable_basic_outer_indexing() -> None:
+    var = create_coords(scale=2.0, shape=(4, 4))["x"].variable
+
+    assert var[0, 0] == 0.0
+    assert var[0, 1] == 2.0
+    assert var[0, -1] == 6.0  # negative indices count from the end
+    np.testing.assert_array_equal(var[:, 0:2], [[0.0, 2.0]] * 4)
+
+    with pytest.raises(IndexError, match="out of bounds index"):
+        var[5]
+
+    with pytest.raises(IndexError, match="out of bounds index"):
+        var[-5]
+
+
+def test_coordinate_transform_variable_vectorized_indexing() -> None:
+    var = create_coords(scale=2.0, shape=(4, 4))["x"].variable
+
+    actual = var[{"x": xr.Variable("z", [0]), "y": xr.Variable("z", [0])}]
+    expected = xr.Variable("z", [0.0])
+    assert_equal(actual, expected)
+
+    with pytest.raises(IndexError, match="out of bounds index"):
+        var[{"x": xr.Variable("z", [5]), "y": xr.Variable("z", [5])}]
+
+
+def test_coordinate_transform_setitem_error() -> None:
+    var = create_coords(scale=2.0, shape=(4, 4))["x"].variable
+
+    # basic indexing
+    with pytest.raises(TypeError, match="setting values is not supported"):
+        var[0, 0] = 1.0
+
+    # outer indexing
+    with pytest.raises(TypeError, match="setting values is not supported"):
+        var[[0, 2], 0] = [1.0, 2.0]
+
+    # vectorized indexing
+    with pytest.raises(TypeError, match="setting values is not supported"):
+        var[{"x": xr.Variable("z", [0]), "y": xr.Variable("z", [0])}] = 1.0
+
+
+def test_coordinate_transform_transpose() -> None:
+    coords = create_coords(scale=2.0, shape=(2, 2))
+
+    actual = coords["x"].transpose().values
+    expected = [[0.0, 0.0], [2.0, 2.0]]
+    np.testing.assert_array_equal(actual, expected)
+
+
+def test_coordinate_transform_equals() -> None:
+    ds1 = create_coords(scale=2.0, shape=(2, 2)).to_dataset()
+    ds2 = create_coords(scale=2.0, shape=(2, 2)).to_dataset()
+    ds3 = create_coords(scale=4.0, shape=(2, 2)).to_dataset()
+
+    # cannot use `assert_equal()` test utility function here yet
+    # (the index invariant checks are still based on IndexVariable, which
+    # doesn't work with coordinate transform index coordinate variables)
+    assert ds1.equals(ds2)
+    assert not ds1.equals(ds3)
+
+
+def test_coordinate_transform_sel() -> None:
+    ds = create_coords(scale=2.0, shape=(4, 4)).to_dataset()
+
+    data = [
+        [0.0, 1.0, 2.0, 3.0],
+        [4.0, 5.0, 6.0, 7.0],
+        [8.0, 9.0, 10.0, 11.0],
+        [12.0, 13.0, 14.0, 15.0],
+    ]
+    ds["data"] = (("y", "x"), data)
+
+    actual = ds.sel(
+        x=xr.Variable("z", [0.5, 5.5]), y=xr.Variable("z", [0.0, 0.5]), method="nearest"
+    )
+    expected = ds.isel(x=xr.Variable("z", [0, 3]), y=xr.Variable("z", [0, 0]))
+
+    # cannot use `assert_equal()` test utility function here yet
+    # (the index invariant checks are still based on IndexVariable, which
+    # doesn't work with coordinate transform index coordinate variables)
+    assert actual.equals(expected)
+
+    with pytest.raises(ValueError, match=".*only supports selection.*nearest"):
+        ds.sel(x=xr.Variable("z", [0.5, 5.5]), y=xr.Variable("z", [0.0, 0.5]))
+
+    with pytest.raises(ValueError, match="missing labels for coordinate.*y"):
+        ds.sel(x=[0.5, 5.5], method="nearest")
+
+    with pytest.raises(TypeError, match=".*only supports advanced.*indexing"):
+        ds.sel(x=[0.5, 5.5], y=[0.0, 0.5], method="nearest")
+
+    with pytest.raises(ValueError, match=".*only supports advanced.*indexing"):
+        ds.sel(
+            x=xr.Variable("z", [0.5, 5.5]),
+            y=xr.Variable("z", [0.0, 0.5, 1.5]),
+            method="nearest",
+        )