diff --git a/.github/workflows/playwright-mysql-e2e.yml b/.github/workflows/playwright-mysql-e2e.yml index 86171a51fda6..dd97f72597a6 100644 --- a/.github/workflows/playwright-mysql-e2e.yml +++ b/.github/workflows/playwright-mysql-e2e.yml @@ -95,7 +95,7 @@ jobs: - name: Install dependencies working-directory: openmetadata-ui/src/main/resources/ui/ - run: yarn --frozen-lockfile + run: yarn --ignore-scripts --frozen-lockfile - name: Install Playwright Browsers run: npx playwright@1.44.1 install --with-deps - name: Run Playwright tests diff --git a/.github/workflows/playwright-postgresql-e2e.yml b/.github/workflows/playwright-postgresql-e2e.yml index c67020115e11..b3ea6c978623 100644 --- a/.github/workflows/playwright-postgresql-e2e.yml +++ b/.github/workflows/playwright-postgresql-e2e.yml @@ -95,7 +95,7 @@ jobs: - name: Install dependencies working-directory: openmetadata-ui/src/main/resources/ui/ - run: yarn --frozen-lockfile + run: yarn --ignore-scripts --frozen-lockfile - name: Install Playwright Browsers run: npx playwright@1.44.1 install --with-deps - name: Run Playwright tests diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py index 86a09c17dd43..ec373acc7452 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py @@ -27,6 +27,7 @@ from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.tests.testCase import TestCase from metadata.ingestion.source.connections import get_connection +from metadata.profiler.orm.registry import Dialects from metadata.utils import fqn @@ -168,7 +169,7 @@ def get_data_diff_url(service_url: str, table_fqn) -> str: table_fqn ) # path needs to include the database AND schema in some of the connectors - if kwargs["scheme"] in ["mssql"]: + if kwargs["scheme"] in {Dialects.MSSQL, Dialects.Snowflake}: kwargs["path"] = f"/{database}/{schema}" return url._replace(**kwargs).geturl() diff --git a/ingestion/src/metadata/ingestion/source/database/common_db_source.py b/ingestion/src/metadata/ingestion/source/database/common_db_source.py index f0589adb8299..08cddc278358 100644 --- a/ingestion/src/metadata/ingestion/source/database/common_db_source.py +++ b/ingestion/src/metadata/ingestion/source/database/common_db_source.py @@ -11,8 +11,11 @@ """ Generic source to build SQL connectors. """ +import math +import time import traceback from abc import ABC +from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from typing import Any, Iterable, List, Optional, Tuple, Union, cast @@ -60,6 +63,7 @@ from metadata.ingestion.connections.session import create_and_bind_thread_safe_session from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification from metadata.ingestion.models.ometa_lineage import OMetaLineageRequest +from metadata.ingestion.models.topology import Queue from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.connections import ( get_connection, @@ -73,6 +77,7 @@ from metadata.utils import fqn from metadata.utils.db_utils import get_view_lineage from metadata.utils.execution_time_tracker import ( + ExecutionTimeTrackerContextMap, calculate_execution_time, calculate_execution_time_generator, ) @@ -583,9 +588,78 @@ def yield_table( ) ) - @calculate_execution_time_generator() - def yield_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]: - logger.info("Processing Lineage for Views") + def multithread_process_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]: + """Multithread Processing of a Node""" + + views_list = list(self.context.get().table_views or []) + views_length = len(views_list) + + if views_length != 0: + chunksize = int(math.ceil(views_length / self.source_config.threads)) + chunks = [ + views_list[i : i + chunksize] for i in range(0, views_length, chunksize) + ] + + thread_pool = ThreadPoolExecutor(max_workers=self.source_config.threads) + queue = Queue() + + futures = [ + thread_pool.submit( + self._process_view_def_chunk, + chunk, + queue, + self.context.get_current_thread_id(), + ) + for chunk in chunks + ] + + while True: + if queue.has_tasks(): + yield from queue.process() + + else: + if not futures: + break + + for i, future in enumerate(futures): + if future.done(): + future.result() + futures.pop(i) + + time.sleep(0.01) + + def _process_view_def_chunk( + self, chunk: List[TableView], queue: Queue, thread_id: int + ) -> None: + """ + Process a chunk of view definitions + """ + self.context.copy_from(thread_id) + ExecutionTimeTrackerContextMap().copy_from_parent(thread_id) + for view in [v for v in chunk if v.view_definition is not None]: + for lineage in get_view_lineage( + view=view, + metadata=self.metadata, + service_name=self.context.get().database_service, + connection_type=self.service_connection.type.value, + timeout_seconds=self.source_config.queryParsingTimeoutLimit, + ): + if lineage.right is not None: + queue.put( + Either( + right=OMetaLineageRequest( + lineage_request=lineage.right, + override_lineage=self.source_config.overrideViewLineage, + ) + ) + ) + else: + queue.put(lineage) + + def _process_view_def_serial(self) -> Iterable[Either[OMetaLineageRequest]]: + """ + Process view definitions serially + """ for view in [ v for v in self.context.get().table_views if v.view_definition is not None ]: @@ -606,6 +680,14 @@ def yield_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]: else: yield lineage + @calculate_execution_time_generator() + def yield_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]: + logger.info("Processing Lineage for Views") + if self.source_config.threads > 1: + yield from self.multithread_process_view_lineage() + else: + yield from self._process_view_def_serial() + def _get_foreign_constraints(self, foreign_columns) -> List[TableConstraint]: """ Search the referred table for foreign constraints diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py index d0500c20c171..6e0ff5101a56 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py @@ -65,9 +65,6 @@ from metadata.ingestion.source.database.incremental_metadata_extraction import ( IncrementalConfig, ) -from metadata.ingestion.source.database.life_cycle_query_mixin import ( - LifeCycleQueryMixin, -) from metadata.ingestion.source.database.multi_db_source import MultiDBSource from metadata.ingestion.source.database.snowflake.models import ( STORED_PROC_LANGUAGE_MAP, @@ -148,7 +145,6 @@ class SnowflakeSource( - LifeCycleQueryMixin, StoredProcedureMixin, ExternalTableLineageMixin, CommonDbSourceService, diff --git a/ingestion/tests/cli_e2e/base/config_builders/builders.py b/ingestion/tests/cli_e2e/base/config_builders/builders.py index 1da3ce0fefce..024f287214c5 100644 --- a/ingestion/tests/cli_e2e/base/config_builders/builders.py +++ b/ingestion/tests/cli_e2e/base/config_builders/builders.py @@ -16,6 +16,10 @@ from copy import deepcopy +from metadata.generated.schema.metadataIngestion.testSuitePipeline import ( + TestSuiteConfigType, +) + from ..e2e_types import E2EType @@ -70,6 +74,34 @@ def build(self) -> dict: return self.config +class DataQualityConfigBuilder(BaseBuilder): + """Builder class for the data quality config""" + + # pylint: disable=invalid-name + def __init__(self, config: dict, config_args: dict) -> None: + super().__init__(config, config_args) + self.test_case_defintions = self.config_args.get("test_case_definitions", []) + self.entity_fqn = self.config_args.get("entity_fqn", []) + + # pylint: enable=invalid-name + + def build(self) -> dict: + """build profiler config""" + del self.config["source"]["sourceConfig"]["config"] + self.config["source"]["sourceConfig"] = { + "config": { + "type": TestSuiteConfigType.TestSuite.value, + "entityFullyQualifiedName": self.entity_fqn, + }, + } + + self.config["processor"] = { + "type": "orm-test-runner", + "config": {"testCases": self.test_case_defintions}, + } + return self.config + + class SchemaConfigBuilder(BaseBuilder): """Builder for schema filter config""" @@ -147,6 +179,7 @@ def builder_factory(builder, config: dict, config_args: dict): """Factory method to return the builder class""" builder_classes = { E2EType.PROFILER.value: ProfilerConfigBuilder, + E2EType.DATA_QUALITY.value: DataQualityConfigBuilder, E2EType.INGEST_DB_FILTER_SCHEMA.value: SchemaConfigBuilder, E2EType.INGEST_DB_FILTER_TABLE.value: TableConfigBuilder, E2EType.INGEST_DB_FILTER_MIX.value: MixConfigBuilder, diff --git a/ingestion/tests/cli_e2e/base/e2e_types.py b/ingestion/tests/cli_e2e/base/e2e_types.py index 81b7eb14890f..442c5c27b884 100644 --- a/ingestion/tests/cli_e2e/base/e2e_types.py +++ b/ingestion/tests/cli_e2e/base/e2e_types.py @@ -24,6 +24,7 @@ class E2EType(Enum): INGEST = "ingest" PROFILER = "profiler" PROFILER_PROCESSOR = "profiler-processor" + DATA_QUALITY = "test" INGEST_DB_FILTER_SCHEMA = "ingest-db-filter-schema" INGEST_DB_FILTER_TABLE = "ingest-db-filter-table" INGEST_DB_FILTER_MIX = "ingest-db-filter-mix" diff --git a/ingestion/tests/cli_e2e/base/test_cli_db.py b/ingestion/tests/cli_e2e/base/test_cli_db.py index 02ebfa6d40ed..0bc5eb23f0fd 100644 --- a/ingestion/tests/cli_e2e/base/test_cli_db.py +++ b/ingestion/tests/cli_e2e/base/test_cli_db.py @@ -17,8 +17,13 @@ from unittest import TestCase import pytest +from pydantic import TypeAdapter +from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects +from metadata.data_quality.api.models import TestCaseDefinition from metadata.generated.schema.entity.data.table import Table +from metadata.generated.schema.tests.basic import TestCaseResult +from metadata.generated.schema.tests.testCase import TestCase as OMTestCase from metadata.ingestion.api.status import Status from .e2e_types import E2EType @@ -208,6 +213,50 @@ def test_profiler_with_time_partition(self) -> None: sink_status, ) + @pytest.mark.order(12) + def test_data_quality(self) -> None: + """12. Test data quality for the connector""" + if self.get_data_quality_table() is None: + return + self.delete_table_and_view() + self.create_table_and_view() + table: Table = self.openmetadata.get_by_name( + Table, self.get_data_quality_table(), nullable=False + ) + self.build_config_file() + self.run_command() + test_case_definitions = self.get_test_case_definitions() + self.build_config_file( + E2EType.DATA_QUALITY, + { + "entity_fqn": table.fullyQualifiedName.root, + "test_case_definitions": TypeAdapter( + List[TestCaseDefinition] + ).dump_python(test_case_definitions), + }, + ) + result = self.run_command("test") + sink_status, source_status = self.retrieve_statuses(result) + self.assert_status_for_data_quality(source_status, sink_status) + test_case_entities = [ + self.openmetadata.get_by_name( + OMTestCase, + ".".join([table.fullyQualifiedName.root, tcd.name]), + fields=["*"], + nullable=False, + ) + for tcd in test_case_definitions + ] + expected = self.get_expected_test_case_results() + try: + for test_case, expected in zip(test_case_entities, expected): + assert_equal_pydantic_objects(expected, test_case.testCaseResult) + finally: + for tc in test_case_entities: + self.openmetadata.delete( + OMTestCase, tc.id, recursive=True, hard_delete=True + ) + def retrieve_table(self, table_name_fqn: str) -> Table: return self.openmetadata.get_by_name(entity=Table, fqn=table_name_fqn) @@ -346,3 +395,15 @@ def get_profiler_processor_config(self, config: dict) -> dict: "config": {"tableConfig": [config]}, } } + + def get_data_quality_table(self): + return None + + def get_test_case_definitions(self) -> List[TestCaseDefinition]: + pass + + def get_expected_test_case_results(self) -> List[TestCaseResult]: + pass + + def assert_status_for_data_quality(self, source_status, sink_status): + pass diff --git a/ingestion/tests/cli_e2e/test_cli_snowflake.py b/ingestion/tests/cli_e2e/test_cli_snowflake.py index dc27e2fee9d8..f4467d9d8de6 100644 --- a/ingestion/tests/cli_e2e/test_cli_snowflake.py +++ b/ingestion/tests/cli_e2e/test_cli_snowflake.py @@ -20,9 +20,12 @@ from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects from metadata.generated.schema.entity.data.table import DmlOperationType, SystemProfile +from metadata.generated.schema.tests.basic import TestCaseResult, TestCaseStatus +from metadata.generated.schema.tests.testCase import TestCaseParameterValue from metadata.generated.schema.type.basic import Timestamp from metadata.ingestion.api.status import Status +from ...src.metadata.data_quality.api.models import TestCaseDefinition from .base.e2e_types import E2EType from .common.test_cli_db import CliCommonDB from .common_e2e_sqa_mixins import SQACommonMethods @@ -65,7 +68,7 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods): insert_data_queries: List[str] = [ "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (1,'Peter Parker');", - "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (1, 'Clark Kent');", + "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (2, 'Clark Kent');", "INSERT INTO e2e_test.e2e_table (varchar_column, int_column) VALUES ('e2e_test.e2e_table', 1);", "INSERT INTO public.e2e_table (varchar_column, int_column) VALUES ('public.e2e_table', 1);", "INSERT INTO e2e_table (varchar_column, int_column) VALUES ('e2e_table', 1);", @@ -316,3 +319,28 @@ def wait_for_query_log(cls, timeout=600): ) if (datetime.now().timestamp() - start) > timeout: raise TimeoutError(f"Query log not updated for {timeout} seconds") + + def get_data_quality_table(self): + return "e2e_snowflake.E2E_DB.E2E_TEST.PERSONS" + + def get_test_case_definitions(self) -> List[TestCaseDefinition]: + return [ + TestCaseDefinition( + name="snowflake_data_diff", + testDefinitionName="tableDiff", + computePassedFailedRowCount=True, + parameterValues=[ + TestCaseParameterValue( + name="table2", + value=self.get_data_quality_table(), + ), + TestCaseParameterValue( + name="keyColumns", + value='["PERSON_ID"]', + ), + ], + ) + ] + + def get_expected_test_case_results(self): + return [TestCaseResult(testCaseStatus=TestCaseStatus.Success)] diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthenticationCodeFlowHandler.java b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthenticationCodeFlowHandler.java index 0e77094f39d3..23ae73f0e326 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthenticationCodeFlowHandler.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/security/AuthenticationCodeFlowHandler.java @@ -119,6 +119,8 @@ public class AuthenticationCodeFlowHandler { public static final String DEFAULT_PRINCIPAL_DOMAIN = "openmetadata.org"; public static final String OIDC_CREDENTIAL_PROFILE = "oidcCredentialProfile"; + public static final String SESSION_REDIRECT_URI = "sessionRedirectUri"; + public static final String REDIRECT_URI_KEY = "redirectUri"; private final OidcClient client; private final List claimsOrder; private final Map claimsMapping; @@ -247,11 +249,12 @@ private OidcClient buildOidcClient(OidcClientConfig clientConfig) { // Login public void handleLogin(HttpServletRequest req, HttpServletResponse resp) { try { + checkAndStoreRedirectUriInSession(req); LOG.debug("Performing Auth Login For User Session: {} ", req.getSession().getId()); Optional credentials = getUserCredentialsFromSession(req); if (credentials.isPresent()) { LOG.debug("Auth Tokens Located from Session: {} ", req.getSession().getId()); - sendRedirectWithToken(resp, credentials.get()); + sendRedirectWithToken(req, resp, credentials.get()); } else { LOG.debug("Performing Auth Code Flow to Idp: {} ", req.getSession().getId()); Map params = buildLoginParams(); @@ -278,6 +281,15 @@ public void handleLogin(HttpServletRequest req, HttpServletResponse resp) { } } + private void checkAndStoreRedirectUriInSession(HttpServletRequest request) { + String redirectUri = request.getParameter(REDIRECT_URI_KEY); + if (nullOrEmpty(redirectUri)) { + throw new TechnicalException("Redirect URI is required"); + } + + request.getSession().setAttribute(SESSION_REDIRECT_URI, redirectUri); + } + // Callback public void handleCallback(HttpServletRequest req, HttpServletResponse resp) { try { @@ -322,7 +334,7 @@ public void handleCallback(HttpServletRequest req, HttpServletResponse resp) { req.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials); // Redirect - sendRedirectWithToken(resp, credentials); + sendRedirectWithToken(req, resp, credentials); } catch (Exception e) { getErrorMessage(resp, e); } @@ -371,7 +383,7 @@ public void handleRefresh( LOG.debug( "Credentials Not Found For User Session: {}, Redirect to Logout ", httpServletRequest.getSession().getId()); - httpServletResponse.sendRedirect(String.format("%s/logout", serverUrl)); + this.handleLogout(httpServletRequest, httpServletResponse); } } catch (Exception e) { getErrorMessage(httpServletResponse, new TechnicalException(e)); @@ -638,7 +650,8 @@ public static void getErrorMessage(HttpServletResponse resp, Exception e) { "

