Skip to content

Commit

Permalink
SNOW-1011774: Adding validation that compute-pool scope is only used …
Browse files Browse the repository at this point in the history
…with services. Test cleanup.
  • Loading branch information
sfc-gh-davwang committed Jan 31, 2024
1 parent 87c88bc commit bab8cf8
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 26 deletions.
9 changes: 5 additions & 4 deletions src/snowflake/cli/plugins/object/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,20 @@
)


def _scope_callback(scope: Tuple[str, str]):
def _scope_validate(object_type: str, scope: Tuple[str, str]):
if scope[1] is not None and not is_valid_identifier(scope[1]):
raise ClickException("scope name must be a valid identifier")
if scope[0] is not None and scope[0].lower() not in VALID_SCOPES:
raise ClickException(
f'scope must be one of the following: {", ".join(VALID_SCOPES)}'
f"scope must be one of the following: {', '.join(VALID_SCOPES)}"
)
return scope
if scope[0] == "compute-pool" and object_type != "service":
raise ClickException("compute-pool scope is only supported for listing service")


ScopeOption = typer.Option(
(None, None),
"--in",
callback=_scope_callback,
help="Specifies the scope of this command using '--in <scope> <name>' (e.g. list tables --in database my_db). Some object types have specialized scopes (e.g. list service --in compute-pool my_pool)",
)

Expand All @@ -69,6 +69,7 @@ def list_(
scope: Tuple[str, str] = ScopeOption,
**options,
):
_scope_validate(object_type, scope)
return QueryResult(
ObjectManager().show(object_type=object_type, like=like, scope=scope)
)
Expand Down
33 changes: 19 additions & 14 deletions tests/object/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from snowflake.cli.api.constants import SUPPORTED_OBJECTS, OBJECT_TO_NAMES
from snowflake.cli.plugins.object.commands import _scope_callback
from snowflake.cli.plugins.object.commands import _scope_validate
from click import ClickException


Expand Down Expand Up @@ -119,29 +119,34 @@ def test_show_with_invalid_scope(


@pytest.mark.parametrize(
"input_scope, input_name",
"object_type, input_scope, input_name",
[
(None, None),
("database", "test_db"),
("schema", "test_schema"),
("compute-pool", "test_pool"),
("user", None, None),
("schema", "database", "test_db"),
("table", "schema", "test_schema"),
("service", "compute-pool", "test_pool"),
],
)
def test_scope_callback(input_scope, input_name):
assert (input_scope, input_name) == _scope_callback((input_scope, input_name))
def test_scope_validate(object_type, input_scope, input_name):
_scope_validate(object_type, (input_scope, input_name))


@pytest.mark.parametrize(
"input_scope, input_name",
"object_type, input_scope, input_name",
[
("database", "invalid identifier"),
("invalid_scope", "identifier"),
("invalid_scope", "invalid identifier"),
("table", "database", "invalid identifier"),
("table", "invalid-scope", "identifier"),
("table", "invalid-scope", "invalid identifier"),
(
"table",
"compute-pool",
"test_pool",
), # 'compute-pool' scope can only be used with 'service'
],
)
def test_invalid_scope_callback(input_scope, input_name):
def test_invalid_scope_validate(object_type, input_scope, input_name):
with pytest.raises(ClickException):
_scope_callback((input_scope, input_name))
_scope_validate(object_type, (input_scope, input_name))


@mock.patch("snowflake.connector")
Expand Down
15 changes: 7 additions & 8 deletions tests_integration/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,25 @@ def test_object_table(runner, test_database, snowflake_session):
def test_list_with_scope(runner, test_database, snowflake_session):
# create a table in a schema other than schema of the current connection
other_schema = "other_schema"
public_table = ObjectNameProvider(
"Test_Object_Table"
).create_and_get_next_object_name()
other_table = object_name = ObjectNameProvider(
"Test_Object_Table"
).create_and_get_next_object_name()

public_table = ObjectNameProvider("Public_Table").create_and_get_next_object_name()

other_table = ObjectNameProvider("Other_Table").create_and_get_next_object_name()

snowflake_session.execute_string(
f"use schema public; create table {public_table} (some_number NUMBER);"
)
snowflake_session.execute_string(
f"create schema {other_schema}; create table {other_table} (some_number NUMBER);"
)
result_list_public = runner.invoke_with_connection_json(
["object", "list", "table", "--schema", "public"]
["object", "list", "table", "--in", "schema", "public"]
)
assert result_list_public.exit_code == 0, result_list_public.output
assert result_list_public.json[0]["name"].lower() == public_table.lower()

result_list_other = runner.invoke_with_connection_json(
["object", "list", "table", "--schema", "other_schema"]
["object", "list", "table", "--in", "schema", "other_schema"]
)
assert result_list_other.exit_code == 0, result_list_other.output
assert result_list_other.json[0]["name"].lower() == other_table.lower()

0 comments on commit bab8cf8

Please sign in to comment.