Skip to content

Commit

Permalink
misc. fixes for Indexes with pd.Index objects (pydata#7003)
Browse files Browse the repository at this point in the history
  • Loading branch information
benbovy authored Sep 23, 2022
1 parent 1f4be33 commit 9d1499e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
22 changes: 19 additions & 3 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,12 +1092,13 @@ def get_unique(self) -> list[T_PandasOrXarrayIndex]:
"""Return a list of unique indexes, preserving order."""

unique_indexes: list[T_PandasOrXarrayIndex] = []
seen: set[T_PandasOrXarrayIndex] = set()
seen: set[int] = set()

for index in self._indexes.values():
if index not in seen:
index_id = id(index)
if index_id not in seen:
unique_indexes.append(index)
seen.add(index)
seen.add(index_id)

return unique_indexes

Expand Down Expand Up @@ -1201,9 +1202,24 @@ def copy_indexes(
"""
new_indexes = {}
new_index_vars = {}

for idx, coords in self.group_by_index():
if isinstance(idx, pd.Index):
convert_new_idx = True
dim = next(iter(coords.values())).dims[0]
if isinstance(idx, pd.MultiIndex):
idx = PandasMultiIndex(idx, dim)
else:
idx = PandasIndex(idx, dim)
else:
convert_new_idx = False

new_idx = idx.copy(deep=deep)
idx_vars = idx.create_variables(coords)

if convert_new_idx:
new_idx = cast(PandasIndex, new_idx).index

new_indexes.update({k: new_idx for k in coords})
new_index_vars.update(idx_vars)

Expand Down
31 changes: 25 additions & 6 deletions xarray/tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import xarray as xr
from xarray.core.indexes import (
Hashable,
Index,
Indexes,
PandasIndex,
Expand Down Expand Up @@ -535,18 +536,37 @@ def test_copy(self) -> None:

class TestIndexes:
@pytest.fixture
def unique_indexes(self) -> list[PandasIndex]:
def indexes_and_vars(self) -> tuple[list[PandasIndex], dict[Hashable, Variable]]:
x_idx = PandasIndex(pd.Index([1, 2, 3], name="x"), "x")
y_idx = PandasIndex(pd.Index([4, 5, 6], name="y"), "y")
z_pd_midx = pd.MultiIndex.from_product(
[["a", "b"], [1, 2]], names=["one", "two"]
)
z_midx = PandasMultiIndex(z_pd_midx, "z")

return [x_idx, y_idx, z_midx]
indexes = [x_idx, y_idx, z_midx]

variables = {}
for idx in indexes:
variables.update(idx.create_variables())

return indexes, variables

@pytest.fixture(params=["pd_index", "xr_index"])
def unique_indexes(
self, request, indexes_and_vars
) -> list[PandasIndex] | list[pd.Index]:
xr_indexes, _ = indexes_and_vars

if request.param == "pd_index":
return [idx.index for idx in xr_indexes]
else:
return xr_indexes

@pytest.fixture
def indexes(self, unique_indexes) -> Indexes[Index]:
def indexes(
self, unique_indexes, indexes_and_vars
) -> Indexes[Index] | Indexes[pd.Index]:
x_idx, y_idx, z_midx = unique_indexes
indexes: dict[Any, Index] = {
"x": x_idx,
Expand All @@ -555,9 +575,8 @@ def indexes(self, unique_indexes) -> Indexes[Index]:
"one": z_midx,
"two": z_midx,
}
variables: dict[Any, Variable] = {}
for idx in unique_indexes:
variables.update(idx.create_variables())

_, variables = indexes_and_vars

return Indexes(indexes, variables)

Expand Down

0 comments on commit 9d1499e

Please sign in to comment.