Skip to content

Commit

Permalink
Merge branch 'main' into feature/recommender_v1
Browse files Browse the repository at this point in the history
  • Loading branch information
ahosler authored Jul 2, 2024
2 parents caf475b + debfe43 commit 817fe1f
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 15 deletions.
1 change: 1 addition & 0 deletions ads/opctl/operator/common/operator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class InputData(DataClassSerializable):
limit: int = None
sql: str = None
table_name: str = None
vault_secret_id: str = None


@dataclass(repr=True)
Expand Down
6 changes: 6 additions & 0 deletions ads/opctl/operator/lowcode/anomaly/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

validation_data:
required: false
Expand Down Expand Up @@ -130,6 +133,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

datetime_column:
type: dict
Expand Down
42 changes: 32 additions & 10 deletions ads/opctl/operator/lowcode/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import argparse
import logging
import os
import shutil
import sys
import tempfile
import time
from string import Template
from typing import Any, Dict, List, Tuple
Expand All @@ -28,6 +30,7 @@
)
from ads.opctl.operator.common.operator_config import OutputDirectory
from ads.common.object_storage_details import ObjectStorageDetails
from ads.secrets import ADBSecretKeeper


def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
Expand All @@ -53,10 +56,12 @@ def load_data(data_spec, storage_options=None, **kwargs):
sql = data_spec.sql
table_name = data_spec.table_name
limit = data_spec.limit

vault_secret_id = data_spec.vault_secret_id
storage_options = storage_options or (
default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
)
if vault_secret_id is not None and connect_args is None:
connect_args = dict()

if filename is not None:
if not format:
Expand All @@ -76,15 +81,32 @@ def load_data(data_spec, storage_options=None, **kwargs):
f"The format {format} is not currently supported for reading data. Please reformat the data source: {filename} ."
)
elif connect_args is not None:
con = oracledb.connect(**connect_args)
if table_name is not None:
data = pd.read_sql_table(table_name, con)
elif sql is not None:
data = pd.read_sql(sql, con)
else:
raise InvalidParameterError(
f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
)
with tempfile.TemporaryDirectory() as temp_dir:
if vault_secret_id is not None:
try:
with ADBSecretKeeper.load_secret(vault_secret_id, wallet_dir=temp_dir) as adwsecret:
if 'wallet_location' in adwsecret and 'wallet_location' not in connect_args:
shutil.unpack_archive(adwsecret["wallet_location"], temp_dir)
connect_args['wallet_location'] = temp_dir
if 'user_name' in adwsecret and 'user' not in connect_args:
connect_args['user'] = adwsecret['user_name']
if 'password' in adwsecret and 'password' not in connect_args:
connect_args['password'] = adwsecret['password']
if 'service_name' in adwsecret and 'service_name' not in connect_args:
connect_args['service_name'] = adwsecret['service_name']

except Exception as e:
raise Exception(f"Could not retrieve database credentials from vault {vault_secret_id}: {e}")

