Skip to content

Commit

Permalink
✨ MapperAsyncIterDataPipe for applying custom async functions (#9)
Browse files Browse the repository at this point in the history
* ✨ MapperAsyncIterDataPipe for applying custom async functions

An asynchronous iterable-style DataPipe for applying a custom asynchronous function over an asynchronous iterable! Uses asyncio.TaskGroup from Python 3.11+ to run several tasks concurrently. Included a doctest, added a new section in the API docs under 'Mapping DataPipes', and set show_toc_level to 3 to make it show in the right sidebar.

* ✅ Add unit test for MapperAsyncIterDataPipe

Ensure that tasks are processed concurrently, included a timer to double check that all 3 tasks (made up of 2 sub-tasks) complete in 0.5 seconds instead of 1.5 seconds!

* 🥅 Use try-except* to catch task errors in ExceptionGroup

To better handle errors from tasks in a TaskGroup, wrap the TaskGroup context manager in a try-except* clause following PEP 654 Exception Groups. Based on the nice examples from https://github.com/jrfk/talk/tree/main/EuroPython2023. Added a unit test to ensure that a ValueError raised in 1 out of 3 tasks can be nicely captured and raised to attention.

* 📝 Document the raising of ExceptionGroup when a task errors out

Mention PEP0654 so that people know what an ExceptionGroup is, and how it could be handled.
  • Loading branch information
weiji14 authored Aug 3, 2023
1 parent 244c05e commit e17af34
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 0 deletions.
1 change: 1 addition & 0 deletions bambooflow/datapipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
AsyncIterDataPipe,
AsyncIterableWrapperAsyncIterDataPipe as AsyncIterableWrapper,
)
from bambooflow.datapipes.callable import MapperAsyncIterDataPipe as Mapper
80 changes: 80 additions & 0 deletions bambooflow/datapipes/callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Asynchronous Iterable DataPipes for asynchronous functions.
"""
import asyncio
from collections.abc import AsyncIterator, Callable, Coroutine
from typing import Any

from bambooflow.datapipes.aiter import AsyncIterDataPipe


class MapperAsyncIterDataPipe(AsyncIterDataPipe):
"""
Applies an asynchronous function over each item from the source DataPipe.
Parameters
----------
datapipe : AsyncIterDataPipe
The source asynchronous iterable-style DataPipe.
fn : Callable
Asynchronous function to be applied over each item.
Yields
------
awaitable : collections.abc.Awaitable
An :py-term:`awaitable` object from the
:py-term:`asynchronous iterator <asynchronous-iterator>`.
Raises
------
ExceptionGroup
If any one of the concurrent tasks raises an :py:class:`Exception`. See
`PEP654 <https://peps.python.org/pep-0654/#handling-exception-groups>`_
for general advice on how to handle exception groups.
Example
-------
>>> import asyncio
>>> from bambooflow.datapipes import AsyncIterableWrapper, Mapper
...
>>> # Apply an asynchronous multiply by two function
>>> async def times_two(x) -> float:
... await asyncio.sleep(delay=x)
... return x * 2
>>> dp = AsyncIterableWrapper(iterable=[0.1, 0.2, 0.3])
>>> dp_map = Mapper(datapipe=dp, fn=times_two)
...
>>> # Loop or iterate over the DataPipe stream
>>> it = aiter(dp_map)
>>> number = anext(it)
>>> asyncio.run(number)
0.2
>>> number = anext(it)
>>> asyncio.run(number)
0.4
>>> # Or if running in an interactive REPL with top-level `await` support
>>> number = anext(it)
>>> await number # doctest: +SKIP
0.6
"""

def __init__(
self, datapipe: AsyncIterDataPipe, fn: Callable[..., Coroutine[Any, Any, Any]]
):
super().__init__()
self._datapipe = datapipe
self._fn = fn

async def __aiter__(self) -> AsyncIterator:
try:
async with asyncio.TaskGroup() as task_group:
tasks: list[asyncio.Task] = [
task_group.create_task(coro=self._fn(data))
async for data in self._datapipe
]
except* BaseException as err:
raise ValueError(f"{err=}") from err

for task in tasks:
result = await task
yield result
94 changes: 94 additions & 0 deletions bambooflow/tests/test_datapipes_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Tests for callable datapipes.
"""
import asyncio
import re
import time
from collections.abc import Awaitable

import pytest

from bambooflow.datapipes import AsyncIterableWrapper, Mapper


# %%
@pytest.fixture(scope="function", name="times_two")
def fixture_times_two():
async def times_two(x) -> int:
await asyncio.sleep(0.2)
print(f"Multiplying {x} by 2")
result = x * 2
return result

return times_two


@pytest.fixture(scope="function", name="times_three")
def fixture_times_three():
async def times_three(x) -> int:
await asyncio.sleep(0.3)
print(f"Multiplying {x} by 3")
result = x * 3
return result

return times_three


@pytest.fixture(scope="function", name="error_four")
def fixture_error_four():
async def error_four(x):
await asyncio.sleep(0.1)
if x == 4:
raise ValueError(f"Some problem with {x}")

return error_four


# %%
async def test_mapper_concurrency(times_two, times_three):
"""
Ensure that MapperAsyncIterDataPipe works to process tasks concurrently,
such that three tasks taking 3*(0.2+0.3)=1.5 seconds in serial can be
completed in just (0.2+0.3)=0.5 seconds instead.
"""
dp = AsyncIterableWrapper(iterable=[0, 1, 2])
dp_map2 = Mapper(datapipe=dp, fn=times_two)
dp_map3 = Mapper(datapipe=dp_map2, fn=times_three)

i = 0
tic = time.perf_counter()
async for num in dp_map3:
# print("Number:", num)
assert num == i * 2 * 3
toc = time.perf_counter()
i += 1
# print(f"Ran in {toc - tic:0.4f} seconds")
print(f"Total: {toc - tic:0.4f} seconds")

assert toc - tic < 0.55 # Total time should be about 0.5 seconds
assert num == 12 # 2*2*3=12


async def test_mapper_exception_handling(error_four):
"""
Ensure that MapperAsyncIterDataPipe can capture exceptions when one of the
tasks raises an error.
"""
dp = AsyncIterableWrapper(iterable=[3, 4, 5])
dp_map = Mapper(datapipe=dp, fn=error_four)

it = aiter(dp_map)
number = anext(it)
# Checek that an ExceptionGroup is already raised on first access
with pytest.raises(
ValueError,
match=re.escape(
"err=ExceptionGroup('unhandled errors in a TaskGroup', [ValueError('Some problem with 4')])"
),
):
await number

# Subsequent access to iterator should raise StopAsyncIteration
number = anext(it)
with pytest.raises(StopAsyncIteration):
await number
3 changes: 3 additions & 0 deletions docs/_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ sphinx:
config:
myst_all_links_external: true
html_show_copyright: false
html_theme_options:
# https://sphinx-book-theme.readthedocs.io/en/stable/customize/sidebar-secondary.html
show_toc_level: 3
extlinks:
py-term:
- 'https://docs.python.org/3/glossary.html#term-%s'
Expand Down
10 changes: 10 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@
.. autoclass:: bambooflow.datapipes.aiter.AsyncIterableWrapperAsyncIterDataPipe
:show-inheritance:
```

### Mapping DataPipes

Datapipes which apply a custom asynchronous function to elements in a DataPipe.

```{eval-rst}
.. autoclass:: bambooflow.datapipes.Mapper
.. autoclass:: bambooflow.datapipes.callable.MapperAsyncIterDataPipe
:show-inheritance:
```

0 comments on commit e17af34

Please sign in to comment.