Skip to content

Commit

Permalink
kwargs support for the transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
yogesh266 committed Jul 25, 2023
1 parent 38e19d1 commit 033ce8d
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 4 deletions.
10 changes: 7 additions & 3 deletions ads/feature_store/common/utils/transformation_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8; -*-

import json
# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

Expand Down Expand Up @@ -28,6 +28,7 @@ def apply_transformation(
spark: SparkSession,
dataframe: Union[DataFrame, pd.DataFrame],
transformation: Transformation,
transformation_kwargs: str,
):
"""
Perform data transformation using either SQL or Pandas, depending on the specified transformation mode.
Expand All @@ -36,6 +37,7 @@ def apply_transformation(
spark: A SparkSession object.
transformation (Transformation): A transformation object containing details of transformation to be performed.
dataframe (DataFrame): The input dataframe to be transformed.
transformation_kwargs(str): The transformation parameters as json string.
Returns:
DataFrame: The resulting transformed data.
Expand All @@ -54,15 +56,17 @@ def apply_transformation(
)
transformed_data = None

transformation_kwargs_dict = json.loads(transformation_kwargs)

if transformation.transformation_mode == TransformationMode.SQL.value:
# Register the temporary table
temporary_table_view = "df_view"
dataframe.createOrReplaceTempView(temporary_table_view)

transformed_data = spark.sql(
transformation_function_caller(temporary_table_view)
transformation_function_caller(temporary_table_view, **transformation_kwargs_dict)
)
elif transformation.transformation_mode == TransformationMode.PANDAS.value:
transformed_data = transformation_function_caller(dataframe)
transformed_data = transformation_function_caller(dataframe, **transformation_kwargs_dict)

return transformed_data
6 changes: 6 additions & 0 deletions ads/feature_store/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def _build_feature_group(
name: str = None,
description: str = None,
compartment_id: str = None,
transformation_kwargs: Dict = None,
):
feature_group_resource = (
FeatureGroup()
Expand All @@ -304,6 +305,7 @@ def _build_feature_group(
.with_entity_id(self.id)
.with_transformation_id(transformation_id)
.with_partition_keys(partition_keys)
.with_transformation_kwargs(transformation_kwargs)
.with_primary_keys(primary_keys)
.with_input_feature_details(input_feature_details)
.with_statistics_config(statistics_config)
Expand All @@ -328,6 +330,7 @@ def create_feature_group(
name: str = None,
description: str = None,
compartment_id: str = None,
transformation_kwargs: Dict = None,
) -> "FeatureGroup":
"""Creates FeatureGroup resource.
Expand Down Expand Up @@ -355,6 +358,8 @@ def create_feature_group(
Description about the Resource.
compartment_id: str = None
compartment_id
transformation_kwargs: Dict
Arguments for the transformation.
Returns
Expand Down Expand Up @@ -391,6 +396,7 @@ def create_feature_group(
name,
description,
compartment_id,
transformation_kwargs,
)

return self.oci_feature_group.create()
Expand Down
12 changes: 11 additions & 1 deletion ads/feature_store/execution_strategy/spark/spark_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd

from ads.common.decorator.runtime_dependency import OptionalDependency
from ads.feature_store.common.utils.base64_encoder_decoder import Base64EncoderDecoder
from ads.feature_store.common.utils.utility import (
get_features,
show_ingestion_summary,
Expand Down Expand Up @@ -238,11 +239,20 @@ def _save_offline_dataframe(
# Apply the transformation
if feature_group.transformation_id:
logger.info("Dataframe is transformation enabled.")

# Get the Transformation Arguments if exists and pass to the transformation function.
transformation_kwargs = Base64EncoderDecoder.decode(
feature_group.transformation_kwargs
)

# Loads the transformation resource
transformation = Transformation.from_id(feature_group.transformation_id)

featured_data = TransformationUtils.apply_transformation(
self._spark_session, data_frame, transformation
self._spark_session,
data_frame,
transformation,
transformation_kwargs,
)
else:
logger.info("Transformation not defined.")
Expand Down
31 changes: 31 additions & 0 deletions ads/feature_store/feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ads.feature_store.common.exceptions import (
NotMaterializedError,
)
from ads.feature_store.common.utils.base64_encoder_decoder import Base64EncoderDecoder
from ads.feature_store.common.utils.utility import (
get_metastore_id,
get_execution_engine_type,
Expand Down Expand Up @@ -137,6 +138,7 @@ class FeatureGroup(Builder):
CONST_LIFECYCLE_STATE = "lifecycleState"
CONST_LAST_JOB_ID = "jobId"
CONST_INFER_SCHEMA = "isInferSchema"
CONST_TRANSFORMATION_KWARGS = "transformationParameters"

attribute_map = {
CONST_ID: "id",
Expand All @@ -157,6 +159,7 @@ class FeatureGroup(Builder):
CONST_STATISTICS_CONFIG: "statistics_config",
CONST_INFER_SCHEMA: "is_infer_schema",
CONST_PARTITION_KEYS: "partition_keys",
CONST_TRANSFORMATION_KWARGS: "transformation_parameters",
}

def __init__(self, spec: Dict = None, **kwargs) -> None:
Expand Down Expand Up @@ -325,6 +328,34 @@ def with_primary_keys(self, primary_keys: List[str]) -> "FeatureGroup":
},
)

@property
def transformation_kwargs(self) -> str:
return self.get_spec(self.CONST_TRANSFORMATION_KWARGS)

@transformation_kwargs.setter
def transformation_kwargs(self, value: Dict):
self.with_transformation_kwargs(value)

def with_transformation_kwargs(
self, transformation_kwargs: Dict = {}
) -> "FeatureGroup":
"""Sets the primary keys of the feature group.
Parameters
----------
transformation_kwargs: Dict
Dictionary containing the transformation arguments.
Returns
-------
FeatureGroup
The FeatureGroup instance (self)
"""
return self.set_spec(
self.CONST_TRANSFORMATION_KWARGS,
Base64EncoderDecoder.encode(json.dumps(transformation_kwargs)),
)

@property
def partition_keys(self) -> List[str]:
return self.get_spec(self.CONST_PARTITION_KEYS)
Expand Down

0 comments on commit 033ce8d

Please sign in to comment.