con = oracledb.connect(**connect_args)
if table_name is not None:
data = pd.read_sql(f"SELECT * FROM {table_name}", con)
elif sql is not None:
data = pd.read_sql(sql, con)
else:
raise InvalidParameterError(
f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
)
else:
raise InvalidParameterError(
f"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
Expand Down
9 changes: 9 additions & 0 deletions ads/opctl/operator/lowcode/forecast/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

additional_data:
required: false
Expand Down Expand Up @@ -130,6 +133,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

test_data:
required: false
Expand Down Expand Up @@ -181,6 +187,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string
type: dict

output_directory:
Expand Down
15 changes: 11 additions & 4 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,38 @@
Release Notes
=============

2.11.14
-------
Release date: June 27, 2024

* Added compatibility with Python ``3.11``.
* Fixed the bug in model deployment tail logging.

2.11.13
------
-------
Release date: June 18, 2024

* Update langchain dependencies.
* Support adding and removing artifact in a multi-model setting for model created by reference.


2.11.12
------
-------
Release date: June 13, 2024

* Fixed bugs and introduced enhancements following our recent release.


2.11.11
------
-------
Release date: June 11, 2024

* Fixed the bug that led to timeout when loading config files during jupyterlab load.
* Fixed bugs and introduced enhancements following our recent release.


2.11.10
------
-------
Release date: June 5, 2024

* Support for Bring Your Own Model (BYOM) via AI Quick Actions.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ build-backend = "flit_core.buildapi"

# Required
name = "oracle_ads" # the install (PyPI) name; name for local build in [tool.flit.module] section below
version = "2.11.13"
version = "2.11.14"

# Optional
description = "Oracle Accelerated Data Science SDK"
Expand Down
103 changes: 103 additions & 0 deletions tests/operators/common/test_load_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#!/usr/bin/env python
from typing import Union

# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import pytest
from ads.opctl.operator.lowcode.common.utils import (
load_data,
)
from ads.opctl.operator.common.operator_config import InputData
from unittest.mock import patch, Mock, MagicMock
import unittest
import pandas as pd

mock_secret = {
'user_name': 'mock_user',
'password': 'mock_password',
'service_name': 'mock_service_name'
}

mock_connect_args = {
'user': 'mock_user',
'password': 'mock_password',
'service_name': 'mock_service_name',
'dsn': 'mock_dsn'
}

# Mock data for testing
mock_data = pd.DataFrame({
'id': [1, 2, 3],
'name': ['Alice', 'Bob', 'Charlie']
})

mock_db_connection = MagicMock()

load_secret_err_msg = "Vault exception message"
db_connect_err_msg = "Mocked DB connection error"


def mock_oracledb_connect_failure(*args, **kwargs):
raise Exception(db_connect_err_msg)


def mock_oracledb_connect(**kwargs):
assert kwargs == mock_connect_args, f"Expected connect_args {mock_connect_args}, but got {kwargs}"
return mock_db_connection


class MockADBSecretKeeper:
@staticmethod
def __enter__(*args, **kwargs):
return mock_secret

@staticmethod
def __exit__(*args, **kwargs):
pass

@staticmethod
def load_secret(vault_secret_id, wallet_dir):
return MockADBSecretKeeper()

@staticmethod
def load_secret_fail(*args, **kwargs):
raise Exception(load_secret_err_msg)


class TestDataLoad(unittest.TestCase):
def setUp(self):
self.data_spec = Mock(spec=InputData)
self.data_spec.connect_args = {
'dsn': 'mock_dsn'
}
self.data_spec.vault_secret_id = 'mock_secret_id'
self.data_spec.table_name = 'mock_table_name'
self.data_spec.url = None
self.data_spec.format = None
self.data_spec.columns = None
self.data_spec.limit = None

def testLoadSecretAndDBConnection(self):
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret):
with patch('oracledb.connect', side_effect=mock_oracledb_connect):
with patch('pandas.read_sql', return_value=mock_data) as mock_read_sql:
data = load_data(self.data_spec)
mock_read_sql.assert_called_once_with(f"SELECT * FROM {self.data_spec.table_name}",
mock_db_connection)
pd.testing.assert_frame_equal(data, mock_data)

def testLoadVaultFailure(self):
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret_fail):
with pytest.raises(Exception) as e:
load_data(self.data_spec)

expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {load_secret_err_msg}"
assert str(e.value) == expected_msg, f"Expected exception message '{expected_msg}', but got '{str(e)}'"

def testDBConnectionFailure(self):
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret):
with patch('oracledb.connect', side_effect=mock_oracledb_connect_failure):
with pytest.raises(Exception) as e:
load_data(self.data_spec)

assert str(e.value) == db_connect_err_msg , f"Expected exception message '{db_connect_err_msg }', but got '{str(e)}'"

0 comments on commit 817fe1f

Please sign in to comment.