From 05d5032f750c86c1b930a4ac45382c64e1a28931 Mon Sep 17 00:00:00 2001 From: hvrai Date: Tue, 19 Sep 2023 08:38:05 +0530 Subject: [PATCH] Fixing integration tests --- ads/feature_store/dataset.py | 5 +++-- .../spark/spark_execution.py | 20 +++++++++++-------- tests/integration/feature_store/test_base.py | 4 ++-- .../feature_store/test_dataset_complex.py | 4 +--- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/ads/feature_store/dataset.py b/ads/feature_store/dataset.py index 565d91ef1..4275d3e37 100644 --- a/ads/feature_store/dataset.py +++ b/ads/feature_store/dataset.py @@ -865,10 +865,11 @@ def _update_from_oci_dataset_model(self, oci_dataset: OCIDataset) -> "Dataset": features_list.append(output_feature) value = {self.CONST_ITEMS: features_list} - else: + elif infra_attr == self.CONST_FEATURE_GROUP: value = getattr(self.oci_dataset, dsc_attr) + else: + value = dataset_details[infra_attr] self.set_spec(infra_attr, value) - return self def materialise( diff --git a/ads/feature_store/execution_strategy/spark/spark_execution.py b/ads/feature_store/execution_strategy/spark/spark_execution.py index 687c8b496..caa74dd46 100644 --- a/ads/feature_store/execution_strategy/spark/spark_execution.py +++ b/ads/feature_store/execution_strategy/spark/spark_execution.py @@ -1,7 +1,5 @@ #!/usr/bin/env python # -*- coding: utf-8; -*- -import json - # Copyright (c) 2023 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ @@ -29,8 +27,6 @@ raise from ads.feature_store.common.enums import ( - FeatureStoreJobType, - LifecycleState, EntityType, ExpectationType, ) @@ -47,6 +43,11 @@ from ads.feature_store.feature_statistics.statistics_service import StatisticsService from ads.feature_store.common.utils.utility import validate_input_feature_details +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ads.feature_store.feature_group import FeatureGroup + from ads.feature_store.dataset import Dataset logger = logging.getLogger(__name__) @@ -76,7 +77,10 @@ def __init__(self, metastore_id: str = None): self._jvm = self._spark_context._jvm def ingest_feature_definition( - self, feature_group, feature_group_job: FeatureGroupJob, dataframe + self, + feature_group: "FeatureGroup", + feature_group_job: FeatureGroupJob, + dataframe, ): try: self._save_offline_dataframe(dataframe, feature_group, feature_group_job) @@ -90,7 +94,7 @@ def ingest_dataset(self, dataset, dataset_job: DatasetJob): raise SparkExecutionException(e).with_traceback(e.__traceback__) def delete_feature_definition( - self, feature_group, feature_group_job: FeatureGroupJob + self, feature_group: "FeatureGroup", feature_group_job: FeatureGroupJob ): """ Deletes a feature definition from the system. @@ -122,7 +126,7 @@ def delete_feature_definition( output_details=output_details, ) - def delete_dataset(self, dataset, dataset_job: DatasetJob): + def delete_dataset(self, dataset: "Dataset", dataset_job: DatasetJob): """ Deletes a dataset from the system. @@ -154,7 +158,7 @@ def delete_dataset(self, dataset, dataset_job: DatasetJob): ) @staticmethod - def _validate_expectation(expectation_type, validation_output): + def _validate_expectation(expectation_type, validation_output: dict): """ Validates the expectation based on the given expectation type and the validation output. diff --git a/tests/integration/feature_store/test_base.py b/tests/integration/feature_store/test_base.py index 4055db46d..6ea257e86 100644 --- a/tests/integration/feature_store/test_base.py +++ b/tests/integration/feature_store/test_base.py @@ -22,8 +22,8 @@ client_kwargs = dict( - retry_strategy=oci.retry.NoneRetryStrategy, - service_endpoint=os.getenv("service_endpoint"), + retry_strategy=oci.retry.NoneRetryStrategy(), + fs_service_endpoint=os.getenv("service_endpoint"), ) ads.set_auth(client_kwargs=client_kwargs) diff --git a/tests/integration/feature_store/test_dataset_complex.py b/tests/integration/feature_store/test_dataset_complex.py index 26d1fe99b..315a87ffa 100644 --- a/tests/integration/feature_store/test_dataset_complex.py +++ b/tests/integration/feature_store/test_dataset_complex.py @@ -70,8 +70,6 @@ def test_manual_dataset( ).create() assert len(dataset_resource.feature_groups) == 1 assert dataset_resource.feature_groups[0].id == feature_group.id - assert dataset_resource.get_spec( - Dataset.CONST_FEATURE_GROUP - ).is_manual_association + assert dataset_resource.is_manual_association dataset_resource.delete() return dataset_resource