diff --git a/src/snowflake/cli/plugins/object/commands.py b/src/snowflake/cli/plugins/object/commands.py index dcf0522f98..286d6ef382 100644 --- a/src/snowflake/cli/plugins/object/commands.py +++ b/src/snowflake/cli/plugins/object/commands.py @@ -37,20 +37,20 @@ ) -def _scope_callback(scope: Tuple[str, str]): +def _scope_validate(object_type: str, scope: Tuple[str, str]): if scope[1] is not None and not is_valid_identifier(scope[1]): raise ClickException("scope name must be a valid identifier") if scope[0] is not None and scope[0].lower() not in VALID_SCOPES: raise ClickException( - f'scope must be one of the following: {", ".join(VALID_SCOPES)}' + f"scope must be one of the following: {', '.join(VALID_SCOPES)}" ) - return scope + if scope[0] == "compute-pool" and object_type != "service": + raise ClickException("compute-pool scope is only supported for listing service") ScopeOption = typer.Option( (None, None), "--in", - callback=_scope_callback, help="Specifies the scope of this command using '--in ' (e.g. list tables --in database my_db). Some object types have specialized scopes (e.g. list service --in compute-pool my_pool)", ) @@ -69,6 +69,7 @@ def list_( scope: Tuple[str, str] = ScopeOption, **options, ): + _scope_validate(object_type, scope) return QueryResult( ObjectManager().show(object_type=object_type, like=like, scope=scope) ) diff --git a/tests/object/test_object.py b/tests/object/test_object.py index c1b1df95f0..2078d818e5 100644 --- a/tests/object/test_object.py +++ b/tests/object/test_object.py @@ -2,7 +2,7 @@ import pytest from snowflake.cli.api.constants import SUPPORTED_OBJECTS, OBJECT_TO_NAMES -from snowflake.cli.plugins.object.commands import _scope_callback +from snowflake.cli.plugins.object.commands import _scope_validate from click import ClickException @@ -119,29 +119,34 @@ def test_show_with_invalid_scope( @pytest.mark.parametrize( - "input_scope, input_name", + "object_type, input_scope, input_name", [ - (None, None), - ("database", "test_db"), - ("schema", "test_schema"), - ("compute-pool", "test_pool"), + ("user", None, None), + ("schema", "database", "test_db"), + ("table", "schema", "test_schema"), + ("service", "compute-pool", "test_pool"), ], ) -def test_scope_callback(input_scope, input_name): - assert (input_scope, input_name) == _scope_callback((input_scope, input_name)) +def test_scope_validate(object_type, input_scope, input_name): + _scope_validate(object_type, (input_scope, input_name)) @pytest.mark.parametrize( - "input_scope, input_name", + "object_type, input_scope, input_name", [ - ("database", "invalid identifier"), - ("invalid_scope", "identifier"), - ("invalid_scope", "invalid identifier"), + ("table", "database", "invalid identifier"), + ("table", "invalid-scope", "identifier"), + ("table", "invalid-scope", "invalid identifier"), + ( + "table", + "compute-pool", + "test_pool", + ), # 'compute-pool' scope can only be used with 'service' ], ) -def test_invalid_scope_callback(input_scope, input_name): +def test_invalid_scope_validate(object_type, input_scope, input_name): with pytest.raises(ClickException): - _scope_callback((input_scope, input_name)) + _scope_validate(object_type, (input_scope, input_name)) @mock.patch("snowflake.connector") diff --git a/tests_integration/test_object.py b/tests_integration/test_object.py index fbd7b4e436..aa17231db6 100644 --- a/tests_integration/test_object.py +++ b/tests_integration/test_object.py @@ -59,12 +59,11 @@ def test_object_table(runner, test_database, snowflake_session): def test_list_with_scope(runner, test_database, snowflake_session): # create a table in a schema other than schema of the current connection other_schema = "other_schema" - public_table = ObjectNameProvider( - "Test_Object_Table" - ).create_and_get_next_object_name() - other_table = object_name = ObjectNameProvider( - "Test_Object_Table" - ).create_and_get_next_object_name() + + public_table = ObjectNameProvider("Public_Table").create_and_get_next_object_name() + + other_table = ObjectNameProvider("Other_Table").create_and_get_next_object_name() + snowflake_session.execute_string( f"use schema public; create table {public_table} (some_number NUMBER);" ) @@ -72,13 +71,13 @@ def test_list_with_scope(runner, test_database, snowflake_session): f"create schema {other_schema}; create table {other_table} (some_number NUMBER);" ) result_list_public = runner.invoke_with_connection_json( - ["object", "list", "table", "--schema", "public"] + ["object", "list", "table", "--in", "schema", "public"] ) assert result_list_public.exit_code == 0, result_list_public.output assert result_list_public.json[0]["name"].lower() == public_table.lower() result_list_other = runner.invoke_with_connection_json( - ["object", "list", "table", "--schema", "other_schema"] + ["object", "list", "table", "--in", "schema", "other_schema"] ) assert result_list_other.exit_code == 0, result_list_other.output assert result_list_other.json[0]["name"].lower() == other_table.lower()