diff --git a/pyproject.toml b/pyproject.toml index 96ba16a6d..b1da5aa71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,8 +49,7 @@ dependencies = [ "coverage[toml]>=6.5", "pytest", "pytest-cov>=5.0.0,<6.0.0", -# disabling asyncio, see https://github.com/pytest-dev/pytest-asyncio/issues/1015 -# "pytest-asyncio>=0.24.0", + "pytest-asyncio>=0.24.0", "black>=23.1.0", "ruff>=0.0.243", "databricks-connect==15.1", @@ -65,7 +64,10 @@ dependencies = [ reconcile = "databricks.labs.remorph.reconcile.execute:main" [tool.hatch.envs.default.scripts] -test = "pytest --cov src --cov-report=xml tests/unit" +# temporarily disabling coverage in unit testing because collection of coverage results crashes the CI +# see https://github.com/pytest-dev/pytest-asyncio/issues/1015 +# test = "pytest --cov src --cov-report=xml tests/unit" +test = "pytest src tests/unit" coverage = "pytest --cov src tests/unit --cov-report=html" integration = "pytest --cov src tests/integration --durations 20" fmt = ["black .", diff --git a/src/databricks/labs/remorph/cli.py b/src/databricks/labs/remorph/cli.py index c2841c453..11cdc06bd 100644 --- a/src/databricks/labs/remorph/cli.py +++ b/src/databricks/labs/remorph/cli.py @@ -1,3 +1,4 @@ +import asyncio import json import os from pathlib import Path @@ -80,8 +81,10 @@ def transpile( mode=mode, sdk_config=sdk_config, ) + status, errors = asyncio.run(do_transpile(ctx.workspace_client, engine, config)) - status = do_transpile(ctx.workspace_client, engine, config) + for error in errors: + print(str(error)) print(json.dumps(status)) diff --git a/src/databricks/labs/remorph/transpiler/execute.py b/src/databricks/labs/remorph/transpiler/execute.py index 02986644e..671d0c2a2 100644 --- a/src/databricks/labs/remorph/transpiler/execute.py +++ b/src/databricks/labs/remorph/transpiler/execute.py @@ -2,6 +2,7 @@ import logging import os from pathlib import Path +from typing import Any from databricks.labs.remorph.__about__ import __version__ from 
databricks.labs.remorph.config import ( @@ -34,7 +35,7 @@ logger = logging.getLogger(__name__) -def _process_file( +async def _process_file( config: TranspileConfig, validator: Validator | None, transpiler: TranspileEngine, @@ -47,8 +48,8 @@ def _process_file( with input_path.open("r") as f: source_sql = remove_bom(f.read()) - transpile_result = asyncio.run( - _transpile(transpiler, config.source_dialect, config.target_dialect, source_sql, input_path) + transpile_result = await _transpile( + transpiler, config.source_dialect, config.target_dialect, source_sql, input_path ) error_list.extend(transpile_result.error_list) @@ -72,7 +73,7 @@ def _process_file( return transpile_result.success_count, error_list -def _process_directory( +async def _process_directory( config: TranspileConfig, validator: Validator | None, transpiler: TranspileEngine, @@ -94,14 +95,14 @@ def _process_directory( continue output_file_name = output_folder_base / file.name - success_count, error_list = _process_file(config, validator, transpiler, file, output_file_name) + success_count, error_list = await _process_file(config, validator, transpiler, file, output_file_name) counter = counter + success_count all_errors.extend(error_list) return counter, all_errors -def _process_input_dir(config: TranspileConfig, validator: Validator | None, transpiler: TranspileEngine): +async def _process_input_dir(config: TranspileConfig, validator: Validator | None, transpiler: TranspileEngine): error_list = [] file_list = [] counter = 0 @@ -113,13 +114,13 @@ def _process_input_dir(config: TranspileConfig, validator: Validator | None, tra msg = f"Processing for sqls under this folder: {folder}" logger.info(msg) file_list.extend(files) - no_of_sqls, errors = _process_directory(config, validator, transpiler, root, base_root, files) + no_of_sqls, errors = await _process_directory(config, validator, transpiler, root, base_root, files) counter = counter + no_of_sqls error_list.extend(errors) return 
TranspileStatus(file_list, counter, error_list) -def _process_input_file( +async def _process_input_file( config: TranspileConfig, validator: Validator | None, transpiler: TranspileEngine ) -> TranspileStatus: if not is_sql_file(config.input_path): @@ -136,12 +137,23 @@ def _process_input_file( make_dir(output_path) output_file = output_path / config.input_path.name - no_of_sqls, error_list = _process_file(config, validator, transpiler, config.input_path, output_file) + no_of_sqls, error_list = await _process_file(config, validator, transpiler, config.input_path, output_file) return TranspileStatus([config.input_path], no_of_sqls, error_list) @timeit -def transpile(workspace_client: WorkspaceClient, engine: TranspileEngine, config: TranspileConfig): +async def transpile( + workspace_client: WorkspaceClient, engine: TranspileEngine, config: TranspileConfig +) -> tuple[list[dict[str, Any]], list[TranspileError]]: + await engine.initialize(config) + status, errors = await _do_transpile(workspace_client, engine, config) + await engine.shutdown() + return status, errors + + +async def _do_transpile( + workspace_client: WorkspaceClient, engine: TranspileEngine, config: TranspileConfig +) -> tuple[list[dict[str, Any]], list[TranspileError]]: """ [Experimental] Transpiles the SQL queries from one dialect to another. @@ -163,9 +175,9 @@ def transpile(workspace_client: WorkspaceClient, engine: TranspileEngine, config if config.input_source is None: raise InvalidInputException("Missing input source!") if config.input_path.is_dir(): - result = _process_input_dir(config, validator, engine) + result = await _process_input_dir(config, validator, engine) elif config.input_path.is_file(): - result = _process_input_file(config, validator, engine) + result = await _process_input_file(config, validator, engine) else: msg = f"{config.input_source} does not exist." 
logger.error(msg) @@ -191,7 +203,7 @@ def transpile(workspace_client: WorkspaceClient, engine: TranspileEngine, config "error_log_file": str(error_log_file), } ) - return status + return status, result.error_list def verify_workspace_client(workspace_client: WorkspaceClient) -> WorkspaceClient: @@ -209,9 +221,9 @@ def verify_workspace_client(workspace_client: WorkspaceClient) -> WorkspaceClien async def _transpile( - transpiler: TranspileEngine, from_dialect: str, to_dialect: str, source_code: str, input_path: Path + engine: TranspileEngine, from_dialect: str, to_dialect: str, source_code: str, input_path: Path ) -> TranspileResult: - return await transpiler.transpile(from_dialect, to_dialect, source_code, input_path) + return await engine.transpile(from_dialect, to_dialect, source_code, input_path) def _validation( diff --git a/src/databricks/labs/remorph/transpiler/lsp/lsp_engine.py b/src/databricks/labs/remorph/transpiler/lsp/lsp_engine.py index 1bb20b0c3..6d0735bda 100644 --- a/src/databricks/labs/remorph/transpiler/lsp/lsp_engine.py +++ b/src/databricks/labs/remorph/transpiler/lsp/lsp_engine.py @@ -30,7 +30,10 @@ TextDocumentIdentifier, METHOD_TO_TYPES, LanguageKind, + Range as LSPRange, + Position as LSPPosition, _SPECIAL_PROPERTIES, + DiagnosticSeverity, ) from pygls.lsp.client import BaseLanguageClient from pygls.exceptions import FeatureRequestError @@ -40,6 +43,13 @@ from databricks.labs.remorph.config import TranspileConfig, TranspileResult from databricks.labs.remorph.errors.exceptions import IllegalStateException from databricks.labs.remorph.transpiler.transpile_engine import TranspileEngine +from databricks.labs.remorph.transpiler.transpile_status import ( + TranspileError, + ErrorKind, + ErrorSeverity, + CodeRange, + CodePosition, +) logger = logging.getLogger(__name__) @@ -231,6 +241,46 @@ def _apply(cls, lines: list[str], change: TextEdit) -> list[str]: return result +class DiagnosticConverter(abc.ABC): + + _KIND_NAMES = {e.name for e in ErrorKind} 
+ + @classmethod + def apply(cls, file_path: Path, diagnostic: Diagnostic) -> TranspileError: + code = str(diagnostic.code) + kind = ErrorKind.INTERNAL + parts = code.split("-") + if len(parts) >= 2 and parts[0] in cls._KIND_NAMES: + kind = ErrorKind[parts[0]] + parts.pop(0) + code = "-".join(parts) + severity = cls._convert_severity(diagnostic.severity) + lsp_range = cls._convert_range(diagnostic.range) + return TranspileError( + code=code, kind=kind, severity=severity, path=file_path, message=diagnostic.message, range=lsp_range + ) + + @classmethod + def _convert_range(cls, lsp_range: LSPRange | None) -> CodeRange | None: + if not lsp_range: + return None + return CodeRange(cls._convert_position(lsp_range.start), cls._convert_position(lsp_range.end)) + + @classmethod + def _convert_position(cls, lsp_position: LSPPosition) -> CodePosition: + return CodePosition(lsp_position.line, lsp_position.character) + + @classmethod + def _convert_severity(cls, severity: DiagnosticSeverity | None) -> ErrorSeverity: + if severity == DiagnosticSeverity.Information: + return ErrorSeverity.INFO + if severity == DiagnosticSeverity.Warning: + return ErrorSeverity.WARNING + if severity == DiagnosticSeverity.Error: + return ErrorSeverity.ERROR + return ErrorSeverity.INFO + + class LSPEngine(TranspileEngine): @classmethod @@ -320,7 +370,8 @@ async def transpile( response = await self.transpile_document(file_path) self.close_document(file_path) transpiled_code = ChangeManager.apply(source_code, response.changes) - return TranspileResult(transpiled_code, 1, []) + transpile_errors = [DiagnosticConverter.apply(file_path, diagnostic) for diagnostic in response.diagnostics] + return TranspileResult(transpiled_code, 1, transpile_errors) def analyse_table_lineage( self, source_dialect: str, source_code: str, file_path: Path diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py b/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py index b0f52cc68..8e3094e5b 
100644 --- a/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py +++ b/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py @@ -9,7 +9,7 @@ from sqlglot.expressions import Expression from sqlglot.tokens import Token, TokenType -from databricks.labs.remorph.config import TranspileResult +from databricks.labs.remorph.config import TranspileResult, TranspileConfig from databricks.labs.remorph.helpers.string_utils import format_error_message from databricks.labs.remorph.transpiler.sqlglot import lca_utils from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import get_dialect @@ -69,6 +69,12 @@ def _partial_transpile( problem_list.append(ParserProblem(parsed_expression.original_sql, error)) return transpiled_sqls, problem_list + async def initialize(self, config: TranspileConfig) -> None: + pass + + async def shutdown(self) -> None: + pass + async def transpile( self, source_dialect: str, target_dialect: str, source_code: str, file_path: Path ) -> TranspileResult: diff --git a/src/databricks/labs/remorph/transpiler/transpile_engine.py b/src/databricks/labs/remorph/transpiler/transpile_engine.py index b2b8fbb71..0057aded8 100644 --- a/src/databricks/labs/remorph/transpiler/transpile_engine.py +++ b/src/databricks/labs/remorph/transpiler/transpile_engine.py @@ -3,7 +3,7 @@ from collections.abc import Iterable from pathlib import Path -from databricks.labs.remorph.config import TranspileResult +from databricks.labs.remorph.config import TranspileResult, TranspileConfig class TranspileEngine(abc.ABC): @@ -30,6 +30,12 @@ def analyse_table_lineage( self, source_dialect: str, source_code: str, file_path: Path ) -> Iterable[tuple[str, str]]: ... + @abc.abstractmethod + async def initialize(self, config: TranspileConfig) -> None: ... + + @abc.abstractmethod + async def shutdown(self) -> None: ... 
+ @abc.abstractmethod async def transpile( self, source_dialect: str, target_dialect: str, source_code: str, file_path: Path diff --git a/tests/resources/lsp_transpiler/internal.sql b/tests/resources/lsp_transpiler/internal.sql new file mode 100644 index 000000000..fc563464e --- /dev/null +++ b/tests/resources/lsp_transpiler/internal.sql @@ -0,0 +1 @@ +create table stuff(name varchar(12)) diff --git a/tests/resources/lsp_transpiler/lsp_server.py b/tests/resources/lsp_transpiler/lsp_server.py index 488cbf64b..e5f02adc1 100644 --- a/tests/resources/lsp_transpiler/lsp_server.py +++ b/tests/resources/lsp_transpiler/lsp_server.py @@ -1,6 +1,7 @@ import os import sys from collections.abc import Sequence +from pathlib import Path from typing import Any, Literal from uuid import uuid4 @@ -23,6 +24,7 @@ Position, METHOD_TO_TYPES, _SPECIAL_PROPERTIES, + DiagnosticSeverity, ) from pygls.lsp.server import LanguageServer @@ -112,14 +114,39 @@ async def did_initialize(self, init_params: InitializeParams) -> None: def transpile_to_databricks(self, params: TranspileDocumentParams) -> TranspileDocumentResult: source_sql = self.workspace.get_text_document(params.uri).source source_lines = source_sql.split("\n") - transpiled_sql = source_sql.upper() - changes = [ - TextEdit( - range=Range(start=Position(0, 0), end=Position(len(source_lines), len(source_lines[-1]))), - new_text=transpiled_sql, + range = Range(start=Position(0, 0), end=Position(len(source_lines), len(source_lines[-1]))) + transpiled_sql, diagnostics = self._transpile(Path(params.uri).name, range, source_sql) + changes = [TextEdit(range=range, new_text=transpiled_sql)] + return TranspileDocumentResult(uri=params.uri, changes=changes, diagnostics=diagnostics) + + def _transpile(self, file_name: str, lsp_range: Range, source_sql: str) -> tuple[str, list[Diagnostic]]: + if file_name == "no_transpile.sql": + diagnostic = Diagnostic( + range=lsp_range, + message="No transpilation required", + 
severity=DiagnosticSeverity.Information, + code="GENERATION-NOT_REQUIRED", ) - ] - return TranspileDocumentResult(uri=params.uri, changes=changes, diagnostics=[]) + return source_sql, [diagnostic] + elif file_name == "unsupported_lca.sql": + diagnostic = Diagnostic( + range=lsp_range, + message="LCA conversion not supported", + severity=DiagnosticSeverity.Error, + code="ANALYSIS-UNSUPPORTED_LCA", + ) + return source_sql, [diagnostic] + elif file_name == "internal.sql": + diagnostic = Diagnostic( + range=lsp_range, + message="Something went wrong", + severity=DiagnosticSeverity.Warning, + code="SOME_ERROR_CODE", + ) + return source_sql, [diagnostic] + else: + # general test case + return source_sql.upper(), [] server = TestLspServer("test-lsp-server", "v0.1") diff --git a/tests/resources/lsp_transpiler/no_transpile.sql b/tests/resources/lsp_transpiler/no_transpile.sql new file mode 100644 index 000000000..fc563464e --- /dev/null +++ b/tests/resources/lsp_transpiler/no_transpile.sql @@ -0,0 +1 @@ +create table stuff(name varchar(12)) diff --git a/tests/resources/lsp_transpiler/unsupported_lca.sql b/tests/resources/lsp_transpiler/unsupported_lca.sql new file mode 100644 index 000000000..fc563464e --- /dev/null +++ b/tests/resources/lsp_transpiler/unsupported_lca.sql @@ -0,0 +1 @@ +create table stuff(name varchar(12)) diff --git a/tests/unit/test_cli_transpile.py b/tests/unit/test_cli_transpile.py index cacf84d62..ec4d0e5cf 100644 --- a/tests/unit/test_cli_transpile.py +++ b/tests/unit/test_cli_transpile.py @@ -1,4 +1,5 @@ -from unittest.mock import create_autospec, patch, PropertyMock, ANY +import asyncio +from unittest.mock import create_autospec, patch, PropertyMock, ANY, MagicMock import pytest @@ -8,6 +9,7 @@ from databricks.sdk import WorkspaceClient from databricks.labs.remorph.transpiler.transpile_engine import TranspileEngine +from tests.unit.conftest import path_to_resource def test_transpile_with_missing_installation(): @@ -31,11 +33,22 @@ def 
test_transpile_with_missing_installation(): ) +def patch_do_transpile(): + mock_transpile = MagicMock(return_value=({}, [])) + + async def patched_do_transpile(*args, **kwargs): + # native coroutine: @asyncio.coroutine was removed in Python 3.11 + return mock_transpile(*args, **kwargs) + + return mock_transpile, patched_do_transpile + + def test_transpile_with_no_sdk_config(): workspace_client = create_autospec(WorkspaceClient) + mock_transpile, patched_do_transpile = patch_do_transpile() with ( patch("databricks.labs.remorph.cli.ApplicationContext", autospec=True) as mock_app_context, - patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile, + patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile), patch("os.path.exists", return_value=True), ): default_config = TranspileConfig( @@ -81,10 +94,11 @@ def test_transpile_with_no_sdk_config(): def test_transpile_with_warehouse_id_in_sdk_config(): workspace_client = create_autospec(WorkspaceClient) + mock_transpile, patched_do_transpile = patch_do_transpile() with ( patch("databricks.labs.remorph.cli.ApplicationContext", autospec=True) as mock_app_context, patch("os.path.exists", return_value=True), - patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile, + patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile), ): sdk_config = {"warehouse_id": "w_id"} default_config = TranspileConfig( @@ -130,10 +144,11 @@ def test_transpile_with_cluster_id_in_sdk_config(): workspace_client = create_autospec(WorkspaceClient) + mock_transpile, patched_do_transpile = patch_do_transpile() with ( patch("databricks.labs.remorph.cli.ApplicationContext", autospec=True) as mock_app_context, patch("os.path.exists", return_value=True), - patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile, + patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile), ): sdk_config = {"cluster_id": "c_id"}
default_config = TranspileConfig( @@ -276,9 +291,10 @@ def test_transpile_with_valid_input(mock_workspace_client_cli): mode = "current" sdk_config = {'cluster_id': 'test_cluster'} + mock_transpile, patched_do_transpile = patch_do_transpile() with ( patch("os.path.exists", return_value=True), - patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile, + patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile), ): cli.transpile( mock_workspace_client_cli, @@ -308,6 +324,47 @@ def test_transpile_with_valid_input(mock_workspace_client_cli): ) +def test_transpile_with_valid_transpiler(mock_workspace_client_cli): + transpiler_config_path = path_to_resource("lsp_transpiler", "lsp_config.yml") + source_dialect = "snowflake" + input_source = path_to_resource("functional", "snowflake", "aggregates", "least_1.sql") + output_folder = path_to_resource("lsp_transpiler") + skip_validation = "true" + catalog_name = "my_catalog" + schema_name = "my_schema" + mode = "current" + sdk_config = {'cluster_id': 'test_cluster'} + + mock_transpile, patched_do_transpile = patch_do_transpile() + with (patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile),): + cli.transpile( + mock_workspace_client_cli, + transpiler_config_path, + source_dialect, + input_source, + output_folder, + skip_validation, + catalog_name, + schema_name, + mode, + ) + mock_transpile.assert_called_once_with( + mock_workspace_client_cli, + ANY, + TranspileConfig( + transpiler_config_path=transpiler_config_path, + source_dialect=source_dialect, + input_source=input_source, + output_folder=output_folder, + sdk_config=sdk_config, + skip_validation=True, + catalog_name=catalog_name, + schema_name=schema_name, + mode=mode, + ), + ) + + def test_transpile_empty_output_folder(mock_workspace_client_cli): transpiler = "sqlglot" source_dialect = "snowflake" @@ -320,9 +377,10 @@ def test_transpile_empty_output_folder(mock_workspace_client_cli): mode = "current" 
sdk_config = {'cluster_id': 'test_cluster'} + mock_transpile, patched_do_transpile = patch_do_transpile() with ( patch("os.path.exists", return_value=True), - patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile, + patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile), ): cli.transpile( mock_workspace_client_cli, @@ -377,3 +435,28 @@ def test_transpile_with_invalid_mode(mock_workspace_client_cli): schema_name, mode, ) + + +def test_transpile_prints_errors(capsys, tmp_path, mock_workspace_client_cli): + transpiler_config_path = path_to_resource("lsp_transpiler", "lsp_config.yml") + source_dialect = "snowflake" + input_source = path_to_resource("lsp_transpiler", "unsupported_lca.sql") + output_folder = str(tmp_path) + skip_validation = "true" + catalog_name = "my_catalog" + schema_name = "my_schema" + mode = "current" + cli.transpile( + mock_workspace_client_cli, + transpiler_config_path, + source_dialect, + input_source, + output_folder, + skip_validation, + catalog_name, + schema_name, + mode, + ) + captured = capsys.readouterr() + assert "TranspileError" in captured.out + assert "UNSUPPORTED_LCA" in captured.out diff --git a/tests/unit/transpiler/test_execute.py b/tests/unit/transpiler/test_execute.py index 93e7f71c9..3cdc46cc1 100644 --- a/tests/unit/transpiler/test_execute.py +++ b/tests/unit/transpiler/test_execute.py @@ -1,3 +1,4 @@ +import asyncio import re import shutil from pathlib import Path @@ -8,11 +9,13 @@ from databricks.connect import DatabricksSession from databricks.labs.lsql.backends import MockBackend from databricks.labs.lsql.core import Row +from databricks.sdk import WorkspaceClient + from databricks.labs.remorph.config import TranspileConfig, ValidationResult from databricks.labs.remorph.helpers.file_utils import make_dir from databricks.labs.remorph.helpers.validation import Validator from databricks.labs.remorph.transpiler.execute import ( - transpile, + transpile as do_transpile, 
transpile_column_exp, transpile_sql, ) @@ -24,6 +27,10 @@ # pylint: disable=unspecified-encoding +def transpile(workspace_client: WorkspaceClient, engine: SqlglotEngine, config: TranspileConfig): + return asyncio.run(do_transpile(workspace_client, engine, config)) + + def safe_remove_dir(dir_path: Path): if dir_path.exists(): shutil.rmtree(dir_path) @@ -163,7 +170,7 @@ def test_with_dir_skip_validation(initial_setup, mock_workspace_client): # call transpile with patch('databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend()): - status = transpile(mock_workspace_client, SqlglotEngine(), config) + status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config) # assert the status assert status is not None, "Status returned by morph function is None" assert isinstance(status, list), "Status returned by morph function is not a list" @@ -222,7 +229,7 @@ def test_with_dir_with_output_folder_skip_validation(initial_setup, mock_workspa skip_validation=True, ) with patch('databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend()): - status = transpile(mock_workspace_client, SqlglotEngine(), config) + status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config) # assert the status assert status is not None, "Status returned by morph function is None" assert isinstance(status, list), "Status returned by morph function is not a list" @@ -295,7 +302,7 @@ def test_with_file(initial_setup, mock_workspace_client): ), patch("databricks.labs.remorph.transpiler.execute.Validator", return_value=mock_validate), ): - status = transpile(mock_workspace_client, SqlglotEngine(), config) + status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config) # assert the status assert status is not None, "Status returned by transpile function is None" @@ -344,7 +351,7 @@ def test_with_file_with_output_folder_skip_validation(initial_setup, mock_worksp 
'databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend(), ): - status = transpile(mock_workspace_client, SqlglotEngine(), config) + status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config) # assert the status assert status is not None, "Status returned by morph function is None" @@ -379,7 +386,7 @@ def test_with_not_a_sql_file_skip_validation(initial_setup, mock_workspace_clien 'databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend(), ): - status = transpile(mock_workspace_client, SqlglotEngine(), config) + status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config) # assert the status assert status is not None, "Status returned by transpile function is None" @@ -497,7 +504,7 @@ def test_with_file_with_success(initial_setup, mock_workspace_client): ), patch("databricks.labs.remorph.transpiler.execute.Validator", return_value=mock_validate), ): - status = transpile(mock_workspace_client, SqlglotEngine(), config) + status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config) # assert the status assert status is not None, "Status returned by morph function is None" assert isinstance(status, list), "Status returned by morph function is not a list" @@ -540,7 +547,7 @@ def test_parse_error_handling(initial_setup, mock_workspace_client): ) with patch('databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend()): - status = transpile(mock_workspace_client, SqlglotEngine(), config) + status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config) # assert the status assert status is not None, "Status returned by morph function is None" @@ -597,7 +604,7 @@ def test_token_error_handling(initial_setup, mock_workspace_client): ) with patch('databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend()): - status = transpile(mock_workspace_client, SqlglotEngine(), config) + status, _errors = 
transpile(mock_workspace_client, SqlglotEngine(), config) # assert the status assert status is not None, "Status returned by morph function is None" assert isinstance(status, list), "Status returned by morph function is not a list" diff --git a/tests/unit/transpiler/test_lsp_engine.py b/tests/unit/transpiler/test_lsp_engine.py index a30185546..b7336e4a5 100644 --- a/tests/unit/transpiler/test_lsp_engine.py +++ b/tests/unit/transpiler/test_lsp_engine.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses from pathlib import Path from time import sleep @@ -10,6 +11,7 @@ LSPEngine, ChangeManager, ) +from databricks.labs.remorph.transpiler.transpile_status import TranspileError, ErrorSeverity, ErrorKind from tests.unit.conftest import path_to_resource @@ -134,3 +136,56 @@ async def test_server_transpiles_document(lsp_engine, transpile_config): def test_change_mgr_replaces_text(source, changes, result): transformed = ChangeManager.apply(source, changes) assert transformed == result + + +@pytest.mark.parametrize( + "resource, errors", + [ + ("source_stuff.sql", []), + ( + "no_transpile.sql", + [ + TranspileError( + "NOT_REQUIRED", + ErrorKind.GENERATION, + ErrorSeverity.INFO, + Path("no_transpile.sql"), + "No transpilation required", + ) + ], + ), + ( + "unsupported_lca.sql", + [ + TranspileError( + "UNSUPPORTED_LCA", + ErrorKind.ANALYSIS, + ErrorSeverity.ERROR, + Path("unsupported_lca.sql"), + "LCA conversion not supported", + ) + ], + ), + ( + "internal.sql", + [ + TranspileError( + "SOME_ERROR_CODE", + ErrorKind.INTERNAL, + ErrorSeverity.WARNING, + Path("internal.sql"), + "Something went wrong", + ) + ], + ), + ], +) +async def test_client_translates_diagnostics(lsp_engine, transpile_config, resource, errors): + sample_path = Path(path_to_resource("lsp_transpiler", resource)) + await lsp_engine.initialize(transpile_config) + result = await lsp_engine.transpile( + transpile_config.source_dialect, "databricks", sample_path.read_text(encoding="utf-8"), sample_path + ) + 
await lsp_engine.shutdown() + actual = [dataclasses.replace(error, path=Path(error.path.name), range=None) for error in result.error_list] + assert actual == errors