Skip to content

Commit

Permalink
make task worker instantiable in sync context
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Aug 6, 2024
1 parent aea46a8 commit bd205c5
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions src/prefect/task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from prefect import Task
from prefect._internal.concurrency.api import create_call, from_sync
from prefect.cache_policies import DEFAULT, NONE
from prefect.client.orchestration import get_client
from prefect.client.orchestration import PrefectClient, get_client
from prefect.client.schemas.objects import TaskRun
from prefect.client.subscriptions import Subscription
from prefect.exceptions import Abort, PrefectHTTPStatusError
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
*tasks: Task,
limit: Optional[int] = 10,
):
self.tasks = []
self.tasks: list[Task] = []
for t in tasks:
if isinstance(t, Task):
if t.cache_policy in [None, NONE, NotSet]:
Expand All @@ -91,22 +91,21 @@ def __init__(
else:
self.tasks.append(t.with_options(persist_result=True))

self.task_keys = set(t.task_key for t in tasks if isinstance(t, Task))
self.task_keys: set[str] = set(t.task_key for t in tasks if isinstance(t, Task))
self.limit: Optional[int] = limit

self._started_at: Optional[pendulum.DateTime] = None
self.stopping: bool = False

self._client = get_client()
self._exit_stack = AsyncExitStack()
self._client: PrefectClient = get_client()
self._exit_stack: AsyncExitStack = AsyncExitStack()

if not asyncio.get_event_loop().is_running():
raise RuntimeError(
"TaskWorker must be initialized within an async context."
)
self._executor: ThreadPoolExecutor = ThreadPoolExecutor(
max_workers=limit if limit else None
)

self._runs_task_group: anyio.abc.TaskGroup = anyio.create_task_group()
self._executor = ThreadPoolExecutor(max_workers=limit if limit else None)
self._limiter = anyio.CapacityLimiter(limit) if limit else None
self._limiter: Optional[anyio.CapacityLimiter] = None
self._runs_task_group: Optional[anyio.abc.TaskGroup] = None

self.in_flight_task_runs: dict[str, dict[UUID, pendulum.DateTime]] = {
task_key: {} for task_key in self.task_keys
Expand All @@ -127,10 +126,6 @@ def started_at(self) -> Optional[pendulum.DateTime]:
def started(self) -> bool:
return self._started_at is not None

@property
def limit(self) -> Optional[int]:
return int(self._limiter.total_tokens) if self._limiter else None

@property
def current_tasks(self) -> Optional[int]:
return (
Expand Down Expand Up @@ -380,6 +375,12 @@ async def __aenter__(self):
if self._client._closed:
self._client = get_client()

if self._runs_task_group is None:
self._runs_task_group = anyio.create_task_group()

if self._limiter is None and self.limit:
self._limiter = anyio.CapacityLimiter(self.limit)

await self._exit_stack.enter_async_context(self._client)
await self._exit_stack.enter_async_context(self._runs_task_group)
self._exit_stack.enter_context(self._executor)
Expand Down

0 comments on commit bd205c5

Please sign in to comment.