Skip to content

Commit

Permalink
MINOR: Apply multi threading in View Lineage Processing (#17922)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulixius9 committed Sep 19, 2024
1 parent 60434fe commit 1437bab
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
"""
Generic source to build SQL connectors.
"""
import math
import time
import traceback
from abc import ABC
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Any, Iterable, List, Optional, Tuple, Union, cast

Expand Down Expand Up @@ -60,6 +63,7 @@
from metadata.ingestion.connections.session import create_and_bind_thread_safe_session
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.models.ometa_lineage import OMetaLineageRequest
from metadata.ingestion.models.topology import Queue
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import (
get_connection,
Expand All @@ -73,6 +77,7 @@
from metadata.utils import fqn
from metadata.utils.db_utils import get_view_lineage
from metadata.utils.execution_time_tracker import (
ExecutionTimeTrackerContextMap,
calculate_execution_time,
calculate_execution_time_generator,
)
Expand Down Expand Up @@ -583,9 +588,78 @@ def yield_table(
)
)

@calculate_execution_time_generator()
def yield_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]:
logger.info("Processing Lineage for Views")
def multithread_process_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]:
    """
    Process view lineage across a pool of worker threads.

    Splits the context's table views into one chunk per configured thread,
    submits each chunk to a worker (workers push their results onto a shared
    queue), and yields lineage requests from the queue on the calling thread
    as they arrive.

    Yields:
        Either[OMetaLineageRequest]: lineage requests (or error Eithers)
        produced by the workers.
    """
    views_list = list(self.context.get().table_views or [])
    views_length = len(views_list)

    if views_length == 0:
        return

    # One chunk per worker; ceil so every view is assigned to some chunk.
    chunksize = int(math.ceil(views_length / self.source_config.threads))
    chunks = [
        views_list[i : i + chunksize] for i in range(0, views_length, chunksize)
    ]

    queue = Queue()

    # Context-manage the executor so worker threads are always reclaimed,
    # even if the consumer stops iterating this generator early.
    with ThreadPoolExecutor(max_workers=self.source_config.threads) as thread_pool:
        futures = [
            thread_pool.submit(
                self._process_view_def_chunk,
                chunk,
                queue,
                self.context.get_current_thread_id(),
            )
            for chunk in chunks
        ]

        # Keep draining until every worker has finished AND the queue is
        # empty. A finished worker can only have enqueued items before it
        # completed, so checking the queue after pruning futures is enough
        # to guarantee nothing is left behind.
        while futures or queue.has_tasks():
            if queue.has_tasks():
                yield from queue.process()
                continue

            # Prune completed futures by rebuilding the list — the original
            # popped inside enumerate(), which mutates the list while
            # iterating and skips the element after each removal.
            still_running = []
            for future in futures:
                if future.done():
                    # Re-raises any exception raised inside the worker.
                    future.result()
                else:
                    still_running.append(future)
            futures = still_running

            if futures:
                # Nothing to consume yet; avoid a hot busy-wait loop.
                time.sleep(0.01)

def _process_view_def_chunk(
    self, chunk: List[TableView], queue: Queue, thread_id: int
) -> None:
    """
    Worker-side processing of one chunk of views.

    Seeds this thread's topology and timing contexts from the parent
    thread, then parses lineage for every view in the chunk that has a
    view definition, pushing each result onto the shared queue.
    """
    # Inherit the parent thread's context before doing any work.
    self.context.copy_from(thread_id)
    ExecutionTimeTrackerContextMap().copy_from_parent(thread_id)

    for view in chunk:
        if view.view_definition is None:
            continue

        lineage_results = get_view_lineage(
            view=view,
            metadata=self.metadata,
            service_name=self.context.get().database_service,
            connection_type=self.service_connection.type.value,
            timeout_seconds=self.source_config.queryParsingTimeoutLimit,
        )
        for lineage in lineage_results:
            if lineage.right is None:
                # Error side: forward untouched so the consumer can log it.
                queue.put(lineage)
                continue

            request = OMetaLineageRequest(
                lineage_request=lineage.right,
                override_lineage=self.source_config.overrideViewLineage,
            )
            queue.put(Either(right=request))

def _process_view_def_serial(self) -> Iterable[Either[OMetaLineageRequest]]:
"""
Process view definitions serially
"""
for view in [
v for v in self.context.get().table_views if v.view_definition is not None
]:
Expand All @@ -606,6 +680,14 @@ def yield_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]:
else:
yield lineage

@calculate_execution_time_generator()
def yield_view_lineage(self) -> Iterable[Either[OMetaLineageRequest]]:
    """
    Yield lineage for all views in the current context, delegating to the
    multithreaded path when more than one thread is configured and to the
    serial path otherwise.
    """
    logger.info("Processing Lineage for Views")
    processor = (
        self.multithread_process_view_lineage
        if self.source_config.threads > 1
        else self._process_view_def_serial
    )
    yield from processor()

def _get_foreign_constraints(self, foreign_columns) -> List[TableConstraint]:
"""
Search the referred table for foreign constraints
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@
from metadata.ingestion.source.database.incremental_metadata_extraction import (
IncrementalConfig,
)
from metadata.ingestion.source.database.life_cycle_query_mixin import (
LifeCycleQueryMixin,
)
from metadata.ingestion.source.database.multi_db_source import MultiDBSource
from metadata.ingestion.source.database.snowflake.models import (
STORED_PROC_LANGUAGE_MAP,
Expand Down Expand Up @@ -148,7 +145,6 @@


class SnowflakeSource(
LifeCycleQueryMixin,
StoredProcedureMixin,
ExternalTableLineageMixin,
CommonDbSourceService,
Expand Down

0 comments on commit 1437bab

Please sign in to comment.