Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP:worker and runner cancel flowruns on CANCEL_NEW concurrency strategy #15440

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/prefect/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def fast_flow():
)
from prefect.states import (
AwaitingConcurrencySlot,
Cancelled,
Crashed,
Pending,
exception_to_failed_state,
Expand Down Expand Up @@ -797,6 +798,20 @@ async def _check_for_cancelled_flow_runs(

return cancelling_flow_runs

async def _cancel_pending(self, flow_run: "FlowRun", deployment: "Deployment"):
    """Cancel a flow run that was rejected by its deployment's concurrency
    limit under the CANCEL_NEW collision strategy, before it ever started.

    Sets the run to ``Cancelled`` and emits a cancellation event. If the
    state transition fails, the failure is logged and no event is emitted.
    """
    run_logger = self._get_flow_run_logger(flow_run)
    new_state = Cancelled(message="Flow run was cancelled by concurrency limit.")
    try:
        await self._client.set_flow_run_state(flow_run.id, new_state)
    except Exception:
        # Bug fix: previously execution fell through and emitted the
        # cancelled event (and a success log) even when the state change
        # failed. Stop here instead.
        run_logger.exception("Failed to cancel flow run %s", flow_run.id)
        return

    try:
        flow = await self._client.read_flow(flow_run.flow_id)
    except ObjectNotFound:
        flow = None

    self._emit_flow_run_cancelled_event(flow_run, flow, deployment)
    run_logger.info(f"Cancelled flow run '{flow_run.name}'!")

async def _cancel_run(self, flow_run: "FlowRun", state_msg: Optional[str] = None):
run_logger = self._get_flow_run_logger(flow_run)

Expand Down Expand Up @@ -843,6 +858,7 @@ async def _cancel_run(self, flow_run: "FlowRun", state_msg: Optional[str] = None
flow = await self._client.read_flow(flow_run.flow_id)
except ObjectNotFound:
flow = None

self._emit_flow_run_cancelled_event(
flow_run=flow_run, flow=flow, deployment=deployment
)
Expand Down Expand Up @@ -1074,7 +1090,16 @@ async def _submit_run_and_capture_errors(
flow_run.deployment_id,
flow_run.name,
)
await self._propose_scheduled_state(flow_run)
if deployment.concurrency_options:
if deployment.concurrency_options.collision_strategy == "CANCEL_NEW":
self._cancelling_flow_run_ids.add(flow_run.id)
self._runs_task_group.start_soon(
self._cancel_pending, flow_run, deployment
)
else:
await self._propose_scheduled_state(flow_run)
else:
await self._propose_scheduled_state(flow_run)

if not task_status._future.done():
task_status.started(exc)
Expand Down
1 change: 1 addition & 0 deletions src/prefect/utilities/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ async def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:
set_state = partial(client.set_task_run_state, task_run_id, state, force=force)
response = await set_state_and_handle_waits(set_state)
elif flow_run_id:
print("setting flow run state", state)
set_state = partial(client.set_flow_run_state, flow_run_id, state, force=force)
response = await set_state_and_handle_waits(set_state)
else:
Expand Down
85 changes: 84 additions & 1 deletion src/prefect/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from prefect.states import (
AwaitingConcurrencySlot,
Cancelled,
Crashed,
Pending,
exception_to_failed_state,
Expand Down Expand Up @@ -784,6 +785,7 @@ async def _submit_scheduled_flow_runs(
f"Worker '{self.name}' submitting flow run '{flow_run.id}'"
)
self._submitting_flow_run_ids.add(flow_run.id)

self._runs_task_group.start_soon(
self._submit_run,
flow_run,
Expand Down Expand Up @@ -896,7 +898,15 @@ async def _submit_run_and_capture_errors(
flow_run.deployment_id,
flow_run.name,
)
await self._propose_scheduled_state(flow_run)
if deployment.concurrency_options:
if deployment.concurrency_options.collision_strategy == "CANCEL_NEW":
self._runs_task_group.start_soon(
self._cancel_pending, flow_run, deployment
)
else:
await self._propose_scheduled_state(flow_run)
else:
await self._propose_scheduled_state(flow_run)

if not task_status._future.done():
task_status.started(exc)
Expand Down Expand Up @@ -997,6 +1007,7 @@ async def _get_configuration(
async def _propose_pending_state(self, flow_run: "FlowRun") -> bool:
run_logger = self.get_flow_run_logger(flow_run)
state = flow_run.state

try:
state = await propose_state(
self._client, Pending(), flow_run_id=flow_run.id
Expand Down Expand Up @@ -1078,6 +1089,22 @@ async def _propose_crashed_state(self, flow_run: "FlowRun", message: str) -> Non
f"Reported flow run '{flow_run.id}' as crashed: {message}"
)

async def _cancel_pending(
    self, flow_run: "FlowRun", deployment: "DeploymentResponse"
):
    """Cancel a flow run that was rejected by its deployment's concurrency
    limit under the CANCEL_NEW collision strategy, before it ever started.

    Sets the run to ``Cancelled`` and emits a cancellation event. If the
    state transition fails, the failure is logged and no event is emitted.
    """
    run_logger = self.get_flow_run_logger(flow_run)
    new_state = Cancelled(message="Flow run was cancelled by concurrency limit.")
    try:
        await self._client.set_flow_run_state(flow_run.id, new_state)
    except Exception:
        # Bug fix: previously execution fell through and emitted the
        # cancelled event (and a success log) even when the state change
        # failed. Stop here instead.
        run_logger.exception("Failed to cancel flow run %s", flow_run.id)
        return

    try:
        flow = await self._client.read_flow(flow_run.flow_id)
    except ObjectNotFound:
        flow = None

    self._emit_flow_run_cancelled_event(flow_run, flow, deployment)
    run_logger.info(f"Cancelled flow run '{flow_run.name}'!")

async def _mark_flow_run_as_cancelled(
self, flow_run: "FlowRun", state_updates: Optional[dict] = None
) -> None:
Expand All @@ -1104,6 +1131,19 @@ async def _set_work_pool_template(self, work_pool, job_template):
),
)

async def _cancel_run(self, flow_run: "FlowRun", state_msg: Optional[str] = None):
    """Mark *flow_run* as cancelled and emit a cancellation event.

    Args:
        flow_run: the run to mark as cancelled.
        state_msg: optional message for the Cancelled state; a default is
            used when omitted.
    """
    run_logger = self.get_flow_run_logger(flow_run)

    await self._mark_flow_run_as_cancelled(
        flow_run,
        state_updates={
            "message": state_msg or "Flow run was cancelled successfully."
        },
    )

    # Bug fix: _emit_flow_run_cancelled_event requires (flow_run, flow,
    # deployment); calling it with no arguments raised TypeError. Look up
    # the flow and deployment (best-effort) and pass them along.
    flow = None
    deployment = None
    try:
        flow = await self._client.read_flow(flow_run.flow_id)
    except ObjectNotFound:
        pass
    if flow_run.deployment_id:
        try:
            deployment = await self._client.read_deployment(flow_run.deployment_id)
        except ObjectNotFound:
            pass

    self._emit_flow_run_cancelled_event(
        flow_run=flow_run, flow=flow, deployment=deployment
    )
    run_logger.info(f"Cancelled flow run '{flow_run.name}'!")

async def _schedule_task(self, __in_seconds: int, fn, *args, **kwargs):
"""
Schedule a background task to start after some time.
Expand Down Expand Up @@ -1219,3 +1259,46 @@ async def _emit_worker_stopped_event(self, started_event: Event):
related=self._event_related_resources(),
follows=started_event,
)

def _emit_flow_run_cancelled_event(
    self,
    flow_run: "FlowRun",
    flow: "Optional[Flow]",
    deployment: "Optional[DeploymentResponse]",
):
    """Emit a cancellation event for *flow_run*, relating it to its flow
    and deployment (when known) and to all of their tags."""
    resources = []
    tag_pool = []

    if deployment:
        resources.append(
            {
                "prefect.resource.id": f"prefect.deployment.{deployment.id}",
                "prefect.resource.role": "deployment",
                "prefect.resource.name": deployment.name,
            }
        )
        tag_pool.extend(deployment.tags)

    if flow:
        resources.append(
            {
                "prefect.resource.id": f"prefect.flow.{flow.id}",
                "prefect.resource.role": "flow",
                "prefect.resource.name": flow.name,
            }
        )

    resources.append(
        {
            "prefect.resource.id": f"prefect.flow-run.{flow_run.id}",
            "prefect.resource.role": "flow-run",
            "prefect.resource.name": flow_run.name,
        }
    )
    tag_pool.extend(flow_run.tags)

    related = [RelatedResource.model_validate(r) for r in resources]
    related.extend(tags_as_related_resources(set(tag_pool)))

    # NOTE(review): event name says "runner" but this method lives in the
    # worker base class — confirm whether "prefect.worker.cancelled-flow-run"
    # was intended.
    emit_event(
        event="prefect.runner.cancelled-flow-run",
        resource=self._event_resource(),
        related=related,
    )
21 changes: 21 additions & 0 deletions testflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from prefect import flow, task
from prefect.deployments.runner import ConcurrencyLimitConfig


@task
def slow_task(time_to_wait: int):
    """Block for *time_to_wait* seconds to simulate long-running work."""
    from time import sleep

    sleep(time_to_wait)


@flow
def slow_flow():
    """Run a single five-minute task; used to exercise the deployment
    concurrency limit's CANCEL_NEW collision strategy."""
    # Removed commented-out `.map()` experiment left over from development.
    slow_task(300)


if __name__ == "__main__":
    # Serve with at most 2 concurrent runs; under CANCEL_NEW, an excess
    # scheduled run is cancelled rather than queued.
    concurrency = ConcurrencyLimitConfig(limit=2, collision_strategy="CANCEL_NEW")
    slow_flow.serve(name="slow-flow", global_limit=concurrency)
38 changes: 38 additions & 0 deletions tests/fixtures/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,44 @@ def hello(name: str = "world"):
return deployment


@pytest.fixture
async def worker_deployment_wq1_cl1_cancel_new(
    session,
    flow,
    flow_function,
    work_queue_1,
):
    """Deployment fixture with ``concurrency_limit=2`` and the CANCEL_NEW
    collision strategy, attached to work queue 1."""

    def hello(name: str = "world"):
        pass

    daily_schedule = schemas.actions.DeploymentScheduleCreate(
        schedule=schemas.schedules.IntervalSchedule(
            interval=datetime.timedelta(days=1),
            anchor_date=pendulum.datetime(2020, 1, 1),
        )
    )

    deployment = await models.deployments.create_deployment(
        session=session,
        deployment=schemas.core.Deployment(
            name="My Deployment 1",
            tags=["test"],
            flow_id=flow.id,
            schedules=[daily_schedule],
            concurrency_limit=2,
            concurrency_options=schemas.core.ConcurrencyOptions(
                collision_strategy="CANCEL_NEW"
            ),
            path="./subdir",
            entrypoint="/file.py:flow",
            parameter_openapi_schema=parameter_schema(hello).model_dump_for_openapi(),
            work_queue_id=work_queue_1.id,
        ),
    )
    await session.commit()
    return deployment


@pytest.fixture
async def worker_deployment_infra_wq1(session, flow, flow_function, work_queue_1):
def hello(name: str = "world"):
Expand Down
47 changes: 47 additions & 0 deletions tests/runner/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,53 @@ async def test(*args, **kwargs):
flow_run = await prefect_client.read_flow_run(flow_run.id)
assert flow_run.state.name == "AwaitingConcurrencySlot"

@pytest.mark.usefixtures("use_hosted_api_server")
async def test_runner_cancels_new_flow_run_when_collision_strategy_is_cancel_new(
    self, prefect_client: PrefectClient, caplog
):
    """A run rejected by a CANCEL_NEW concurrency limit ends up Cancelled."""

    async def fake_run(*args, **kwargs):
        return 0

    with mock.patch(
        "prefect.concurrency.asyncio._acquire_concurrency_slots",
        wraps=_acquire_concurrency_slots,
    ) as acquire_spy:
        # Simulate a Locked response from the API
        acquire_spy.side_effect = AcquireConcurrencySlotTimeoutError

        async with Runner(pause_on_shutdown=False) as runner:
            deployment_id = await runner.add_deployment(
                RunnerDeployment.from_flow(
                    flow=dummy_flow_1,
                    name=__file__,
                    concurrency_limit=ConcurrencyLimitConfig(
                        limit=2, collision_strategy="CANCEL_NEW"
                    ),
                )
            )

            flow_run = await prefect_client.create_flow_run_from_deployment(
                deployment_id=deployment_id
            )
            assert flow_run.state.is_scheduled()

            runner.run = fake_run  # simulate running a flow
            await runner._get_and_submit_flow_runs()

            acquire_spy.assert_called_once_with(
                [f"deployment:{deployment_id}"],
                1,
                timeout_seconds=None,
                create_if_missing=None,
                max_retries=0,
                strict=True,
            )

            flow_run = await prefect_client.read_flow_run(flow_run.id)
            assert flow_run.state.name == "Cancelled"

@pytest.mark.usefixtures("use_hosted_api_server")
async def test_runner_does_not_attempt_to_acquire_limit_if_deployment_has_no_concurrency_limit(
self, prefect_client: PrefectClient, caplog
Expand Down
35 changes: 35 additions & 0 deletions tests/workers/test_base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,41 @@ async def test(*args, **kwargs):
assert occupy_seconds > 0


async def test_worker_with_deployment_concurrency_options_cancel_new(
    prefect_client: PrefectClient, worker_deployment_wq1_cl1_cancel_new, work_pool
):
    """A worker cancels (rather than re-proposes) a run rejected by a
    CANCEL_NEW deployment concurrency limit."""

    async def fake_run(*args, **kwargs):
        return BaseWorkerResult(status_code=0, identifier="123")

    flow_run = await prefect_client.create_flow_run_from_deployment(
        worker_deployment_wq1_cl1_cancel_new.id, name="test"
    )

    with mock.patch(
        "prefect.concurrency.asyncio._acquire_concurrency_slots",
        wraps=_acquire_concurrency_slots,
    ) as acquire_spy:
        # Simulate a Locked response from the API
        acquire_spy.side_effect = AcquireConcurrencySlotTimeoutError

        async with WorkerTestImpl(work_pool_name=work_pool.name) as worker:
            worker._work_pool = work_pool
            worker.run = fake_run  # simulate running a flow
            await worker.get_and_submit_flow_runs()

        acquire_spy.assert_called_once_with(
            [f"deployment:{worker_deployment_wq1_cl1_cancel_new.id}"],
            1,
            timeout_seconds=None,
            create_if_missing=None,
            max_retries=0,
            strict=True,
        )

    flow_run = await prefect_client.read_flow_run(flow_run.id)
    assert flow_run.state.name == "Cancelled"


async def test_worker_with_deployment_concurrency_limit_proposes_awaiting_limit_state_name(
prefect_client: PrefectClient, worker_deployment_wq1_cl1, work_pool
):
Expand Down