Skip to content

Commit

Permalink
Clean up implementation of deepcopy and finished signal definition
Browse files Browse the repository at this point in the history
  • Loading branch information
RealOrangeOne committed Aug 30, 2024
1 parent 17c4a19 commit 57f7ce7
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 41 deletions.
70 changes: 32 additions & 38 deletions django_tasks/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from copy import deepcopy
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
from functools import lru_cache
from inspect import iscoroutinefunction
from typing import (
TYPE_CHECKING,
Expand All @@ -24,7 +24,11 @@
from typing_extensions import ParamSpec, Self

from .exceptions import ResultDoesNotExist
from .utils import SerializedExceptionDict, exception_from_dict, get_module_path
from .utils import (
SerializedExceptionDict,
exception_from_dict,
get_module_path,
)

if TYPE_CHECKING:
from .backends.base import BaseTaskBackend
Expand All @@ -51,11 +55,22 @@ class ResultStatus(TextChoices):
COMPLETE = ("COMPLETE", _("Complete"))


# Use a large number to ensure `lru_cache` still uses a lock
@lru_cache(maxsize=100000)
def _get_task_signal(task: "Task", name: str) -> Signal:
"""
Allow a Task to have a signal, without storing it on the task itself.
This allows the Task to still be hashable, picklable and deepcopyable.
"""
return Signal()


T = TypeVar("T")
P = ParamSpec("P")


@dataclass
@dataclass(frozen=True)
class Task(Generic[P, T]):
priority: int
"""The priority of the task"""
Expand All @@ -78,9 +93,6 @@ class Task(Generic[P, T]):
immediately, or whatever the backend decides
"""

finished: Signal = field(init=False, default_factory=Signal)
"""A signal, fired when the task finished"""

def __post_init__(self) -> None:
self.get_backend().validate_task(self)

Expand All @@ -103,44 +115,21 @@ def using(
Create a new task with modified defaults
"""

task = deepcopy(self)
changes: Dict[str, Any] = {}

if priority is not None:
task.priority = priority
changes["priority"] = priority
if queue_name is not None:
task.queue_name = queue_name
changes["queue_name"] = queue_name
if run_after is not None:
if isinstance(run_after, timedelta):
task.run_after = timezone.now() + run_after
changes["run_after"] = timezone.now() + run_after
else:
task.run_after = run_after
changes["run_after"] = run_after
if backend is not None:
task.backend = backend

task.get_backend().validate_task(task)
changes["backend"] = backend

return task

def __deepcopy__(self, memo: Dict) -> Self:
"""
Copy a task, transplanting the `finished` signal.
Signals can't be deepcopied, so it needs to bypass the copy.
"""
finished_signal = self.finished
deepcopy_method = self.__deepcopy__

try:
self.finished = None # type: ignore[assignment]
self.__deepcopy__ = None # type: ignore[assignment]
task = deepcopy(self, memo)
finally:
self.__deepcopy__ = deepcopy_method # type: ignore[method-assign]
self.finished = finished_signal

task.finished = finished_signal
task.__deepcopy__ = deepcopy_method # type: ignore[method-assign]

return task
return replace(self, **changes)

def enqueue(self, *args: P.args, **kwargs: P.kwargs) -> "TaskResult[T]":
"""
Expand Down Expand Up @@ -208,6 +197,11 @@ def is_modified(self) -> bool:
"""
return self != self.original

@property
def finished(self) -> Signal:
"""A signal, fired when the task finished"""
return _get_task_signal(self, "finished")


# Bare decorator usage
# e.g. @task
Expand Down
22 changes: 19 additions & 3 deletions tests/tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
from copy import deepcopy
from datetime import datetime, timedelta

Expand Down Expand Up @@ -114,18 +115,30 @@ def test_using_creates_new_instance(self) -> None:
self.assertEqual(new_task, test_tasks.noop_task)
self.assertIsNot(new_task, test_tasks.noop_task)

def test_chained_using(self) -> None:
now = timezone.now()

run_after_task = test_tasks.noop_task.using(run_after=now)
self.assertEqual(run_after_task.run_after, now)

priority_task = run_after_task.using(priority=10)
self.assertEqual(priority_task.priority, 10)
self.assertEqual(priority_task.run_after, now)

self.assertEqual(run_after_task.priority, 0)

async def test_refresh_result(self) -> None:
result = await test_tasks.noop_task.aenqueue()

original_result = deepcopy(result)
result_data = dataclasses.asdict(result)

result.refresh()

self.assertEqual(result, original_result)
self.assertEqual(dataclasses.asdict(result), result_data)

await result.arefresh()

self.assertEqual(result, original_result)
self.assertEqual(dataclasses.asdict(result), result_data)

def test_naive_datetime(self) -> None:
with self.assertRaisesMessage(
Expand Down Expand Up @@ -252,3 +265,6 @@ def test_finished_signal(self) -> None:
self.assertIs(
test_tasks.noop_task.using().finished, test_tasks.noop_task.finished
)
self.assertIs(
deepcopy(test_tasks.noop_task).finished, test_tasks.noop_task.finished
)

0 comments on commit 57f7ce7

Please sign in to comment.