[Auth Callback Servlet] Failed in Auth Login : %s

", e.getMessage())); } - private void sendRedirectWithToken(HttpServletResponse response, OidcCredentials credentials) + private void sendRedirectWithToken( + HttpServletRequest request, HttpServletResponse response, OidcCredentials credentials) throws ParseException, IOException { JWT jwt = credentials.getIdToken(); Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); @@ -647,10 +660,12 @@ private void sendRedirectWithToken(HttpServletResponse response, OidcCredentials String userName = findUserNameFromClaims(claimsMapping, claimsOrder, claims); String email = findEmailFromClaims(claimsMapping, claimsOrder, claims, principalDomain); + String redirectUri = (String) request.getSession().getAttribute(SESSION_REDIRECT_URI); + String url = String.format( - "%s/auth/callback?id_token=%s&email=%s&name=%s", - serverUrl, credentials.getIdToken().getParsedString(), email, userName); + "%s?id_token=%s&email=%s&name=%s", + redirectUri, credentials.getIdToken().getParsedString(), email, userName); response.sendRedirect(url); } diff --git a/openmetadata-ui/src/main/resources/ui/src/components/Auth/AppAuthenticators/GenericAuthenticator.tsx b/openmetadata-ui/src/main/resources/ui/src/components/Auth/AppAuthenticators/GenericAuthenticator.tsx index d5264c8b6b27..f70c1757f732 100644 --- a/openmetadata-ui/src/main/resources/ui/src/components/Auth/AppAuthenticators/GenericAuthenticator.tsx +++ b/openmetadata-ui/src/main/resources/ui/src/components/Auth/AppAuthenticators/GenericAuthenticator.tsx @@ -34,7 +34,8 @@ export const GenericAuthenticator = forwardRef( const handleLogin = () => { setIsAuthenticated(false); setIsSigningUp(true); - window.location.assign('api/v1/auth/login'); + const redirectUri = `${window.location.origin}${ROUTES.AUTH_CALLBACK}`; + window.location.assign(`api/v1/auth/login?redirectUri=${redirectUri}`); }; const handleLogout = async () => {