Skip to content

Commit

Permalink
Fix DataArray.to_dataframe when the array has MultiIndex (pydata#4442)
Browse files Browse the repository at this point in the history
Co-authored-by: Keewis <[email protected]>
  • Loading branch information
ghislainp and keewis authored Feb 20, 2021
1 parent c4ad6f1 commit eb7e112
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 2 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ Bug fixes
a float64 array (:issue:`4898`, :pull:`4911`). By `Blair Bonnett <https://github.com/bcbnz>`_.
- Fix decoding of vlen strings using h5py versions greater than 3.0.0 with h5netcdf backend (:issue:`4570`, :pull:`4893`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Allow converting :py:class:`Dataset` or :py:class:`DataArray` objects with a ``MultiIndex``
and at least one other dimension to a ``pandas`` object (:issue:`3008`, :pull:`4442`).
By `ghislainp <https://github.com/ghislainp>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
46 changes: 44 additions & 2 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
cast,
)

import numpy as np
import pandas as pd

from . import formatting, indexing
Expand Down Expand Up @@ -107,8 +108,49 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
return self._data.get_index(dim) # type: ignore
else:
indexes = [self._data.get_index(k) for k in ordered_dims] # type: ignore
names = list(ordered_dims)
return pd.MultiIndex.from_product(indexes, names=names)

# compute the sizes of the repeat and tile for the cartesian product
# (taken from pandas.core.reshape.util)
index_lengths = np.fromiter(
(len(index) for index in indexes), dtype=np.intp
)
cumprod_lengths = np.cumproduct(index_lengths)

if cumprod_lengths[-1] != 0:
# sizes of the repeats
repeat_counts = cumprod_lengths[-1] / cumprod_lengths
else:
# if any factor is empty, the cartesian product is empty
repeat_counts = np.zeros_like(cumprod_lengths)

# sizes of the tiles
tile_counts = np.roll(cumprod_lengths, 1)
tile_counts[0] = 1

# loop over the indexes
# for each MultiIndex or Index compute the cartesian product of the codes

code_list = []
level_list = []
names = []

for i, index in enumerate(indexes):
if isinstance(index, pd.MultiIndex):
codes, levels = index.codes, index.levels
else:
code, level = pd.factorize(index)
codes = [code]
levels = [level]

# compute the cartesian product
code_list += [
np.tile(np.repeat(code, repeat_counts[i]), tile_counts[i])
for code in codes
]
level_list += levels
names += index.names

return pd.MultiIndex(level_list, code_list, names=names)

def update(self, other: Mapping[Hashable, Any]) -> None:
other_vars = getattr(other, "variables", other)
Expand Down
27 changes: 27 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3635,6 +3635,33 @@ def test_to_dataframe(self):
with raises_regex(ValueError, "unnamed"):
arr.to_dataframe()

def test_to_dataframe_multiindex(self):
# regression test for #3008
arr_np = np.random.randn(4, 3)

mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"])

arr = DataArray(arr_np, [("MI", mindex), ("C", [5, 6, 7])], name="foo")

actual = arr.to_dataframe()
assert_array_equal(actual["foo"].values, arr_np.flatten())
assert_array_equal(actual.index.names, list("ABC"))
assert_array_equal(actual.index.levels[0], [1, 2])
assert_array_equal(actual.index.levels[1], ["a", "b"])
assert_array_equal(actual.index.levels[2], [5, 6, 7])

def test_to_dataframe_0length(self):
# regression test for #3008
arr_np = np.random.randn(4, 0)

mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"])

arr = DataArray(arr_np, [("MI", mindex), ("C", [])], name="foo")

actual = arr.to_dataframe()
assert len(actual) == 0
assert_array_equal(actual.index.names, list("ABC"))

def test_to_pandas_name_matches_coordinate(self):
# coordinate with same name as array
arr = DataArray([1, 2, 3], dims="x", name="x")
Expand Down

0 comments on commit eb7e112

Please sign in to comment.