diff --git a/dltflow/quality/dlt_meta.py b/dltflow/quality/dlt_meta.py index 2a33d1a..02cc2cc 100644 --- a/dltflow/quality/dlt_meta.py +++ b/dltflow/quality/dlt_meta.py @@ -52,6 +52,7 @@ def orchestrate(self): import inspect import typing as t import logging +from copy import deepcopy from functools import wraps from warnings import warn @@ -363,12 +364,10 @@ def table_view_expectation_wrapper(self, child_function, execution_config, *args self._logger.debug(f'Expectations provided. Applying DLT expectations to {child_function.__name__}.') return execution_config.dlt_config.expectation_function( execution_config.dlt_config.dlt_expectations - )( - execution_config.table_or_view_func( + )(execution_config.table_or_view_func( child_function, **execution_config.dlt_config.write_opts.model_dump(exclude_none=True), ) - ) else: self._logger.debug(f'Expectations not provided. Applying DLT expectations to {child_function.__name__}.') @@ -437,8 +436,13 @@ def streaming_table_expectation_wrapper(self, child_function, execution_config): elif execution_config.dlt_config.apply_chg_config: dlt.create_streaming_table(name='target') + # rename view with a prefix of vw_ to separate out + # the user table target from the intermediate target + vw_execution_config = execution_config + vw_execution_config.dlt_config.apply_chg_config.target = f'vw_{execution_config.dlt_config.apply_chg_config.target}' + view_with_expectations = self.table_view_expectation_wrapper( - child_function, execution_config + child_function, vw_execution_config ) dlt.apply_changes( diff --git a/tests/unit/quality/test_dlt_meta.py b/tests/unit/quality/test_dlt_meta.py index 0216f82..1270931 100644 --- a/tests/unit/quality/test_dlt_meta.py +++ b/tests/unit/quality/test_dlt_meta.py @@ -203,7 +203,7 @@ def test_dlt_calls_streaming_table_append_flow( def test_dlt_calls_streaming_table_apply_changes( - pipeline_instance, pipeline_config + pipeline_instance, pipeline_config, sample_df ): # pragma: no cover """ This test checks that the DLT calls are being made correctly. @@ -223,11 +223,21 @@ def test_dlt_calls_streaming_table_apply_changes( with patch("dltflow.quality.dlt_meta.dlt.create_streaming_table") as mock_streaming_table: pipeline_instance = MyPipeline(init_conf=pipeline_config) - assert mock_streaming_table.call_count == 1 - - with patch("dltflow.quality.dlt_meta.dlt.apply_changes") as mock_apply_changes: - out_df = pipeline_instance.orchestrate() - mock_apply_changes.assert_called() + with patch.object( + pipeline_instance._execution_conf[0].dlt_config, + "_expectation_function", + autospec=True, + ) as mock_expect: + with patch.object( + pipeline_instance._execution_conf[0], "table_or_view_func", autospec=True + ) as mock_table: + mock_table.return_value = sample_df + mock_expect.return_value = lambda *args, **kwargs: sample_df + assert mock_streaming_table.call_count == 1 + + with patch("dltflow.quality.dlt_meta.dlt.apply_changes") as mock_apply_changes: + out_df = pipeline_instance.orchestrate() + mock_apply_changes.assert_called()