Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create asyncpg pools from connector #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion google/cloud/sql/connector/asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""
import ssl
from typing import Any, TYPE_CHECKING
from typing import Any, Awaitable, Callable, TYPE_CHECKING, Union

SERVER_PROXY_PORT = 3307

Expand Down Expand Up @@ -62,3 +62,47 @@ async def connect(
direct_tls=True,
**kwargs,
)


async def create_pool(
ip_address: str,
ctx: Union[Callable[[], ssl.SSLContext], Callable[[], Awaitable[ssl.SSLContext]]],
**kwargs: Any
) -> "asyncpg.Connection":
"""Helper function to create an asyncpg DB-API connection pool object.

:type ip_address: str
:param ip_address: A string containing an IP address for the Cloud SQL
instance.

:type ctx: Callable[[], ssl.SSLContext]
:param ctx: A callable that returns an SSLContext object created from the
Cloud SQL server CA cert and ephemeral cert.

:type kwargs: Any
:param kwargs: Keyword arguments for establishing asyncpg connection
object to Cloud SQL instance.

:rtype: asyncpg.Connection
:returns: An asyncpg.Connection object to a Cloud SQL instance.
"""
try:
import asyncpg
except ImportError:
raise ImportError(
'Unable to import module "asyncpg." Please install and try again.'
)
user = kwargs.pop("user")
db = kwargs.pop("db")
passwd = kwargs.pop("password", None)

return await asyncpg.create_pool(
user=user,
database=db,
password=passwd,
host=ip_address,
port=SERVER_PROXY_PORT,
ssl=ctx,
direct_tls=True,
**kwargs,
)
115 changes: 115 additions & 0 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from threading import Thread
from typing import Any, Dict, Optional, Type
from functools import partial
import ssl

logger = logging.getLogger(name=__name__)

Expand Down Expand Up @@ -266,6 +267,120 @@ async def get_connection() -> Any:
instance.force_refresh()
raise

async def create_pool(
self, instance_connection_string: str, driver: str, **kwargs: Any
) -> Any:
"""Prepares and returns a database connection pool object and starts a
background task to refresh the certificates and metadata.

:type instance_connection_string: str
:param instance_connection_string:
A string containing the GCP project name, region name, and instance
name separated by colons.

Example: example-proj:example-region-us6:example-instance

:type driver: str
:param: driver:
A string representing the driver to connect with. Currently the only
supported driver is asyncpg.

:param kwargs:
Pass in any driver-specific arguments needed to connect to the Cloud
SQL instance.

:rtype: asyncpg.Pool
:returns:
A DB-API connection to the specified Cloud SQL instance.
"""
# Create an Instance object from the connection string.
# The Instance should verify arguments.
#
# Use the Instance to establish an SSL Connection.
#
# Return a DBAPI connection
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
if instance_connection_string in self._instances:
instance = self._instances[instance_connection_string]
if enable_iam_auth != instance._enable_iam_auth:
raise ValueError(
f"connect() called with `enable_iam_auth={enable_iam_auth}`, "
f"but previously used enable_iam_auth={instance._enable_iam_auth}`. "
"If you require both for your use case, please use a new "
"connector.Connector object."
)
else:
instance = Instance(
instance_connection_string,
driver,
self._keys,
self._loop,
self._credentials,
enable_iam_auth,
self._quota_project,
self._sqladmin_api_endpoint,
)
self._instances[instance_connection_string] = instance

connect_func = {
"asyncpg": asyncpg.create_pool,
}

# only accept supported database drivers
try:
connector = connect_func[driver]
except KeyError:
raise KeyError(f"Driver '{driver}' is not supported.")

ip_type = kwargs.pop("ip_type", self._ip_type)
timeout = kwargs.pop("timeout", self._timeout)
if "connect_timeout" in kwargs:
timeout = kwargs.pop("connect_timeout")

# Host and ssl options come from the certificates and metadata, so we don't
# want the user to specify them.
kwargs.pop("host", None)
kwargs.pop("ssl", None)
kwargs.pop("port", None)

# helper function to wrap in timeout
async def get_pool() -> Any:
instance_data, ip_address = await instance.connect_info(ip_type)

async def get_context() -> ssl.SSLContext:
instance_data, _ = await instance.connect_info(ip_type)
return instance_data.context

# format `user` param for automatic IAM database authn
if enable_iam_auth:
formatted_user = format_database_user(
instance_data.database_version, kwargs["user"]
)
if formatted_user != kwargs["user"]:
logger.debug(
f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}"
)
kwargs["user"] = formatted_user

