diff --git a/pyrevolut/api/payout_links/resources/payout_link.py b/pyrevolut/api/payout_links/resources/payout_link.py index bf58934..5d083d7 100644 --- a/pyrevolut/api/payout_links/resources/payout_link.py +++ b/pyrevolut/api/payout_links/resources/payout_link.py @@ -89,7 +89,7 @@ class ResourcePayoutLink(BaseModel): request_id: Annotated[ str, Field( - description="The ID of the request, provided by the sender.", max_length=40 + description="The ID of the request, provided by the sender.", ), ] expiry_date: Annotated[ diff --git a/pyrevolut/api/transactions/resources/transaction.py b/pyrevolut/api/transactions/resources/transaction.py index 8ea3cb2..9874ba0 100644 --- a/pyrevolut/api/transactions/resources/transaction.py +++ b/pyrevolut/api/transactions/resources/transaction.py @@ -132,7 +132,7 @@ class ModelCard(BaseModel): request_id: Annotated[ str | None, Field( - description="The request ID that you provided previously.", max_length=40 + description="The request ID that you provided previously.", ), ] = None state: Annotated[ diff --git a/pyrevolut/api/webhooks/resources/transaction_created.py b/pyrevolut/api/webhooks/resources/transaction_created.py index ed74e2e..6a58e8e 100644 --- a/pyrevolut/api/webhooks/resources/transaction_created.py +++ b/pyrevolut/api/webhooks/resources/transaction_created.py @@ -96,7 +96,7 @@ class ModelCounterparty(BaseModel): request_id: Annotated[ str | None, Field( - description="The request ID that you provided previously.", max_length=40 + description="The request ID that you provided previously.", ), ] = None reason_code: Annotated[ diff --git a/pyrevolut/client/asynchronous.py b/pyrevolut/client/asynchronous.py index a5247ba..eb83956 100644 --- a/pyrevolut/client/asynchronous.py +++ b/pyrevolut/client/asynchronous.py @@ -25,17 +25,17 @@ class AsyncClient(BaseClient): """The asynchronous client for the Revolut API""" - Accounts: EndpointAccountsAsync | None = None - Cards: EndpointCardsAsync | None = None - Counterparties: EndpointCounterpartiesAsync | None = None - ForeignExchange: EndpointForeignExchangeAsync | None = None - PaymentDrafts: EndpointPaymentDraftsAsync | None = None - PayoutLinks: EndpointPayoutLinksAsync | None = None - Simulations: EndpointSimulationsAsync | None = None - TeamMembers: EndpointTeamMembersAsync | None = None - Transactions: EndpointTransactionsAsync | None = None - Transfers: EndpointTransfersAsync | None = None - Webhooks: EndpointWebhooksAsync | None = None + Accounts: EndpointAccountsAsync + Cards: EndpointCardsAsync + Counterparties: EndpointCounterpartiesAsync + ForeignExchange: EndpointForeignExchangeAsync + PaymentDrafts: EndpointPaymentDraftsAsync + PayoutLinks: EndpointPayoutLinksAsync + Simulations: EndpointSimulationsAsync + TeamMembers: EndpointTeamMembersAsync + Transactions: EndpointTransactionsAsync + Transfers: EndpointTransfersAsync + Webhooks: EndpointWebhooksAsync async def open(self): """Opens the client connection""" @@ -43,7 +43,6 @@ async def open(self): return self.client = HTTPClient() - self.load_endpoints() async def close(self): """Closes the client connection""" diff --git a/pyrevolut/client/base.py b/pyrevolut/client/base.py index cd95d3e..59931e7 100644 --- a/pyrevolut/client/base.py +++ b/pyrevolut/client/base.py @@ -1,6 +1,7 @@ from typing import Type, TypeVar, Literal, Annotated import logging import json +import base64 from pydantic import BaseModel, Field import pendulum @@ -53,6 +54,7 @@ class BaseClient: """The base client for the Revolut API""" creds_loc: str + creds: str | dict | None = None credentials: ModelCreds domain: str sandbox: bool @@ -63,6 +65,7 @@ class BaseClient: def __init__( self, creds_loc: str = "credentials/creds.json", + creds: str | dict | None = None, sandbox: bool = True, return_type: Literal["raw", "dict", "model"] = "dict", error_response: Literal["raw", "raise", "dict", "model"] = "raise", @@ -72,7 +75,12 @@ def __init__( Parameters ---------- creds_loc : str, optional - The location of the credentials file, by default "credentials/creds.json" + The location of the credentials file, by default "credentials/creds.json". + If the creds input is not provided, will load the credentials from this file. + creds : str | dict, optional + The credentials to use for the client, by default None. If not provided, will + load the credentials from the creds_loc file. + Can be a dictionary of the credentials or a base64 encoded string of the credentials dictionary. sandbox : bool, optional Whether to use the sandbox environment, by default True return_type : Literal["raw", "dict", "model"], optional @@ -95,19 +103,21 @@ def __init__( If "model": The client will return a Pydantic model of the error response """ - self.creds_loc = creds_loc - self.sandbox = sandbox assert return_type in [ "raw", "dict", "model", ], "return_type must be 'raw', 'dict', or 'model'" - self.return_type = return_type assert error_response in [ "raise", "dict", "model", ], "error_response must be 'raise', 'dict', or 'model'" + + self.creds_loc = creds_loc + self.creds = creds + self.sandbox = sandbox + self.return_type = return_type self.error_response = error_response # Set domain based on environment @@ -119,6 +129,9 @@ def __init__( # Load the credentials self.load_credentials() + # Load the endpoints + self.load_endpoints() + def process_response( self, response: Response, @@ -475,6 +488,10 @@ def _prep_put( **kwargs, } + def load_endpoints(self): + """Loads all the endpoints from the api directory""" + raise NotImplementedError("load_endpoints method must be implemented") + @property def required_headers(self) -> dict[str, str]: """The headers to be attached to each request @@ -514,7 +531,9 @@ def __check_client(self): If the client is not open or if the long-term credentials have expired """ if self.client is None: - raise ValueError("Client is not open") + raise RuntimeError( + "Client is not open. Use .open() or the contextmanager to open the client." + ) if self.credentials.credentials_expired: raise ValueError( @@ -604,8 +623,9 @@ def __replace_null_with_none(self, data: D) -> D: return data def load_credentials(self): - """Load the credentials from the credentials file. + """Load the credentials from the credentials inputs. + - If credentials are not provided, will load them from the credentials file. - If the credentials file does not exist, raise an error. - If the credentials file is invalid, raise an error. - If the credentials are expired, raise an error. @@ -616,14 +636,28 @@ def load_credentials(self): "\n\nPlease reauthenticate using the `pyrevolut auth-manual` command." ) - try: - self.credentials = load_creds(location=self.creds_loc) - except FileNotFoundError as exc: - raise ValueError( - f"Credentials file not found: {exc}. {solution_msg}" - ) from exc - except Exception as exc: - raise ValueError(f"Error loading credentials: {exc}.") from exc + if self.creds is not None: + if isinstance(self.creds, str): + _creds = json.loads(base64.b64decode(self.creds).decode("utf-8")) + else: + _creds = self.creds + try: + self.credentials = ModelCreds(**_creds) + except Exception as exc: + raise ValueError( + f"Error loading credentials: {exc}. {solution_msg}" + ) from exc + else: + try: + self.credentials = load_creds(location=self.creds_loc) + except FileNotFoundError as exc: + raise ValueError( + f"Credentials file not found: {exc}. {solution_msg}" + ) from exc + except Exception as exc: + raise ValueError( + f"Error loading credentials: {exc}. {solution_msg}" + ) from exc # Check if the credentials are still valid if self.credentials.credentials_expired: diff --git a/pyrevolut/client/synchronous.py b/pyrevolut/client/synchronous.py index da0e56f..a8b2b4c 100644 --- a/pyrevolut/client/synchronous.py +++ b/pyrevolut/client/synchronous.py @@ -24,17 +24,17 @@ class Client(BaseClient): """The synchronous client for the Revolut API""" - Accounts: EndpointAccountsSync | None = None - Cards: EndpointCardsSync | None = None - Counterparties: EndpointCounterpartiesSync | None = None - ForeignExchange: EndpointForeignExchangeSync | None = None - PaymentDrafts: EndpointPaymentDraftsSync | None = None - PayoutLinks: EndpointPayoutLinksSync | None = None - Simulations: EndpointSimulationsSync | None = None - TeamMembers: EndpointTeamMembersSync | None = None - Transactions: EndpointTransactionsSync | None = None - Transfers: EndpointTransfersSync | None = None - Webhooks: EndpointWebhooksSync | None = None + Accounts: EndpointAccountsSync + Cards: EndpointCardsSync + Counterparties: EndpointCounterpartiesSync + ForeignExchange: EndpointForeignExchangeSync + PaymentDrafts: EndpointPaymentDraftsSync + PayoutLinks: EndpointPayoutLinksSync + Simulations: EndpointSimulationsSync + TeamMembers: EndpointTeamMembersSync + Transactions: EndpointTransactionsSync + Transfers: EndpointTransfersSync + Webhooks: EndpointWebhooksSync def open(self): """Opens the client connection""" @@ -42,7 +42,6 @@ def open(self): return self.client = HTTPClient() - self.load_endpoints() def close(self): """Closes the client connection""" diff --git a/pyrevolut/utils/date.py b/pyrevolut/utils/date.py index d5e62b0..80687ce 100644 --- a/pyrevolut/utils/date.py +++ b/pyrevolut/utils/date.py @@ -33,7 +33,7 @@ def __get_pydantic_core_schema__( A Pydantic CoreSchema with the Date validation. """ return core_schema.no_info_wrap_validator_function( - cls._validate, core_schema.datetime_schema() + cls._validate, core_schema.date_schema() ) @classmethod @@ -120,7 +120,8 @@ def string_to_date(string: str) -> Date: ] for format in formats: try: - return pendulum.from_format(string=string, fmt=format, tz="UTC") + dt = pendulum.from_format(string=string, fmt=format, tz="UTC") + return pendulum.Date(year=dt.year, month=dt.month, day=dt.day) except Exception: pass raise PydanticCustomError(f"Error converting string to pendulum Date: {string}") diff --git a/pyrevolut/utils/datetime.py b/pyrevolut/utils/datetime.py index d6db033..e840e8b 100644 --- a/pyrevolut/utils/datetime.py +++ b/pyrevolut/utils/datetime.py @@ -59,11 +59,6 @@ def _validate( Any The validated value or raises a PydanticCustomError. """ - # if we are passed an existing instance, pass it straight through. - if isinstance(value, _DateTime): - return handler(value) - - # otherwise, parse it. try: data = to_datetime(value) except Exception as exc: diff --git a/tests/conftest.py b/tests/conftest.py index 1ebfd77..b1bbf29 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,8 @@ import time import random import itertools +import json +import base64 import pytest import pytest_asyncio @@ -42,6 +44,7 @@ # All the JSON files in the credentials folder CREDENTIALS_LOC = glob.glob(os.path.join("tests/credentials", "*.json")) CREDENTIALS_LOC_ITER = itertools.cycle(CREDENTIALS_LOC) +CREDENTIALS_CHOICE_ITER = itertools.cycle(["creds_loc", "creds", "creds_base64"]) @pytest.fixture(scope="session", autouse=True) @@ -55,23 +58,56 @@ def event_loop(): @pytest.fixture(scope="function") -def random_creds_file(): +def random_creds(): """Context manager that selects a random credentials file Yields ------ - str - The path to the credentials file + dict[str, str | dict] + A dictionary containing the credentials location, credentials, or base64 encoded credentials. + Will randomly choose between providing the creds_loc, creds, or creds_base64. + For example: { + "creds_loc": "tests/credentials/creds_1.json", + "creds": None, + } + or { + "creds_loc": None, + "creds": { + "client_id": "12345678-abcd-1234-abcd-1234567890ab", + "client_secret": "123" + } + } + or { + "creds_loc": None, + "creds": "eyJjbGllbnRfa" # Base64 encoded credentials + } """ # Select a random credentials file creds_loc = next(CREDENTIALS_LOC_ITER) - # Yield for test - yield creds_loc + # Select a random credentials loading choice + choice = next(CREDENTIALS_CHOICE_ITER) + + # Load the credentials file + with open(creds_loc, "r") as file: + creds = json.load(file) + + # Base64 encode the credentials dict + creds_base64 = base64.b64encode(json.dumps(creds).encode(encoding="utf-8")).decode( + encoding="utf-8" + ) + + # Randomly choose between providing the creds_loc, creds, or creds_base64 + if choice == "creds_loc": + yield {"creds_loc": creds_loc, "creds": None} + elif choice == "creds": + yield {"creds_loc": creds_loc, "creds": creds} + else: + yield {"creds_loc": creds_loc, "creds": creds_base64} @pytest.fixture(scope="function") -def base_sync_client(random_creds_file: str): +def base_sync_client(random_creds: dict[str, str | dict]): """Context manager that initializes the sync client Yields @@ -80,7 +116,8 @@ def base_sync_client(random_creds_file: str): """ # Initialize the client client = Client( - creds_loc=random_creds_file, + creds_loc=random_creds["creds_loc"], + creds=random_creds["creds"], sandbox=True, return_type="dict", ) @@ -90,16 +127,18 @@ def base_sync_client(random_creds_file: str): @pytest.fixture(scope="function") -def base_async_client(random_creds_file: str): +def base_async_client(random_creds: dict[str, str | dict]): """Context manager that initializes the async client Yields ------ None """ + # Initialize the client client = AsyncClient( - creds_loc=random_creds_file, + creds_loc=random_creds["creds_loc"], + creds=random_creds["creds"], sandbox=True, return_type="dict", ) diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py index bbab0c0..b622f22 100644 --- a/tests/test_webhooks.py +++ b/tests/test_webhooks.py @@ -348,9 +348,9 @@ async def test_webhook_app(async_client: Client, litestar_client_url: str): ### Create a transaction ### # Get all accounts - accounts: list[ - RetrieveAllAccounts.Response - ] = await async_client.Accounts.get_all_accounts() + accounts: list[RetrieveAllAccounts.Response] = ( + await async_client.Accounts.get_all_accounts() + ) await asyncio.sleep(random.randint(1, 3)) # Get EUR account @@ -378,9 +378,9 @@ async def test_webhook_app(async_client: Client, litestar_client_url: str): assert response.state == EnumTransactionState.COMPLETED # Get all counterparties - counterparties: list[ - RetrieveListOfCounterparties.Response - ] = await async_client.Counterparties.get_all_counterparties() + counterparties: list[RetrieveListOfCounterparties.Response] = ( + await async_client.Counterparties.get_all_counterparties() + ) # Get a EUR counterparty with an IBAN eur_counterparties: list[RetrieveListOfCounterparties.Response] = []