Skip to content

Commit

Permalink
fix: prepare api call kwargs correctly for synchronous methods when u…
Browse files Browse the repository at this point in the history
…sing FivetranHookAsync (#115)

* fix: always use (username,password) tuple for sync auth and aiohttp.BasicAuth for async auth

This is related to issue #107

Changes:

Always return kwargs["auth"] as a tuple from FivetranHook._prepare_api_call_kwargs
Rename FivetranHookAsync._prepare_api_call_kwargs to FivetranHookAsync._prepare_api_call_kwargs_async
  • Loading branch information
JeremyDOwens authored Dec 30, 2024
1 parent 7f3a96c commit 9d42127
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 5 deletions.
8 changes: 4 additions & 4 deletions fivetran_provider_async/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _prepare_api_call_kwargs(self, method: str, endpoint: str, **kwargs: Any) ->

auth = (self.fivetran_conn.login, self.fivetran_conn.password)

kwargs.setdefault("auth", auth)
kwargs["auth"] = auth
kwargs.setdefault("headers", {})

kwargs["headers"].setdefault("User-Agent", self.api_user_agent + self._get_airflow_version())
Expand Down Expand Up @@ -634,8 +634,8 @@ class FivetranHookAsync(FivetranHook):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def _prepare_api_call_kwargs(self, method: str, endpoint: str, **kwargs: Any) -> dict[str, Any]:
kwargs = super()._prepare_api_call_kwargs(method, endpoint, **kwargs)
def _prepare_api_call_kwargs_async(self, method: str, endpoint: str, **kwargs: Any) -> dict[str, Any]:
kwargs = self._prepare_api_call_kwargs(method, endpoint, **kwargs)
auth = kwargs.get("auth")
if auth is not None and is_container(auth) and 2 <= len(auth) <= 3:
kwargs["auth"] = aiohttp.BasicAuth(*auth)
Expand Down Expand Up @@ -666,7 +666,7 @@ async def _do_api_call_async(

url = f"{self.api_protocol}://{self.api_host}/{endpoint}"

kwargs = self._prepare_api_call_kwargs(method, endpoint, **kwargs)
kwargs = self._prepare_api_call_kwargs_async(method, endpoint, **kwargs)

async with aiohttp.ClientSession() as session:
attempt_num = 1
Expand Down
66 changes: 65 additions & 1 deletion tests/hooks/test_fivetran.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import pendulum
import pytest
import requests_mock
from aiohttp import ClientResponseError, RequestInfo
from aiohttp import BasicAuth, ClientResponseError, RequestInfo
from airflow.exceptions import AirflowException
from airflow.utils.helpers import is_container

from fivetran_provider_async.hooks import FivetranHook, FivetranHookAsync
from tests.common.static import (
Expand Down Expand Up @@ -596,6 +597,34 @@ async def mock_fun(arg1, arg2, arg3, arg4):
response = await hook._do_api_call_async(("POST", "v1/connectors/test"))
assert response == {"status": "success"}

@pytest.mark.asyncio
@mock.patch("fivetran_provider_async.hooks.aiohttp.ClientSession")
@mock.patch("fivetran_provider_async.hooks.FivetranHookAsync.get_connection")
async def test_do_api_call_async_verify_using_async_kwargs_preparation(
self, mock_get_connection, mock_session
):
"""Tests that _do_api_call_async calls _prepare_api_call_kwargs_async"""

async def mock_fun(arg1, arg2, arg3, arg4):
return {"status": "success"}

mock_session.return_value.__aexit__.return_value = mock_fun
mock_session.return_value.__aenter__.return_value.request.return_value.json.return_value = {
"status": "success"
}

hook = FivetranHookAsync(fivetran_conn_id="conn_fivetran")

hook.fivetran_conn = mock_get_connection
hook.fivetran_conn.login = LOGIN
hook.fivetran_conn.password = PASSWORD
with mock.patch(
"fivetran_provider_async.hooks.FivetranHookAsync._prepare_api_call_kwargs_async"
) as prep_func:
await hook._do_api_call_async(("POST", "v1/connectors/test"))

prep_func.assert_called_once_with("POST", "v1/connectors/test")

@pytest.mark.asyncio
@mock.patch("fivetran_provider_async.hooks.aiohttp.ClientSession")
@mock.patch("fivetran_provider_async.hooks.FivetranHookAsync.get_connection")
Expand Down Expand Up @@ -648,6 +677,24 @@ async def test_do_api_call_async_with_retryable_client_response_error(

assert str(exc.value) == "API requests to Fivetran failed 3 times. Giving up."

@pytest.mark.asyncio
@mock.patch("fivetran_provider_async.hooks.FivetranHookAsync.get_connection")
async def test_prepare_api_call_kwargs_async_returns_aiohttp_basicauth(self, mock_get_connection):
"""Tests to verify that the 'auth' value returned from kwarg preparation is
of type aiohttp.BasicAuth"""
hook = FivetranHookAsync(fivetran_conn_id="conn_fivetran")
hook.fivetran_conn = mock_get_connection
hook.fivetran_conn.login = LOGIN
hook.fivetran_conn.password = PASSWORD

# Test first without passing in an auth kwarg
kwargs = hook._prepare_api_call_kwargs_async("POST", "v1/connectors/test")
assert isinstance(kwargs["auth"], BasicAuth)

# Pass in auth kwarg of a different type (using a string for the test)
kwargs = hook._prepare_api_call_kwargs_async("POST", "v1/connectors/test", auth="BadAuth")
assert isinstance(kwargs["auth"], BasicAuth)


# Mock the `conn_fivetran` Airflow connection (note the `@` after `API_SECRET`)
@mock.patch.dict("os.environ", AIRFLOW_CONN_CONN_FIVETRAN="http://API_KEY:API_SECRET@")
Expand Down Expand Up @@ -794,3 +841,20 @@ def test_start_fivetran_sync(self, m):
)
result = hook.start_fivetran_sync(connector_id="interchangeable_revenge")
assert result is not None

def test_prepare_api_call_kwargs_always_returns_tuple(self):
"""Tests to verify that given a valid fivetran_conn _prepare_api_call_kwargs always returns
a username/password tuple"""
hook = FivetranHook(
fivetran_conn_id="conn_fivetran",
)

# Test first without passing in an auth kwarg
kwargs = hook._prepare_api_call_kwargs("POST", "v1/connectors/test")
assert not isinstance(kwargs["auth"], BasicAuth)
assert is_container(kwargs["auth"]) and len(kwargs["auth"]) == 2

# Pass in auth kwarg of a different type (using a string for the test)
kwargs = hook._prepare_api_call_kwargs("POST", "v1/connectors/test", auth="BadAuth")
assert not isinstance(kwargs["auth"], BasicAuth)
assert is_container(kwargs["auth"]) and len(kwargs["auth"]) == 2

0 comments on commit 9d42127

Please sign in to comment.