
Commit 9e5ab20
code refactoring
yogesh266 committed Jul 8, 2023
1 parent 85ae918 commit 9e5ab20
Showing 10 changed files with 81 additions and 369 deletions.
2 changes: 2 additions & 0 deletions ads/feature_store/common/spark_session_singleton.py
@@ -62,6 +62,7 @@ def __init__(self, metastore_id: str = None):
         )

         if not developer_enabled() and metastore_id:
+            print("Not Developer Enabled")
             # Get the authentication credentials for the OCI data catalog service
             auth = copy.copy(ads.auth.default_signer())

@@ -79,6 +80,7 @@ def __init__(self, metastore_id: str = None):
                 .config("spark.driver.memory", "16G")

         if developer_enabled():
+            print("Developer Enabled")
             # Configure spark session with delta jars only in developer mode. In other cases,
             # jars should be part of the conda pack
             self.spark_session = configure_spark_with_delta_pip(
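For context, a minimal sketch of how a caller obtains the session this constructor builds — `SparkSessionSingleton` and `get_spark_session` appear elsewhere in this commit; the metastore OCID below is a hypothetical placeholder:

from ads.feature_store.common.spark_session_singleton import SparkSessionSingleton

# Hypothetical metastore OCID; with developer mode off, the branch above wires
# in the OCI Data Catalog metastore, while developer mode instead builds the
# session with Delta jars via configure_spark_with_delta_pip().
METASTORE_ID = "ocid1.datacatalogmetastore.oc1..<unique_id>"

spark = SparkSessionSingleton(METASTORE_ID).get_spark_session()
print(spark.version)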
96 changes: 0 additions & 96 deletions ads/feature_store/common/utils/transformation_query_validator.py

This file was deleted.

27 changes: 14 additions & 13 deletions ads/feature_store/common/utils/utility.py
@@ -44,7 +44,7 @@


 def get_execution_engine_type(
-        data_frame: Union[DataFrame, pd.DataFrame]
+    data_frame: Union[DataFrame, pd.DataFrame]
 ) -> ExecutionEngine:
     """
     Determines the execution engine type for a given DataFrame.
@@ -84,7 +84,7 @@ def get_metastore_id(feature_store_id: str):


 def validate_delta_format_parameters(
-        timestamp: datetime = None, version_number: int = None, is_restore: bool = False
+    timestamp: datetime = None, version_number: int = None, is_restore: bool = False
 ):
     """
     Validate the user input provided as part of preview, restore APIs for ingested data, Ingested data is
@@ -118,9 +118,9 @@


 def get_features(
-        output_columns: List[dict],
-        parent_id: str,
-        entity_type: EntityType = EntityType.FEATURE_GROUP,
+    output_columns: List[dict],
+    parent_id: str,
+    entity_type: EntityType = EntityType.FEATURE_GROUP,
 ) -> List[Feature]:
     """
     Returns a list of features, given a list of output_columns and a feature_group_id.
@@ -154,8 +154,8 @@ def get_features(
     return features


-def get_schema_from_pandas_df(df: pd.DataFrame):
-    spark = SparkSessionSingleton().get_spark_session()
+def get_schema_from_pandas_df(df: pd.DataFrame, feature_store_id: str):
+    spark = SparkSessionSingleton(get_metastore_id(feature_store_id)).get_spark_session()
     converted_df = spark.createDataFrame(df)
     return get_schema_from_spark_df(converted_df)

@@ -174,27 +174,28 @@ def get_schema_from_spark_df(df: DataFrame):
     return schema_details


-def get_schema_from_df(data_frame: Union[DataFrame, pd.DataFrame]) -> List[dict]:
+def get_schema_from_df(data_frame: Union[DataFrame, pd.DataFrame], feature_store_id: str) -> List[dict]:
     """
     Given a DataFrame, returns a list of dictionaries that describe its schema.
     If the DataFrame is a pandas DataFrame, it uses pandas methods to get the schema.
     If it's a PySpark DataFrame, it uses PySpark methods to get the schema.
     """
     if isinstance(data_frame, pd.DataFrame):
-        return get_schema_from_pandas_df(data_frame)
+        return get_schema_from_pandas_df(data_frame, feature_store_id)
     else:
         return get_schema_from_spark_df(data_frame)


 def get_input_features_from_df(
-        data_frame: Union[DataFrame, pd.DataFrame]
+    data_frame: Union[DataFrame, pd.DataFrame],
+    feature_store_id: str
 ) -> List[FeatureDetail]:
     """
     Given a DataFrame, returns a list of FeatureDetail objects that represent its input features.
     Each FeatureDetail object contains information about a single input feature, such as its name, data type, and
     whether it's categorical or numerical.
     """
-    schema_details = get_schema_from_df(data_frame)
+    schema_details = get_schema_from_df(data_frame, feature_store_id)
     feature_details = []

     for schema_detail in schema_details:
@@ -204,7 +205,7 @@ def get_input_features_from_df(


 def convert_expectation_suite_to_expectation(
-        expectation_suite: ExpectationSuite, expectation_type: ExpectationType
+    expectation_suite: ExpectationSuite, expectation_type: ExpectationType
 ):
     """
     Convert an ExpectationSuite object to an Expectation object with detailed rule information.
@@ -282,7 +283,7 @@ def convert_pandas_datatype_with_schema(
     else:
         logger.warning("column" + column + "doesn't exist in the input feature details")
         columns_to_remove.append(column)
-    return input_df.drop(columns = columns_to_remove)
+    return input_df.drop(columns=columns_to_remove)


 def convert_spark_dataframe_with_schema(
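Taken together, the utility changes thread `feature_store_id` through the schema helpers so that pandas inputs can be converted on a metastore-backed Spark session. A short usage sketch under that assumption — the OCID is a hypothetical placeholder:

import pandas as pd

from ads.feature_store.common.utils.utility import get_input_features_from_df

df = pd.DataFrame({"user_id": [1, 2], "spend": [9.99, 14.50]})

# feature_store_id is now mandatory; it is resolved to a metastore id
# internally before a Spark session is created for the pandas DataFrame.
feature_details = get_input_features_from_df(
    df, feature_store_id="ocid1.featurestore.oc1..<unique_id>"
)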
2 changes: 1 addition & 1 deletion ads/feature_store/entity.py
@@ -370,7 +370,7 @@ def create_feature_group(
         raw_feature_details = (
             input_feature_details
             if input_feature_details
-            else get_input_features_from_df(schema_details_dataframe)
+            else get_input_features_from_df(schema_details_dataframe, self.feature_store_id)
         )

         self.oci_feature_group = self._build_feature_group(
6 changes: 5 additions & 1 deletion ads/feature_store/feature_group.py
@@ -476,7 +476,11 @@ def with_input_feature_details(
     def with_schema_details_from_dataframe(
         self, data_frame: Union[DataFrame, pd.DataFrame]
     ) -> "FeatureGroup":
-        schema_details = get_schema_from_df(data_frame)
+
+        if not self.feature_store_id:
+            raise ValueError("FeatureStore id must be set before calling `with_schema_details_from_dataframe`")
+
+        schema_details = get_schema_from_df(data_frame, self.feature_store_id)
         feature_details = []

         for schema_detail in schema_details:
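The new guard makes the builder ordering explicit: a feature store id must be attached before schema details can be derived from a DataFrame, otherwise the method raises `ValueError`. A sketch of the intended call order — `with_feature_store_id` is assumed from the builder's usual `with_*` pattern, and the OCID is a placeholder:

import pandas as pd

from ads.feature_store.feature_group import FeatureGroup

df = pd.DataFrame({"user_id": [1, 2], "spend": [9.99, 14.50]})

feature_group = (
    FeatureGroup()
    .with_feature_store_id("ocid1.featurestore.oc1..<unique_id>")  # must be set first
    .with_schema_details_from_dataframe(df)  # safe now: feature_store_id is present
)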
15 changes: 4 additions & 11 deletions ads/feature_store/feature_store.py
@@ -426,12 +426,11 @@ def list_entities_df(

     def _build_transformation(
         self,
-        transformation_mode: TransformationMode,
-        source_code_func=None,
+        source_code_func,
+        transformation_mode,
         display_name: str = None,
         description: str = None,
         compartment_id: str = None,
-        sql_query: str = None,
     ):
         transformation = (
             Transformation()
@@ -443,19 +442,17 @@
                 compartment_id if compartment_id else self.compartment_id
             )
             .with_feature_store_id(self.id)
-            .with_transformation_query_input(sql_query)
         )

         return transformation

     def create_transformation(
         self,
+        source_code_func,
         transformation_mode: TransformationMode,
-        source_code_func=None,
         display_name: str = None,
         description: str = None,
         compartment_id: str = None,
-        sql_query: str = None,
     ) -> "Transformation":
         """Creates transformation resource from feature store.
@@ -471,9 +468,6 @@ def create_transformation(
             description for the entity.
         compartment_id: str
             compartment_id for the entity.
-        sql_query: str
-            inline sql query to be passed for transformation creation,
-            Please ensure to use DATA_SOURCE_INPUT as FROM table name

         Returns
         -------
@@ -486,12 +480,11 @@
         )

         self.oci_transformation = self._build_transformation(
-            transformation_mode,
             source_code_func,
+            transformation_mode,
             display_name,
             description,
             compartment_id,
-            sql_query,
         )

         return self.oci_transformation.create()
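With this reorder, `source_code_func` becomes the first, required argument and the inline `sql_query` path is removed entirely, so transformation logic is always supplied as a function. A hedged sketch of the new call shape — `feature_store` is an existing `FeatureStore` instance, and both the import path and the `PANDAS` enum member are assumptions, not confirmed by this diff:

from ads.feature_store.common.enums import TransformationMode  # assumed import path

def spend_per_user(df):
    # Illustrative transformation body; aggregates a pandas DataFrame.
    return df.groupby("user_id", as_index=False)["spend"].sum()

transformation = feature_store.create_transformation(
    spend_per_user,             # source_code_func now comes first
    TransformationMode.PANDAS,  # assumed enum member
    display_name="spend_per_user",
)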