# async drivers are unblocking and can be awaited directly
if driver in ASYNC_DRIVERS:
return await connector(ip_address, get_context, **kwargs)
# synchronous drivers are blocking and run using executor
connect_partial = partial(
connector, ip_address, instance_data.context, **kwargs
)
return await self._loop.run_in_executor(None, connect_partial)

# attempt to make connection to Cloud SQL instance for given timeout
try:
return await asyncio.wait_for(get_pool(), timeout)
except asyncio.TimeoutError:
raise TimeoutError(f"Connection timed out after {timeout}s")
except Exception:
# with any other exception, we attempt a force refresh, then throw the error
instance.force_refresh()
raise

def __enter__(self) -> Any:
"""Enter context manager by returning Connector object"""
return self
Expand Down
42 changes: 41 additions & 1 deletion tests/system/test_asyncpg_iam_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


@pytest.fixture(name="conn")
async def setup() -> AsyncGenerator:
async def setup_conn() -> AsyncGenerator:
# initialize Cloud SQL Python Connector object
connector = await create_async_connector()
conn: asyncpg.Connection = await connector.connect_async(
Expand All @@ -49,6 +49,31 @@ async def setup() -> AsyncGenerator:
await connector.close_async()


@pytest.fixture(name="pool")
async def setup_pool() -> AsyncGenerator:
# initialize Cloud SQL Python Connector object
connector = await create_async_connector()
pool: asyncpg.Pool = await connector.create_pool(
os.environ["POSTGRES_IAM_CONNECTION_NAME"],
"asyncpg",
user=os.environ["POSTGRES_IAM_USER"],
db=os.environ["POSTGRES_DB"],
enable_iam_auth=True
)
await pool.execute(
f"CREATE TABLE IF NOT EXISTS {table_name}"
" ( id CHAR(20) NOT NULL, title TEXT NOT NULL );"
)

yield pool

await pool.execute(f"DROP TABLE IF EXISTS {table_name}")
# close asyncpg connection
await pool.close()
# cleanup Connector object
await connector.close_async()


@pytest.mark.asyncio
async def test_connection_with_asyncpg_iam_auth(conn: asyncpg.Connection) -> None:
await conn.execute(
Expand All @@ -62,3 +87,18 @@ async def test_connection_with_asyncpg_iam_auth(conn: asyncpg.Connection) -> Non
titles = [row[0] for row in rows]

assert titles == ["Book One", "Book Two"]


@pytest.mark.asyncio
async def test_connection_pooling_with_asyncpg_iam_auth(pool: asyncpg.Pool) -> None:
await pool.execute(
f"INSERT INTO {table_name} (id, title) VALUES ('book1', 'Book One')"
)
await pool.execute(
f"INSERT INTO {table_name} (id, title) VALUES ('book2', 'Book Two')"
)

rows = await pool.fetch(f"SELECT title FROM {table_name} ORDER BY ID")
titles = [row[0] for row in rows]

assert titles == ["Book One", "Book Two"]
20 changes: 18 additions & 2 deletions tests/unit/test_asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from typing import Any
from mock import patch, AsyncMock

from google.cloud.sql.connector.asyncpg import connect
from google.cloud.sql.connector.asyncpg import connect, create_pool


@pytest.mark.asyncio
@patch("asyncpg.connect", new_callable=AsyncMock)
async def test_asyncpg(mock_connect: AsyncMock, kwargs: Any) -> None:
async def test_asyncpg_connect(mock_connect: AsyncMock, kwargs: Any) -> None:
"""Test to verify that asyncpg gets to proper connection call."""
ip_addr = "0.0.0.0"
context = ssl.create_default_context()
Expand All @@ -32,3 +32,19 @@ async def test_asyncpg(mock_connect: AsyncMock, kwargs: Any) -> None:
assert connection is True
# verify that driver connection call would be made
assert mock_connect.assert_called_once


@pytest.mark.asyncio
@patch("asyncpg.create_pool", new_callable=AsyncMock)
async def test_asyncpg_create_pool(mock_create_pool: AsyncMock, kwargs: Any) -> None:
"""Test to verify that asyncpg gets to proper pool creation call."""
ip_addr = "0.0.0.0"

async def get_context() -> ssl.SSLContext:
return ssl.create_default_context()

mock_create_pool.return_value = True
connection = await create_pool(ip_addr, get_context, **kwargs)
assert connection is True
# verify that driver pool creation call would be made
assert mock_create_pool.assert_called_once