135 changes: 88 additions & 47 deletions asyncpg/pool.py
@@ -7,20 +7,16 @@
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
import functools
import inspect
import logging
import time
import warnings
from collections.abc import Awaitable, Callable
from types import TracebackType
from typing import Any, Optional, Type
import warnings

from . import compat
from . import connection
from . import exceptions
from . import protocol

from . import compat, connection, exceptions, protocol

logger = logging.getLogger(__name__)

@@ -338,27 +334,46 @@ class Pool:
"""

__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
'_holders', '_initialized', '_initializing', '_closing',
'_closed', '_connection_class', '_record_class', '_generation',
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
"_queue",
"_loop",
"_minsize",
"_maxsize",
"_init",
"_connect",
"_reset",
"_connect_args",
"_connect_kwargs",
"_holders",
"_initialized",
"_initializing",
"_closing",
"_closed",
"_connection_class",
"_record_class",
"_generation",
"_setup",
"_max_queries",
"_max_inactive_connection_lifetime",
"_pool_timeout",
)

def __init__(self, *connect_args,
min_size,
max_size,
max_queries,
max_inactive_connection_lifetime,
connect=None,
setup=None,
init=None,
reset=None,
loop,
connection_class,
record_class,
**connect_kwargs):

def __init__(
self,
*connect_args,
min_size,
max_size,
max_queries,
max_inactive_connection_lifetime,
pool_timeout=None,
connect=None,
setup=None,
init=None,
reset=None,
loop,
connection_class,
record_class,
**connect_kwargs,
):
if len(connect_args) > 1:
warnings.warn(
"Passing multiple positional arguments to asyncpg.Pool "
@@ -389,6 +404,11 @@ def __init__(self, *connect_args,
'max_inactive_connection_lifetime is expected to be greater '
'or equal to zero')

if pool_timeout is not None and pool_timeout <= 0:
raise ValueError(
"pool_timeout is expected to be greater than zero or None"
)

if not issubclass(connection_class, connection.Connection):
raise TypeError(
'connection_class is expected to be a subclass of '
@@ -423,8 +443,10 @@ def __init__(self, *connect_args,
self._reset = reset

self._max_queries = max_queries
self._max_inactive_connection_lifetime = \
self._max_inactive_connection_lifetime = (
max_inactive_connection_lifetime
)
self._pool_timeout = pool_timeout

async def _async__init__(self):
if self._initialized:
@@ -578,7 +600,7 @@ async def execute(
self,
query: str,
*args,
timeout: Optional[float]=None,
timeout: Optional[float] = None,
) -> str:
"""Execute an SQL command (or commands).

@@ -596,7 +618,7 @@ async def executemany(
command: str,
args,
*,
timeout: Optional[float]=None,
timeout: Optional[float] = None,
):
"""Execute an SQL *command* for each sequence of arguments in *args*.

@@ -853,6 +875,7 @@ def acquire(self, *, timeout=None):
"""Acquire a database connection from the pool.

:param float timeout: A timeout for acquiring a Connection.
If not specified, defaults to the pool's *pool_timeout*.
:return: An instance of :class:`~asyncpg.connection.Connection`.

Can be used in an ``await`` expression or with an ``async with`` block.
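
Taken together with the create_pool() change further down, the new parameter gives the pool a single acquire/release default that individual calls can still override. A minimal usage sketch, assuming this branch is installed and that the placeholder DSN below points at a reachable server:

import asyncio

import asyncpg  # assumes this branch, where create_pool() accepts pool_timeout


async def main():
    # Placeholder DSN for illustration; point it at any reachable server.
    pool = await asyncpg.create_pool(
        'postgresql://localhost/postgres',
        min_size=1,
        max_size=1,
        pool_timeout=5.0,  # pool-wide default for acquire/release
    )

    # No explicit timeout: acquire() waits at most ~5 seconds (pool_timeout).
    async with pool.acquire() as conn:
        await conn.fetchval('SELECT 1')

    # An explicit per-call timeout overrides pool_timeout.
    conn = await pool.acquire(timeout=0.5)
    await pool.release(conn)

    await pool.close()


asyncio.run(main())
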
@@ -892,11 +915,16 @@ async def _acquire_impl():
raise exceptions.InterfaceError('pool is closing')
self._check_init()

if timeout is None:
# Use pool_timeout as fallback if no timeout specified
effective_timeout = timeout or self._pool_timeout

if effective_timeout is None:
return await _acquire_impl()
else:
return await compat.wait_for(
_acquire_impl(), timeout=timeout)
_acquire_impl(),
timeout=effective_timeout
)

async def release(self, connection, *, timeout=None):
"""Release a database connection back to the pool.
@@ -906,7 +934,8 @@ async def release(self, connection, *, timeout=None):
:param float timeout:
A timeout for releasing the connection. If not specified, defaults
to the timeout provided in the corresponding call to the
:meth:`Pool.acquire() <asyncpg.pool.Pool.acquire>` method.
:meth:`Pool.acquire() <asyncpg.pool.Pool.acquire>` method, or
to the pool's *pool_timeout* if no acquire timeout was set.

.. versionchanged:: 0.14.0
Added the *timeout* parameter.
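
With this change the effective release timeout resolves in order: the explicit *timeout* argument, then the timeout used at acquire time, then the pool's *pool_timeout*. A short sketch of that chain under the same assumptions as above (this branch installed, placeholder DSN):

import asyncio

import asyncpg  # assumes this branch, where create_pool() accepts pool_timeout


async def demo_release_timeouts():
    pool = await asyncpg.create_pool(
        'postgresql://localhost/postgres',  # placeholder DSN
        min_size=1,
        max_size=1,
        pool_timeout=5.0,
    )

    conn = await pool.acquire(timeout=2.0)
    await pool.release(conn)               # reuses the acquire timeout (2.0)

    conn = await pool.acquire()
    await pool.release(conn)               # no acquire timeout -> pool_timeout (5.0)

    conn = await pool.acquire()
    await pool.release(conn, timeout=0.5)  # explicit timeout takes precedence

    await pool.close()


asyncio.run(demo_release_timeouts())
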
@@ -929,7 +958,7 @@

ch = connection._holder
if timeout is None:
timeout = ch._timeout
timeout = ch._timeout or self._pool_timeout

# Use asyncio.shield() to guarantee that task cancellation
# does not prevent the connection from being returned to the
@@ -1065,26 +1094,32 @@ async def __aexit__(
self.done = True
con = self.connection
self.connection = None
await self.pool.release(con)
# Use the acquire timeout if set, otherwise fall back to pool_timeout
release_timeout = self.timeout or self.pool._pool_timeout
await self.pool.release(con, timeout=release_timeout)

def __await__(self):
self.done = True
return self.pool._acquire(self.timeout).__await__()


def create_pool(dsn=None, *,
min_size=10,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
connect=None,
setup=None,
init=None,
reset=None,
loop=None,
connection_class=connection.Connection,
record_class=protocol.Record,
**connect_kwargs):
def create_pool(
dsn=None,
*,
min_size=10,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
pool_timeout=None,
connect=None,
setup=None,
init=None,
reset=None,
loop=None,
connection_class=connection.Connection,
record_class=protocol.Record,
**connect_kwargs,
):
r"""Create a connection pool.

Can be used either with an ``async with`` block:
@@ -1161,6 +1196,11 @@ def create_pool(dsn=None, *,
Number of seconds after which inactive connections in the
pool will be closed. Pass ``0`` to disable this mechanism.

:param float pool_timeout:
Default timeout, in seconds, for pool operations (connection acquire
and release). If not specified, pool operations without an explicit
*timeout* may wait indefinitely. Individual calls can still override
this with their own *timeout* argument.

:param coroutine connect:
A coroutine that is called instead of
:func:`~asyncpg.connection.connect` whenever the pool needs to make a
@@ -1238,6 +1278,7 @@ def create_pool(dsn=None, *,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
pool_timeout=pool_timeout,
loop=loop,
connect=connect,
setup=setup,
44 changes: 42 additions & 2 deletions tests/test_pool.py
@@ -29,9 +29,9 @@

class SlowResetConnection(pg_connection.Connection):
"""Connection class to simulate races with Connection.reset()."""
async def reset(self, *, timeout=None):
async def _reset(self):
await asyncio.sleep(0.2)
return await super().reset(timeout=timeout)
return await super()._reset()


class SlowCancelConnection(pg_connection.Connection):
@@ -1004,6 +1004,46 @@ async def worker():
conn = await pool.acquire(timeout=0.1)
await pool.release(conn)

async def test_pool_timeout_acquire_timeout(self):
pool = await self.create_pool(
database='postgres',
min_size=1,
max_size=1, # Only 1 connection to force timeout
pool_timeout=0.1
)

# First acquire the only connection
conn1 = await pool.acquire()

# Now try to acquire another - should time out due to pool_timeout
start_time = time.monotonic()
with self.assertRaises(asyncio.TimeoutError):
await pool.acquire()
end_time = time.monotonic()

self.assertLess(end_time - start_time, 0.2)

await pool.release(conn1)
await pool.close()

async def test_pool_timeout_release_with_slow_reset(self):
pool = await self.create_pool(
database='postgres',
min_size=1,
max_size=1,
pool_timeout=0.1,
connection_class=SlowResetConnection,
)

start_time = time.monotonic()
with self.assertRaises(asyncio.TimeoutError):
conn = await pool.acquire()
await pool.release(conn)
end_time = time.monotonic()

self.assertLess(end_time - start_time, 0.2)
await pool.close()


@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
class TestPoolReconnectWithTargetSessionAttrs(tb.ClusterTestCase):
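
The two new tests above exercise the acquire and release timeout paths. As a further hedged sketch (not part of this diff), a method in the same style that one might add to the same test class to check that an explicit per-call timeout takes precedence over pool_timeout; the test name and timing bound are illustrative:

async def test_pool_timeout_acquire_explicit_override(self):
    # Hypothetical extra test: a per-call timeout should win over the
    # pool-wide pool_timeout default.
    pool = await self.create_pool(
        database='postgres',
        min_size=1,
        max_size=1,
        pool_timeout=10.0,  # deliberately generous pool-level default
    )

    # Occupy the only connection so further acquires must wait.
    conn1 = await pool.acquire()

    start_time = time.monotonic()
    with self.assertRaises(asyncio.TimeoutError):
        await pool.acquire(timeout=0.1)
    end_time = time.monotonic()

    # The explicit 0.1 s timeout should fire long before 10 s.
    self.assertLess(end_time - start_time, 1.0)

    await pool.release(conn1)
    await pool.close()
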