Skip to content

Commit

Permalink
Remove use_cache parameter (#11)
Browse files Browse the repository at this point in the history
* Remove `use_cache` parameter

* Update tests
  • Loading branch information
yakimka authored Apr 24, 2024
1 parent 91c8105 commit 4746a9c
Show file tree
Hide file tree
Showing 11 changed files with 390 additions and 284 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ We follow [Semantic Versions](https://semver.org/).
## Version 0.1.1

- Fix context manager error

## Version 0.2.0

- Removed `use_cache` parameter
- Added tests
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM python:3.11-slim-bullseye as builder

ARG WHEEL=picodi-0.1.1-py3-none-any.whl
ARG WHEEL=picodi-0.2.0-py3-none-any.whl
ENV VENV=/venv
ENV PATH="$VENV/bin:$PATH"

Expand Down
74 changes: 33 additions & 41 deletions picodi/picodi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@
import inspect
import threading
from collections.abc import Awaitable, Callable, Coroutine, Generator
from contextlib import (
AsyncExitStack,
ExitStack,
asynccontextmanager,
contextmanager,
nullcontext,
)
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -43,7 +37,7 @@
}


def Provide(dependency: Dependency, /, use_cache: bool = True) -> Any: # noqa: N802
def Provide(dependency: Dependency, /) -> Any: # noqa: N802
"""
Declare a provider.
It takes a single "dependency" callable (like a function).
Expand Down Expand Up @@ -75,7 +69,7 @@ def my_service(db=Provide(get_db), settings=Provide(get_settings)):
"""
if not getattr(dependency, "_scope_", None):
dependency._scope_ = "null" # type: ignore[attr-defined] # noqa: SF01
return Depends.from_dependency(dependency, use_cache)
return Depends.from_dependency(dependency)


def inject(fn: Callable[P, T]) -> Callable[P, T | Coroutine[Any, Any, T]]:
Expand All @@ -100,14 +94,10 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()
exit_stack = AsyncExitStack()
for depends, names, get_value in _resolve_depends(
for names, get_value in _arguments_to_getter(
bound, exit_stack, is_async=True
):
if depends.use_cache:
value = await get_value()
bound.arguments.update({name: value for name in names})
else:
bound.arguments.update({name: await get_value() for name in names})
bound.arguments.update({name: await get_value() for name in names})

async with exit_stack:
result = await fn(*bound.args, **bound.kwargs)
Expand All @@ -120,14 +110,10 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()
exit_stack = ExitStack()
for depends, names, get_value in _resolve_depends(
for names, get_value in _arguments_to_getter(
bound, exit_stack, is_async=False
):
if depends.use_cache:
value = get_value()
bound.arguments.update({name: value for name in names})
else:
bound.arguments.update({name: get_value() for name in names})
bound.arguments.update({name: get_value() for name in names})

with exit_stack:
result = fn(*bound.args, **bound.kwargs)
Expand Down Expand Up @@ -159,7 +145,7 @@ def get_db():
raise TypeError("Resource should be a generator function")
fn._scope_ = "singleton" # type: ignore[attr-defined] # noqa: SF01
with _lock:
_resources.append(Depends.from_dependency(fn, use_cache=True))
_resources.append(Depends.from_dependency(fn))
return fn


Expand Down Expand Up @@ -199,34 +185,28 @@ def shutdown_resources() -> Awaitable | None:
@dataclass(frozen=True)
class Depends:
dependency: Dependency
use_cache: bool
context_manager: CallableManager | None = field(compare=False)
is_async: bool = field(compare=False)

def get_scope_name(self) -> str:
return self.dependency._scope_ # type: ignore[attr-defined] # noqa: SF01

def value_as_context_manager(self) -> Any:
if self.context_manager:
return self.context_manager()
return nullcontext(self.dependency())

@classmethod
def from_dependency(cls, dependency: Dependency, use_cache: bool) -> Depends:
def from_dependency(cls, dependency: Dependency) -> Depends:
context_manager: Callable | None = None
is_async = False
is_async = inspect.iscoroutinefunction(dependency)
if inspect.isasyncgenfunction(dependency):
context_manager = asynccontextmanager(dependency)
is_async = True
elif inspect.isgeneratorfunction(dependency):
context_manager = contextmanager(dependency)

return cls(dependency, use_cache, context_manager, is_async)
return cls(dependency, context_manager, is_async)


def _resolve_depends(
def _arguments_to_getter(
bound: BoundArguments, exit_stack: AsyncExitStack | ExitStack, is_async: bool
) -> Generator[tuple[Depends, list[str], Callable[[], Any]], None, None]:
) -> Generator[tuple[list[str], Callable[[], Any]], None, None]:
dependencies: dict[Depends, list[str]] = {}
for name, value in bound.arguments.items():
if isinstance(value, Depends):
Expand All @@ -236,7 +216,7 @@ def _resolve_depends(

for depends, names in dependencies.items():
get_value = functools.partial(get_val, depends, exit_stack) # type: ignore
yield depends, names, get_value
yield names, get_value


def _get_value_from_depends(
Expand All @@ -248,18 +228,17 @@ def _get_value_from_depends(
try:
value = scope.get(depends.dependency)
except KeyError:
context_manager = depends.value_as_context_manager()
exit_stack = local_exit_stack
if scope_name == "singleton":
exit_stack = _exit_stack
if depends.is_async:
value = depends.dependency
value = depends.dependency()
else:
with _lock:
try:
value = scope.get(depends.dependency)
except KeyError:
value = exit_stack.enter_context(context_manager)
value = _call_value(depends, exit_stack)
scope.set(depends.dependency, value)
return value

Expand All @@ -273,22 +252,35 @@ async def _get_value_from_depends_async(
try:
value = scope.get(depends.dependency)
except KeyError:
context_manager = depends.value_as_context_manager()
exit_stack = local_exit_stack
if scope_name == "singleton":
exit_stack = _async_exit_stack
with _lock:
try:
value = scope.get(depends.dependency)
except KeyError:
value = _call_value(depends, exit_stack)
if depends.is_async:
value = await exit_stack.enter_async_context(context_manager)
else:
value = exit_stack.enter_context(context_manager)
value = await value
scope.set(depends.dependency, value)
return value


def _call_value(depends: Depends, exit_stack: ExitStack | AsyncExitStack) -> Any:
if depends.is_async:
if depends.context_manager:
return exit_stack.enter_async_context( # type: ignore[union-attr]
depends.context_manager() # type: ignore[arg-type]
)
return depends.dependency()
else:
if depends.context_manager:
return exit_stack.enter_context(
depends.context_manager() # type: ignore[arg-type]
)
return depends.dependency()


def _is_async_environment() -> bool:
try:
asyncio.get_running_loop()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "picodi"
description = "Simple Dependency Injection for Python"
version = "0.1.1"
version = "0.2.0"
license = "MIT"
authors = [
"yakimka"
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ exclude =
# Ignoring some errors in some files:
per-file-ignores =
# TC002 Move third-party import into a type-checking block
tests/*.py: TC002
tests/*.py: TC002, S311, DUO102

### Plugins
# flake8-bugbear
Expand Down
41 changes: 0 additions & 41 deletions tests/test_complex_logic.py

This file was deleted.

Loading

0 comments on commit 4746a9c

Please sign in to comment.