-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ MapperAsyncIterDataPipe for applying custom async functions (#9)
* ✨ MapperAsyncIterDataPipe for applying custom async functions An asynchronous iterable-style DataPipe for applying a custom asynchronous function over an asynchronous iterable! Uses asyncio.TaskGroup from Python 3.11+ to run several tasks concurrently. Included a doctest, added a new section in the API docs under 'Mapping DataPipes', and set show_toc_level to 3 to make it show in the right sidebar. * ✅ Add unit test for MapperAsyncIterDataPipe Ensure that tasks are processed concurrently, included a timer to double check that all 3 tasks (made up of 2 sub-tasks) complete in 0.5 seconds instead of 1.5 seconds! * 🥅 Use try-except* to catch task errors in ExceptionGroup To better handle errors from tasks in a TaskGroup, wrap the TaskGroup context manager in a try-except* clause following PEP 654 Exception Groups. Based on the nice examples from https://github.com/jrfk/talk/tree/main/EuroPython2023. Added a unit test to ensure that a ValueError raised in 1 out of 3 tasks can be nicely captured and raised to attention. * 📝 Document the raising of ExceptionGroup when a task errors out Mention PEP0654 so that people know what an ExceptionGroup is, and how it could be handled.
- Loading branch information
Showing
5 changed files
with
188 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
""" | ||
Asynchronous Iterable DataPipes for asynchronous functions. | ||
""" | ||
import asyncio | ||
from collections.abc import AsyncIterator, Callable, Coroutine | ||
from typing import Any | ||
|
||
from bambooflow.datapipes.aiter import AsyncIterDataPipe | ||
|
||
|
||
class MapperAsyncIterDataPipe(AsyncIterDataPipe): | ||
""" | ||
Applies an asynchronous function over each item from the source DataPipe. | ||
Parameters | ||
---------- | ||
datapipe : AsyncIterDataPipe | ||
The source asynchronous iterable-style DataPipe. | ||
fn : Callable | ||
Asynchronous function to be applied over each item. | ||
Yields | ||
------ | ||
awaitable : collections.abc.Awaitable | ||
An :py-term:`awaitable` object from the | ||
:py-term:`asynchronous iterator <asynchronous-iterator>`. | ||
Raises | ||
------ | ||
ExceptionGroup | ||
If any one of the concurrent tasks raises an :py:class:`Exception`. See | ||
`PEP654 <https://peps.python.org/pep-0654/#handling-exception-groups>`_ | ||
for general advice on how to handle exception groups. | ||
Example | ||
------- | ||
>>> import asyncio | ||
>>> from bambooflow.datapipes import AsyncIterableWrapper, Mapper | ||
... | ||
>>> # Apply an asynchronous multiply by two function | ||
>>> async def times_two(x) -> float: | ||
... await asyncio.sleep(delay=x) | ||
... return x * 2 | ||
>>> dp = AsyncIterableWrapper(iterable=[0.1, 0.2, 0.3]) | ||
>>> dp_map = Mapper(datapipe=dp, fn=times_two) | ||
... | ||
>>> # Loop or iterate over the DataPipe stream | ||
>>> it = aiter(dp_map) | ||
>>> number = anext(it) | ||
>>> asyncio.run(number) | ||
0.2 | ||
>>> number = anext(it) | ||
>>> asyncio.run(number) | ||
0.4 | ||
>>> # Or if running in an interactive REPL with top-level `await` support | ||
>>> number = anext(it) | ||
>>> await number # doctest: +SKIP | ||
0.6 | ||
""" | ||
|
||
def __init__( | ||
self, datapipe: AsyncIterDataPipe, fn: Callable[..., Coroutine[Any, Any, Any]] | ||
): | ||
super().__init__() | ||
self._datapipe = datapipe | ||
self._fn = fn | ||
|
||
async def __aiter__(self) -> AsyncIterator: | ||
try: | ||
async with asyncio.TaskGroup() as task_group: | ||
tasks: list[asyncio.Task] = [ | ||
task_group.create_task(coro=self._fn(data)) | ||
async for data in self._datapipe | ||
] | ||
except* BaseException as err: | ||
raise ValueError(f"{err=}") from err | ||
|
||
for task in tasks: | ||
result = await task | ||
yield result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
""" | ||
Tests for callable datapipes. | ||
""" | ||
import asyncio | ||
import re | ||
import time | ||
from collections.abc import Awaitable | ||
|
||
import pytest | ||
|
||
from bambooflow.datapipes import AsyncIterableWrapper, Mapper | ||
|
||
|
||
# %% | ||
@pytest.fixture(scope="function", name="times_two") | ||
def fixture_times_two(): | ||
async def times_two(x) -> int: | ||
await asyncio.sleep(0.2) | ||
print(f"Multiplying {x} by 2") | ||
result = x * 2 | ||
return result | ||
|
||
return times_two | ||
|
||
|
||
@pytest.fixture(scope="function", name="times_three") | ||
def fixture_times_three(): | ||
async def times_three(x) -> int: | ||
await asyncio.sleep(0.3) | ||
print(f"Multiplying {x} by 3") | ||
result = x * 3 | ||
return result | ||
|
||
return times_three | ||
|
||
|
||
@pytest.fixture(scope="function", name="error_four") | ||
def fixture_error_four(): | ||
async def error_four(x): | ||
await asyncio.sleep(0.1) | ||
if x == 4: | ||
raise ValueError(f"Some problem with {x}") | ||
|
||
return error_four | ||
|
||
|
||
# %% | ||
async def test_mapper_concurrency(times_two, times_three): | ||
""" | ||
Ensure that MapperAsyncIterDataPipe works to process tasks concurrently, | ||
such that three tasks taking 3*(0.2+0.3)=1.5 seconds in serial can be | ||
completed in just (0.2+0.3)=0.5 seconds instead. | ||
""" | ||
dp = AsyncIterableWrapper(iterable=[0, 1, 2]) | ||
dp_map2 = Mapper(datapipe=dp, fn=times_two) | ||
dp_map3 = Mapper(datapipe=dp_map2, fn=times_three) | ||
|
||
i = 0 | ||
tic = time.perf_counter() | ||
async for num in dp_map3: | ||
# print("Number:", num) | ||
assert num == i * 2 * 3 | ||
toc = time.perf_counter() | ||
i += 1 | ||
# print(f"Ran in {toc - tic:0.4f} seconds") | ||
print(f"Total: {toc - tic:0.4f} seconds") | ||
|
||
assert toc - tic < 0.55 # Total time should be about 0.5 seconds | ||
assert num == 12 # 2*2*3=12 | ||
|
||
|
||
async def test_mapper_exception_handling(error_four): | ||
""" | ||
Ensure that MapperAsyncIterDataPipe can capture exceptions when one of the | ||
tasks raises an error. | ||
""" | ||
dp = AsyncIterableWrapper(iterable=[3, 4, 5]) | ||
dp_map = Mapper(datapipe=dp, fn=error_four) | ||
|
||
it = aiter(dp_map) | ||
number = anext(it) | ||
# Checek that an ExceptionGroup is already raised on first access | ||
with pytest.raises( | ||
ValueError, | ||
match=re.escape( | ||
"err=ExceptionGroup('unhandled errors in a TaskGroup', [ValueError('Some problem with 4')])" | ||
), | ||
): | ||
await number | ||
|
||
# Subsequent access to iterator should raise StopAsyncIteration | ||
number = anext(it) | ||
with pytest.raises(StopAsyncIteration): | ||
await number |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters