fix!: update defaults to not touching status
Prior to this commit, job.update() would also override status. This is
problematic when using update() for heartbeats, since it would override
an abort. Now, a job's status only gets updated if status is passed in
directly.
tobymao committed Feb 12, 2025
1 parent b867237 commit e7be16e
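For illustration only (not part of the diff), a minimal sketch of the heartbeat scenario the commit message describes; `long_task` and `do_chunk_of_work` are hypothetical:

```python
import asyncio


async def do_chunk_of_work() -> None:  # hypothetical stand-in for real work
    await asyncio.sleep(0.1)


async def long_task(ctx: dict) -> None:
    job = ctx["job"]
    for _ in range(10):
        await do_chunk_of_work()
        # Heartbeat: re-persists the job so its `touched` timestamp moves
        # forward. Before this commit, this write also stored job.status, so
        # a concurrent abort (status -> ABORTING) could be clobbered back to
        # ACTIVE. After this commit, the stored status is left alone unless
        # it is passed explicitly, e.g. job.update(status=Status.ACTIVE).
        await job.update()
```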
Showing 7 changed files with 51 additions and 74 deletions.
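Because this is a breaking change (note the `fix!` prefix), call sites that set `job.status` locally and relied on `update()` to persist it must now pass the status explicitly, as the worker and test changes below do:

```python
# before this commit
job.status = Status.ACTIVE
await job.update()

# after this commit
await job.update(status=Status.ACTIVE)
```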
6 changes: 3 additions & 3 deletions saq/job.py
@@ -292,10 +292,10 @@ async def update(self, **kwargs: t.Any) -> None:
         Updates the stored job in redis.
 
         Set properties with passed in kwargs.
+        If status is not explicitly passed in, the status will not update as this is usually controlled by the workers.
         """
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-        await self.get_queue().update(self)
+        await self.get_queue().update(self, **kwargs)
 
     async def refresh(self, until_complete: float | None = None) -> None:
         """
9 changes: 8 additions & 1 deletion saq/queue/base.py
@@ -102,8 +102,15 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
     async def notify(self, job: Job) -> None:
         pass
 
+    async def update(self, job: Job, **kwargs: t.Any) -> None:
+        job.touched = now()
+        for k, v in kwargs.items():
+            if hasattr(job, k):
+                setattr(job, k, v)
+        await self._update(job, **kwargs)
+
     @abstractmethod
-    async def update(self, job: Job) -> None:
+    async def _update(self, job: Job, status: Status | None = None, **kwargs: t.Any) -> None:
         pass
 
     @abstractmethod
58 changes: 16 additions & 42 deletions saq/queue/http.py
@@ -41,25 +41,23 @@ def serialize(job: t.Optional[Job]) -> str | None:
 
     async def process(self, body: str) -> str | None:
         req = json.loads(body)
-        kind = req["kind"]
-        job = self.queue.deserialize(req.get("job"))
+        kind = req.pop("kind")
+        job = self.queue.deserialize(req.pop("job", None))
 
         if job:
             if kind == "enqueue":
                 return self.serialize(await self.queue.enqueue(job))
             if kind == "update":
-                await self.queue.update(job)
+                await self.queue._update(job, **req)
                 return None
             if kind == "finish":
-                await self.queue.finish(
-                    job, status=req["status"], result=req["result"], error=req["error"]
-                )
+                await self.queue.finish(job, **req)
                 return None
             if kind == "retry":
-                await self.queue.retry(job, error=req["error"])
+                await self.queue.retry(job, **req)
                 return None
             if kind == "abort":
-                await self.queue.abort(job, error=req["error"], ttl=req["ttl"])
+                await self.queue.abort(job, **req)
                 return None
             if kind == "finish_abort":
                 await self.queue.finish_abort(job)
@@ -69,48 +67,26 @@ async def process(self, body: str) -> str | None:
                 return None
         else:
             if kind == "dequeue":
-                return self.serialize(await self.queue.dequeue(req["timeout"]))
+                return self.serialize(await self.queue.dequeue(**req))
             if kind == "job":
-                return self.serialize(await self.queue.job(req["job_key"]))
+                return self.serialize(await self.queue.job(**req))
             if kind == "jobs":
                 return json.dumps(
-                    [
-                        job.to_dict() if job else None
-                        for job in await self.queue.jobs(req["job_keys"])
-                    ]
+                    [job.to_dict() if job else None for job in await self.queue.jobs(**req)]
                 )
             if kind == "iter_jobs":
-                return json.dumps(
-                    [
-                        job.to_dict()
-                        async for job in self.queue.iter_jobs(
-                            statuses=req["statuses"], batch_size=req["batch_size"]
-                        )
-                    ]
-                )
+                return json.dumps([job.to_dict() async for job in self.queue.iter_jobs(**req)])
 
             if kind == "count":
                 return json.dumps(await self.queue.count(req["count_kind"]))
             if kind == "schedule":
                 return json.dumps(await self.queue.schedule(req["lock"]))
             if kind == "sweep":
-                return json.dumps(await self.queue.sweep(lock=req["lock"], abort=req["abort"]))
+                return json.dumps(await self.queue.sweep(**req))
             if kind == "info":
-                return json.dumps(
-                    await self.queue.info(
-                        jobs=req["jobs"], offset=req["offset"], limit=req["limit"]
-                    )
-                )
+                return json.dumps(await self.queue.info(**req))
             if kind == "write_worker_info":
-                await self.queue.write_worker_info(
-                    worker_id=req["worker_id"],
-                    info={
-                        "stats": req["stats"],
-                        "queue_key": req["queue_key"],
-                        "metadata": req["metadata"],
-                    },
-                    ttl=req["ttl"],
-                )
+                await self.queue.write_worker_info(**req)
                 return None
         raise ValueError(f"Invalid request {body}")
@@ -179,8 +155,8 @@ async def _retry(self, job: Job, error: str | None) -> None:
     async def notify(self, job: Job) -> None:
         await self._send("notify", job=self.serialize(job))
 
-    async def update(self, job: Job) -> None:
-        await self._send("update", job=self.serialize(job))
+    async def _update(self, job: Job, status: Status | None = None, **kwargs: t.Any) -> None:
+        await self._send("update", job=self.serialize(job), status=status, **kwargs)
 
     async def job(self, job_key: str) -> Job | None:
         return self.deserialize(await self._send("job", job_key=job_key))
@@ -222,10 +198,8 @@ async def write_worker_info(
         await self._send(
             "write_worker_info",
             worker_id=worker_id,
-            stats=info["stats"],
             ttl=ttl,
-            queue_key=info["queue_key"],
-            metadata=info["metadata"],
+            info=info,
         )
 
     async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
32 changes: 15 additions & 17 deletions saq/queue/postgres.py
@@ -21,7 +21,7 @@
 from saq.multiplexer import Multiplexer
 from saq.queue.base import Queue, logger
 from saq.queue.postgres_migrations import get_migrations
-from saq.utils import now, now_seconds
+from saq.utils import now_seconds
 
 if t.TYPE_CHECKING:
     from collections.abc import Iterable
@@ -459,19 +459,16 @@ async def listen(
     async def notify(self, job: Job, connection: AsyncConnection | None = None) -> None:
         await self._notify(job.key, job.status, connection)
 
-    async def update(
-        self,
-        job: Job,
-        connection: AsyncConnection | None = None,
-        expire_at: float | None = -1,
-        **kwargs: t.Any,
-    ) -> None:
-        job.touched = now()
-
-        for k, v in kwargs.items():
-            setattr(job, k, v)
+    async def _update(self, job: Job, status: Status | None = None, **kwargs: t.Any) -> None:
+        expire_at = kwargs.pop("expire_at", -1)
+        connection = kwargs.pop("connection", None)
 
         async with self.nullcontext(connection) if connection else self.pool.connection() as conn:
+            if not status:
+                status = await self.get_job_status(job.key, connection=conn)
+
+            job.status = status or job.status
+
             await conn.execute(
                 SQL(
                     dedent(
@@ -582,7 +579,7 @@ async def iter_jobs(
     async def abort(self, job: Job, error: str, ttl: float = 5) -> None:
         async with self.pool.connection() as conn:
             status = await self.get_job_status(job.key, for_update=True, connection=conn)
-            if status == Status.QUEUED:
+            if not status or status == Status.QUEUED:
                 await self.finish(job, Status.ABORTED, error=error, connection=conn)
             else:
                 await self.update(job, status=Status.ABORTING, error=error, connection=conn)
@@ -769,7 +766,7 @@ async def get_job_status(
         key: str,
         for_update: bool = False,
         connection: AsyncConnection | None = None,
-    ) -> Status:
+    ) -> Status | None:
         async with self.nullcontext(
             connection
         ) if connection else self.pool.connection() as conn, conn.cursor() as cursor:
@@ -792,8 +789,8 @@
                 },
             )
             result = await cursor.fetchone()
-            assert result
-            return result[0]
+            if result:
+                return result[0]
+            return None
 
     async def _retry(self, job: Job, error: str | None) -> None:
         next_retry_delay = job.next_retry_delay()
@@ -802,7 +800,7 @@ async def _retry(self, job: Job, error: str | None) -> None:
         else:
             scheduled = job.scheduled or now_seconds()
 
-        await self.update(job, scheduled=int(scheduled), expire_at=None)
+        await self.update(job, status=Status.QUEUED, scheduled=int(scheduled), expire_at=None)
 
     async def _finish(
         self,
7 changes: 5 additions & 2 deletions saq/queue/redis.py
@@ -273,8 +273,11 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
     async def notify(self, job: Job) -> None:
         await self.redis.publish(job.id, job.status)
 
-    async def update(self, job: Job) -> None:
-        job.touched = now()
+    async def _update(self, job: Job, status: Status | None = None, **kwargs: t.Any) -> None:
+        if not status:
+            stored = await self.job(job.key)
+            status = stored.status if stored else None
+        job.status = status or job.status
         await self.redis.set(job.id, self.serialize(job))
         await self.notify(job)

3 changes: 1 addition & 2 deletions saq/worker.py
@@ -298,9 +298,8 @@ async def process(self) -> bool:
             return False
 
         job.started = now()
-        job.status = Status.ACTIVE
         job.attempts += 1
-        await job.update()
+        await job.update(status=Status.ACTIVE)
         context = {**self.context, "job": job}
         await self._before_process(context)
         logger.info("Processing %s", job.info(logger.isEnabledFor(logging.DEBUG)))
10 changes: 3 additions & 7 deletions tests/test_queue.py
@@ -354,8 +354,7 @@ async def test_iter_jobs(self) -> None:
             await self.queue.enqueue("test")
 
         async for job in self.queue.iter_jobs(batch_size=1):
-            job.status = Status.ACTIVE
-            await job.update()
+            await job.update(status=Status.ACTIVE)
             break
 
         self.assertEqual(9, len([job async for job in self.queue.iter_jobs()]))
@@ -435,9 +434,8 @@ async def test_sweep(self, mock_time: MagicMock) -> None:
         job3 = await self.enqueue("test", timeout=1)
         for _ in range(4):
             job = await self.dequeue()
-            job.status = Status.ACTIVE
             job.started = 1000
-            await self.queue.update(job)
+            await self.queue.update(job, status=Status.ACTIVE)
         await self.dequeue()
 
         # missing job
@@ -569,9 +567,7 @@ async def test_sweep_stuck(self, mock_time: MagicMock) -> None:
         another_queue = await self.create_queue()
         for _ in range(2):
             job = await another_queue.dequeue()
-            job.status = Status.ACTIVE
-            job.started = 1000
-            await another_queue.update(job)
+            await another_queue.update(job, status=Status.ACTIVE, started=1000)
 
         # Disconnect another_queue to simulate worker going down
         await another_queue.disconnect()
