Skip to content

Commit 0949dbb

Browse files
author
Morty Dev
committed
refactor: auto-transaction implementation, mypy fixes
1 parent b4ba477 commit 0949dbb

File tree

9 files changed

+73
-82
lines changed

9 files changed

+73
-82
lines changed

app/features/content/api.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ...resources.providers import ConnectionProvider
66
from .models import Content, ContentIn
77
from .tables import CONTENT_TABLE
8-
from .tasks import fetch_content
8+
# from .tasks import fetch_content_async
99

1010
CONNECTION = AppContainer.resources.db.connection # type: ignore
1111
API = APIRouter()
@@ -31,12 +31,12 @@ async def create_content(
3131
async with connection.acquire() as cn, c1.inject(connection), c2.inject(
3232
connection
3333
):
34-
assert c1.injected == connection.current == cn
35-
assert c2.injected == connection.current == cn
36-
content = await connection.one(
37-
CONTENT_TABLE.insert(content_in.dict()).returning(CONTENT_TABLE),
38-
commit=True,
39-
)
40-
if content:
41-
fetch_content.apply_async((content["id"],))
34+
assert c1.current == connection.current == cn
35+
assert c2.current == connection.current == cn
36+
content = await connection.one(
37+
CONTENT_TABLE.insert(content_in.dict()).returning(CONTENT_TABLE),
38+
# commit=True,
39+
)
40+
# if content:
41+
# await fetch_content_async(content["id"], connection=connection)
4242
return content

app/features/content/tasks.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,21 @@
1515
@inject
1616
async def fetch_content_async(
1717
content_id,
18-
cursor: ConnectionProvider = Provide[AppContainer.resources.db],
18+
connection: ConnectionProvider = Provide[
19+
AppContainer.resources.db.connection
20+
],
1921
):
20-
content = await cursor.one(
21-
CONTENT_TABLE.select(CONTENT_TABLE.c.id == content_id)
22-
)
23-
async with httpx.AsyncClient() as cl:
24-
body = (await cl.get(content["url"])).text
25-
await cursor.execute(
26-
CONTENT_TABLE.update(CONTENT_TABLE.c.id == content_id).values(
27-
{"body": body}
28-
),
29-
commit=True,
30-
)
22+
async with connection.acquire():
23+
content = await connection.one(
24+
CONTENT_TABLE.select(CONTENT_TABLE.c.id == content_id)
25+
)
26+
async with httpx.AsyncClient() as cl:
27+
body = (await cl.get(content["url"])).text
28+
await connection.execute(
29+
CONTENT_TABLE.update(CONTENT_TABLE.c.id == content_id).values(
30+
{"body": body}
31+
),
32+
)
3133

3234

3335
@shared_task(autoretry_for=(Exception,), name="fetch_content")

app/migrations/script.py.mako

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ ${imports if imports else ""}
1111

1212
# revision identifiers, used by Alembic.
1313
revision = ${repr(up_revision)}
14-
down_revision = ${repr(down_revision)}
15-
branch_labels = ${repr(branch_labels)}
16-
depends_on = ${repr(depends_on)}
14+
down_revision: str | None = ${repr(down_revision)}
15+
branch_labels: str | None = ${repr(branch_labels)}
16+
depends_on: str | None = ${repr(depends_on)}
1717

1818

1919
def upgrade():

app/migrations/versions/0001_initial_migration.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22
33
Revision ID: 0001
44
Revises:
5-
Created Date: 2022-05-08 16:55:48.639025
5+
Created Date: 2022-10-10 00:50:17.533906
66
77
"""
88
import sqlalchemy as sa
99
from alembic import op
1010

1111
# revision identifiers, used by Alembic.
1212
revision = "0001"
13-
down_revision = None
14-
branch_labels = None
15-
depends_on = None
13+
down_revision: str | None = None
14+
branch_labels: str | None = None
15+
depends_on: str | None = None
1616

1717

1818
def upgrade():

app/resources/providers.py

+17-34
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ async def somelogic(cn, dal1, dal2):
6666

6767
def __init__(self, engine):
6868
self.engine = engine
69-
self.current = False
70-
self.injected = None
69+
self.current: AsyncConnection | None = None
7170

7271
@asynccontextmanager
7372
async def inject(
@@ -77,52 +76,36 @@ async def inject(
7776
raise ValueError(
7877
"Connection already used, you cannot override current one"
7978
)
80-
if self.injected:
81-
raise ValueError("Already injected connection exists")
8279

83-
self.injected = connection.current
80+
self.current = connection.current
8481
yield self
85-
self.injected = None
82+
self.current = None
8683

8784
@asynccontextmanager
88-
async def acquire(
89-
self, tx=True
90-
) -> TP.AsyncGenerator[AsyncConnection, None]:
91-
if self.current and not self.injected:
92-
raise ValueError(
93-
"Connection already used, you cannot override current one"
94-
)
95-
96-
if self.injected:
97-
yield self.injected
85+
async def acquire(self) -> TP.AsyncGenerator[AsyncConnection, None]:
86+
if self.current:
87+
yield self.current
9888
else:
99-
manager = self.engine.begin if tx else self.engine.connect
100-
async with manager() as cn:
89+
async with self.engine.begin() as cn:
10190
self.current = cn
10291
yield cn
10392
self.current = None
10493

105-
async def scalar(self, query, commit: bool = False) -> TP.Any | None:
94+
async def execute(self, query) -> None:
95+
async with self.acquire() as cn:
96+
await cn.execute(query)
97+
98+
async def scalar(self, query) -> TP.Any | None:
10699
async with self.acquire() as cn:
107100
result = (await cn.execute(query)).scalar()
108-
if commit:
109-
await cn.commit()
110-
return result
101+
return result
111102

112-
async def one(
113-
self, query, commit: bool = False
114-
) -> dict[str, TP.Any] | None:
103+
async def one(self, query) -> dict[str, TP.Any]:
115104
async with self.acquire() as cn:
116105
row = (await cn.execute(query)).fetchone()
117-
if commit:
118-
await cn.commit()
119-
return dict(row) if row else None
106+
return dict(row) if row else dict()
120107

121-
async def many(
122-
self, query, commit: bool = False
123-
) -> list[dict[str, TP.Any]]:
108+
async def many(self, query) -> list[dict[str, TP.Any]]:
124109
async with self.acquire() as cn:
125110
rows = (await cn.execute(query)).fetchall()
126-
if commit:
127-
await cn.commit()
128-
return [dict(row) for row in rows]
111+
return [dict(row) for row in rows]

mypy.ini

+3
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
[mypy]
22
python_version = 3.10
3+
4+
[mypy-asyncpg.*]
5+
ignore_missing_imports = True

poetry.lock

+19-19
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

poetry.toml

100755100644
File mode changed.

tests/test_content.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
@pytest.mark.asyncio
99
async def test_content_created(client, cursor):
10-
COUNT = 1000
10+
COUNT = 100
1111

1212
URLS = [f"https://www.google.com/{i}" for i in range(COUNT)]
1313
responses = await asyncio.gather(
@@ -28,3 +28,6 @@ async def test_content_created(client, cursor):
2828
assert resp.status_code == 409, resp.json()
2929

3030
assert len(URLS) == len(list(await cursor.many(CONTENT_TABLE.select())))
31+
32+
contents = await client.get("/contents")
33+
assert len(URLS) == len(contents.json())

0 commit comments

Comments
 (0)