From 4a595dff198edfc6163fd9bb6d2b3c095320ac2b Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 16:22:12 +0200 Subject: [PATCH 01/16] Add coordinate transform classes from prototype --- xarray/core/coordinate_transform.py | 74 ++++++++++++++ xarray/core/indexes.py | 111 +++++++++++++++++++++ xarray/core/indexing.py | 145 ++++++++++++++++++++++++++++ 3 files changed, 330 insertions(+) create mode 100644 xarray/core/coordinate_transform.py diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py new file mode 100644 index 00000000000..1d4db3e9b7e --- /dev/null +++ b/xarray/core/coordinate_transform.py @@ -0,0 +1,74 @@ +from typing import Any, Iterable, Hashable, Mapping + +import numpy as np + + +class CoordinateTransform: + """Abstract coordinate transform with dimension & coordinate names.""" + + coord_names: tuple[Hashable, ...] + dims: tuple[str, ...] + dim_size: dict[str, int] + dtype: Any + + def __init__( + self, + coord_names: Iterable[Hashable], + dim_size: Mapping[str, int], + dtype: Any = np.dtype(np.float64), + ): + self.coord_names = tuple(coord_names) + self.dims = tuple(dim_size) + self.dim_size = dict(dim_size) + self.dtype = dtype + + def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: + """Perform grid -> world coordinate transformation. + + Parameters + ---------- + dim_positions : dict + Grid location(s) along each dimension (axis). + + Returns + ------- + coord_labels : dict + World coordinate labels. + + """ + # TODO: cache the results in order to avoid re-computing + # all labels when accessing the values of each coordinate one at a time + raise NotImplementedError + + def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: + """Perform world -> grid coordinate reverse transformation. + + Parameters + ---------- + labels : dict + World coordinate labels. + + Returns + ------- + dim_positions : dict + Grid relative location(s) along each dimension (axis). 
+ + """ + raise NotImplementedError + + def equals(self, other: "CoordinateTransform") -> bool: + """Check equality with another CoordinateTransform of the same kind.""" + raise NotImplementedError + + def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]: + """Returns all "world" coordinate labels.""" + if dims is None: + dims = self.dims + + positions = np.meshgrid( + *[np.arange(self.dim_size[d]) for d in dims], + indexing="ij", + ) + dim_positions = {dim: positions[i] for i, dim in enumerate(dims)} + + return self.forward(dim_positions) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5abc2129e3e..8d90c955bfe 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -10,6 +10,7 @@ import pandas as pd from xarray.core import formatting, nputils, utils +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexing import ( IndexSelResult, PandasIndexingAdapter, @@ -1372,6 +1373,116 @@ def rename(self, name_dict, dims_dict): ) +class CoordinateTransformIndex(Index): + """Xarray index abstract class for transformation between "pixel" + and "world" coordinates. + + """ + + transform: CoordinateTransform + + def __init__( + self, + transform: CoordinateTransform, + ): + self.transform = transform + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + new_variables = {} + + for name in self.transform.coord_names: + # copy attributes, if any + attrs: Mapping[Hashable, Any] | None + + if variables is not None and name in variables: + var = variables[name] + attrs = var.attrs + else: + attrs = None + + data = CoordinateTransformIndexingAdapter(self.transform, name) + new_variables[name] = Variable(self.transform.dims, data, attrs=attrs) + + return new_variables + + def create_coordinates(self) -> Coordinates: + # TODO: move this in xarray.Index base class? 
+ variables = self.create_variables() + indexes = {name: self for name in variables} + return xr.Coordinates(coords=variables, indexes=indexes) + + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Self | None: + # TODO: support returning a new index (e.g., possible to re-calculate the + # the transform or calculate another transform on a reduced dimension space) + return None + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + if method != "nearest": + raise ValueError( + "CoordinateTransformIndex only supports selection with method='nearest'" + ) + + labels_set = set(labels) + coord_names_set = set(self.transform.coord_names) + + missing_labels = coord_names_set - labels_set + if missing_labels: + raise ValueError( + f"missing labels for coordinate(s): {','.join(missing_labels)}." + ) + + label0_obj = next(iter(labels.values())) + dim_size0 = getattr(label0_obj, "sizes", None) + + is_xr_obj = [ + isinstance(label, (xr.DataArray, xr.Variable)) for label in labels.values() + ] + if not all(is_xr_obj): + raise TypeError( + "CoordinateTransformIndex only supports advanced (point-wise) indexing " + "with either xarray.DataArray or xarray.Variable objects." + ) + dim_size = [getattr(label, "sizes", None) for label in labels.values()] + if any([ds != dim_size0 for ds in dim_size]): + raise ValueError( + "CoordinateTransformIndex only supports advanced (point-wise) indexing " + "with xarray.DataArray or xarray.Variable objects of macthing dimensions." 
+ ) + + coord_labels = { + name: labels[name].values for name in self.transform.coord_names + } + dim_positions = self.transform.reverse(coord_labels) + + results = {} + for dim, pos in dim_positions.items(): + if isinstance(label0_obj, Variable): + xr_pos = Variable(label.dims, idx) + else: + # dataarray + xr_pos = DataArray(idx, dims=label.dims) + results[dim] = idx + + return IndexSelResult(results) + + def equals(self, other: Self) -> bool: + return self.transform.equals(other.transform) + + def rename( + self, + name_dict: Mapping[Any, Hashable], + dims_dict: Mapping[Any, Hashable], + ) -> Self: + # TODO: maybe update self.transform coord_names, dim_size and dims attributes + return self + + def create_default_index_implicit( dim_variable: Variable, all_variables: Mapping | Iterable[Hashable] | None = None, diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 67912908a2b..35fd2597b85 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -15,6 +15,7 @@ import pandas as pd from xarray.core import duck_array_ops +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS from xarray.core.types import T_Xarray @@ -1303,6 +1304,42 @@ def _decompose_outer_indexer( return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) +def _posify_indices(indices: np.typing.ArrayLike, size: int) -> np.ndarray: + """Convert negative indices by their equivalent positive indices. + + Note: the resulting indices may still be out of bounds (< 0 or >= size). 
+ + """ + return np.where(indices < 0, size + indices, indices) + + +def _check_bounds(indices, size): + """Check if the given indices are all within the array boundaries.""" + if np.any((indices < 0) | (indices >= size)): + raise IndexError("out of bounds index") + + +def _arrayize_outer_indexer(indexer: OuterIndexer, shape) -> OuterIndexer: + """Return a similar oindex with after replacing slices by arrays and + negative indices by their corresponding positive indices. + + Also check if array indices are within bounds. + + """ + new_key = [] + + for axis, value in enumerate(indexer.tuple): + size = shape[axis] + if isinstance(value, slice): + value = _expand_slice(value, size) + else: + value = _posify_indices(value, size) + _check_bounds(value, size) + new_key.append(value) + + return OuterIndexer(tuple(new_key)) + + def _arrayize_vectorized_indexer( indexer: VectorizedIndexer, shape: _Shape ) -> VectorizedIndexer: @@ -1921,3 +1958,111 @@ def copy(self, deep: bool = True) -> Self: # see PandasIndexingAdapter.copy array = self.array.copy(deep=True) if deep else self.array return type(self)(array, self._dtype, self.level) + + +class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap a CoordinateTransform to support explicit indexing and + lazy coordinate labels. + + """ + + _transform: CoordinateTransform + _coord_name: Hashable + _dims: tuple[str, ...] 
+ + def __init__( + self, + transform: CoordinateTransform, + coord_name: Hashable, + dims: tuple[str] | None = None, + ): + self._transform = transform + self._coord_name = coord_name + self._dims = dims or transform.dims + + @property + def dtype(self) -> np.dtype: + return self._transform.dtype + + @property + def shape(self): + return tuple(self._transform.dim_size.values()) + + def get_duck_array(self) -> np.ndarray: + all_coords = self._transform.generate_coords(dims=self._dims) + return np.asarray(all_coords[self._coord_name]) + + def _oindex_get(self, indexer: OuterIndexer): + expanded_indexer_ = OuterIndexer(expanded_indexer(indexer.tuple, self.ndim)) + array_indexer = _arrayize_outer_indexer(expanded_indexer_, self.shape) + + positions = np.meshgrid(*array_indexer.tuple, indexing="ij") + dim_positions = { + dim: pos for dim, pos in zip(self._dims, positions, strict=False) + } + + result = self._transform.forward(dim_positions) + return np.asarray(result[self._coord_name]).squeeze() + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def _vindex_get(self, indexer: VectorizedIndexer): + expanded_indexer_ = VectorizedIndexer( + expanded_indexer(indexer.tuple, self.ndim) + ) + array_indexer = _arrayize_vectorized_indexer(expanded_indexer_, self.shape) + + dim_positions = {} + for i, (dim, pos) in enumerate( + zip(self._dims, array_indexer.tuple, strict=False) + ): + pos = _posify_indices(pos, self.shape[i]) + _check_bounds(pos, self.shape[i]) + dim_positions[dim] = pos + + result = self._transform.forward(dim_positions) + return np.asarray(result[self._coord_name]) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." 
+ ) + + def __getitem__(self, indexer: ExplicitIndexer): + # TODO: make it lazy (i.e., re-calculate and re-wrap the transform) when possible? + self._check_and_raise_if_non_basic_indexer(indexer) + + # also works with basic indexing + return self._oindex_get(indexer) + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def transpose(self, order): + new_dims = tuple([self._dims[i] for i in order]) + return type(self)(self._transform, self._coord_name, new_dims) + + def __repr__(self: Any) -> str: + return f"{type(self).__name__}(transform={self._transform!r})" + + def _get_array_subset(self) -> np.ndarray: + threshold = max(100, OPTIONS["display_values_threshold"] + 2) + if self.size > threshold: + pos = threshold // 2 + indices = np.concatenate([np.arange(0, pos), np.arange(-pos, 0)]) + subset = self.vindex[VectorizedIndexer((indices,) * self.ndim)] + else: + subset = self + + return np.asarray(subset) + + def _repr_inline_(self, max_width: int) -> str: + """Good to see some labels even for a lazy coordinate.""" + from xarray.core.formatting import format_array_flat + + return format_array_flat(self._get_array_subset(), max_width) From 0b545cf61cf192cd1037e2e1d312921a8ab5843c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 20:26:15 +0200 Subject: [PATCH 02/16] lint, public API and docstrings --- xarray/__init__.py | 2 ++ xarray/core/coordinate_transform.py | 10 ++++++--- xarray/core/indexes.py | 32 ++++++++++++++++++++--------- xarray/core/indexing.py | 7 ++++--- xarray/indexes/__init__.py | 9 ++++++-- 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index e3b7ec469e9..b49ab1848b7 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -30,6 +30,7 @@ where, ) from xarray.core.concat import concat +from xarray.core.coordinate_transform import CoordinateTransform from 
xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -109,6 +110,7 @@ "CFTimeIndex", "Context", "Coordinates", + "CoordinateTransform", "DataArray", "Dataset", "DataTree", diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py index 1d4db3e9b7e..40043da46bc 100644 --- a/xarray/core/coordinate_transform.py +++ b/xarray/core/coordinate_transform.py @@ -1,4 +1,5 @@ -from typing import Any, Iterable, Hashable, Mapping +from collections.abc import Hashable, Iterable, Mapping +from typing import Any import numpy as np @@ -15,11 +16,14 @@ def __init__( self, coord_names: Iterable[Hashable], dim_size: Mapping[str, int], - dtype: Any = np.dtype(np.float64), + dtype: Any = None, ): self.coord_names = tuple(coord_names) self.dims = tuple(dim_size) self.dim_size = dict(dim_size) + + if dtype is None: + dtype = np.dtype(np.float64) self.dtype = dtype def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: @@ -61,7 +65,7 @@ def equals(self, other: "CoordinateTransform") -> bool: raise NotImplementedError def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]: - """Returns all "world" coordinate labels.""" + """Compute all coordinate labels at once.""" if dims is None: dims = self.dims diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 8d90c955bfe..e154b727fc5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -12,6 +12,7 @@ from xarray.core import formatting, nputils, utils from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexing import ( + CoordinateTransformIndexingAdapter, IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter, @@ -25,6 +26,7 @@ ) if TYPE_CHECKING: + from xarray.core.coordinate import Coordinates from xarray.core.types import ErrorOptions, JoinOptions, Self from xarray.core.variable import Variable @@ -1374,8 +1376,13 @@ def 
rename(self, name_dict, dims_dict): class CoordinateTransformIndex(Index): - """Xarray index abstract class for transformation between "pixel" - and "world" coordinates. + """Helper class for creating Xarray indexes based on coordinate transforms. + + - wraps a :py:class:`CoordinateTransform` instance + - takes care of creating the index (lazy) coordinates + - supports point-wise label-based selection + - supports exact alignment only, by comparing indexes based on their transform + (not on their explicit coordinate labels) """ @@ -1409,9 +1416,11 @@ def create_variables( def create_coordinates(self) -> Coordinates: # TODO: move this in xarray.Index base class? + from xarray.core.coordinates import Coordinates + variables = self.create_variables() indexes = {name: self for name in variables} - return xr.Coordinates(coords=variables, indexes=indexes) + return Coordinates(coords=variables, indexes=indexes) def isel( self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] @@ -1423,6 +1432,9 @@ def isel( def sel( self, labels: dict[Any, Any], method=None, tolerance=None ) -> IndexSelResult: + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + if method != "nearest": raise ValueError( "CoordinateTransformIndex only supports selection with method='nearest'" @@ -1433,15 +1445,14 @@ def sel( missing_labels = coord_names_set - labels_set if missing_labels: - raise ValueError( - f"missing labels for coordinate(s): {','.join(missing_labels)}." 
- ) + missing_labels_str = ",".join([f"{name}" for name in missing_labels]) + raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") label0_obj = next(iter(labels.values())) dim_size0 = getattr(label0_obj, "sizes", None) is_xr_obj = [ - isinstance(label, (xr.DataArray, xr.Variable)) for label in labels.values() + isinstance(label, DataArray | Variable) for label in labels.values() ] if not all(is_xr_obj): raise TypeError( @@ -1461,13 +1472,14 @@ def sel( dim_positions = self.transform.reverse(coord_labels) results = {} + dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): if isinstance(label0_obj, Variable): - xr_pos = Variable(label.dims, idx) + xr_pos = Variable(dims0, pos) else: # dataarray - xr_pos = DataArray(idx, dims=label.dims) - results[dim] = idx + xr_pos = DataArray(pos, dims=dims0) + results[dim] = xr_pos return IndexSelResult(results) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 35fd2597b85..047e16c240d 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1961,8 +1961,9 @@ def copy(self, deep: bool = True) -> Self: class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): - """Wrap a CoordinateTransform to support explicit indexing and - lazy coordinate labels. + """Wrap a CoordinateTransform as a lazy coordinate array. + + Supports explicit indexing (both outer and vectorized). 
""" @@ -2036,7 +2037,7 @@ def __getitem__(self, indexer: ExplicitIndexer): self._check_and_raise_if_non_basic_indexer(indexer) # also works with basic indexing - return self._oindex_get(indexer) + return self._oindex_get(OuterIndexer(indexer.tuple)) def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: raise TypeError( diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index b1bf7a1af11..e2857b8602b 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -3,6 +3,11 @@ """ -from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex +from xarray.core.indexes import ( + CoordinateTransformIndex, + Index, + PandasIndex, + PandasMultiIndex, +) -__all__ = ["Index", "PandasIndex", "PandasMultiIndex"] +__all__ = ["CoordinateTransformIndex", "Index", "PandasIndex", "PandasMultiIndex"] From 8af6614086f8ca181ec070859fca1e019663c837 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 20:30:52 +0200 Subject: [PATCH 03/16] missing import --- xarray/core/indexes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e154b727fc5..b56d1faf295 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1397,6 +1397,8 @@ def __init__( def create_variables( self, variables: Mapping[Any, Variable] | None = None ) -> IndexVars: + from xarray.core.variable import Variable + new_variables = {} for name in self.transform.coord_names: From e9a11ef6df072c4f61eea7ea7be00e12d7cee5da Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 20:48:25 +0200 Subject: [PATCH 04/16] sel: convert inverse transform results to ints --- xarray/core/indexes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index b56d1faf295..ab725f86833 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1476,6 +1476,7 @@ def sel( results = {} dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): + 
pos = np.round(pos).astype("int") if isinstance(label0_obj, Variable): xr_pos = Variable(dims0, pos) else: From 0b3fd9ee751f64b9695609a601ae31b336c1e0a0 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 22:13:40 +0200 Subject: [PATCH 05/16] sel: add todo note about rounding decimal pos --- xarray/core/indexes.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index ab725f86833..987039e1f87 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1476,6 +1476,9 @@ def sel( results = {} dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): + # TODO: rounding the decimal positions is not always the behavior we expect + # (there are different ways to represent implicit intervals) + # we should probably make this customizable. pos = np.round(pos).astype("int") if isinstance(label0_obj, Variable): xr_pos = Variable(dims0, pos) else: From acf1c478c68fcadcc6bfbdd4414bc97b8667383f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 26 Sep 2024 10:37:04 +0200 Subject: [PATCH 06/16] rename create_coordinates -> create_coords More consistent with the rest of the Xarray API where `coords` is used everywhere. --- xarray/core/indexes.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 987039e1f87..4f2bba20844 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1417,6 +1417,11 @@ def create_variables( return new_variables def create_coordinates(self) -> Coordinates: + # TODO: remove this alias before merging https://github.com/pydata/xarray/pull/9543! + # (we keep it there so it doesn't break the code of those who are experimenting with this) + return self.create_coords() + + def create_coords(self) -> Coordinates: # TODO: move this in xarray.Index base class? 
from xarray.core.coordinates import Coordinates From e101585e9fb30a3b73d6d37a7bc0be1607f991b1 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 26 Sep 2024 10:46:13 +0200 Subject: [PATCH 07/16] add a Coordinates.from_transform convenience method --- xarray/core/coordinates.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index a6dec863aec..af622aaca8b 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -14,7 +14,9 @@ from xarray.core import formatting from xarray.core.alignment import Aligner +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexes import ( + CoordinateTransformIndex, Index, Indexes, PandasIndex, @@ -356,7 +358,7 @@ def _construct_direct( def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: """Wrap a pandas multi-index as Xarray coordinates (dimension + levels). - The returned coordinates can be directly assigned to a + The returned coordinate variables can be directly assigned to a :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the ``coords`` argument of their constructor. @@ -380,6 +382,28 @@ def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: return cls(coords=variables, indexes=indexes) + @classmethod + def from_transform(cls, transform: CoordinateTransform) -> Self: + """Wrap a coordinate transform as Xarray (lazy) coordinates. + + The returned coordinate variables can be directly assigned to a + :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the + ``coords`` argument of their constructor. + + Parameters + ---------- + transform : :py:class:`CoordinateTransform` + Xarray coordinate transform object. + + Returns + ------- + coords : Coordinates + A collection of Xarray indexed coordinates created from the transform. 
+ + """ + index = CoordinateTransformIndex(transform) + return index.create_coords() + @property def _names(self) -> set[Hashable]: return self._data._coord_names From 09667c5da4e2de2f1db6896e3acce0205e3608e3 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Oct 2024 14:50:55 +0200 Subject: [PATCH 08/16] fix repr (extract subset values of any n-d array) --- xarray/core/indexing.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 047e16c240d..04677bb8d60 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -2055,8 +2055,12 @@ def _get_array_subset(self) -> np.ndarray: threshold = max(100, OPTIONS["display_values_threshold"] + 2) if self.size > threshold: pos = threshold // 2 - indices = np.concatenate([np.arange(0, pos), np.arange(-pos, 0)]) - subset = self.vindex[VectorizedIndexer((indices,) * self.ndim)] + flat_indices = np.concatenate( + [np.arange(0, pos), np.arange(self.size - pos, self.size)] + ) + subset = self.vindex[ + VectorizedIndexer(np.unravel_index(flat_indices, self.shape)) + ] else: subset = self From 4c7ce28884c0dd9af4caabb5297036a6d5644a9a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 12 Feb 2025 14:55:19 -0700 Subject: [PATCH 09/16] Apply suggestions from code review Co-authored-by: Max Jones <14077947+maxrjones@users.noreply.github.com> --- xarray/core/indexes.py | 2 +- xarray/core/indexing.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 53049f9e5a1..fb04b829737 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1472,7 +1472,7 @@ def sel( "with either xarray.DataArray or xarray.Variable objects." 
) dim_size = [getattr(label, "sizes", None) for label in labels.values()] - if any([ds != dim_size0 for ds in dim_size]): + if any(ds != dim_size0 for ds in dim_size): raise ValueError( "CoordinateTransformIndex only supports advanced (point-wise) indexing " "with xarray.DataArray or xarray.Variable objects of macthing dimensions." diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 35d4fc52e8c..3b337612239 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -2058,9 +2058,7 @@ def _oindex_get(self, indexer: OuterIndexer): array_indexer = _arrayize_outer_indexer(expanded_indexer_, self.shape) positions = np.meshgrid(*array_indexer.tuple, indexing="ij") - dim_positions = { - dim: pos for dim, pos in zip(self._dims, positions, strict=False) - } + dim_positions = dict(zip(self._dims, positions, strict=False)) result = self._transform.forward(dim_positions) return np.asarray(result[self._coord_name]).squeeze() From 5cfb1afa9bdb45817e0527cde82764b12586f6d8 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 09:11:31 +0100 Subject: [PATCH 10/16] remove specific create coordinates methods In favor of the more generic `Coordinates.from_xindex()`. 
--- xarray/core/coordinates.py | 24 ------------------------ xarray/core/indexes.py | 14 -------------- 2 files changed, 38 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 1bca543ca20..47773ddfbb6 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -14,9 +14,7 @@ from xarray.core import formatting from xarray.core.alignment import Aligner -from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexes import ( - CoordinateTransformIndex, Index, Indexes, PandasIndex, @@ -418,28 +416,6 @@ def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: return cls(coords=variables, indexes=indexes) - @classmethod - def from_transform(cls, transform: CoordinateTransform) -> Self: - """Wrap a coordinate transform as Xarray (lazy) coordinates. - - The returned coordinate variables can be directly assigned to a - :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the - ``coords`` argument of their constructor. - - Parameters - ---------- - transform : :py:class:`CoordinateTransform` - Xarray coordinate transform object. - - Returns - ------- - coords : Coordinates - A collection of Xarray indexed coordinates created from the transform. - - """ - index = CoordinateTransformIndex(transform) - return index.create_coords() - @property def _names(self) -> set[Hashable]: return self._data._coord_names diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index fb04b829737..833ec8bb926 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -26,7 +26,6 @@ ) if TYPE_CHECKING: - from xarray.core.coordinate import Coordinates from xarray.core.types import ErrorOptions, JoinOptions, Self from xarray.core.variable import Variable @@ -1421,19 +1420,6 @@ def create_variables( return new_variables - def create_coordinates(self) -> Coordinates: - # TODO: remove this alias before merging https://github.com/pydata/xarray/pull/9543! 
- # (we keep it there so it doesn't break the code of those who are experimenting with this) - return self.create_coords() - - def create_coords(self) -> Coordinates: - # TODO: move this in xarray.Index base class? - from xarray.core.coordinates import Coordinates - - variables = self.create_variables() - indexes = {name: self for name in variables} - return Coordinates(coords=variables, indexes=indexes) - def isel( self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] ) -> Self | None: From 632c71b103d4e659216f37d3adf0f5b8e8aad091 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 09:54:09 +0100 Subject: [PATCH 11/16] fix more typing issues --- xarray/core/coordinate_transform.py | 4 +++- xarray/core/indexes.py | 11 +++++------ xarray/core/indexing.py | 6 +++--- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py index 40043da46bc..52e7f13ca14 100644 --- a/xarray/core/coordinate_transform.py +++ b/xarray/core/coordinate_transform.py @@ -64,7 +64,9 @@ def equals(self, other: "CoordinateTransform") -> bool: """Check equality with another CoordinateTransform of the same kind.""" raise NotImplementedError - def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]: + def generate_coords( + self, dims: tuple[str, ...] 
| None = None + ) -> dict[Hashable, Any]: """Compute all coordinate labels at once.""" if dims is None: dims = self.dims diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 833ec8bb926..240b4f178ec 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1447,7 +1447,7 @@ def sel( raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") label0_obj = next(iter(labels.values())) - dim_size0 = getattr(label0_obj, "sizes", None) + dim_size0 = getattr(label0_obj, "sizes", {}) is_xr_obj = [ isinstance(label, DataArray | Variable) for label in labels.values() @@ -1457,7 +1457,7 @@ def sel( "CoordinateTransformIndex only supports advanced (point-wise) indexing " "with either xarray.DataArray or xarray.Variable objects." ) - dim_size = [getattr(label, "sizes", None) for label in labels.values()] + dim_size = [getattr(label, "sizes", {}) for label in labels.values()] if any(ds != dim_size0 for ds in dim_size): raise ValueError( "CoordinateTransformIndex only supports advanced (point-wise) indexing " @@ -1469,7 +1469,7 @@ def sel( } dim_positions = self.transform.reverse(coord_labels) - results = {} + results: dict[str, Variable | DataArray] = {} dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): # TODO: rounding the decimal positions is not always the behavior we expect @@ -1477,11 +1477,10 @@ def sel( # we should probably make this customizable. 
pos = np.round(pos).astype("int") if isinstance(label0_obj, Variable): - xr_pos = Variable(dims0, pos) + results[dim] = Variable(dims0, pos) else: # dataarray - xr_pos = DataArray(pos, dims=dims0) - results[dim] = xr_pos + results[dim] = DataArray(pos, dims=dims0) return IndexSelResult(results) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 3b337612239..f379a932019 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1302,7 +1302,7 @@ def _decompose_outer_indexer( return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) -def _posify_indices(indices: np.typing.ArrayLike, size: int) -> np.ndarray: +def _posify_indices(indices: Any, size: int) -> np.ndarray: """Convert negative indices by their equivalent positive indices. Note: the resulting indices may still be out of bounds (< 0 or >= size). @@ -1311,7 +1311,7 @@ def _posify_indices(indices: np.typing.ArrayLike, size: int) -> np.ndarray: return np.where(indices < 0, size + indices, indices) -def _check_bounds(indices, size): +def _check_bounds(indices: Any, size: int): """Check if the given indices are all within the array boundaries.""" if np.any((indices < 0) | (indices >= size)): raise IndexError("out of bounds index") @@ -2046,7 +2046,7 @@ def dtype(self) -> np.dtype: return self._transform.dtype @property - def shape(self): + def shape(self) -> tuple[int, ...]: return tuple(self._transform.dim_size.values()) def get_duck_array(self) -> np.ndarray: From ae8b318c7a136c35a784e96fd0b63225011b9bf0 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 09:55:31 +0100 Subject: [PATCH 12/16] remove public imports: not ready yet for public use --- xarray/__init__.py | 2 -- xarray/indexes/__init__.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index 05cfecb2b8b..8af936ed27a 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -30,7 +30,6 @@ where, ) from xarray.core.concat 
import concat -from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -118,7 +117,6 @@ "CFTimeIndex", "Context", "Coordinates", - "CoordinateTransform", "DataArray", "DataTree", "Dataset", diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index e2857b8602b..9073cbc2ed4 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -4,10 +4,9 @@ """ from xarray.core.indexes import ( - CoordinateTransformIndex, Index, PandasIndex, PandasMultiIndex, ) -__all__ = ["CoordinateTransformIndex", "Index", "PandasIndex", "PandasMultiIndex"] +__all__ = ["Index", "PandasIndex", "PandasMultiIndex"] From 1c425e30fda0541d5571250b87c9c0b3b3736dac Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 10:43:38 +0100 Subject: [PATCH 13/16] add experimental notice in docstrings --- xarray/core/coordinate_transform.py | 6 +++++- xarray/core/indexes.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py index 52e7f13ca14..d9e09cea173 100644 --- a/xarray/core/coordinate_transform.py +++ b/xarray/core/coordinate_transform.py @@ -5,7 +5,11 @@ class CoordinateTransform: - """Abstract coordinate transform with dimension & coordinate names.""" + """Abstract coordinate transform with dimension & coordinate names. + + EXPERIMENTAL (not ready for public use yet). + + """ coord_names: tuple[Hashable, ...] dims: tuple[str, ...] diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 240b4f178ec..43e231e84d4 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1382,6 +1382,8 @@ def rename(self, name_dict, dims_dict): class CoordinateTransformIndex(Index): """Helper class for creating Xarray indexes based on coordinate transforms. + EXPERIMENTAL (not ready for public use yet). 
+ - wraps a :py:class:`CoordinateTransform` instance - takes care of creating the index (lazy) coordinates - supports point-wise label-based selection From 952faa78fb96b7a486b503465edda732342517e7 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 11:36:08 +0100 Subject: [PATCH 14/16] add coordinate transform tests --- xarray/tests/test_coordinate_transform.py | 218 ++++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 xarray/tests/test_coordinate_transform.py diff --git a/xarray/tests/test_coordinate_transform.py b/xarray/tests/test_coordinate_transform.py new file mode 100644 index 00000000000..26746657dbc --- /dev/null +++ b/xarray/tests/test_coordinate_transform.py @@ -0,0 +1,218 @@ +from collections.abc import Hashable +from typing import Any + +import numpy as np +import pytest + +import xarray as xr +from xarray.core.coordinate_transform import CoordinateTransform +from xarray.core.indexes import CoordinateTransformIndex +from xarray.tests import assert_equal + + +class SimpleCoordinateTransform(CoordinateTransform): + """Simple uniform scale transform in a 2D space (x/y coordinates).""" + + def __init__(self, shape: tuple[int, int], scale: float, dtype: Any = None): + super().__init__(("x", "y"), {"x": shape[1], "y": shape[0]}, dtype=dtype) + + self.scale = scale + + # array dimensions in reverse order (y = rows, x = cols) + self.xy_dims = tuple(self.dims) + self.dims = (self.dims[1], self.dims[0]) + + def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: + assert set(dim_positions) == set(self.dims) + return {dim: dim_positions[dim] * self.scale for dim in self.xy_dims} + + def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: + return {dim: coord_labels[dim] / self.scale for dim in self.xy_dims} + + def equals(self, other: "CoordinateTransform") -> bool: + if not isinstance(other, SimpleCoordinateTransform): + return False + return self.scale == other.scale + + def __repr__(self) -> 
str: + return f"Scale({self.scale})" + + +def test_abstract_coordinate_transform() -> None: + tr = CoordinateTransform(["x"], {"x": 5}) + + with pytest.raises(NotImplementedError): + tr.forward({"x": [1, 2]}) + + with pytest.raises(NotImplementedError): + tr.reverse({"x": [3.0, 4.0]}) + + with pytest.raises(NotImplementedError): + tr.equals(CoordinateTransform(["x"], {"x": 5})) + + +def test_coordinate_transform_init() -> None: + tr = SimpleCoordinateTransform((4, 4), 2.0) + + assert tr.coord_names == ("x", "y") + # array dimensions in reverse order (y = rows, x = cols) + assert tr.dims == ("y", "x") + assert tr.dim_size == {"x": 4, "y": 4} + assert tr.dtype == np.dtype(np.float64) + + tr2 = SimpleCoordinateTransform((4, 4), 2.0, dtype=np.int64) + assert tr2.dtype == np.dtype(np.int64) + + +@pytest.mark.parametrize("dims", [None, ("y", "x")]) +def test_coordinate_transform_generate_coords(dims) -> None: + tr = SimpleCoordinateTransform((2, 2), 2.0) + + actual = tr.generate_coords(dims) + expected = {"x": [[0.0, 2.0], [0.0, 2.0]], "y": [[0.0, 0.0], [2.0, 2.0]]} + assert set(actual) == set(expected) + np.testing.assert_array_equal(actual["x"], expected["x"]) + np.testing.assert_array_equal(actual["y"], expected["y"]) + + +def create_coords(scale: float, shape: tuple[int, int]) -> xr.Coordinates: + """Create x/y Xarray coordinate variables from a simple coordinate transform.""" + tr = SimpleCoordinateTransform(shape, scale) + index = CoordinateTransformIndex(tr) + return xr.Coordinates.from_xindex(index) + + +def test_coordinate_transform_variable() -> None: + coords = create_coords(scale=2.0, shape=(2, 2)) + + assert coords["x"].dtype == np.dtype(np.float64) + assert coords["y"].dtype == np.dtype(np.float64) + assert coords["x"].shape == (2, 2) + assert coords["y"].shape == (2, 2) + + np.testing.assert_array_equal(np.array(coords["x"]), [[0.0, 2.0], [0.0, 2.0]]) + np.testing.assert_array_equal(np.array(coords["y"]), [[0.0, 0.0], [2.0, 2.0]]) + + def assert_repr(var: 
xr.Variable): + assert ( + repr(var._data) + == "CoordinateTransformIndexingAdapter(transform=Scale(2.0))" + ) + + assert_repr(coords["x"].variable) + assert_repr(coords["y"].variable) + + +def test_coordinate_transform_variable_repr_inline() -> None: + var = create_coords(scale=2.0, shape=(2, 2))["x"].variable + + actual = var._data._repr_inline_(70) # type: ignore[union-attr] + assert actual == "0.0 2.0 0.0 2.0" + + # truncated inline repr + var2 = create_coords(scale=2.0, shape=(10, 10))["x"].variable + + actual2 = var2._data._repr_inline_(70) # type: ignore[union-attr] + assert ( + actual2 == "0.0 2.0 4.0 6.0 8.0 10.0 12.0 ... 6.0 8.0 10.0 12.0 14.0 16.0 18.0" + ) + + +def test_coordinate_transform_variable_basic_outer_indexing() -> None: + var = create_coords(scale=2.0, shape=(4, 4))["x"].variable + + assert var[0, 0] == 0.0 + assert var[0, 1] == 2.0 + assert var[0, -1] == 6.0 + np.testing.assert_array_equal(var[:, 0:2], [[0.0, 2.0]] * 4) + + with pytest.raises(IndexError, match="out of bounds index"): + var[5] + + with pytest.raises(IndexError, match="out of bounds index"): + var[-5] + + +def test_coordinate_transform_variable_vectorized_indexing() -> None: + var = create_coords(scale=2.0, shape=(4, 4))["x"].variable + + actual = var[{"x": xr.Variable("z", [0]), "y": xr.Variable("z", [0])}] + expected = xr.Variable("z", [0.0]) + assert_equal(actual, expected) + + with pytest.raises(IndexError, match="out of bounds index"): + var[{"x": xr.Variable("z", [5]), "y": xr.Variable("z", [5])}] + + +def test_coordinate_transform_setitem_error() -> None: + var = create_coords(scale=2.0, shape=(4, 4))["x"].variable + + # basic indexing + with pytest.raises(TypeError, match="setting values is not supported"): + var[0, 0] = 1.0 + + # outer indexing + with pytest.raises(TypeError, match="setting values is not supported"): + var[[0, 2], 0] = [1.0, 2.0] + + # vectorized indexing + with pytest.raises(TypeError, match="setting values is not supported"): + var[{"x": 
xr.Variable("z", [0]), "y": xr.Variable("z", [0])}] = 1.0 + + +def test_coordinate_transform_transpose() -> None: + coords = create_coords(scale=2.0, shape=(2, 2)) + + actual = coords["x"].transpose().values + expected = [[0.0, 0.0], [2.0, 2.0]] + np.testing.assert_array_equal(actual, expected) + + +def test_coordinate_transform_equals() -> None: + ds1 = create_coords(scale=2.0, shape=(2, 2)).to_dataset() + ds2 = create_coords(scale=2.0, shape=(2, 2)).to_dataset() + ds3 = create_coords(scale=4.0, shape=(2, 2)).to_dataset() + + # cannot use `assert_equal()` test utility function here yet + # (indexes invariant check are still based on IndexVariable, which + # doesn't work with coordinate transform index coordinate variables) + assert ds1.equals(ds2) + assert not ds1.equals(ds3) + + +def test_coordinate_transform_sel() -> None: + ds = create_coords(scale=2.0, shape=(4, 4)).to_dataset() + + data = [ + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + [8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + ] + ds["data"] = (("y", "x"), data) + + actual = ds.sel( + x=xr.Variable("z", [0.5, 5.5]), y=xr.Variable("z", [0.0, 0.5]), method="nearest" + ) + expected = ds.isel(x=xr.Variable("z", [0, 3]), y=xr.Variable("z", [0, 0])) + + # cannot use `assert_equal()` test utility function here yet + # (indexes invariant check are still based on IndexVariable, which + # doesn't work with coordinate transform index coordinate variables) + assert actual.equals(expected) + + with pytest.raises(ValueError, match=".*only supports selection.*nearest"): + ds.sel(x=xr.Variable("z", [0.5, 5.5]), y=xr.Variable("z", [0.0, 0.5])) + + with pytest.raises(ValueError, match="missing labels for coordinate.*y"): + ds.sel(x=[0.5, 5.5], method="nearest") + + with pytest.raises(TypeError, match=".*only supports advanced.*indexing"): + ds.sel(x=[0.5, 5.5], y=[0.0, 0.5], method="nearest") + + with pytest.raises(ValueError, match=".*only supports advanced.*indexing"): + ds.sel( + x=xr.Variable("z", [0.5, 
5.5]), + y=xr.Variable("z", [0.0, 0.5, 1.5]), + method="nearest", + ) From 03fdc90404a195dba8a39af5c549d93b0d2363cb Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 14:17:14 +0100 Subject: [PATCH 15/16] typing fixes --- xarray/core/indexing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index f379a932019..521abcdfddd 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -2035,7 +2035,7 @@ def __init__( self, transform: CoordinateTransform, coord_name: Hashable, - dims: tuple[str] | None = None, + dims: tuple[str, ...] | None = None, ): self._transform = transform self._coord_name = coord_name @@ -2102,7 +2102,7 @@ def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: "setting values is not supported on coordinate transform arrays." ) - def transpose(self, order): + def transpose(self, order: Iterable[int]) -> Self: new_dims = tuple([self._dims[i] for i in order]) return type(self)(self._transform, self._coord_name, new_dims) From 406b03b7067604438295ffdf34c8a608ace69669 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 14:18:51 +0100 Subject: [PATCH 16/16] update what's new --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 43edc5ee33e..d9d4998d983 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,6 +28,8 @@ New Features By `Kai Mühlbauer `_. - support python 3.13 (no free-threading) (:issue:`9664`, :pull:`9681`) By `Justus Magin `_. +- Added experimental support for coordinate transforms (not ready for public use yet!) (:pull:`9543`) + By `Benoit Bovy `_. Breaking changes ~~~~~~~~~~~~~~~~