diff --git a/source-braintree-native/source_braintree_native/api.py b/source-braintree-native/source_braintree_native/api.py index a7b08acd5..7284630d4 100644 --- a/source-braintree-native/source_braintree_native/api.py +++ b/source-braintree-native/source_braintree_native/api.py @@ -1,3 +1,4 @@ +import asyncio from braintree import ( BraintreeGateway, AddOnGateway, @@ -7,8 +8,10 @@ CreditCardVerificationSearch, DisputeSearch, SubscriptionSearch, + ResourceCollection, ) from braintree.attribute_getter import AttributeGetter +from braintree.paginated_collection import PaginatedCollection from datetime import datetime, timedelta, UTC from logging import Logger from typing import AsyncGenerator @@ -75,14 +78,30 @@ def _braintree_object_to_dict(braintree_object): data.pop('_setattrs', None) return data -# TODO(bair): Refactor snapshot_ and fetch_ functions to make asynchronous API requests instead of synchronous requests. + +# Braintree's SDK makes synchronous API requests, which prevents multiple streams from +# sending concurrent API requests. asyncio.to_thread is used as a wrapper to run these +# synchronous API calls in a separate thread and avoid blocking the main thread's event loop. +async def _async_iterator_wrapper(collection: ResourceCollection | PaginatedCollection): + def _braintree_iterator(collection: ResourceCollection | PaginatedCollection): + for object in collection.items: + yield object + + it = await asyncio.to_thread(_braintree_iterator, collection) + + for object in it: + yield object + + async def snapshot_resources( braintree_gateway: BraintreeGateway, gateway_property: str, gateway_response_field: str | None, log: Logger, ) -> AsyncGenerator[FullRefreshResource, None]: - resources = getattr(braintree_gateway, gateway_property).all() + resources = await asyncio.to_thread( + getattr(braintree_gateway, gateway_property).all + ) iterator = getattr(resources, gateway_response_field) if gateway_response_field else resources for object in iterator: @@ -100,23 +119,21 @@ async def fetch_transactions( window_end = log_cursor + timedelta(hours=window_size) end = min(window_end, datetime.now(tz=UTC)) - collection = braintree_gateway.transaction.search( - TransactionSearch.created_at.between(log_cursor, end) + collection: ResourceCollection = await asyncio.to_thread( + braintree_gateway.transaction.search, + TransactionSearch.created_at.between(log_cursor, end), ) - count = 0 + if collection.maximum_size >= TRANSACTION_SEARCH_LIMIT: + raise RuntimeError(_search_limit_error_message(collection.maximum_size, "transactions")) - for object in collection.items: - count += 1 + async for object in _async_iterator_wrapper(collection): doc = IncrementalResource.model_validate(_braintree_object_to_dict(object)) if doc.created_at > log_cursor: yield doc most_recent_created_at = doc.created_at - if count >= TRANSACTION_SEARCH_LIMIT: - raise RuntimeError(_search_limit_error_message(count, "transactions")) - if end == window_end: yield window_end elif most_recent_created_at > log_cursor: @@ -134,23 +151,21 @@ async def fetch_customers( window_end = log_cursor + timedelta(hours=window_size) end = min(window_end, datetime.now(tz=UTC)) - collection = braintree_gateway.customer.search( - CustomerSearch.created_at.between(log_cursor, end) + collection: ResourceCollection = await asyncio.to_thread( + braintree_gateway.customer.search, + CustomerSearch.created_at.between(log_cursor, end), ) - count = 0 + if collection.maximum_size >= SEARCH_LIMIT: + raise RuntimeError(_search_limit_error_message(collection.maximum_size, "customers")) - for object in collection.items: - count += 1 + async for object in _async_iterator_wrapper(collection): doc = IncrementalResource.model_validate(_braintree_object_to_dict(object)) if doc.created_at > log_cursor: yield doc most_recent_created_at = doc.created_at - if count >= SEARCH_LIMIT: - raise RuntimeError(_search_limit_error_message(count, "customers")) - if end == window_end: yield window_end elif most_recent_created_at > log_cursor: @@ -168,29 +183,26 @@ async def fetch_credit_card_verifications( window_end = log_cursor + timedelta(hours=window_size) end = min(window_end, datetime.now(tz=UTC)) - collection = braintree_gateway.verification.search( - CreditCardVerificationSearch.created_at.between(log_cursor, end) + collection: ResourceCollection = await asyncio.to_thread( + braintree_gateway.verification.search, + CreditCardVerificationSearch.created_at.between(log_cursor, end), ) - count = 0 + if collection.maximum_size >= SEARCH_LIMIT: + raise RuntimeError(_search_limit_error_message(collection.maximum_size, "credit card verifications")) - for object in collection.items: - count += 1 + async for object in _async_iterator_wrapper(collection): doc = IncrementalResource.model_validate(_braintree_object_to_dict(object)) if doc.created_at > log_cursor: yield doc most_recent_created_at = doc.created_at - if count >= SEARCH_LIMIT: - raise RuntimeError(_search_limit_error_message(count, "credit card verifications")) - if end == window_end: yield window_end elif most_recent_created_at > log_cursor: yield most_recent_created_at - async def fetch_subscriptions( braintree_gateway: BraintreeGateway, window_size: int, @@ -202,23 +214,21 @@ async def fetch_subscriptions( window_end = log_cursor + timedelta(hours=window_size) end = min(window_end, datetime.now(tz=UTC)) - collection = braintree_gateway.subscription.search( - SubscriptionSearch.created_at.between(log_cursor, end) + collection: ResourceCollection = await asyncio.to_thread( + braintree_gateway.subscription.search, + SubscriptionSearch.created_at.between(log_cursor, end), ) - count = 0 + if collection.maximum_size >= SEARCH_LIMIT: + raise RuntimeError(_search_limit_error_message(collection.maximum_size, "subscriptions")) - for object in collection.items: - count += 1 + async for object in _async_iterator_wrapper(collection): doc = IncrementalResource.model_validate(_braintree_object_to_dict(object)) if doc.created_at > log_cursor: yield doc most_recent_created_at = doc.created_at - if count >= SEARCH_LIMIT: - raise RuntimeError(_search_limit_error_message(count, "subscriptions")) - if end == window_end: yield window_end elif most_recent_created_at > log_cursor: @@ -239,13 +249,14 @@ async def fetch_disputes( window_end = log_cursor + timedelta(hours=window_size) end = min(window_end, datetime.now(tz=UTC)) - collection = braintree_gateway.dispute.search( - DisputeSearch.received_date.between(start, end) + search_result = await asyncio.to_thread( + braintree_gateway.dispute.search, + DisputeSearch.received_date.between(start, end), ) count = 0 - for object in collection.disputes: + async for object in _async_iterator_wrapper(search_result.disputes): count += 1 doc = IncrementalResource.model_validate(_braintree_object_to_dict(object))