Skip to content

Commit

Permalink
Add configurable transaction management (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
RealOrangeOne authored Aug 2, 2024
1 parent 87e1fdc commit 4c5f63e
Show file tree
Hide file tree
Showing 11 changed files with 419 additions and 59 deletions.
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ The task decorator accepts a few arguments to customize the task:
- `priority`: The priority of the task (between -100 and 100. Larger numbers are higher priority. 0 by default)
- `queue_name`: Whether to run the task on a specific queue
- `backend`: Name of the backend for this task to use (as defined in `TASKS`)
- `enqueue_on_commit`: Whether the task is enqueued when the current transaction commits successfully, or enqueued immediately. By default, this is handled by the backend (see below). `enqueue_on_commit` may not be modified with `.using`.

These attributes can also be modified at run-time with `.using`:
These attributes (besides `enqueue_on_commit`) can also be modified at run-time with `.using`:

```python
modified_task = calculate_meaning_of_life.using(priority=10)
Expand All @@ -88,6 +89,23 @@ The returned `TaskResult` can be interrogated to query the current state of the

If the task takes arguments, these can be passed as-is to `enqueue`.

#### Transactions

By default, tasks are enqueued after the current transaction (if there is one) commits successfully (using Django's `transaction.on_commit` method), rather than enqueueing immediately.

This can be configured using the `ENQUEUE_ON_COMMIT` setting. `True` and `False` force the behaviour.

```python
TASKS = {
"default": {
"BACKEND": "django_tasks.backends.immediate.ImmediateBackend",
"ENQUEUE_ON_COMMIT": False
}
}
```

This can also be configured per-task by passing `enqueue_on_commit` to the `task` decorator.

### Queue names

By default, tasks are enqueued onto the "default" queue. When using multiple queues, it can be useful to constrain the allowed names, so tasks aren't missed.
Expand Down
40 changes: 34 additions & 6 deletions django_tasks/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from typing import Any, Iterable, TypeVar

from asgiref.sync import sync_to_async
from django.core.checks.messages import CheckMessage
from django.core.checks import messages
from django.db import connections
from django.test.testcases import _DatabaseFailure
from django.utils import timezone
from typing_extensions import ParamSpec

Expand All @@ -16,6 +18,9 @@


class BaseTaskBackend(metaclass=ABCMeta):
alias: str
enqueue_on_commit: bool

task_class = Task

supports_defer = False
Expand All @@ -32,6 +37,27 @@ def __init__(self, options: dict) -> None:

self.alias = options["ALIAS"]
self.queues = set(options.get("QUEUES", [DEFAULT_QUEUE_NAME]))
self.enqueue_on_commit = bool(options.get("ENQUEUE_ON_COMMIT", True))

def _get_enqueue_on_commit_for_task(self, task: Task) -> bool:
"""
Determine the correct `enqueue_on_commit` setting to use for a given task.
If the task defines it, use that, otherwise, fall back to the backend.
"""
# If this project doesn't use a database, there's nothing to commit to
if not connections.settings:
return False

# If connections are disabled during tests, there's nothing to commit to
for conn in connections.all():
if isinstance(conn.connect, _DatabaseFailure):
return False

if isinstance(task.enqueue_on_commit, bool):
return task.enqueue_on_commit

return self.enqueue_on_commit

def validate_task(self, task: Task) -> None:
"""
Expand Down Expand Up @@ -101,8 +127,10 @@ async def aget_result(self, result_id: str) -> TaskResult:
result_id=result_id
)

def check(self, **kwargs: Any) -> Iterable[CheckMessage]:
raise NotImplementedError(
"subclasses may provide a check() method to verify that task "
"backend is configured correctly."
)
def check(self, **kwargs: Any) -> Iterable[messages.CheckMessage]:
if self.enqueue_on_commit and not connections.settings:
yield messages.CheckMessage(
messages.ERROR,
"`ENQUEUE_ON_COMMIT` cannot be used when no databases are configured",
hint="Set `ENQUEUE_ON_COMMIT` to False",
)
32 changes: 13 additions & 19 deletions django_tasks/backends/database/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import TYPE_CHECKING, Any, Iterable, TypeVar

from django.apps import apps
from django.core.checks import ERROR, CheckMessage
from django.core.checks import messages
from django.core.exceptions import ValidationError
from django.db import connections, router
from django.db import connections, router, transaction
from typing_extensions import ParamSpec

from django_tasks.backends.base import BaseTaskBackend
Expand Down Expand Up @@ -51,18 +51,10 @@ def enqueue(

db_result = self._task_to_db_task(task, args, kwargs)

db_result.save()

return db_result.task_result

async def aenqueue(
self, task: Task[P, T], args: P.args, kwargs: P.kwargs
) -> TaskResult[T]:
self.validate_task(task)

db_result = self._task_to_db_task(task, args, kwargs)

await db_result.asave()
if self._get_enqueue_on_commit_for_task(task):
transaction.on_commit(db_result.save)
else:
db_result.save()

return db_result.task_result

Expand All @@ -82,14 +74,16 @@ async def aget_result(self, result_id: str) -> TaskResult:
except (DBTaskResult.DoesNotExist, ValidationError) as e:
raise ResultDoesNotExist(result_id) from e

def check(self, **kwargs: Any) -> Iterable[CheckMessage]:
def check(self, **kwargs: Any) -> Iterable[messages.CheckMessage]:
from .models import DBTaskResult

yield from super().check(**kwargs)

backend_name = self.__class__.__name__

if not apps.is_installed("django_tasks.backends.database"):
yield CheckMessage(
ERROR,
yield messages.CheckMessage(
messages.ERROR,
f"{backend_name} configured as django_tasks backend, but database app not installed",
"Insert 'django_tasks.backends.database' in INSTALLED_APPS",
)
Expand All @@ -100,8 +94,8 @@ def check(self, **kwargs: Any) -> Iterable[CheckMessage]:
and hasattr(db_connection, "transaction_mode")
and db_connection.transaction_mode != "EXCLUSIVE"
):
yield CheckMessage(
ERROR,
yield messages.CheckMessage(
messages.ERROR,
f"{backend_name} is using SQLite non-exclusive transactions",
f"Set settings.DATABASES[{db_connection.alias!r}]['OPTIONS']['transaction_mode'] to 'EXCLUSIVE'",
)
9 changes: 7 additions & 2 deletions django_tasks/backends/dummy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from copy import deepcopy
from functools import partial
from typing import List, TypeVar
from uuid import uuid4

from django.db import transaction
from django.utils import timezone
from typing_extensions import ParamSpec

Expand Down Expand Up @@ -42,8 +44,11 @@ def enqueue(
backend=self.alias,
)

# Copy the task to prevent mutation issues
self.results.append(deepcopy(result))
if self._get_enqueue_on_commit_for_task(task) is not False:
# Copy the task to prevent mutation issues
transaction.on_commit(partial(self.results.append, deepcopy(result)))
else:
self.results.append(deepcopy(result))

return result

Expand Down
54 changes: 34 additions & 20 deletions django_tasks/backends/immediate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
from functools import partial
from inspect import iscoroutinefunction
from typing import TypeVar
from uuid import uuid4

from asgiref.sync import async_to_sync
from django.db import transaction
from django.utils import timezone
from typing_extensions import ParamSpec

Expand All @@ -22,53 +24,65 @@
class ImmediateBackend(BaseTaskBackend):
supports_async_task = True

def enqueue(
self, task: Task[P, T], args: P.args, kwargs: P.kwargs
) -> TaskResult[T]:
self.validate_task(task)
def _execute_task(self, task_result: TaskResult) -> None:
"""
Execute the task for the given `TaskResult`, mutating it with the outcome
"""
task = task_result.task

calling_task_func = (
async_to_sync(task.func) if iscoroutinefunction(task.func) else task.func
)

enqueued_at = timezone.now()
started_at = timezone.now()
result_id = str(uuid4())
task_result.started_at = timezone.now()
try:
result = json_normalize(calling_task_func(*args, **kwargs))
status = ResultStatus.COMPLETE
task_result._result = json_normalize(
calling_task_func(*task_result.args, **task_result.kwargs)
)
except BaseException as e:
task_result.finished_at = timezone.now()
try:
result = exception_to_dict(e)
task_result._result = exception_to_dict(e)
except Exception:
logger.exception("Task id=%s unable to save exception", result_id)
result = None
logger.exception("Task id=%s unable to save exception", task_result.id)
task_result._result = None

# Use `.exception` to integrate with error monitoring tools (eg Sentry)
logger.exception(
"Task id=%s path=%s state=%s",
result_id,
task_result.id,
task.module_path,
ResultStatus.FAILED,
)
status = ResultStatus.FAILED
task_result.status = ResultStatus.FAILED

# If the user tried to terminate, let them
if isinstance(e, KeyboardInterrupt):
raise
else:
task_result.finished_at = timezone.now()
task_result.status = ResultStatus.COMPLETE

def enqueue(
self, task: Task[P, T], args: P.args, kwargs: P.kwargs
) -> TaskResult[T]:
self.validate_task(task)

task_result = TaskResult[T](
task=task,
id=result_id,
status=status,
enqueued_at=enqueued_at,
started_at=started_at,
finished_at=timezone.now(),
id=str(uuid4()),
status=ResultStatus.NEW,
enqueued_at=timezone.now(),
started_at=None,
finished_at=None,
args=json_normalize(args),
kwargs=json_normalize(kwargs),
backend=self.alias,
)

task_result._result = result
if self._get_enqueue_on_commit_for_task(task) is not False:
transaction.on_commit(partial(self._execute_task, task_result))
else:
self._execute_task(task_result)

return task_result
16 changes: 14 additions & 2 deletions django_tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,17 @@ class Task(Generic[P, T]):
"""The name of the backend the task will run on"""

queue_name: str = DEFAULT_QUEUE_NAME
"""The name of the queue the task will run on """
"""The name of the queue the task will run on"""

run_after: Optional[datetime] = None
"""The earliest this task will run"""

enqueue_on_commit: Optional[bool] = None
"""
Whether the task will be enqueued when the current transaction commits,
immediately, or whatever the backend decides
"""

def __post_init__(self) -> None:
self.get_backend().validate_task(self)

Expand Down Expand Up @@ -170,6 +176,7 @@ def task(
priority: int = DEFAULT_PRIORITY,
queue_name: str = DEFAULT_QUEUE_NAME,
backend: str = DEFAULT_TASK_BACKEND_ALIAS,
enqueue_on_commit: Optional[bool] = None,
) -> Callable[[Callable[P, T]], Task[P, T]]: ...


Expand All @@ -180,6 +187,7 @@ def task(
priority: int = DEFAULT_PRIORITY,
queue_name: str = DEFAULT_QUEUE_NAME,
backend: str = DEFAULT_TASK_BACKEND_ALIAS,
enqueue_on_commit: Optional[bool] = None,
) -> Union[Task[P, T], Callable[[Callable[P, T]], Task[P, T]]]:
"""
A decorator used to create a task.
Expand All @@ -188,7 +196,11 @@ def task(

def wrapper(f: Callable[P, T]) -> Task[P, T]:
return tasks[backend].task_class(
priority=priority, func=f, queue_name=queue_name, backend=backend
priority=priority,
func=f,
queue_name=queue_name,
backend=backend,
enqueue_on_commit=enqueue_on_commit,
)

if function:
Expand Down
10 changes: 10 additions & 0 deletions tests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,13 @@ def complex_exception() -> None:
@task()
def exit_task() -> None:
exit(1)


@task(enqueue_on_commit=True)
def enqueue_on_commit_task() -> None:
pass


@task(enqueue_on_commit=False)
def never_enqueue_on_commit_task() -> None:
pass
Loading

0 comments on commit 4c5f63e

Please sign in to comment.