Skip to content

Commit

Permalink
Slash in git execute fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-astus committed Sep 17, 2024
1 parent e4a7a13 commit f94e874
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 11 deletions.
40 changes: 34 additions & 6 deletions src/snowflake/cli/_plugins/git/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,22 @@
from snowflake.cli.api.identifiers import FQN
from snowflake.connector.cursor import SnowflakeCursor

OMIT_FIRST = slice(1, None)
OMIT_STAGE = slice(3, None)
ONLY_STAGE = slice(3)


class GitStagePathParts(StagePathParts):
def __init__(self, stage_path: str):
self.stage = GitManager.get_stage_from_path(stage_path)
stage_path_parts = Path(stage_path).parts
stage_path_parts = GitManager.split_git_path(stage_path)
git_repo_name = stage_path_parts[0].split(".")[-1]
if git_repo_name.startswith("@"):
git_repo_name = git_repo_name[1:]
git_repo_name = git_repo_name[OMIT_FIRST]
self.stage_name = "/".join([git_repo_name, *stage_path_parts[1:3], ""])
self.directory = "/".join(stage_path_parts[3:])
self.directory = "/".join(stage_path_parts[OMIT_STAGE])
if self.directory == "/":
self.directory = ""
self.is_directory = True if stage_path.endswith("/") else False

@property
Expand All @@ -45,15 +51,15 @@ def path(self) -> str:

@classmethod
def get_directory(cls, stage_path: str) -> str:
return "/".join(Path(stage_path).parts[3:])
return "/".join(GitManager.split_git_path(stage_path)[OMIT_STAGE])

@property
def full_path(self) -> str:
return f"{self.stage.rstrip('/')}/{self.directory}"

def replace_stage_prefix(self, file_path: str) -> str:
stage = Path(self.stage).parts[0]
file_path_without_prefix = Path(file_path).parts[1:]
file_path_without_prefix = Path(file_path).parts[OMIT_FIRST]
return f"{stage}/{'/'.join(file_path_without_prefix)}"

