code refactor for streaming dataframe
yogesh266 committed Oct 4, 2023
1 parent dc4c9de commit 6bbb58a
Showing 9 changed files with 130 additions and 92 deletions.
20 changes: 18 additions & 2 deletions ads/feature_store/common/enums.py
@@ -49,7 +49,20 @@ class DatasetIngestionMode(Enum):
SQL = "SQL"


class IngestionMode(Enum):
class IngestionType(Enum):
"""
The type of ingestion that can be performed.
Possible values:
* STREAMING: The data is ingested in real time.
* BATCH: The data is ingested in batches.
"""

STREAMING = "STREAMING"
BATCH = "BATCH"


class BatchIngestionMode(Enum):
"""
An enumeration that represents the supported ingestion modes in the feature store.
@@ -67,18 +80,21 @@ class IngestionMode(Enum):
DEFAULT = "DEFAULT"
UPSERT = "UPSERT"

class StreamIngestionMode(Enum):

class StreamingIngestionMode(Enum):
"""
Enumeration for stream ingestion modes.
- `COMPLETE`: Represents complete stream ingestion where the entire dataset is replaced.
- `APPEND`: Represents appending new data to the existing dataset.
- `UPDATE`: Represents updating existing data in the dataset.
"""

COMPLETE = "COMPLETE"
APPEND = "APPEND"
UPDATE = "UPDATE"


class JoinType(Enum):
"""Enumeration of supported SQL join types.
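For context, a minimal sketch of how the split enums might be selected at a call site. It is illustrative only: the helper below is hypothetical, and only the enum names come from this change.

from ads.feature_store.common.enums import (
    BatchIngestionMode,
    IngestionType,
    StreamingIngestionMode,
)


def default_ingestion_mode(ingestion_type: IngestionType):
    # Hypothetical helper: pick a sensible default mode for each ingestion type.
    if ingestion_type == IngestionType.STREAMING:
        return StreamingIngestionMode.APPEND
    return BatchIngestionMode.OVERWRITE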
5 changes: 4 additions & 1 deletion ads/feature_store/common/utils/transformation_utils.py
@@ -69,7 +69,10 @@ def apply_transformation(
temporary_table_view, **transformation_kwargs_dict
)
)
elif transformation.transformation_mode in [TransformationMode.PANDAS.value, TransformationMode.SPARK.value]:
elif transformation.transformation_mode in [
TransformationMode.PANDAS.value,
TransformationMode.SPARK.value,
]:
transformed_data = transformation_function_caller(
dataframe, **transformation_kwargs_dict
)
7 changes: 4 additions & 3 deletions ads/feature_store/dataset.py
@@ -21,14 +21,15 @@
ExecutionEngine,
ExpectationType,
EntityType,
BatchIngestionMode,
)
from ads.feature_store.common.exceptions import NotMaterializedError
from ads.feature_store.common.utils.utility import (
get_metastore_id,
validate_delta_format_parameters,
convert_expectation_suite_to_expectation,
)
from ads.feature_store.dataset_job import DatasetJob, IngestionMode
from ads.feature_store.dataset_job import DatasetJob
from ads.feature_store.execution_strategy.engine.spark_engine import SparkEngine
from ads.feature_store.execution_strategy.execution_strategy_provider import (
OciExecutionStrategyProvider,
@@ -779,7 +780,7 @@ def delete(self):
None
"""
# Create DataSet Job and persist it
dataset_job = self._build_dataset_job(IngestionMode.DEFAULT)
dataset_job = self._build_dataset_job(BatchIngestionMode.DEFAULT)

# Create the Job
dataset_job.create()
@@ -874,7 +875,7 @@ def _update_from_oci_dataset_model(self, oci_dataset: OCIDataset) -> "Dataset":

def materialise(
self,
ingestion_mode: IngestionMode = IngestionMode.OVERWRITE,
ingestion_mode: BatchIngestionMode = BatchIngestionMode.OVERWRITE,
feature_option_details: FeatureOptionDetails = None,
):
"""Creates a dataset job.
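A hedged usage sketch of the updated materialise signature; `dataset` stands for an existing ads.feature_store.dataset.Dataset instance and is not part of this diff.

from ads.feature_store.common.enums import BatchIngestionMode

# Upsert into the existing dataset table instead of overwriting it.
dataset.materialise(ingestion_mode=BatchIngestionMode.UPSERT)

# With no argument, the default of BatchIngestionMode.OVERWRITE applies.
dataset.materialise()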
16 changes: 12 additions & 4 deletions ads/feature_store/dataset_job.py
@@ -5,13 +5,17 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import logging
from copy import deepcopy
from typing import Dict, List, Any
from typing import Dict, List, Any, Union

import pandas

from ads.common import utils
from ads.feature_store.common.enums import (
JobConfigurationType,
BatchIngestionMode,
StreamingIngestionMode,
)
from ads.feature_store.feature_option_details import FeatureOptionDetails
from ads.feature_store.common.enums import IngestionMode, JobConfigurationType
from ads.feature_store.service.oci_dataset_job import OCIDatasetJob
from ads.jobs.builders.base import Builder

@@ -225,10 +229,14 @@ def ingestion_mode(self) -> str:
return self.get_spec(self.CONST_INGESTION_MODE)

@ingestion_mode.setter
def ingestion_mode(self, ingestion_mode: IngestionMode) -> "DatasetJob":
def ingestion_mode(
self, ingestion_mode: Union[BatchIngestionMode, StreamingIngestionMode]
) -> "DatasetJob":
return self.with_ingestion_mode(ingestion_mode)

def with_ingestion_mode(self, ingestion_mode: IngestionMode) -> "DatasetJob":
def with_ingestion_mode(
self, ingestion_mode: Union[BatchIngestionMode, StreamingIngestionMode]
) -> "DatasetJob":
"""Sets the mode of the dataset ingestion mode.
Parameters
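A sketch of what the widened Union signature accepts. Constructing DatasetJob with no arguments is an assumption based on its Builder base class, not something shown in this diff.

from ads.feature_store.common.enums import BatchIngestionMode, StreamingIngestionMode
from ads.feature_store.dataset_job import DatasetJob

batch_job = DatasetJob().with_ingestion_mode(BatchIngestionMode.DEFAULT)
stream_job = DatasetJob().with_ingestion_mode(StreamingIngestionMode.APPEND)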
@@ -7,7 +7,7 @@
import logging

from ads.common.decorator.runtime_dependency import OptionalDependency
from ads.feature_store.common.enums import IngestionMode
from ads.feature_store.common.enums import BatchIngestionMode
from ads.feature_store.execution_strategy.engine.spark_engine import SparkEngine

try:
@@ -57,21 +57,10 @@ def write_dataframe_to_delta_lake(
None.
"""
logger.info(f"target table name {target_table_name}")
# query = (
# dataflow_output.writeStream.outputMode("append")
# .format("delta")
# .option(
# "checkpointLocation",
# "/Users/yogeshkumawat/Desktop/Github-Oracle/accelerated-data-science/TestYogi/streaming",
# )
# .toTable(target_table_name)
# )
#
# query.awaitTermination()

if (
self.spark_engine.is_delta_table_exists(target_table_name)
and ingestion_mode.upper() == IngestionMode.UPSERT.value
and ingestion_mode.upper() == BatchIngestionMode.UPSERT.value
):
logger.info(f"Upsert ops for target table {target_table_name} begin")

@@ -365,16 +354,17 @@ def write_stream_dataframe_to_delta_lake(
checkpoint_dir,
feature_option_details,
):
if query_name is None:
query_name = "insert_stream_" + target_table.split(".")[1]

query = (
stream_dataframe
.writeStream.
outputMode(output_mode)
stream_dataframe.writeStream.outputMode(output_mode)
.format("delta")
.option(
"checkpointLocation",
checkpoint_dir,
)
.options(self.get_delta_write_config(feature_option_details))
.options(**self.get_delta_write_config(feature_option_details))
.queryName(query_name)
.toTable(target_table)
)
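The refactored streaming write follows the standard Spark Structured Streaming pattern for Delta tables. A standalone sketch with placeholder source, table name, and checkpoint path (DataStreamWriter.toTable requires Spark 3.1+; the options dict here is illustrative, not the feature store's actual config):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Placeholder streaming source; in the feature store this is the caller's stream dataframe.
stream_dataframe = spark.readStream.format("rate").load()

query = (
    stream_dataframe.writeStream.outputMode("append")  # or "complete" / "update"
    .format("delta")
    .option("checkpointLocation", "/tmp/checkpoints/my_table")
    .options(**{"mergeSchema": "true"})  # writer options dict unpacked, mirroring the fix above
    .queryName("insert_stream_my_table")
    .toTable("feature_db.my_table")
)
# query.awaitTermination()  # optionally block until the stream stops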
25 changes: 14 additions & 11 deletions ads/feature_store/execution_strategy/engine/spark_engine.py
@@ -8,7 +8,6 @@
from datetime import datetime

from ads.common.decorator.runtime_dependency import OptionalDependency
from ads.feature_store.common.utils.utility import get_schema_from_spark_dataframe, get_schema_from_spark_df

try:
from pyspark.sql import SparkSession
@@ -43,10 +42,10 @@ def __init__(self, metastore_id: str = None, spark_session: SparkSession = None)
)

def get_time_version_data(
self,
delta_table_name: str,
version_number: int = None,
timestamp: datetime = None,
self,
delta_table_name: str,
version_number: int = None,
timestamp: datetime = None,
):
split_db_name = delta_table_name.split(".")

@@ -104,10 +103,10 @@ def _read_delta_table(self, delta_table_path: str, read_options: Dict):
return df

def sql(
self,
query: str,
dataframe_type: DataFrameType = DataFrameType.SPARK,
is_online: bool = False,
self,
query: str,
dataframe_type: DataFrameType = DataFrameType.SPARK,
is_online: bool = False,
):
"""Execute SQL command on the offline or online feature store database
@@ -187,7 +186,9 @@ def get_tables_from_database(self, database):

return permanent_tables

def get_output_columns_from_table_or_dataframe(self, table_name: str = None, dataframe=None):
def get_output_columns_from_table_or_dataframe(
self, table_name: str = None, dataframe=None
):
"""Returns the column(features) along with type from the given table.
Args:
@@ -200,7 +201,9 @@ def get_output_columns_from_table_or_dataframe(self, table_name: str = None, dat
"""
if table_name is None and dataframe is None:
raise ValueError("Either 'table_name' or 'dataframe' must be provided to retrieve output columns.")
raise ValueError(
"Either 'table_name' or 'dataframe' must be provided to retrieve output columns."
)

if dataframe is not None:
feature_data_target = dataframe
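For reference, the version/timestamp reads that a method like get_time_version_data performs are typically expressed with Delta Lake time travel. A generic sketch with placeholder path and values (the method body itself is not shown in this diff):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Read an earlier snapshot of a Delta table by version or by timestamp.
df_v3 = (
    spark.read.format("delta")
    .option("versionAsOf", 3)
    .load("oci://bucket@namespace/feature_db/my_table")
)
df_old = (
    spark.read.format("delta")
    .option("timestampAsOf", "2023-10-04 00:00:00")
    .load("oci://bucket@namespace/feature_db/my_table")
)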