Skip to content

Commit

Permalink
feat: ibm-db2 integration (#262)
Browse files Browse the repository at this point in the history
* feat: ibm-db2 integration

* fix: custom regex query for mssql,mysql and db2
  • Loading branch information
Ryuk-me authored Oct 29, 2024
1 parent 2f5c115 commit 758ab53
Show file tree
Hide file tree
Showing 17 changed files with 928 additions and 285 deletions.
4 changes: 4 additions & 0 deletions dcs_core/core/common/models/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class DataSourceType(str, Enum):
DATABRICKS = "databricks"
SPARK_DF = "spark_df"
ORACLE = "oracle"
DB2 = "db2"


class DataSourceLanguageSupport(str, Enum):
Expand Down Expand Up @@ -79,6 +80,9 @@ class DataSourceConnectionConfiguration:

service_name: Optional[str] = None # Oracle specific configuration

security: Optional[str] = None # IBM DB2 specific configuration
protocol: Optional[str] = None # IBM DB2 specific configuration


@dataclass
class DataSourceConfiguration:
Expand Down
10 changes: 7 additions & 3 deletions dcs_core/core/configuration/configuration_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def _data_source_connection_config_parser(
warehouse=config["connection"].get("warehouse"),
role=config["connection"].get("role"),
service_name=config["connection"].get("service_name"),
security=config["connection"].get("security"),
protocol=config["connection"].get("protocol"),
)
return connection_config

Expand Down Expand Up @@ -138,9 +140,11 @@ def parse(self, config: Dict) -> Dict[str, ValidationConfigByDataset]:
validation_config = ValidationConfig(
name=validation_name,
on=value.get("on"),
threshold=self._parse_threshold_str(value.get("threshold"))
if value.get("threshold")
else None,
threshold=(
self._parse_threshold_str(value.get("threshold"))
if value.get("threshold")
else None
),
where=value.get("where"),
query=value.get("query"),
regex=value.get("regex"),
Expand Down
1 change: 1 addition & 0 deletions dcs_core/core/datasource/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class DataSourceManager:
"snowflake": "SnowFlakeDataSource",
"mssql": "MssqlDataSource",
"oracle": "OracleDataSource",
"db2": "DB2DataSource",
}

def __init__(self, config: Configuration):
Expand Down
142 changes: 72 additions & 70 deletions dcs_core/core/datasource/sql_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from datetime import datetime
from typing import Dict, List, Tuple, Union

from loguru import logger
from sqlalchemy import inspect, text
from sqlalchemy.engine import Connection

Expand All @@ -32,6 +33,72 @@ def __init__(self, data_source_name: str, data_connection: Dict):
self.connection: Union[Connection, None] = None
self.database: str = data_connection.get("database")
self.use_sa_text_query = True
self.regex_patterns = {
"uuid": r"^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$",
"usa_phone": r"^(\+1[-.\s]?)?(\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}$",
"email": r"^(?!.*\.\.)(?!.*@.*@)[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
"usa_zip_code": r"^[0-9]{5}(?:-[0-9]{4})?$",
"ssn": r"^(?!000|666|9\d{2})\d{3}-(?!00)\d{2}-(?!0000)\d{4}$",
"sedol": r"[B-Db-dF-Hf-hJ-Nj-nP-Tp-tV-Xv-xYyZz\d]{6}\d",
"lei": r"^[A-Z0-9]{18}[0-9]{2}$",
"cusip": r"^[0-9A-Z]{9}$",
"figi": r"^BBG[A-Z0-9]{9}$",
"isin": r"^[A-Z]{2}[A-Z0-9]{9}[0-9]$",
"perm_id": r"^\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{3}$",
}

self.valid_state_codes = [
"AL",
"AK",
"AZ",
"AR",
"CA",
"CO",
"CT",
"DE",
"FL",
"GA",
"HI",
"ID",
"IL",
"IN",
"IA",
"KS",
"KY",
"LA",
"ME",
"MD",
"MA",
"MI",
"MN",
"MS",
"MO",
"MT",
"NE",
"NV",
"NH",
"NJ",
"NM",
"NY",
"NC",
"ND",
"OH",
"OK",
"OR",
"PA",
"RI",
"SC",
"SD",
"TN",
"TX",
"UT",
"VT",
"VA",
"WA",
"WV",
"WI",
"WY",
]

def is_connected(self) -> bool:
"""
Expand Down Expand Up @@ -401,27 +468,13 @@ def query_string_pattern_validity(
filters = f"WHERE {filters}" if filters else ""
qualified_table_name = self.qualified_table_name(table)

regex_patterns = {
"uuid": r"^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$",
"usa_phone": r"^(\+1[-.\s]?)?(\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}$",
"email": r"^(?!.*\.\.)(?!.*@.*@)[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
"usa_zip_code": r"^[0-9]{5}(?:-[0-9]{4})?$",
"ssn": r"^(?!000|666|9\d{2})\d{3}-(?!00)\d{2}-(?!0000)\d{4}$",
"sedol": r"[B-Db-dF-Hf-hJ-Nj-nP-Tp-tV-Xv-xYyZz\d]{6}\d",
"lei": r"^[A-Z0-9]{18}[0-9]{2}$",
"cusip": r"^[0-9A-Z]{9}$",
"figi": r"^BBG[A-Z0-9]{9}$",
"isin": r"^[A-Z]{2}[A-Z0-9]{9}[0-9]$",
"perm_id": r"^\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{3}$",
}

if not regex_pattern and not predefined_regex_pattern:
raise ValueError(
"Either regex_pattern or predefined_regex_pattern should be provided"
)

if predefined_regex_pattern:
regex_query = f"case when {field} ~ '{regex_patterns[predefined_regex_pattern]}' then 1 else 0 end"
regex_query = f"case when {field} ~ '{self.regex_patterns[predefined_regex_pattern]}' then 1 else 0 end"
else:
regex_query = f"case when {field} ~ '{regex_pattern}' then 1 else 0 end"

Expand Down Expand Up @@ -507,61 +560,10 @@ def query_get_usa_state_code_validity(
:param filters: filter condition
:return: count of valid state codes, count of total row count
"""
# List of valid state codes
valid_state_codes = [
"AL",
"AK",
"AZ",
"AR",
"CA",
"CO",
"CT",
"DE",
"FL",
"GA",
"HI",
"ID",
"IL",
"IN",
"IA",
"KS",
"KY",
"LA",
"ME",
"MD",
"MA",
"MI",
"MN",
"MS",
"MO",
"MT",
"NE",
"NV",
"NH",
"NJ",
"NM",
"NY",
"NC",
"ND",
"OH",
"OK",
"OR",
"PA",
"RI",
"SC",
"SD",
"TN",
"TX",
"UT",
"VT",
"VA",
"WA",
"WV",
"WI",
"WY",
]

valid_state_codes_str = ", ".join(f"'{code}'" for code in valid_state_codes)
valid_state_codes_str = ", ".join(
f"'{code}'" for code in self.valid_state_codes
)

filters = f"WHERE {filters}" if filters else ""

Expand Down Expand Up @@ -1002,5 +1004,5 @@ def query_timestamp_date_not_in_future_metric(
raise ValueError(f"Unknown operation: {operation}")

except Exception as e:
print(f"Error occurred: {e}")
logger.error(f"Error occurred: {e}")
return 0, 0
Loading

0 comments on commit 758ab53

Please sign in to comment.