Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #17

Merged
merged 2 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyrevolut/api/payout_links/resources/payout_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ResourcePayoutLink(BaseModel):
request_id: Annotated[
str,
Field(
description="The ID of the request, provided by the sender.", max_length=40
description="The ID of the request, provided by the sender.",
),
]
expiry_date: Annotated[
Expand Down
2 changes: 1 addition & 1 deletion pyrevolut/api/transactions/resources/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class ModelCard(BaseModel):
request_id: Annotated[
str | None,
Field(
description="The request ID that you provided previously.", max_length=40
description="The request ID that you provided previously.",
),
] = None
state: Annotated[
Expand Down
2 changes: 1 addition & 1 deletion pyrevolut/api/webhooks/resources/transaction_created.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class ModelCounterparty(BaseModel):
request_id: Annotated[
str | None,
Field(
description="The request ID that you provided previously.", max_length=40
description="The request ID that you provided previously.",
),
] = None
reason_code: Annotated[
Expand Down
23 changes: 11 additions & 12 deletions pyrevolut/client/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,24 @@
class AsyncClient(BaseClient):
"""The asynchronous client for the Revolut API"""

Accounts: EndpointAccountsAsync | None = None
Cards: EndpointCardsAsync | None = None
Counterparties: EndpointCounterpartiesAsync | None = None
ForeignExchange: EndpointForeignExchangeAsync | None = None
PaymentDrafts: EndpointPaymentDraftsAsync | None = None
PayoutLinks: EndpointPayoutLinksAsync | None = None
Simulations: EndpointSimulationsAsync | None = None
TeamMembers: EndpointTeamMembersAsync | None = None
Transactions: EndpointTransactionsAsync | None = None
Transfers: EndpointTransfersAsync | None = None
Webhooks: EndpointWebhooksAsync | None = None
Accounts: EndpointAccountsAsync
Cards: EndpointCardsAsync
Counterparties: EndpointCounterpartiesAsync
ForeignExchange: EndpointForeignExchangeAsync
PaymentDrafts: EndpointPaymentDraftsAsync
PayoutLinks: EndpointPayoutLinksAsync
Simulations: EndpointSimulationsAsync
TeamMembers: EndpointTeamMembersAsync
Transactions: EndpointTransactionsAsync
Transfers: EndpointTransfersAsync
Webhooks: EndpointWebhooksAsync

async def open(self):
"""Opens the client connection"""
if self.client is not None:
return

self.client = HTTPClient()
self.load_endpoints()

async def close(self):
"""Closes the client connection"""
Expand Down
62 changes: 48 additions & 14 deletions pyrevolut/client/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Type, TypeVar, Literal, Annotated
import logging
import json
import base64

from pydantic import BaseModel, Field
import pendulum
Expand Down Expand Up @@ -53,6 +54,7 @@
"""The base client for the Revolut API"""

creds_loc: str
creds: str | dict | None = None
credentials: ModelCreds
domain: str
sandbox: bool
Expand All @@ -63,6 +65,7 @@
def __init__(
self,
creds_loc: str = "credentials/creds.json",
creds: str | dict | None = None,
sandbox: bool = True,
return_type: Literal["raw", "dict", "model"] = "dict",
error_response: Literal["raw", "raise", "dict", "model"] = "raise",
Expand All @@ -72,7 +75,12 @@
Parameters
----------
creds_loc : str, optional
The location of the credentials file, by default "credentials/creds.json"
The location of the credentials file, by default "credentials/creds.json".
If the creds input is not provided, will load the credentials from this file.
creds : str | dict, optional
The credentials to use for the client, by default None. If not provided, will
load the credentials from the creds_loc file.
Can be a dictionary of the credentials or a base64 encoded string of the credentials dictionary.
sandbox : bool, optional
Whether to use the sandbox environment, by default True
return_type : Literal["raw", "dict", "model"], optional
Expand All @@ -95,19 +103,21 @@
If "model":
The client will return a Pydantic model of the error response
"""
self.creds_loc = creds_loc
self.sandbox = sandbox
assert return_type in [
"raw",
"dict",
"model",
], "return_type must be 'raw', 'dict', or 'model'"
self.return_type = return_type
assert error_response in [
"raise",
"dict",
"model",
], "error_response must be 'raise', 'dict', or 'model'"

self.creds_loc = creds_loc
self.creds = creds
self.sandbox = sandbox
self.return_type = return_type
self.error_response = error_response

# Set domain based on environment
Expand All @@ -119,6 +129,9 @@
# Load the credentials
self.load_credentials()

# Load the endpoints
self.load_endpoints()

def process_response(
self,
response: Response,
Expand Down Expand Up @@ -475,6 +488,10 @@
**kwargs,
}

def load_endpoints(self):
"""Loads all the endpoints from the api directory"""
raise NotImplementedError("load_endpoints method must be implemented")

@property
def required_headers(self) -> dict[str, str]:
"""The headers to be attached to each request
Expand Down Expand Up @@ -514,7 +531,9 @@
If the client is not open or if the long-term credentials have expired
"""
if self.client is None:
raise ValueError("Client is not open")
raise RuntimeError(

Check warning on line 534 in pyrevolut/client/base.py

View check run for this annotation

Codecov / codecov/patch

pyrevolut/client/base.py#L534

Added line #L534 was not covered by tests
"Client is not open. Use .open() or the contextmanager to open the client."
)

if self.credentials.credentials_expired:
raise ValueError(
Expand Down Expand Up @@ -604,8 +623,9 @@
return data

def load_credentials(self):
"""Load the credentials from the credentials file.
"""Load the credentials from the credentials inputs.

- If credentials are not provided, will load them from the credentials file.
- If the credentials file does not exist, raise an error.
- If the credentials file is invalid, raise an error.
- If the credentials are expired, raise an error.
Expand All @@ -616,14 +636,28 @@
"\n\nPlease reauthenticate using the `pyrevolut auth-manual` command."
)

try:
self.credentials = load_creds(location=self.creds_loc)
except FileNotFoundError as exc:
raise ValueError(
f"Credentials file not found: {exc}. {solution_msg}"
) from exc
except Exception as exc:
raise ValueError(f"Error loading credentials: {exc}.") from exc
if self.creds is not None:
if isinstance(self.creds, str):
_creds = json.loads(base64.b64decode(self.creds).decode("utf-8"))
else:
_creds = self.creds
try:
self.credentials = ModelCreds(**_creds)
except Exception as exc:
raise ValueError(

Check warning on line 647 in pyrevolut/client/base.py

View check run for this annotation

Codecov / codecov/patch

pyrevolut/client/base.py#L646-L647

Added lines #L646 - L647 were not covered by tests
f"Error loading credentials: {exc}. {solution_msg}"
) from exc
else:
try:
self.credentials = load_creds(location=self.creds_loc)
except FileNotFoundError as exc:
raise ValueError(

Check warning on line 654 in pyrevolut/client/base.py

View check run for this annotation

Codecov / codecov/patch

pyrevolut/client/base.py#L653-L654

Added lines #L653 - L654 were not covered by tests
f"Credentials file not found: {exc}. {solution_msg}"
) from exc
except Exception as exc:
raise ValueError(

Check warning on line 658 in pyrevolut/client/base.py

View check run for this annotation

Codecov / codecov/patch

pyrevolut/client/base.py#L657-L658

Added lines #L657 - L658 were not covered by tests
f"Error loading credentials: {exc}. {solution_msg}"
) from exc

# Check if the credentials are still valid
if self.credentials.credentials_expired:
Expand Down
23 changes: 11 additions & 12 deletions pyrevolut/client/synchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,24 @@
class Client(BaseClient):
"""The synchronous client for the Revolut API"""

Accounts: EndpointAccountsSync | None = None
Cards: EndpointCardsSync | None = None
Counterparties: EndpointCounterpartiesSync | None = None
ForeignExchange: EndpointForeignExchangeSync | None = None
PaymentDrafts: EndpointPaymentDraftsSync | None = None
PayoutLinks: EndpointPayoutLinksSync | None = None
Simulations: EndpointSimulationsSync | None = None
TeamMembers: EndpointTeamMembersSync | None = None
Transactions: EndpointTransactionsSync | None = None
Transfers: EndpointTransfersSync | None = None
Webhooks: EndpointWebhooksSync | None = None
Accounts: EndpointAccountsSync
Cards: EndpointCardsSync
Counterparties: EndpointCounterpartiesSync
ForeignExchange: EndpointForeignExchangeSync
PaymentDrafts: EndpointPaymentDraftsSync
PayoutLinks: EndpointPayoutLinksSync
Simulations: EndpointSimulationsSync
TeamMembers: EndpointTeamMembersSync
Transactions: EndpointTransactionsSync
Transfers: EndpointTransfersSync
Webhooks: EndpointWebhooksSync

def open(self):
"""Opens the client connection"""
if self.client is not None:
return

self.client = HTTPClient()
self.load_endpoints()

def close(self):
"""Closes the client connection"""
Expand Down
5 changes: 3 additions & 2 deletions pyrevolut/utils/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __get_pydantic_core_schema__(
A Pydantic CoreSchema with the Date validation.
"""
return core_schema.no_info_wrap_validator_function(
cls._validate, core_schema.datetime_schema()
cls._validate, core_schema.date_schema()
)

@classmethod
Expand Down Expand Up @@ -120,7 +120,8 @@ def string_to_date(string: str) -> Date:
]
for format in formats:
try:
return pendulum.from_format(string=string, fmt=format, tz="UTC")
dt = pendulum.from_format(string=string, fmt=format, tz="UTC")
return pendulum.Date(year=dt.year, month=dt.month, day=dt.day)
except Exception:
pass
raise PydanticCustomError(f"Error converting string to pendulum Date: {string}")
5 changes: 0 additions & 5 deletions pyrevolut/utils/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,6 @@ def _validate(
Any
The validated value or raises a PydanticCustomError.
"""
# if we are passed an existing instance, pass it straight through.
if isinstance(value, _DateTime):
return handler(value)

# otherwise, parse it.
try:
data = to_datetime(value)
except Exception as exc:
Expand Down
57 changes: 48 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import time
import random
import itertools
import json
import base64

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -42,6 +44,7 @@
# All the JSON files in the credentials folder
CREDENTIALS_LOC = glob.glob(os.path.join("tests/credentials", "*.json"))
CREDENTIALS_LOC_ITER = itertools.cycle(CREDENTIALS_LOC)
CREDENTIALS_CHOICE_ITER = itertools.cycle(["creds_loc", "creds", "creds_base64"])


@pytest.fixture(scope="session", autouse=True)
Expand All @@ -55,23 +58,56 @@ def event_loop():


@pytest.fixture(scope="function")
def random_creds_file():
def random_creds():
"""Context manager that selects a random credentials file

Yields
------
str
The path to the credentials file
dict[str, str | dict]
A dictionary containing the credentials location, credentials, or base64 encoded credentials.
Will randomly choose between providing the creds_loc, creds, or creds_base64.
For example: {
"creds_loc": "tests/credentials/creds_1.json",
"creds": None,
}
or {
"creds_loc": None,
"creds": {
"client_id": "12345678-abcd-1234-abcd-1234567890ab",
"client_secret": "123"
}
}
or {
"creds_loc": None,
"creds": "eyJjbGllbnRfa" # Base64 encoded credentials
}
"""
# Select a random credentials file
creds_loc = next(CREDENTIALS_LOC_ITER)

# Yield for test
yield creds_loc
# Select a random credentials loading choice
choice = next(CREDENTIALS_CHOICE_ITER)

# Load the credentials file
with open(creds_loc, "r") as file:
creds = json.load(file)

# Base64 encode the credentials dict
creds_base64 = base64.b64encode(json.dumps(creds).encode(encoding="utf-8")).decode(
encoding="utf-8"
)

# Randomly choose between providing the creds_loc, creds, or creds_base64
if choice == "creds_loc":
yield {"creds_loc": creds_loc, "creds": None}
elif choice == "creds":
yield {"creds_loc": creds_loc, "creds": creds}
else:
yield {"creds_loc": creds_loc, "creds": creds_base64}


@pytest.fixture(scope="function")
def base_sync_client(random_creds_file: str):
def base_sync_client(random_creds: dict[str, str | dict]):
"""Context manager that initializes the sync client

Yields
Expand All @@ -80,7 +116,8 @@ def base_sync_client(random_creds_file: str):
"""
# Initialize the client
client = Client(
creds_loc=random_creds_file,
creds_loc=random_creds["creds_loc"],
creds=random_creds["creds"],
sandbox=True,
return_type="dict",
)
Expand All @@ -90,16 +127,18 @@ def base_sync_client(random_creds_file: str):


@pytest.fixture(scope="function")
def base_async_client(random_creds_file: str):
def base_async_client(random_creds: dict[str, str | dict]):
"""Context manager that initializes the async client

Yields
------
None
"""

# Initialize the client
client = AsyncClient(
creds_loc=random_creds_file,
creds_loc=random_creds["creds_loc"],
creds=random_creds["creds"],
sandbox=True,
return_type="dict",
)
Expand Down
Loading
Loading