Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow customizing connection state reset #1191

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,11 +1515,10 @@ def terminate(self):
self._abort()
self._cleanup()

async def reset(self, *, timeout=None):
async def _reset(self):
self._check_open()
self._listeners.clear()
self._log_listeners.clear()
reset_query = self._get_reset_query()

if self._protocol.is_in_transaction() or self._top_xact is not None:
if self._top_xact is None or not self._top_xact._managed:
Expand All @@ -1531,10 +1530,36 @@ async def reset(self, *, timeout=None):
})

self._top_xact = None
reset_query = 'ROLLBACK;\n' + reset_query
await self.execute("ROLLBACK")

async def reset(self, *, timeout=None):
"""Reset the connection state.

Calling this will reset the connection session state to a state
resembling that of a newly obtained connection. Namely, an open
transaction (if any) is rolled back, open cursors are closed,
all `LISTEN <https://www.postgresql.org/docs/current/sql-listen.html>`_
registrations are removed, all session configuration
variables are reset to their default values, and all advisory locks
are released.

Note that the above describes the default query returned by
:meth:`Connection.get_reset_query`. If one overloads the method
by subclassing ``Connection``, then this method will do whatever
the overloaded method returns, except open transactions are always
terminated and any callbacks registered by
:meth:`Connection.add_listener` or :meth:`Connection.add_log_listener`
are removed.

if reset_query:
await self.execute(reset_query, timeout=timeout)
:param float timeout:
A timeout for resetting the connection. If not specified, defaults
to no timeout.
"""
async with compat.timeout(timeout):
await self._reset()
reset_query = self.get_reset_query()
if reset_query:
await self.execute(reset_query)

def _abort(self):
# Put the connection into the aborted state.
Expand Down Expand Up @@ -1695,7 +1720,15 @@ def _unwrap(self):
con_ref = self._proxy
return con_ref

def _get_reset_query(self):
def get_reset_query(self):
"""Return the query sent to server on connection release.

The query returned by this method is used by :meth:`Connection.reset`,
which is, in turn, used by :class:`~asyncpg.pool.Pool` before making
the connection available to another acquirer.

.. versionadded:: 0.30.0
"""
if self._reset_query is not None:
return self._reset_query

Expand Down
36 changes: 32 additions & 4 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,12 @@ async def release(self, timeout):
if budget is not None:
budget -= time.monotonic() - started

await self._con.reset(timeout=budget)
if self._pool._reset is not None:
async with compat.timeout(budget):
await self._con._reset()
await self._pool._reset(self._con)
else:
await self._con.reset(timeout=budget)
except (Exception, asyncio.CancelledError) as ex:
# If the `reset` call failed, terminate the connection.
# A new one will be created when `acquire` is called
Expand Down Expand Up @@ -313,7 +318,7 @@ class Pool:

__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_init', '_connect', '_connect_args', '_connect_kwargs',
'_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
'_holders', '_initialized', '_initializing', '_closing',
'_closed', '_connection_class', '_record_class', '_generation',
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
Expand All @@ -327,6 +332,7 @@ def __init__(self, *connect_args,
connect=None,
setup=None,
init=None,
reset=None,
loop,
connection_class,
record_class,
Expand Down Expand Up @@ -393,6 +399,7 @@ def __init__(self, *connect_args,

self._setup = setup
self._init = init
self._reset = reset

self._max_queries = max_queries
self._max_inactive_connection_lifetime = \
Expand Down Expand Up @@ -1036,6 +1043,7 @@ def create_pool(dsn=None, *,
connect=None,
setup=None,
init=None,
reset=None,
loop=None,
connection_class=connection.Connection,
record_class=protocol.Record,
Expand Down Expand Up @@ -1125,7 +1133,7 @@ def create_pool(dsn=None, *,

:param coroutine setup:
A coroutine to prepare a connection right before it is returned
from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
from :meth:`Pool.acquire()`. An example use
case would be to automatically set up notifications listeners for
all connections of a pool.

Expand All @@ -1137,6 +1145,25 @@ def create_pool(dsn=None, *,
or :meth:`Connection.set_type_codec() <\
asyncpg.connection.Connection.set_type_codec>`.

:param coroutine reset:
A coroutine to reset a connection before it is returned to the pool by
:meth:`Pool.release()`. The function is supposed
to reset any changes made to the database session so that the next
acquirer gets the connection in a well-defined state.

The default implementation calls :meth:`Connection.reset() <\
asyncpg.connection.Connection.reset>`, which runs the following::

SELECT pg_advisory_unlock_all();
CLOSE ALL;
UNLISTEN *;
RESET ALL;

The exact reset query is determined by detected server capabilities,
and a custom *reset* implementation can obtain the default query
by calling :meth:`Connection.get_reset_query() <\
asyncpg.connection.Connection.get_reset_query>`.

:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
Expand Down Expand Up @@ -1165,7 +1192,7 @@ def create_pool(dsn=None, *,
Added the *record_class* parameter.

.. versionchanged:: 0.30.0
Added the *connect* parameter.
Added the *connect* and *reset* parameters.
"""
return Pool(
dsn,
Expand All @@ -1178,6 +1205,7 @@ def create_pool(dsn=None, *,
connect=connect,
setup=setup,
init=init,
reset=reset,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
**connect_kwargs,
)
17 changes: 16 additions & 1 deletion tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,20 +137,31 @@ async def setup(con):
async def test_pool_07(self):
cons = set()
connect_called = 0
init_called = 0
setup_called = 0
reset_called = 0

async def connect(*args, **kwargs):
nonlocal connect_called
connect_called += 1
return await pg_connection.connect(*args, **kwargs)

async def setup(con):
nonlocal setup_called
if con._con not in cons: # `con` is `PoolConnectionProxy`.
raise RuntimeError('init was not called before setup')
setup_called += 1

async def init(con):
nonlocal init_called
if con in cons:
raise RuntimeError('init was called more than once')
cons.add(con)
init_called += 1

async def reset(con):
nonlocal reset_called
reset_called += 1

async def user(pool):
async with pool.acquire() as con:
Expand All @@ -162,12 +173,16 @@ async def user(pool):
max_size=5,
connect=connect,
init=init,
setup=setup) as pool:
setup=setup,
reset=reset) as pool:
users = asyncio.gather(*[user(pool) for _ in range(10)])
await users

self.assertEqual(len(cons), 5)
self.assertEqual(connect_called, 5)
self.assertEqual(init_called, 5)
self.assertEqual(setup_called, 10)
self.assertEqual(reset_called, 10)

async def bad_connect(*args, **kwargs):
return 1
Expand Down
Loading