From f94e874c4d5e09eb4d910abec11a4ce3ecc72384 Mon Sep 17 00:00:00 2001 From: Adam Stus Date: Mon, 16 Sep 2024 12:28:05 +0200 Subject: [PATCH] Slash in git execute fix --- src/snowflake/cli/_plugins/git/manager.py | 40 +++++++-- src/snowflake/cli/_plugins/stage/manager.py | 12 ++- .../git/__snapshots__/test_git_commands.ambr | 48 +++++++++++ tests/git/test_git_commands.py | 86 +++++++++++++++++++ tests/stage/test_stage.py | 2 +- 5 files changed, 177 insertions(+), 11 deletions(-) diff --git a/src/snowflake/cli/_plugins/git/manager.py b/src/snowflake/cli/_plugins/git/manager.py index a411aa7193..d03a8d5eb5 100644 --- a/src/snowflake/cli/_plugins/git/manager.py +++ b/src/snowflake/cli/_plugins/git/manager.py @@ -27,16 +27,22 @@ from snowflake.cli.api.identifiers import FQN from snowflake.connector.cursor import SnowflakeCursor +OMIT_FIRST = slice(1, None) +OMIT_STAGE = slice(3, None) +ONLY_STAGE = slice(3) + class GitStagePathParts(StagePathParts): def __init__(self, stage_path: str): self.stage = GitManager.get_stage_from_path(stage_path) - stage_path_parts = Path(stage_path).parts + stage_path_parts = GitManager.split_git_path(stage_path) git_repo_name = stage_path_parts[0].split(".")[-1] if git_repo_name.startswith("@"): - git_repo_name = git_repo_name[1:] + git_repo_name = git_repo_name[OMIT_FIRST] self.stage_name = "/".join([git_repo_name, *stage_path_parts[1:3], ""]) - self.directory = "/".join(stage_path_parts[3:]) + self.directory = "/".join(stage_path_parts[OMIT_STAGE]) + if self.directory == "/": + self.directory = "" self.is_directory = True if stage_path.endswith("/") else False @property @@ -45,7 +51,7 @@ def path(self) -> str: @classmethod def get_directory(cls, stage_path: str) -> str: - return "/".join(Path(stage_path).parts[3:]) + return "/".join(GitManager.split_git_path(stage_path)[OMIT_STAGE]) @property def full_path(self) -> str: @@ -53,7 +59,7 @@ def full_path(self) -> str: def replace_stage_prefix(self, file_path: str) -> str: stage = Path(self.stage).parts[0] - file_path_without_prefix = Path(file_path).parts[1:] + file_path_without_prefix = Path(file_path).parts[OMIT_FIRST] return f"{stage}/{'/'.join(file_path_without_prefix)}" def add_stage_prefix(self, file_path: str) -> str: @@ -95,7 +101,8 @@ def get_stage_from_path(path: str): Returns stage name from potential path on stage. For example repo/branches/main/foo/bar -> repo/branches/main/ """ - return f"{'/'.join(Path(path).parts[0:3])}/" + path_parts = GitManager.split_git_path(path) + return f"{'/'.join(path_parts[ONLY_STAGE])}/" @staticmethod def _stage_path_part_factory(stage_path: str) -> StagePathParts: @@ -103,3 +110,24 @@ def _stage_path_part_factory(stage_path: str) -> StagePathParts: if stage_path.startswith(USER_STAGE_PREFIX): return UserStagePathParts(stage_path) return GitStagePathParts(stage_path) + + @staticmethod + def split_git_path(path: str): + # Check if path contains quotes and split it accordingly + if '/"' in path and '"/' in path: + if path.count('"') > 2: + raise ValueError('Too much " in path, expected 0 or 2.') + + path_parts = path.split('"') + + # Check if quoted part is third part of path + if len(Path(path_parts[0]).parts) != 2: + raise ValueError("Invalid path.") + + return [ + *Path(path_parts[0]).parts, + f'"{path_parts[1]}"', + *Path(path_parts[2]).parts, + ] + else: + return Path(path).parts diff --git a/src/snowflake/cli/_plugins/stage/manager.py b/src/snowflake/cli/_plugins/stage/manager.py index 90621e7011..05636b1617 100644 --- a/src/snowflake/cli/_plugins/stage/manager.py +++ b/src/snowflake/cli/_plugins/stage/manager.py @@ -57,6 +57,8 @@ ".py", ) # tuple to preserve order but it's a set +OMIT_FIRST = slice(1, None) + @dataclass class StagePathParts: @@ -67,7 +69,7 @@ class StagePathParts: @classmethod def get_directory(cls, stage_path: str) -> str: - return "/".join(Path(stage_path).parts[1:]) + return "/".join(Path(stage_path).parts[OMIT_FIRST]) @property def path(self) -> str: @@ -119,7 +121,9 @@ def __init__(self, stage_path: str): self.directory = self.get_directory(stage_path) self.stage = StageManager.get_stage_from_path(stage_path) stage_name = self.stage.split(".")[-1] - stage_name = stage_name[1:] if stage_name.startswith("@") else stage_name + stage_name = ( + stage_name[OMIT_FIRST] if stage_name.startswith("@") else stage_name + ) self.stage_name = stage_name self.is_directory = True if stage_path.endswith("/") else False @@ -133,7 +137,7 @@ def full_path(self) -> str: def replace_stage_prefix(self, file_path: str) -> str: stage = Path(self.stage).parts[0] - file_path_without_prefix = Path(file_path).parts[1:] + file_path_without_prefix = Path(file_path).parts[OMIT_FIRST] return f"{stage}/{'/'.join(file_path_without_prefix)}" def add_stage_prefix(self, file_path: str) -> str: @@ -461,7 +465,7 @@ def _call_execute_immediate( on_error: OnErrorType, ) -> Dict: try: - query = f"execute immediate from {file_stage_path}" + query = f"execute immediate from {self.quote_stage_name(file_stage_path)}" if variables: query += variables self._execute_query(query) diff --git a/tests/git/__snapshots__/test_git_commands.ambr b/tests/git/__snapshots__/test_git_commands.ambr index ab159786d5..724ab9245d 100644 --- a/tests/git/__snapshots__/test_git_commands.ambr +++ b/tests/git/__snapshots__/test_git_commands.ambr @@ -117,3 +117,51 @@ ''' # --- +# name: test_execute_slash_in_repository_name[@db.schema.repo/branches/"feature/commit"/-@db.schema.repo/branches/"feature/commit"/-expected_files3] + ''' + SUCCESS - @db.schema.repo/branches/"feature/commit"/s1.sql + SUCCESS - @db.schema.repo/branches/"feature/commit"/a/S3.sql + +----------------------------------------------------------------------+ + | File | Status | Error | + |----------------------------------------------------+---------+-------| + | @db.schema.repo/branches/"feature/commit"/s1.sql | SUCCESS | None | + | @db.schema.repo/branches/"feature/commit"/a/S3.sql | SUCCESS | None | + +----------------------------------------------------------------------+ + + ''' +# --- +# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/-@repo/branches/"feature/commit"/-expected_files0] + ''' + SUCCESS - @repo/branches/"feature/commit"/s1.sql + SUCCESS - @repo/branches/"feature/commit"/a/S3.sql + +------------------------------------------------------------+ + | File | Status | Error | + |------------------------------------------+---------+-------| + | @repo/branches/"feature/commit"/s1.sql | SUCCESS | None | + | @repo/branches/"feature/commit"/a/S3.sql | SUCCESS | None | + +------------------------------------------------------------+ + + ''' +# --- +# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/a/-@repo/branches/"feature/commit"/-expected_files2] + ''' + SUCCESS - @repo/branches/"feature/commit"/a/S3.sql + +------------------------------------------------------------+ + | File | Status | Error | + |------------------------------------------+---------+-------| + | @repo/branches/"feature/commit"/a/S3.sql | SUCCESS | None | + +------------------------------------------------------------+ + + ''' +# --- +# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/s1.sql-@repo/branches/"feature/commit"/-expected_files1] + ''' + SUCCESS - @repo/branches/"feature/commit"/s1.sql + +----------------------------------------------------------+ + | File | Status | Error | + |----------------------------------------+---------+-------| + | @repo/branches/"feature/commit"/s1.sql | SUCCESS | None | + +----------------------------------------------------------+ + + ''' +# --- diff --git a/tests/git/test_git_commands.py b/tests/git/test_git_commands.py index 1ce01516bb..5ffdff4422 100644 --- a/tests/git/test_git_commands.py +++ b/tests/git/test_git_commands.py @@ -664,6 +664,71 @@ def test_execute_new_git_repository_list_files( assert result.output == os_agnostic_snapshot +@pytest.mark.parametrize( + "repository_path, expected_stage, expected_files", + [ + ( + '@repo/branches/"feature/commit"/', + '@repo/branches/"feature/commit"/', + [ + '@repo/branches/"feature/commit"/s1.sql', + '@repo/branches/"feature/commit"/a/S3.sql', + ], + ), + ( + '@repo/branches/"feature/commit"/s1.sql', + '@repo/branches/"feature/commit"/', + [ + '@repo/branches/"feature/commit"/s1.sql', + ], + ), + ( + '@repo/branches/"feature/commit"/a/', + '@repo/branches/"feature/commit"/', + [ + '@repo/branches/"feature/commit"/a/S3.sql', + ], + ), + ( + '@db.schema.repo/branches/"feature/commit"/', + '@db.schema.repo/branches/"feature/commit"/', + [ + '@db.schema.repo/branches/"feature/commit"/s1.sql', + '@db.schema.repo/branches/"feature/commit"/a/S3.sql', + ], + ), + ], +) +@mock.patch(f"{STAGE_MANAGER}._execute_query") +def test_execute_slash_in_repository_name( + mock_execute, + mock_cursor, + runner, + repository_path, + expected_stage, + expected_files, + os_agnostic_snapshot, +): + mock_execute.return_value = mock_cursor( + [ + {"name": '/branches/"feature/commit"/a/S3.sql'}, + {"name": '/branches/"feature/commit"/s1.sql'}, + {"name": '/branches/"feature/commit"/s2'}, + ], + [], + ) + + result = runner.invoke(["git", "execute", repository_path]) + + assert result.exit_code == 0, result.output + ls_call, *execute_calls = mock_execute.mock_calls + assert ls_call == mock.call(f"ls '{expected_stage}'", cursor_class=DictCursor) + assert execute_calls == [ + mock.call(f"execute immediate from '{p}'") for p in expected_files + ] + assert result.output == os_agnostic_snapshot + + @mock.patch(f"{STAGE_MANAGER}._execute_query") def test_execute_with_variables(mock_execute, mock_cursor, runner): mock_execute.return_value = mock_cursor([{"name": "repo/branches/main/s1.sql"}], []) @@ -695,6 +760,27 @@ def test_execute_with_variables(mock_execute, mock_cursor, runner): ] +@mock.patch(f"{STAGE_MANAGER}._execute_query") +def test_execute_file_with_space_in_name(mock_execute, mock_cursor, runner): + mock_execute.return_value = mock_cursor( + [{"name": "repo/branches/main/Script 1.sql"}], [] + ) + + result = runner.invoke( + [ + "git", + "execute", + "@repo/branches/main/", + ] + ) + + assert result.exit_code == 0 + assert mock_execute.mock_calls == [ + mock.call("ls @repo/branches/main/", cursor_class=DictCursor), + mock.call(f"execute immediate from '@repo/branches/main/Script 1.sql'"), + ] + + @mock.patch("snowflake.connector.connect") @pytest.mark.parametrize( "command, parameters", diff --git a/tests/stage/test_stage.py b/tests/stage/test_stage.py index 2cb452a7c7..1e50fc4449 100644 --- a/tests/stage/test_stage.py +++ b/tests/stage/test_stage.py @@ -859,7 +859,7 @@ def test_execute_from_user_stage( ls_call, *execute_calls = mock_execute.mock_calls assert ls_call == mock.call(f"ls '@~'", cursor_class=DictCursor) assert execute_calls == [ - mock.call(f"execute immediate from {p}") for p in expected_files + mock.call(f"execute immediate from '{p}'") for p in expected_files ] assert result.output == snapshot