def add_stage_prefix(self, file_path: str) -> str:
Expand Down Expand Up @@ -95,11 +101,33 @@ def get_stage_from_path(path: str):
Returns stage name from potential path on stage. For example
repo/branches/main/foo/bar -> repo/branches/main/
"""
return f"{'/'.join(Path(path).parts[0:3])}/"
path_parts = GitManager.split_git_path(path)
return f"{'/'.join(path_parts[ONLY_STAGE])}/"

@staticmethod
def _stage_path_part_factory(stage_path: str) -> StagePathParts:
stage_path = StageManager.get_standard_stage_prefix(stage_path)
if stage_path.startswith(USER_STAGE_PREFIX):
return UserStagePathParts(stage_path)
return GitStagePathParts(stage_path)

@staticmethod
def split_git_path(path: str):
# Check if path contains quotes and split it accordingly
if '/"' in path and '"/' in path:
if path.count('"') > 2:
raise ValueError('Too much " in path, expected 0 or 2.')

path_parts = path.split('"')

# Check if quoted part is third part of path
if len(Path(path_parts[0]).parts) != 2:
raise ValueError("Invalid path.")

return [
*Path(path_parts[0]).parts,
f'"{path_parts[1]}"',
*Path(path_parts[2]).parts,
]
else:
return Path(path).parts
12 changes: 8 additions & 4 deletions src/snowflake/cli/_plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
".py",
) # tuple to preserve order but it's a set

OMIT_FIRST = slice(1, None)


@dataclass
class StagePathParts:
Expand All @@ -67,7 +69,7 @@ class StagePathParts:

@classmethod
def get_directory(cls, stage_path: str) -> str:
return "/".join(Path(stage_path).parts[1:])
return "/".join(Path(stage_path).parts[OMIT_FIRST])

@property
def path(self) -> str:
Expand Down Expand Up @@ -119,7 +121,9 @@ def __init__(self, stage_path: str):
self.directory = self.get_directory(stage_path)
self.stage = StageManager.get_stage_from_path(stage_path)
stage_name = self.stage.split(".")[-1]
stage_name = stage_name[1:] if stage_name.startswith("@") else stage_name
stage_name = (
stage_name[OMIT_FIRST] if stage_name.startswith("@") else stage_name
)
self.stage_name = stage_name
self.is_directory = True if stage_path.endswith("/") else False

Expand All @@ -133,7 +137,7 @@ def full_path(self) -> str:

def replace_stage_prefix(self, file_path: str) -> str:
stage = Path(self.stage).parts[0]
file_path_without_prefix = Path(file_path).parts[1:]
file_path_without_prefix = Path(file_path).parts[OMIT_FIRST]
return f"{stage}/{'/'.join(file_path_without_prefix)}"

def add_stage_prefix(self, file_path: str) -> str:
Expand Down Expand Up @@ -461,7 +465,7 @@ def _call_execute_immediate(
on_error: OnErrorType,
) -> Dict:
try:
query = f"execute immediate from {file_stage_path}"
query = f"execute immediate from {self.quote_stage_name(file_stage_path)}"
if variables:
query += variables
self._execute_query(query)
Expand Down
48 changes: 48 additions & 0 deletions tests/git/__snapshots__/test_git_commands.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,51 @@

'''
# ---
# name: test_execute_slash_in_repository_name[@db.schema.repo/branches/"feature/commit"/-@db.schema.repo/branches/"feature/commit"/-expected_files3]
'''
SUCCESS - @db.schema.repo/branches/"feature/commit"/s1.sql
SUCCESS - @db.schema.repo/branches/"feature/commit"/a/S3.sql
+----------------------------------------------------------------------+
| File | Status | Error |
|----------------------------------------------------+---------+-------|
| @db.schema.repo/branches/"feature/commit"/s1.sql | SUCCESS | None |
| @db.schema.repo/branches/"feature/commit"/a/S3.sql | SUCCESS | None |
+----------------------------------------------------------------------+

'''
# ---
# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/-@repo/branches/"feature/commit"/-expected_files0]
'''
SUCCESS - @repo/branches/"feature/commit"/s1.sql
SUCCESS - @repo/branches/"feature/commit"/a/S3.sql
+------------------------------------------------------------+
| File | Status | Error |
|------------------------------------------+---------+-------|
| @repo/branches/"feature/commit"/s1.sql | SUCCESS | None |
| @repo/branches/"feature/commit"/a/S3.sql | SUCCESS | None |
+------------------------------------------------------------+

'''
# ---
# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/a/-@repo/branches/"feature/commit"/-expected_files2]
'''
SUCCESS - @repo/branches/"feature/commit"/a/S3.sql
+------------------------------------------------------------+
| File | Status | Error |
|------------------------------------------+---------+-------|
| @repo/branches/"feature/commit"/a/S3.sql | SUCCESS | None |
+------------------------------------------------------------+

'''
# ---
# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/s1.sql-@repo/branches/"feature/commit"/-expected_files1]
'''
SUCCESS - @repo/branches/"feature/commit"/s1.sql
+----------------------------------------------------------+
| File | Status | Error |
|----------------------------------------+---------+-------|
| @repo/branches/"feature/commit"/s1.sql | SUCCESS | None |
+----------------------------------------------------------+

'''
# ---
86 changes: 86 additions & 0 deletions tests/git/test_git_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,71 @@ def test_execute_new_git_repository_list_files(
assert result.output == os_agnostic_snapshot


@pytest.mark.parametrize(
"repository_path, expected_stage, expected_files",
[
(
'@repo/branches/"feature/commit"/',
'@repo/branches/"feature/commit"/',
[
'@repo/branches/"feature/commit"/s1.sql',
'@repo/branches/"feature/commit"/a/S3.sql',
],
),
(
'@repo/branches/"feature/commit"/s1.sql',
'@repo/branches/"feature/commit"/',
[
'@repo/branches/"feature/commit"/s1.sql',
],
),
(
'@repo/branches/"feature/commit"/a/',
'@repo/branches/"feature/commit"/',
[
'@repo/branches/"feature/commit"/a/S3.sql',
],
),
(
'@db.schema.repo/branches/"feature/commit"/',
'@db.schema.repo/branches/"feature/commit"/',
[
'@db.schema.repo/branches/"feature/commit"/s1.sql',
'@db.schema.repo/branches/"feature/commit"/a/S3.sql',
],
),
],
)
@mock.patch(f"{STAGE_MANAGER}._execute_query")
def test_execute_slash_in_repository_name(
mock_execute,
mock_cursor,
runner,
repository_path,
expected_stage,
expected_files,
os_agnostic_snapshot,
):
mock_execute.return_value = mock_cursor(
[
{"name": '/branches/"feature/commit"/a/S3.sql'},
{"name": '/branches/"feature/commit"/s1.sql'},
{"name": '/branches/"feature/commit"/s2'},
],
[],
)

result = runner.invoke(["git", "execute", repository_path])

assert result.exit_code == 0, result.output
ls_call, *execute_calls = mock_execute.mock_calls
assert ls_call == mock.call(f"ls '{expected_stage}'", cursor_class=DictCursor)
assert execute_calls == [
mock.call(f"execute immediate from '{p}'") for p in expected_files
]
assert result.output == os_agnostic_snapshot


@mock.patch(f"{STAGE_MANAGER}._execute_query")
def test_execute_with_variables(mock_execute, mock_cursor, runner):
mock_execute.return_value = mock_cursor([{"name": "repo/branches/main/s1.sql"}], [])
Expand Down Expand Up @@ -695,6 +760,27 @@ def test_execute_with_variables(mock_execute, mock_cursor, runner):
]


@mock.patch(f"{STAGE_MANAGER}._execute_query")
def test_execute_file_with_space_in_name(mock_execute, mock_cursor, runner):
mock_execute.return_value = mock_cursor(
[{"name": "repo/branches/main/Script 1.sql"}], []
)

result = runner.invoke(
[
"git",
"execute",
"@repo/branches/main/",
]
)

assert result.exit_code == 0
assert mock_execute.mock_calls == [
mock.call("ls @repo/branches/main/", cursor_class=DictCursor),
mock.call(f"execute immediate from '@repo/branches/main/Script 1.sql'"),
]


@mock.patch("snowflake.connector.connect")
@pytest.mark.parametrize(
"command, parameters",
Expand Down
2 changes: 1 addition & 1 deletion tests/stage/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def test_execute_from_user_stage(
ls_call, *execute_calls = mock_execute.mock_calls
assert ls_call == mock.call(f"ls '@~'", cursor_class=DictCursor)
assert execute_calls == [
mock.call(f"execute immediate from {p}") for p in expected_files
mock.call(f"execute immediate from '{p}'") for p in expected_files
]
assert result.output == snapshot

Expand Down

0 comments on commit f94e874

Please sign in to comment.