Skip to content

Commit 52f5cf1

Browse files
authored
Automatically chunk other in GroupBy binary ops. (#7684)
* Automatically chunk `other` in GroupBy binary ops. Closes #7683 * Update xarray/core/groupby.py * Add test * Update xarray/core/groupby.py
1 parent db12b0d commit 52f5cf1

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

xarray/core/groupby.py

+15
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,21 @@ def _binary_op(self, other, f, reflexive=False):
888888
group = group.where(~mask, drop=True)
889889
codes = codes.where(~mask, drop=True).astype(int)
890890

891+
# if obj is dask-backed (and other is not), that's a hint that the
892+
# "expanded" dataset is too big to hold in memory.
893+
# this can be the case when `other` was read from disk
894+
# and contains our lazy indexing classes
895+
# We need to check for dask-backed Datasets
896+
# so utils.is_duck_dask_array does not work for this check
897+
if obj.chunks and not other.chunks:
898+
# TODO: What about datasets with some dask vars, and others not?
899+
# This handles dims other than `name`
900+
chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims}
901+
# a chunk size of 1 seems reasonable since we expect individual elements of
902+
# other to be repeated multiple times across the reduced dimension(s)
903+
chunks[name] = 1
904+
other = other.chunk(chunks)
905+
891906
# codes are defined for coord, so we align `other` with `coord`
892907
# before indexing
893908
other, _ = align(other, coord, join="right", copy=False)

xarray/tests/test_groupby.py

+15
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from xarray import DataArray, Dataset, Variable
1313
from xarray.core.groupby import _consolidate_slices
1414
from xarray.tests import (
15+
InaccessibleArray,
1516
assert_allclose,
1617
assert_array_equal,
1718
assert_equal,
@@ -2392,3 +2393,17 @@ def test_min_count_error(use_flox: bool) -> None:
23922393
with xr.set_options(use_flox=use_flox):
23932394
with pytest.raises(TypeError):
23942395
da.groupby("labels").mean(min_count=1)
2396+
2397+
2398+
@requires_dask
2399+
def test_groupby_math_auto_chunk():
2400+
da = xr.DataArray(
2401+
[[1, 2, 3], [1, 2, 3], [1, 2, 3]],
2402+
dims=("y", "x"),
2403+
coords={"label": ("x", [2, 2, 1])},
2404+
)
2405+
sub = xr.DataArray(
2406+
InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]}
2407+
)
2408+
actual = da.chunk(x=1, y=2).groupby("label") - sub
2409+
assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}

0 commit comments

Comments
 (0)