From b14b2e9c1468d05acc4be689f8387ac0954b898d Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Sat, 26 Oct 2024 11:31:23 -0700 Subject: [PATCH] Close connection automatically if dependency injected --- iceaxe/__tests__/mountaineer/__init__.py | 0 .../mountaineer/dependencies/__init__.py | 0 .../mountaineer/dependencies/test_core.py | 73 +++++++++++++++++++ iceaxe/mountaineer/dependencies/core.py | 9 ++- 4 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 iceaxe/__tests__/mountaineer/__init__.py create mode 100644 iceaxe/__tests__/mountaineer/dependencies/__init__.py create mode 100644 iceaxe/__tests__/mountaineer/dependencies/test_core.py diff --git a/iceaxe/__tests__/mountaineer/__init__.py b/iceaxe/__tests__/mountaineer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/iceaxe/__tests__/mountaineer/dependencies/__init__.py b/iceaxe/__tests__/mountaineer/dependencies/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/iceaxe/__tests__/mountaineer/dependencies/test_core.py b/iceaxe/__tests__/mountaineer/dependencies/test_core.py new file mode 100644 index 0000000..87b471c --- /dev/null +++ b/iceaxe/__tests__/mountaineer/dependencies/test_core.py @@ -0,0 +1,73 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import asyncpg +import pytest +from mountaineer import CoreDependencies + +from iceaxe.mountaineer.config import DatabaseConfig +from iceaxe.mountaineer.dependencies.core import get_db_connection +from iceaxe.session import DBConnection + + +@pytest.fixture(autouse=True) +def mock_db_connect(): + conn = AsyncMock(spec=asyncpg.Connection) + conn.close = AsyncMock() + + with patch("asyncpg.connect", new_callable=AsyncMock) as mock: + mock.return_value = conn + yield mock + + +@pytest.fixture +def mock_config(): + return DatabaseConfig( + POSTGRES_HOST="test-host", + POSTGRES_PORT=5432, + POSTGRES_USER="test-user", + POSTGRES_PASSWORD="test-pass", + POSTGRES_DB="test-db", + ) + + +@pytest.fixture +def mock_connection(): + conn = AsyncMock(spec=asyncpg.Connection) + conn.close = AsyncMock() + return conn + + +@pytest.mark.asyncio +async def test_get_db_connection_closes_after_yield( + mock_config: DatabaseConfig, + mock_connection: AsyncMock, + mock_db_connect: AsyncMock, +): + mock_get_config = MagicMock(return_value=mock_config) + CoreDependencies.get_config_with_type = mock_get_config + + mock_db_connect.return_value = mock_connection + + # Get the generator + db_gen = get_db_connection(mock_config) + + # Get the connection + connection = await anext(db_gen) # noqa: F821 + + assert isinstance(connection, DBConnection) + assert connection.conn == mock_connection + mock_db_connect.assert_called_once_with( + host=mock_config.POSTGRES_HOST, + port=mock_config.POSTGRES_PORT, + user=mock_config.POSTGRES_USER, + password=mock_config.POSTGRES_PASSWORD, + database=mock_config.POSTGRES_DB, + ) + + # Simulate the end of the generator's scope + try: + await db_gen.aclose() + except StopAsyncIteration: + pass + + mock_connection.close.assert_called_once() diff --git a/iceaxe/mountaineer/dependencies/core.py b/iceaxe/mountaineer/dependencies/core.py index 2c04b99..524ea52 100644 --- a/iceaxe/mountaineer/dependencies/core.py +++ b/iceaxe/mountaineer/dependencies/core.py @@ -3,6 +3,8 @@ """ +from typing import AsyncGenerator + import asyncpg from mountaineer import CoreDependencies, Depends @@ -14,7 +16,7 @@ async def get_db_connection( config: DatabaseConfig = Depends( CoreDependencies.get_config_with_type(DatabaseConfig) ), -) -> DBConnection: +) -> AsyncGenerator[DBConnection, None]: conn = await asyncpg.connect( host=config.POSTGRES_HOST, port=config.POSTGRES_PORT, @@ -22,4 +24,7 @@ async def get_db_connection( password=config.POSTGRES_PASSWORD, database=config.POSTGRES_DB, ) - return DBConnection(conn) + try: + yield DBConnection(conn) + finally: + await conn.close()