Skip to content

Commit

Permalink
fix #17
Browse files Browse the repository at this point in the history
  • Loading branch information
Rafael Marques committed Jul 22, 2024
1 parent 4932bd0 commit 75dfdfb
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 31 deletions.
26 changes: 13 additions & 13 deletions pydantic_settings_aws/aws.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, AnyStr, Dict, Optional, Union
from typing import Any, AnyStr, Dict, Optional, Type, Union

import boto3 # type: ignore[import-untyped]
from pydantic import ValidationError
Expand All @@ -10,7 +10,7 @@


def get_ssm_content(
settings: type[BaseSettings],
settings: Type[BaseSettings],
field_name: str,
ssm_info: Optional[Union[Dict[Any, AnyStr], AnyStr]] = None
) -> Optional[str]:
Expand All @@ -36,14 +36,14 @@ def get_ssm_content(
client = _get_ssm_boto3_client(settings)

logger.debug(f"Getting parameter {ssm_name} value with boto3 client")
ssm_response: dict[str, Any] = client.get_parameter( # type: ignore
ssm_response: Dict[str, Any] = client.get_parameter( # type: ignore
Name=ssm_name, WithDecryption=True
)

return ssm_response.get("Parameter", {}).get("Value", None)


def get_secrets_content(settings: type[BaseSettings]) -> dict[str, Any]:
def get_secrets_content(settings: Type[BaseSettings]) -> Dict[str, Any]:
client = _get_secrets_boto3_client(settings)
secrets_args: AwsSecretsArgs = _get_secrets_args(settings)

Expand All @@ -69,7 +69,7 @@ def get_secrets_content(settings: type[BaseSettings]) -> dict[str, Any]:
raise json_err


def _get_secrets_boto3_client( settings: type[BaseSettings]): # type: ignore[no-untyped-def]
def _get_secrets_boto3_client( settings: Type[BaseSettings]): # type: ignore[no-untyped-def]
logger.debug("Getting secrets manager content.")
client = settings.model_config.get("secrets_client", None)

Expand All @@ -80,7 +80,7 @@ def _get_secrets_boto3_client( settings: type[BaseSettings]): # type: ignore[no-
return _create_secrets_client(settings)


def _create_secrets_client(settings: type[BaseSettings]): # type: ignore[no-untyped-def]
def _create_secrets_client(settings: Type[BaseSettings]): # type: ignore[no-untyped-def]
"""Create a boto3 client for secrets manager.
Neither `boto3` nor `pydantic` exceptions will be handled.
Expand All @@ -92,7 +92,7 @@ def _create_secrets_client(settings: type[BaseSettings]): # type: ignore[no-unt
SecretsManagerClient: A secrets manager boto3 client.
"""
logger.debug("Extracting settings prefixed with aws_")
args: dict[str, Any] = {
args: Dict[str, Any] = {
k: v for k, v in settings.model_config.items() if k.startswith("aws_")
}

Expand All @@ -105,11 +105,11 @@ def _create_secrets_client(settings: type[BaseSettings]): # type: ignore[no-unt
return session.client("secretsmanager")


def _get_secrets_args(settings: type[BaseSettings]) -> AwsSecretsArgs:
def _get_secrets_args(settings: Type[BaseSettings]) -> AwsSecretsArgs:
logger.debug(
"Extracting settings prefixed with secrets_, except _client and _dir"
)
args: dict[str, Any] = {
args: Dict[str, Any] = {
k: v
for k, v in settings.model_config.items()
if k.startswith("secrets_")
Expand Down Expand Up @@ -139,7 +139,7 @@ def _get_secrets_content(
logger.debug(
"SecretString was not present. Getting content from SecretBinary."
)
secret_binary: bytes | None = secret.get("SecretBinary")
secret_binary: Optional[bytes] = secret.get("SecretBinary")

if secret_binary:
try:
Expand All @@ -152,7 +152,7 @@ def _get_secrets_content(
return secrets_content


def _get_ssm_boto3_client(settings: type[BaseSettings]): # type: ignore[no-untyped-def]
def _get_ssm_boto3_client(settings: Type[BaseSettings]): # type: ignore[no-untyped-def]
logger.debug("Getting secrets manager content.")
client = settings.model_config.get("ssm_client", None)

Expand All @@ -165,7 +165,7 @@ def _get_ssm_boto3_client(settings: type[BaseSettings]): # type: ignore[no-untyp
return _create_ssm_client(settings)


def _create_ssm_client(settings: type[BaseSettings]): # type: ignore[no-untyped-def]
def _create_ssm_client(settings: Type[BaseSettings]): # type: ignore[no-untyped-def]
"""Create a boto3 client for parameter store.
Neither `boto3` nor `pydantic` exceptions will be handled.
Expand All @@ -177,7 +177,7 @@ def _create_ssm_client(settings: type[BaseSettings]): # type: ignore[no-untyped-
SSMClient: A parameter ssm boto3 client.
"""
logger.debug("Extracting settings prefixed with aws_")
args: dict[str, Any] = {
args: Dict[str, Any] = {
k: v for k, v in settings.model_config.items() if k.startswith("aws_")
}

Expand Down
10 changes: 6 additions & 4 deletions pydantic_settings_aws/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple, Type

from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
Expand All @@ -10,12 +12,12 @@ class ParameterStoreBaseSettings(BaseSettings):
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
ParameterStoreSettingsSource(settings_cls),
Expand All @@ -29,12 +31,12 @@ class SecretsManagerBaseSettings(BaseSettings):
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
SecretsManagerSettingsSource(settings_cls),
Expand Down
18 changes: 9 additions & 9 deletions pydantic_settings_aws/sources.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict, Tuple, Type

from pydantic.fields import FieldInfo
from pydantic_settings import (
Expand All @@ -12,12 +12,12 @@
class ParameterStoreSettingsSource(PydanticBaseSettingsSource):
"""Source class for loading settings from AWS Parameter Store.
"""
def __init__(self, settings_cls: type[BaseSettings]):
def __init__(self, settings_cls: Type[BaseSettings]):
super().__init__(settings_cls)

def get_field_value(
self, field: FieldInfo, field_name: str
) -> tuple[Any, str, bool]:
) -> Tuple[Any, str, bool]:
ssm_info = utils.get_ssm_name_from_annotated_field(field.metadata)
field_value = aws.get_ssm_content(self.settings_cls, field_name, ssm_info)

Expand All @@ -32,8 +32,8 @@ def prepare_field_value(
) -> Any:
return value

def __call__(self) -> dict[str, Any]:
d: dict[str, Any] = {}
def __call__(self) -> Dict[str, Any]:
d: Dict[str, Any] = {}

for field_name, field in self.settings_cls.model_fields.items():
field_value, field_key, value_is_complex = self.get_field_value(
Expand All @@ -49,13 +49,13 @@ def __call__(self) -> dict[str, Any]:


class SecretsManagerSettingsSource(PydanticBaseSettingsSource):
def __init__(self, settings_cls: type[BaseSettings]):
def __init__(self, settings_cls: Type[BaseSettings]):
super().__init__(settings_cls)
self._json_content = aws.get_secrets_content(settings_cls)

def get_field_value(
self, field: FieldInfo, field_name: str
) -> tuple[Any, str, bool]:
) -> Tuple[Any, str, bool]:
field_value = self._json_content.get(field_name)
return field_value, field_name, False

Expand All @@ -68,8 +68,8 @@ def prepare_field_value(
) -> Any:
return value

def __call__(self) -> dict[str, Any]:
d: dict[str, Any] = {}
def __call__(self) -> Dict[str, Any]:
d: Dict[str, Any] = {}

for field_name, field in self.settings_cls.model_fields.items():
field_value, field_key, value_is_complex = self.get_field_value(
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ classifiers = [
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
Expand All @@ -34,7 +35,7 @@ classifiers = [
'Topic :: Software Development :: Libraries :: Python Modules',
'Topic :: Internet',
]
requires-python = '>=3.9'
requires-python = '>=3.8'
dependencies = [
'pydantic>=2.0.1',
'pydantic-settings>=2.0.2',
Expand Down Expand Up @@ -85,7 +86,7 @@ quote-style = 'double'
indent-style = 'space'

[tool.mypy]
python_version = '3.10'
python_version = '3.8'
show_error_codes = true
follow_imports = 'silent'
strict_optional = true
Expand Down
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
black>=24.4.2
ruff>=0.5.1
mypy>=1.10.1
pre-commit>=3.7.1
pre-commit>=3.5.0
pytest>=8.2.2
pytest-cov>=5.0.0
5 changes: 3 additions & 2 deletions tests/settings_mocks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
from typing import Annotated, Optional
from typing import List, Optional

from pydantic import BaseModel
from pydantic_settings import SettingsConfigDict
from typing_extensions import Annotated

from pydantic_settings_aws import (
ParameterStoreBaseSettings,
Expand Down Expand Up @@ -45,7 +46,7 @@ class MySecretsWithClientConfig(SecretsManagerBaseSettings):


class NestedContent(BaseModel):
roles: list[str]
roles: List[str]


class SecretsWithNestedContent(SecretsManagerBaseSettings):
Expand Down

0 comments on commit 75dfdfb

Please sign in to comment.