Skip to content

Commit

Permalink
Allow multiple dims to be passed with min_count (pydata#4356)
Browse files Browse the repository at this point in the history
* Allow multiple dims to be passed with min_count

* Add whatsnew
  • Loading branch information
max-sixty authored Aug 20, 2020
1 parent efabe74 commit 43a2a4b
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
5 changes: 4 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ New Features
- Support multiple outputs in :py:func:`xarray.apply_ufunc` when using ``dask='parallelized'``. (:issue:`1815`, :pull:`4060`)
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling`
now accept more than 1 dimension.(:pull:`4219`)
now accept more than 1 dimension. (:pull:`4219`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- ``min_count`` can be supplied to reductions such as ``.sum`` when specifying
multiple dimension to reduce over. (:pull:`4356`)
By `Maximilian Roos <https://github.com/max-sixty>`_.
- Build ``CFTimeIndex.__repr__`` explicitly as :py:class:`pandas.Index`. Add ``calendar`` as a new
property for :py:class:`CFTimeIndex` and show ``calendar`` and ``length`` in
``CFTimeIndex.__repr__`` (:issue:`2416`, :pull:`4092`)
Expand Down
6 changes: 1 addition & 5 deletions xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,9 @@ def _maybe_null_out(result, axis, mask, min_count=1):
"""
xarray version of pandas.core.nanops._maybe_null_out
"""
if hasattr(axis, "__len__"): # if tuple or list
raise ValueError(
"min_count is not available for reduction with more than one dimensions."
)

if axis is not None and getattr(result, "ndim", False):
null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0
if null_mask.any():
dtype, fill_value = dtypes.maybe_promote(result.dtype)
result = result.astype(dtype)
Expand Down
25 changes: 22 additions & 3 deletions xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,24 @@ def test_min_count(dim_num, dtype, dask, func, aggdim):
assert_dask_array(actual, dask)


@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_])
@pytest.mark.parametrize("dask", [False, True])
@pytest.mark.parametrize("func", ["sum", "prod"])
def test_min_count_nd(dtype, dask, func):
if dask and not has_dask:
pytest.skip("requires dask")

min_count = 3
dim_num = 3
da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask)
actual = getattr(da, func)(dim=["x", "y", "z"], skipna=True, min_count=min_count)
# Supplying all dims is equivalent to supplying `...` or `None`
expected = getattr(da, func)(dim=..., skipna=True, min_count=min_count)

assert_allclose(actual, expected)
assert_dask_array(actual, dask)


@pytest.mark.parametrize("func", ["sum", "prod"])
def test_min_count_dataset(func):
da = construct_dataarray(2, dtype=float, contains_nan=True, dask=False)
Expand All @@ -606,14 +624,15 @@ def test_min_count_dataset(func):

@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_])
@pytest.mark.parametrize("dask", [False, True])
@pytest.mark.parametrize("skipna", [False, True])
@pytest.mark.parametrize("func", ["sum", "prod"])
def test_multiple_dims(dtype, dask, func):
def test_multiple_dims(dtype, dask, skipna, func):
if dask and not has_dask:
pytest.skip("requires dask")
da = construct_dataarray(3, dtype, contains_nan=True, dask=dask)

actual = getattr(da, func)(("x", "y"))
expected = getattr(getattr(da, func)("x"), func)("y")
actual = getattr(da, func)(("x", "y"), skipna=skipna)
expected = getattr(getattr(da, func)("x", skipna=skipna), func)("y", skipna=skipna)
assert_allclose(actual, expected)


Expand Down

0 comments on commit 43a2a4b

Please sign in to comment.