Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dsn attribute to ADWSecretKeeper #986

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions ads/secrets/adb.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import ads
from ads.secrets import SecretKeeper, Secret
import json
import os
import tempfile
import zipfile

from tqdm.auto import tqdm

import ads
from ads.secrets import Secret, SecretKeeper

logger = ads.getLogger("ads.secrets")

from dataclasses import dataclass, field
Expand All @@ -25,7 +26,7 @@ class ADBSecret(Secret):

user_name: str
password: str
service_name: str
service_name: str = field(default=None)
wallet_location: str = field(
default=None, metadata={"serializable": False}
) # Not saved in vault
Expand All @@ -40,6 +41,7 @@ class ADBSecret(Secret):
wallet_secret_ids: list = field(
repr=False, default_factory=list
) # Not exposed through environment or `to_dict` function
dsn: str = field(default=None)

def __post_init__(self):
self.wallet_file_name = (
Expand Down Expand Up @@ -76,6 +78,22 @@ class ADBSecretKeeper(SecretKeeper):
>>> print(adw_keeper.secret_id) # Prints the secret_id of the stored credentials
>>> adw_keeper.export_vault_details("adw_employee_att.json", format="json") # Save the secret id and vault info to a json file


>>> # Saving credentials for TLS connection
>>> from ads.secrets.adw import ADBSecretKeeper
>>> vault_id = "ocid1.vault.oc1..<unique_ID>"
>>> kid = "ocid1.ke..<unique_ID>"

>>> import ads
>>> ads.set_auth("resource_principal") # If using resource principal for authentication
>>> connection_parameters={
... "user_name":"admin",
... "password":"<your password>",
... "dsn":"<dsn string>"
... }
>>> adw_keeper = ADBSecretKeeper(vault_id=vault_id, key_id=kid, **connection_parameters)
>>> adw_keeper.save("adw_employee", "My DB credentials", freeform_tags={"schema":"emp"})

>>> # Loading credentails
>>> import ads
>>> ads.set_auth("resource_principal") # If using resource principal for authentication
Expand Down Expand Up @@ -133,6 +151,7 @@ def __init__(
wallet_dir: str = None,
repository_path: str = None,
repository_key: str = None,
dsn: str = None,
**kwargs,
):
"""
Expand All @@ -152,6 +171,8 @@ def __init__(
Path to credentials repository. For more details refer `ads.database.connection`
repository_key: (str, optional). Default None.
Configuration key for loading the right configuration from repository. For more details refer `ads.database.connection`
dsn: (str, optional). Default None.
dsn string copied from the OCI console for TLS connection
kwargs:
vault_id: str. OCID of the vault where the secret is stored. Required for saving secret.
key_id: str. OCID of the key used for encrypting the secret. Required for saving secret.
Expand Down Expand Up @@ -180,6 +201,7 @@ def __init__(
password=password,
service_name=service_name,
wallet_location=wallet_location,
dsn=dsn,
)
self.wallet_dir = wallet_dir

Expand Down Expand Up @@ -252,7 +274,7 @@ def decode(self) -> "ads.secrets.adb.ADBSecretKeeper":
logger.debug(f"Setting wallet file to {self.data.wallet_location}")
data.wallet_location = self.data.wallet_location
elif data.wallet_secret_ids and len(data.wallet_secret_ids) > 0:
logger.debug(f"Secret ids corresponding to the wallet files found.")
logger.debug("Secret ids corresponding to the wallet files found.")
# If the secret ids for wallet files are available in secret, then we
# can generate the wallet file.

Expand Down
74 changes: 70 additions & 4 deletions tests/unitary/default_setup/secret/test_secretkeeper_adw.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python

# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from ads.secrets import ADBSecretKeeper
Expand Down Expand Up @@ -36,6 +36,24 @@ def key_encoding():
)


@pytest.fixture
def key_encoding_dsn():
user_name = "myuser"
password = "this-is-not-the-secret"
dsn = "my long dsn string....................."
secret_dict = {
"user_name": user_name,
"password": password,
"dsn": dsn,
}
encoded = b64encode(json.dumps(secret_dict).encode("utf-8")).decode("utf-8")
return (
(user_name, password, dsn),
secret_dict,
encoded,
)


def generate_wallet_data(wallet_zip_path, wallet_dir_path):
files = 4
file_content = {}
Expand Down Expand Up @@ -133,9 +151,9 @@ def test_encode(mock_client, mock_signer, key_encoding):

@patch("ads.common.auth.default_signer")
@patch("ads.common.oci_client.OCIClientFactory")
def test_adw_save(mock_client, mock_signer, key_encoding, tmpdir):
def test_adw_tls_save(mock_client, mock_signer, key_encoding_dsn, tmpdir):
adwsecretkeeper = ADBSecretKeeper(
*key_encoding[0],
**key_encoding_dsn[1],
vault_id="ocid.vault",
key_id="ocid.key",
compartment_id="dummy",
Expand Down Expand Up @@ -211,6 +229,7 @@ def test_adw_context(mock_client, mock_signer, key_encoding):
assert adwsecretkeeper == {
**key_encoding[1],
"wallet_location": "/this/is/mywallet.zip",
"dsn": None,
}
assert os.environ.get("user_name") == key_encoding[0][0]
assert os.environ.get("password") == key_encoding[0][1]
Expand All @@ -221,13 +240,44 @@ def test_adw_context(mock_client, mock_signer, key_encoding):
"password": None,
"service_name": None,
"wallet_location": None,
"dsn": None,
}
assert os.environ.get("user_name") is None
assert os.environ.get("password") is None
assert os.environ.get("service_name") is None
assert os.environ.get("wallet_location") is None


@patch("ads.common.auth.default_signer")
@patch("ads.common.oci_client.OCIClientFactory")
def test_adw_context_tls(mock_client, mock_signer, key_encoding_dsn):
with mock.patch(
"ads.vault.Vault.get_secret", return_value=key_encoding_dsn[2]
) as mocked_getsecret:
with ADBSecretKeeper.load_secret(
source="ocid.secret.id",
export_env=True,
) as adwsecretkeeper:
assert adwsecretkeeper == {
**key_encoding_dsn[1],
"service_name": None,
"wallet_location": None,
}
assert os.environ.get("user_name") == key_encoding_dsn[0][0]
assert os.environ.get("password") == key_encoding_dsn[0][1]
assert os.environ.get("dsn") == key_encoding_dsn[0][2]
assert adwsecretkeeper == {
"user_name": None,
"password": None,
"service_name": None,
"wallet_location": None,
"dsn": None,
}
assert os.environ.get("user_name") is None
assert os.environ.get("password") is None
assert os.environ.get("dsn") is None


@patch("ads.common.auth.default_signer")
@patch("ads.common.oci_client.OCIClientFactory")
def test_adw_keeper_no_wallet(mock_client, mock_signer, key_encoding):
Expand All @@ -240,13 +290,14 @@ def test_adw_keeper_no_wallet(mock_client, mock_signer, key_encoding):
assert adwsecretkeeper == {
**key_encoding[1],
"wallet_location": None,
"dsn": None,
}


@patch("ads.common.auth.default_signer")
@patch("ads.common.oci_client.OCIClientFactory")
def test_adw_keeper_with_repository(mock_client, mock_signer, key_encoding, tmpdir):
expected = {**key_encoding[1], "wallet_location": key_encoding[3]}
expected = {**key_encoding[1], "wallet_location": key_encoding[3], "dsn": None}
os.makedirs(os.path.join(tmpdir, "testdb"))
with open(os.path.join(tmpdir, "testdb", "config.json"), "w") as conffile:
json.dump(expected, conffile)
Expand All @@ -270,6 +321,7 @@ def test_adw_context_namespace(mock_client, mock_signer, key_encoding):
assert adwsecretkeeper == {
**key_encoding[1],
"wallet_location": "/this/is/mywallet.zip",
"dsn": None,
}
assert os.environ.get("myapp.user_name") == key_encoding[0][0]
assert os.environ.get("myapp.password") == key_encoding[0][1]
Expand All @@ -280,6 +332,7 @@ def test_adw_context_namespace(mock_client, mock_signer, key_encoding):
"password": None,
"service_name": None,
"wallet_location": None,
"dsn": None,
}
assert os.environ.get("myapp.user_name") is None
assert os.environ.get("myapp.password") is None
Expand All @@ -300,6 +353,7 @@ def test_adw_context_noexport(mock_client, mock_signer, key_encoding):
assert adwsecretkeeper == {
**key_encoding[1],
"wallet_location": "/this/is/mywallet.zip",
"dsn": None,
}

assert os.environ.get("user_name") is None
Expand All @@ -312,6 +366,7 @@ def test_adw_context_noexport(mock_client, mock_signer, key_encoding):
"password": None,
"service_name": None,
"wallet_location": None,
"dsn": None,
}


Expand Down Expand Up @@ -413,6 +468,7 @@ def mock_get_secret_id(
"password": key_encoding_with_wallet.credentials.password,
"service_name": key_encoding_with_wallet.credentials.service_name,
"wallet_location": f"{os.path.join(wallet_dir,'wallet.zip')}",
"dsn": None,
}

# with open(key_encoding_with_wallet[3], "rb") as orgfile:
Expand Down Expand Up @@ -449,6 +505,7 @@ def mock_get_secret_id(
"password": key_encoding_with_wallet.credentials.password,
"service_name": key_encoding_with_wallet.credentials.service_name,
"wallet_location": f"{os.path.join(wallet_dir,'wallet.zip')}",
"dsn": None,
}
assert (
os.environ.get("user_name")
Expand All @@ -472,6 +529,7 @@ def mock_get_secret_id(
"password": None,
"service_name": None,
"wallet_location": None,
"dsn": None,
}
assert os.environ.get("user_name") is None
assert os.environ.get("password") is None
Expand Down Expand Up @@ -508,6 +566,7 @@ def mock_get_secret_id(
"password": key_encoding_with_wallet.credentials.password,
"service_name": key_encoding_with_wallet.credentials.service_name,
"wallet_location": f"{os.path.join(wallet_dir,'wallet.zip')}",
"dsn": None,
}
assert (
os.environ.get("myapp.user_name")
Expand All @@ -531,6 +590,7 @@ def mock_get_secret_id(
"password": None,
"service_name": None,
"wallet_location": None,
"dsn": None,
}
assert os.environ.get("myapp.user_name") is None
assert os.environ.get("myapp.password") is None
Expand Down Expand Up @@ -565,6 +625,7 @@ def mock_get_secret_id(
"password": key_encoding_with_wallet.credentials.password,
"service_name": key_encoding_with_wallet.credentials.service_name,
"wallet_location": f"{os.path.join(wallet_dir,'wallet.zip')}",
"dsn": None,
}
assert os.environ.get("user_name") is None
assert os.environ.get("password") is None
Expand All @@ -576,6 +637,7 @@ def mock_get_secret_id(
"password": None,
"service_name": None,
"wallet_location": None,
"dsn": None,
}


Expand Down Expand Up @@ -730,6 +792,7 @@ def mock_get_secret_id(
"password": key_encoding_with_wallet.credentials.password,
"service_name": key_encoding_with_wallet.credentials.service_name,
"wallet_location": f"{os.path.join(wallet_dir,'wallet.zip')}",
"dsn": None,
}
assert (
os.environ.get("user_name")
Expand All @@ -753,6 +816,7 @@ def mock_get_secret_id(
"password": None,
"service_name": None,
"wallet_location": None,
"dsn": None,
}
assert os.environ.get("user_name") is None
assert os.environ.get("password") is None
Expand Down Expand Up @@ -786,6 +850,7 @@ def mock_get_secret_id(
"password": key_encoding_with_wallet.credentials.password,
"service_name": key_encoding_with_wallet.credentials.service_name,
"wallet_location": f"{os.path.join(wallet_dir,'wallet.zip')}",
"dsn": None,
}
assert (
os.environ.get("user_name")
Expand All @@ -809,6 +874,7 @@ def mock_get_secret_id(
"password": None,
"service_name": None,
"wallet_location": None,
"dsn": None,
}
assert os.environ.get("user_name") is None
assert os.environ.get("password") is None
Expand Down