Skip to content

Commit

Permalink
Merge pull request #155 from tobymao/vchan/fix-queue-abort
Browse files Browse the repository at this point in the history
Fix: Handle aborting a job before it has started processing in Postgres
  • Loading branch information
tobymao authored Sep 16, 2024
2 parents 6276317 + 02e1ff3 commit 9a59770
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 24 deletions.
3 changes: 2 additions & 1 deletion saq/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ async def finish(
*,
result: t.Any = None,
error: str | None = None,
**kwargs: t.Any,
) -> None:
job.status = status
job.result = result
Expand All @@ -237,7 +238,7 @@ async def finish(
if status == Status.COMPLETE:
job.progress = 1.0

await self._finish(job=job, status=status, result=result, error=error)
await self._finish(job=job, status=status, result=result, error=error, **kwargs)
logger.info("Finished %s", job.info(logger.isEnabledFor(logging.DEBUG)))

if status == Status.COMPLETE:
Expand Down
44 changes: 41 additions & 3 deletions saq/queue/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,12 @@ async def iter_jobs(
break

async def abort(self, job: Job, error: str, ttl: float = 5) -> None:
job.error = error
await self.update(job, status=Status.ABORTING)
async with self.pool.connection() as conn:
status = await self.get_job_status(job.key, for_update=True, connection=conn)
if status == Status.QUEUED:
await self.finish(job, Status.ABORTED, error=error, connection=conn)
else:
await self.update(job, status=Status.ABORTING, error=error, connection=conn)

async def dequeue(self, timeout: float = 0) -> Job | None:
"""Wait on `self.cond` to dequeue.
Expand Down Expand Up @@ -580,6 +584,37 @@ async def listen_for_enqueues(self, timeout: float | None = None) -> None:
async with self.cond:
self.cond.notify(1)

async def get_job_status(
self,
key: str,
for_update: bool = False,
connection: AsyncConnection | None = None,
) -> Status:
async with self.nullcontext(
connection
) if connection else self.pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(
SQL(
dedent(
"""
SELECT status
FROM {jobs_table}
WHERE key = %(key)s
{for_update}
"""
)
).format(
jobs_table=self.jobs_table,
for_update=SQL("FOR UPDATE" if for_update else ""),
),
{
"key": key,
},
)
result = await cursor.fetchone()
assert result
return result[0]

async def _retry(self, job: Job, error: str | None) -> None:
next_retry_delay = job.next_retry_delay()
if next_retry_delay:
Expand All @@ -597,10 +632,13 @@ async def _finish(
*,
result: t.Any = None,
error: str | None = None,
connection: AsyncConnection | None = None,
) -> None:
key = job.key

async with self.pool.connection() as conn, conn.cursor() as cursor:
async with self.nullcontext(
connection
) if connection else self.pool.connection() as conn, conn.cursor() as cursor:
if job.ttl >= 0:
expire_at = seconds(now()) + job.ttl if job.ttl > 0 else None
await self.update(job, status=status, expire_at=expire_at, connection=conn)
Expand Down
24 changes: 4 additions & 20 deletions tests/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,14 +461,6 @@ async def test_job_key(self) -> None:
async def test_schedule(self, mock_time: MagicMock) -> None:
pass

async def test_batch(self) -> None:
with contextlib.suppress(ValueError):
async with self.queue.batch():
job = await self.enqueue("echo", a=1)
raise ValueError()

self.assertEqual(job.status, Status.ABORTING)

async def test_enqueue_dup(self) -> None:
job = await self.enqueue("test", key="1")
self.assertEqual(job.id, "1")
Expand All @@ -482,6 +474,9 @@ async def test_abort(self) -> None:
await self.queue.abort(job, "test")
self.assertEqual(await self.count("queued"), 0)
self.assertEqual(await self.count("incomplete"), 0)
await job.refresh()
self.assertEqual(job.status, Status.ABORTED)
self.assertEqual(await self.queue.get_job_status(job.key), Status.ABORTED)

job = await self.enqueue("test", retries=2)
await self.dequeue()
Expand All @@ -492,18 +487,7 @@ async def test_abort(self) -> None:
self.assertEqual(await self.count("queued"), 0)
self.assertEqual(await self.count("incomplete"), 0)
self.assertEqual(await self.count("active"), 0)
async with self.queue.pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(
SQL(
"""
SELECT status
FROM {}
WHERE key = %s
"""
).format(self.queue.jobs_table),
(job.key,),
)
self.assertEqual(await cursor.fetchone(), (Status.ABORTING,))
self.assertEqual(await self.queue.get_job_status(job.key), Status.ABORTING)

@mock.patch("saq.utils.time")
async def test_sweep(self, mock_time: MagicMock) -> None:
Expand Down

0 comments on commit 9a59770

Please sign in to comment.