diff --git a/fivetran_provider_async/hooks.py b/fivetran_provider_async/hooks.py index cacf279..9368b78 100644 --- a/fivetran_provider_async/hooks.py +++ b/fivetran_provider_async/hooks.py @@ -106,7 +106,7 @@ def _prepare_api_call_kwargs(self, method: str, endpoint: str, **kwargs: Any) -> auth = (self.fivetran_conn.login, self.fivetran_conn.password) - kwargs.setdefault("auth", auth) + kwargs["auth"] = auth kwargs.setdefault("headers", {}) kwargs["headers"].setdefault("User-Agent", self.api_user_agent + self._get_airflow_version()) @@ -634,8 +634,8 @@ class FivetranHookAsync(FivetranHook): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - def _prepare_api_call_kwargs(self, method: str, endpoint: str, **kwargs: Any) -> dict[str, Any]: - kwargs = super()._prepare_api_call_kwargs(method, endpoint, **kwargs) + def _prepare_api_call_kwargs_async(self, method: str, endpoint: str, **kwargs: Any) -> dict[str, Any]: + kwargs = self._prepare_api_call_kwargs(method, endpoint, **kwargs) auth = kwargs.get("auth") if auth is not None and is_container(auth) and 2 <= len(auth) <= 3: kwargs["auth"] = aiohttp.BasicAuth(*auth) @@ -666,7 +666,7 @@ async def _do_api_call_async( url = f"{self.api_protocol}://{self.api_host}/{endpoint}" - kwargs = self._prepare_api_call_kwargs(method, endpoint, **kwargs) + kwargs = self._prepare_api_call_kwargs_async(method, endpoint, **kwargs) async with aiohttp.ClientSession() as session: attempt_num = 1 diff --git a/tests/hooks/test_fivetran.py b/tests/hooks/test_fivetran.py index 26dade7..1bf4518 100644 --- a/tests/hooks/test_fivetran.py +++ b/tests/hooks/test_fivetran.py @@ -5,8 +5,9 @@ import pendulum import pytest import requests_mock -from aiohttp import ClientResponseError, RequestInfo +from aiohttp import BasicAuth, ClientResponseError, RequestInfo from airflow.exceptions import AirflowException +from airflow.utils.helpers import is_container from fivetran_provider_async.hooks import FivetranHook, FivetranHookAsync from tests.common.static import ( @@ -596,6 +597,34 @@ async def mock_fun(arg1, arg2, arg3, arg4): response = await hook._do_api_call_async(("POST", "v1/connectors/test")) assert response == {"status": "success"} + @pytest.mark.asyncio + @mock.patch("fivetran_provider_async.hooks.aiohttp.ClientSession") + @mock.patch("fivetran_provider_async.hooks.FivetranHookAsync.get_connection") + async def test_do_api_call_async_verify_using_async_kwargs_preparation( + self, mock_get_connection, mock_session + ): + """Tests that _do_api_call_async calls _prepare_api_call_kwargs_async""" + + async def mock_fun(arg1, arg2, arg3, arg4): + return {"status": "success"} + + mock_session.return_value.__aexit__.return_value = mock_fun + mock_session.return_value.__aenter__.return_value.request.return_value.json.return_value = { + "status": "success" + } + + hook = FivetranHookAsync(fivetran_conn_id="conn_fivetran") + + hook.fivetran_conn = mock_get_connection + hook.fivetran_conn.login = LOGIN + hook.fivetran_conn.password = PASSWORD + with mock.patch( + "fivetran_provider_async.hooks.FivetranHookAsync._prepare_api_call_kwargs_async" + ) as prep_func: + await hook._do_api_call_async(("POST", "v1/connectors/test")) + + prep_func.assert_called_once_with("POST", "v1/connectors/test") + @pytest.mark.asyncio @mock.patch("fivetran_provider_async.hooks.aiohttp.ClientSession") @mock.patch("fivetran_provider_async.hooks.FivetranHookAsync.get_connection") @@ -648,6 +677,24 @@ async def test_do_api_call_async_with_retryable_client_response_error( assert str(exc.value) == "API requests to Fivetran failed 3 times. Giving up." + @pytest.mark.asyncio + @mock.patch("fivetran_provider_async.hooks.FivetranHookAsync.get_connection") + async def test_prepare_api_call_kwargs_async_returns_aiohttp_basicauth(self, mock_get_connection): + """Tests to verify that the 'auth' value returned from kwarg preparation is + of type aiohttp.BasicAuth""" + hook = FivetranHookAsync(fivetran_conn_id="conn_fivetran") + hook.fivetran_conn = mock_get_connection + hook.fivetran_conn.login = LOGIN + hook.fivetran_conn.password = PASSWORD + + # Test first without passing in an auth kwarg + kwargs = hook._prepare_api_call_kwargs_async("POST", "v1/connectors/test") + assert isinstance(kwargs["auth"], BasicAuth) + + # Pass in auth kwarg of a different type (using a string for the test) + kwargs = hook._prepare_api_call_kwargs_async("POST", "v1/connectors/test", auth="BadAuth") + assert isinstance(kwargs["auth"], BasicAuth) + # Mock the `conn_fivetran` Airflow connection (note the `@` after `API_SECRET`) @mock.patch.dict("os.environ", AIRFLOW_CONN_CONN_FIVETRAN="http://API_KEY:API_SECRET@") @@ -794,3 +841,20 @@ def test_start_fivetran_sync(self, m): ) result = hook.start_fivetran_sync(connector_id="interchangeable_revenge") assert result is not None + + def test_prepare_api_call_kwargs_always_returns_tuple(self): + """Tests to verify that given a valid fivetran_conn _prepare_api_call_kwargs always returns + a username/password tuple""" + hook = FivetranHook( + fivetran_conn_id="conn_fivetran", + ) + + # Test first without passing in an auth kwarg + kwargs = hook._prepare_api_call_kwargs("POST", "v1/connectors/test") + assert not isinstance(kwargs["auth"], BasicAuth) + assert is_container(kwargs["auth"]) and len(kwargs["auth"]) == 2 + + # Pass in auth kwarg of a different type (using a string for the test) + kwargs = hook._prepare_api_call_kwargs("POST", "v1/connectors/test", auth="BadAuth") + assert not isinstance(kwargs["auth"], BasicAuth) + assert is_container(kwargs["auth"]) and len(kwargs["auth"]) == 2