diff --git a/src/snowflake/cli/_app/telemetry.py b/src/snowflake/cli/_app/telemetry.py index 326e453ee8..106ae969c4 100644 --- a/src/snowflake/cli/_app/telemetry.py +++ b/src/snowflake/cli/_app/telemetry.py @@ -54,6 +54,8 @@ class CLITelemetryField(Enum): COMMAND_EXECUTION_TIME = "command_execution_time" # Configuration CONFIG_FEATURE_FLAGS = "config_feature_flags" + # Metrics + COUNTERS = "counters" # Information EVENT = "event" ERROR_MSG = "error_msg" @@ -72,6 +74,16 @@ class TelemetryEvent(Enum): TelemetryDict = Dict[Union[CLITelemetryField, TelemetryField], Any] +def _get_command_metrics() -> TelemetryDict: + cli_context = get_cli_context() + + return { + CLITelemetryField.COUNTERS: { + **cli_context.metrics.counters, + } + } + + def _find_command_info() -> TelemetryDict: ctx = click.get_current_context() command_path = ctx.command_path.split(" ")[1:] @@ -168,6 +180,7 @@ def log_command_result(execution: ExecutionMetadata): CLITelemetryField.COMMAND_EXECUTION_ID: execution.execution_id, CLITelemetryField.COMMAND_RESULT_STATUS: execution.status.value, CLITelemetryField.COMMAND_EXECUTION_TIME: execution.get_duration(), + **_get_command_metrics(), } ) @@ -183,6 +196,7 @@ def log_command_execution_error(exception: Exception, execution: ExecutionMetada CLITelemetryField.ERROR_TYPE: exception_type, CLITelemetryField.IS_CLI_EXCEPTION: is_cli_exception, CLITelemetryField.COMMAND_EXECUTION_TIME: execution.get_duration(), + **_get_command_metrics(), } ) diff --git a/src/snowflake/cli/_plugins/nativeapp/codegen/compiler.py b/src/snowflake/cli/_plugins/nativeapp/codegen/compiler.py index 8b254887a4..7f89588b69 100644 --- a/src/snowflake/cli/_plugins/nativeapp/codegen/compiler.py +++ b/src/snowflake/cli/_plugins/nativeapp/codegen/compiler.py @@ -34,7 +34,9 @@ TemplatesProcessor, ) from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag +from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.metrics import CLICounterField from snowflake.cli.api.project.schemas.v1.native_app.path_mapping import ( ProcessorMapping, ) @@ -72,6 +74,9 @@ def compile_artifacts(self): Go through every artifact object in the project definition of a native app, and execute processors in order of specification for each of the artifact object. May have side-effects on the filesystem by either directly editing source files or the deploy root. """ + metrics = get_cli_context().metrics + metrics.set_counter_default(CLICounterField.TEMPLATES_PROCESSOR, 0) + metrics.set_counter_default(CLICounterField.SNOWPARK_PROCESSOR, 0) if not self._should_invoke_processors(): return diff --git a/src/snowflake/cli/_plugins/nativeapp/codegen/snowpark/python_processor.py b/src/snowflake/cli/_plugins/nativeapp/codegen/snowpark/python_processor.py index f0bdf66a09..98f9ca1eae 100644 --- a/src/snowflake/cli/_plugins/nativeapp/codegen/snowpark/python_processor.py +++ b/src/snowflake/cli/_plugins/nativeapp/codegen/snowpark/python_processor.py @@ -48,7 +48,9 @@ NativeAppExtensionFunction, ) from snowflake.cli._plugins.stage.diff import to_stage_path +from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.metrics import CLICounterField from snowflake.cli.api.project.schemas.v1.native_app.path_mapping import ( PathMapping, ProcessorMapping, @@ -176,6 +178,8 @@ def process( setup script with generated SQL that registers these functions. """ + get_cli_context().metrics.set_counter(CLICounterField.SNOWPARK_PROCESSOR, 1) + bundle_map = BundleMap( project_root=self._bundle_ctx.project_root, deploy_root=self._bundle_ctx.deploy_root, diff --git a/src/snowflake/cli/_plugins/nativeapp/codegen/templates/templates_processor.py b/src/snowflake/cli/_plugins/nativeapp/codegen/templates/templates_processor.py index 779da8717e..4c04443653 100644 --- a/src/snowflake/cli/_plugins/nativeapp/codegen/templates/templates_processor.py +++ b/src/snowflake/cli/_plugins/nativeapp/codegen/templates/templates_processor.py @@ -25,6 +25,7 @@ from snowflake.cli._plugins.nativeapp.exceptions import InvalidTemplateInFileError from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.metrics import CLICounterField from snowflake.cli.api.project.schemas.v1.native_app.path_mapping import ( PathMapping, ProcessorMapping, @@ -98,6 +99,8 @@ def process( Process the artifact by executing the template expansion logic on it. """ + get_cli_context().metrics.set_counter(CLICounterField.TEMPLATES_PROCESSOR, 1) + bundle_map = BundleMap( project_root=self._bundle_ctx.project_root, deploy_root=self._bundle_ctx.deploy_root, diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application.py b/src/snowflake/cli/_plugins/nativeapp/entities/application.py index 5cf3e64468..0010cc6ccf 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application.py @@ -42,6 +42,7 @@ ) from snowflake.cli._plugins.nativeapp.utils import needs_confirmation from snowflake.cli._plugins.workspace.action_context import ActionContext +from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.console.abc import AbstractConsole from snowflake.cli.api.entities.common import EntityBase, get_sql_executor from snowflake.cli.api.entities.utils import ( @@ -57,6 +58,7 @@ NOT_SUPPORTED_ON_DEV_MODE_APPLICATIONS, ONLY_SUPPORTED_ON_DEV_MODE_APPLICATIONS, ) +from snowflake.cli.api.metrics import CLICounterField from snowflake.cli.api.project.schemas.entities.common import ( EntityModelBase, Identifier, @@ -81,7 +83,6 @@ APPLICATION_NO_LONGER_AVAILABLE, } - ApplicationOwnedObject = TypedDict("ApplicationOwnedObject", {"name": str, "type": str}) @@ -563,14 +564,13 @@ def create_or_upgrade_app( ) # hooks always executed after a create or upgrade - if post_deploy_hooks: - cls.execute_post_deploy_hooks( - console=console, - project_root=project_root, - post_deploy_hooks=post_deploy_hooks, - app_name=app_name, - app_warehouse=app_warehouse, - ) + cls.execute_post_deploy_hooks( + console=console, + project_root=project_root, + post_deploy_hooks=post_deploy_hooks, + app_name=app_name, + app_warehouse=app_warehouse, + ) return except ProgrammingError as err: @@ -622,14 +622,13 @@ def create_or_upgrade_app( print_messages(console, create_cursor) # hooks always executed after a create or upgrade - if post_deploy_hooks: - cls.execute_post_deploy_hooks( - console=console, - project_root=project_root, - post_deploy_hooks=post_deploy_hooks, - app_name=app_name, - app_warehouse=app_warehouse, - ) + cls.execute_post_deploy_hooks( + console=console, + project_root=project_root, + post_deploy_hooks=post_deploy_hooks, + app_name=app_name, + app_warehouse=app_warehouse, + ) except ProgrammingError as err: generic_sql_error_handler(err) @@ -643,14 +642,19 @@ def execute_post_deploy_hooks( app_name: str, app_warehouse: Optional[str], ): - with cls.use_application_warehouse(app_warehouse): - execute_post_deploy_hooks( - console=console, - project_root=project_root, - post_deploy_hooks=post_deploy_hooks, - deployed_object_type="application", - database_name=app_name, - ) + get_cli_context().metrics.set_counter_default( + CLICounterField.POST_DEPLOY_SCRIPTS, 0 + ) + + if post_deploy_hooks: + with cls.use_application_warehouse(app_warehouse): + execute_post_deploy_hooks( + console=console, + project_root=project_root, + post_deploy_hooks=post_deploy_hooks, + deployed_object_type="application", + database_name=app_name, + ) @staticmethod @contextmanager diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py index c3fbd35777..add6d81947 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py @@ -43,6 +43,7 @@ from snowflake.cli._plugins.stage.diff import DiffResult from snowflake.cli._plugins.stage.manager import StageManager from snowflake.cli._plugins.workspace.action_context import ActionContext +from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.console.abc import AbstractConsole from snowflake.cli.api.entities.common import EntityBase, get_sql_executor from snowflake.cli.api.entities.utils import ( @@ -55,6 +56,7 @@ ) from snowflake.cli.api.errno import DOES_NOT_EXIST_OR_NOT_AUTHORIZED from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError +from snowflake.cli.api.metrics import CLICounterField from snowflake.cli.api.project.schemas.entities.common import ( EntityModelBase, Identifier, @@ -401,15 +403,14 @@ def deploy( if not policy.should_proceed("Proceed with using this package?"): raise typer.Abort() from e with get_sql_executor().use_role(package_role): - if package_scripts: - cls.apply_package_scripts( - console=console, - package_scripts=package_scripts, - package_warehouse=package_warehouse, - project_root=project_root, - package_role=package_role, - package_name=package_name, - ) + cls.apply_package_scripts( + console=console, + package_scripts=package_scripts, + package_warehouse=package_warehouse, + project_root=project_root, + package_role=package_role, + package_name=package_name, + ) # 3. Upload files from deploy root local folder to the above stage stage_schema = extract_schema(stage_fqn) @@ -427,14 +428,13 @@ def deploy( print_diff=print_diff, ) - if post_deploy_hooks: - cls.execute_post_deploy_hooks( - console=console, - project_root=project_root, - post_deploy_hooks=post_deploy_hooks, - package_name=package_name, - package_warehouse=package_warehouse, - ) + cls.execute_post_deploy_hooks( + console=console, + project_root=project_root, + post_deploy_hooks=post_deploy_hooks, + package_name=package_name, + package_warehouse=package_warehouse, + ) if validate: cls.validate_setup_script( @@ -1037,10 +1037,17 @@ def apply_package_scripts( applies all package scripts in-order to the application package. """ - if package_scripts: - console.warning( - "WARNING: native_app.package.scripts is deprecated. Please migrate to using native_app.package.post_deploy." - ) + metrics = get_cli_context().metrics + metrics.set_counter_default(CLICounterField.PACKAGE_SCRIPTS, 0) + + if not package_scripts: + return + + metrics.set_counter(CLICounterField.PACKAGE_SCRIPTS, 1) + + console.warning( + "WARNING: native_app.package.scripts is deprecated. Please migrate to using native_app.package.post_deploy." + ) queued_queries = render_script_templates( project_root, @@ -1129,14 +1136,19 @@ def execute_post_deploy_hooks( package_name: str, package_warehouse: Optional[str], ): - with cls.use_package_warehouse(package_warehouse): - execute_post_deploy_hooks( - console=console, - project_root=project_root, - post_deploy_hooks=post_deploy_hooks, - deployed_object_type="application package", - database_name=package_name, - ) + get_cli_context().metrics.set_counter_default( + CLICounterField.POST_DEPLOY_SCRIPTS, 0 + ) + + if post_deploy_hooks: + with cls.use_package_warehouse(package_warehouse): + execute_post_deploy_hooks( + console=console, + project_root=project_root, + post_deploy_hooks=post_deploy_hooks, + deployed_object_type="application package", + database_name=package_name, + ) @classmethod def validate_setup_script( diff --git a/src/snowflake/cli/api/cli_global_context.py b/src/snowflake/cli/api/cli_global_context.py index e810da9d62..0b53495166 100644 --- a/src/snowflake/cli/api/cli_global_context.py +++ b/src/snowflake/cli/api/cli_global_context.py @@ -22,6 +22,7 @@ from snowflake.cli.api.connections import ConnectionContext, OpenConnectionCache from snowflake.cli.api.exceptions import MissingConfiguration +from snowflake.cli.api.metrics import CLIMetrics from snowflake.cli.api.output.formats import OutputFormat from snowflake.cli.api.rendering.jinja import CONTEXT_KEY from snowflake.connector import SnowflakeConnection @@ -46,6 +47,8 @@ class _CliGlobalContextManager: experimental: bool = False enable_tracebacks: bool = True + metrics: CLIMetrics = field(default_factory=CLIMetrics) + project_path_arg: str | None = None project_is_optional: bool = True project_env_overrides_args: dict[str, str] = field(default_factory=dict) @@ -152,6 +155,10 @@ def connection_context(self) -> ConnectionContext: def enable_tracebacks(self) -> bool: return self._manager.enable_tracebacks + @property + def metrics(self): + return self._manager.metrics + @property def output_format(self) -> OutputFormat: return self._manager.output_format diff --git a/src/snowflake/cli/api/entities/utils.py b/src/snowflake/cli/api/entities/utils.py index 9b514a9ba7..c60a2224b4 100644 --- a/src/snowflake/cli/api/entities/utils.py +++ b/src/snowflake/cli/api/entities/utils.py @@ -31,6 +31,7 @@ NO_WAREHOUSE_SELECTED_IN_SESSION, ) from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError +from snowflake.cli.api.metrics import CLICounterField from snowflake.cli.api.project.schemas.entities.common import PostDeployHook from snowflake.cli.api.rendering.sql_templates import ( choose_sql_jinja_env_based_on_template_syntax, @@ -249,6 +250,8 @@ def execute_post_deploy_hooks( if not post_deploy_hooks: return + get_cli_context().metrics.set_counter(CLICounterField.POST_DEPLOY_SCRIPTS, 1) + with console.phase(f"Executing {deployed_object_type} post-deploy actions"): sql_scripts_paths = [] for hook in post_deploy_hooks: diff --git a/src/snowflake/cli/api/metrics.py b/src/snowflake/cli/api/metrics.py new file mode 100644 index 0000000000..3ab4d70316 --- /dev/null +++ b/src/snowflake/cli/api/metrics.py @@ -0,0 +1,92 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + + +class _TypePrefix: + FEATURES = "features" + + +class _DomainPrefix: + GLOBAL = "global" + APP = "app" + SQL = "sql" + + +class CLICounterField: + """ + for each counter field we're adopting a convention of + .. + for example, if we're tracking a global feature, then the field name would be + features.global.feature_name + + The metrics API is implemented to be generic, but we are adopting a convention + for feature tracking with the following model for a given command execution: + * counter not present -> feature is not available + * counter == 0 -> feature is available, but not used + * counter == 1 -> feature is used + this makes it easy to compute percentages for feature dashboards in Snowsight + """ + + TEMPLATES_PROCESSOR = ( + f"{_TypePrefix.FEATURES}.{_DomainPrefix.GLOBAL}.templates_processor" + ) + SQL_TEMPLATES = f"{_TypePrefix.FEATURES}.{_DomainPrefix.SQL}.sql_templates" + PDF_TEMPLATES = f"{_TypePrefix.FEATURES}.{_DomainPrefix.GLOBAL}.pdf_templates" + SNOWPARK_PROCESSOR = ( + f"{_TypePrefix.FEATURES}.{_DomainPrefix.APP}.snowpark_processor" + ) + POST_DEPLOY_SCRIPTS = ( + f"{_TypePrefix.FEATURES}.{_DomainPrefix.APP}.post_deploy_scripts" + ) + PACKAGE_SCRIPTS = f"{_TypePrefix.FEATURES}.{_DomainPrefix.APP}.package_scripts" + + +class CLIMetrics: + """ + Class to track various metrics across the execution of a command + """ + + def __init__(self): + self._counters: Dict[str, int] = {} + + def __eq__(self, other): + if isinstance(other, CLIMetrics): + return self._counters == other._counters + return False + + def get_counter(self, name: str) -> Optional[int]: + return self._counters.get(name) + + def set_counter(self, name: str, value: int) -> None: + self._counters[name] = value + + def set_counter_default(self, name: str, value: int) -> None: + """ + sets the counter if it does not already exist + """ + if name not in self._counters: + self.set_counter(name, value) + + def increment_counter(self, name: str, value: int = 1) -> None: + if name not in self._counters: + self.set_counter(name, value) + else: + self._counters[name] += value + + @property + def counters(self) -> Dict[str, int]: + # return a copy of the original dict to avoid mutating the original + return self._counters.copy() diff --git a/src/snowflake/cli/api/rendering/sql_templates.py b/src/snowflake/cli/api/rendering/sql_templates.py index 5834d1cd44..1c400b07f8 100644 --- a/src/snowflake/cli/api/rendering/sql_templates.py +++ b/src/snowflake/cli/api/rendering/sql_templates.py @@ -21,6 +21,7 @@ from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.console.console import cli_console from snowflake.cli.api.exceptions import InvalidTemplate +from snowflake.cli.api.metrics import CLICounterField from snowflake.cli.api.rendering.jinja import ( CONTEXT_KEY, FUNCTION_KEY, @@ -96,4 +97,9 @@ def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str: context_data = get_cli_context().template_context context_data.update(data) env = choose_sql_jinja_env_based_on_template_syntax(content) + + get_cli_context().metrics.set_counter( + CLICounterField.SQL_TEMPLATES, int(has_sql_templates(content)) + ) + return env.from_string(content).render(context_data) diff --git a/src/snowflake/cli/api/utils/definition_rendering.py b/src/snowflake/cli/api/utils/definition_rendering.py index 1755b31609..66f5a0b7a5 100644 --- a/src/snowflake/cli/api/utils/definition_rendering.py +++ b/src/snowflake/cli/api/utils/definition_rendering.py @@ -19,8 +19,10 @@ from jinja2 import Environment, TemplateSyntaxError, nodes from packaging.version import Version +from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.exceptions import CycleDetectedError, InvalidTemplate +from snowflake.cli.api.metrics import CLICounterField from snowflake.cli.api.project.schemas.project_definition import ( ProjectProperties, build_project_definition, @@ -266,6 +268,12 @@ def find_any_template_vars(element): return referenced_vars +def _has_referenced_vars_in_definition( + template_env: TemplatedEnvironment, definition: Definition +) -> bool: + return len(_get_referenced_vars_in_definition(template_env, definition)) > 0 + + def _template_version_warning(): cc.warning( "Ignoring template pattern in project definition file. " @@ -291,6 +299,17 @@ def _add_defaults_to_definition(original_definition: Definition) -> Definition: return definition_with_defaults +def _update_metrics(template_env: TemplatedEnvironment, definition: Definition): + metrics = get_cli_context().metrics + + # render_definition_template is invoked multiple times both by the user + # and by us so we should make sure we don't overwrite a 1 with a 0 here + metrics.set_counter_default(CLICounterField.PDF_TEMPLATES, 0) + + if _has_referenced_vars_in_definition(template_env, definition): + metrics.set_counter(CLICounterField.PDF_TEMPLATES, 1) + + def render_definition_template( original_definition: Optional[Definition], context_overrides: Context ) -> ProjectProperties: @@ -326,10 +345,7 @@ def render_definition_template( definition["definition_version"] ) < Version("1.1"): try: - referenced_vars = _get_referenced_vars_in_definition( - template_env, definition - ) - if referenced_vars: + if _has_referenced_vars_in_definition(template_env, definition): _template_version_warning() except Exception: # also warn on Exception, as it means the user is incorrectly attempting to use templating @@ -340,6 +356,10 @@ def render_definition_template( project_context[CONTEXT_KEY]["env"] = environment_overrides return ProjectProperties(project_definition, project_context) + # need to have the metrics added here since we add defaults to the + # definition that the user might not have added themselves later + _update_metrics(template_env, definition) + definition = _add_defaults_to_definition(definition) project_context = {CONTEXT_KEY: definition} diff --git a/tests/api/test_metrics.py b/tests/api/test_metrics.py new file mode 100644 index 0000000000..133269dda5 --- /dev/null +++ b/tests/api/test_metrics.py @@ -0,0 +1,101 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from snowflake.cli.api.metrics import CLIMetrics + + +def test_metrics_no_counters(): + # given + metrics = CLIMetrics() + + # when + + # then + assert metrics.counters == {} + assert metrics.get_counter("counter1") is None + + +def test_metrics_set_one_counter(): + # given + metrics = CLIMetrics() + + # when + metrics.set_counter("counter1", 1) + + # then + assert metrics.counters == {"counter1": 1} + assert metrics.get_counter("counter1") == 1 + + +def test_metrics_increment_new_counter(): + # given + metrics = CLIMetrics() + + # when + metrics.increment_counter("counter1") + + # then + assert metrics.counters == {"counter1": 1} + assert metrics.get_counter("counter1") == 1 + + +def test_metrics_increment_existing_counter(): + # given + metrics = CLIMetrics() + + # when + metrics.set_counter("counter1", 2) + metrics.increment_counter(name="counter1", value=2) + + # then + assert metrics.counters == {"counter1": 4} + assert metrics.get_counter("counter1") == 4 + + +def test_metrics_set_multiple_counters(): + # given + metrics = CLIMetrics() + + # when + metrics.set_counter("counter1", 1) + metrics.set_counter("counter2", 0) + metrics.set_counter(name="counter2", value=2) + + # then + assert metrics.counters == {"counter1": 1, "counter2": 2} + assert metrics.get_counter("counter1") == 1 + assert metrics.get_counter("counter2") == 2 + + +def test_metrics_set_default_new_counter(): + # given + metrics = CLIMetrics() + + # when + metrics.set_counter_default("c1", 3) + + # then + assert metrics.counters == {"c1": 3} + + +def test_metrics_set_default_existing_counter(): + # given + metrics = CLIMetrics() + + # when + metrics.set_counter("c2", 2) + metrics.set_counter_default("c2", 1) + + # then + assert metrics.counters == {"c2": 2} diff --git a/tests_integration/nativeapp/test_feature_metrics.py b/tests_integration/nativeapp/test_feature_metrics.py new file mode 100644 index 0000000000..33312f7037 --- /dev/null +++ b/tests_integration/nativeapp/test_feature_metrics.py @@ -0,0 +1,155 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from shlex import split +from typing import Dict, Any +from unittest import mock +from unittest.mock import MagicMock + +from snowflake.cli.api.metrics import CLICounterField +from tests.project.fixtures import * + + +def _extract_first_result_executing_command_telemetry_message( + mock_telemetry: MagicMock, +) -> Dict[str, Any]: + # The method is called with a TelemetryData type, so we cast it to dict for simpler comparison + return next( + args.args[0].to_dict()["message"] + for args in mock_telemetry.call_args_list + if args.args[0].to_dict().get("message").get("type") + == "result_executing_command" + ) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "command,expected_counter", + [ + ( + [ + "sql", + "-q", + "select '<% ctx.env.test %>'", + "--env", + "test=value_from_cli", + ], + 1, + ), + (["sql", "-q", "select 'string'"], 0), + ], +) +@mock.patch("snowflake.connector.telemetry.TelemetryClient.try_add_log_to_batch") +def test_sql_templating_emits_counter( + mock_telemetry, + command: List[str], + expected_counter, + runner, +): + result = runner.invoke_with_connection_json(command) + + assert result.exit_code == 0 + + message = _extract_first_result_executing_command_telemetry_message(mock_telemetry) + + assert message["counters"][CLICounterField.SQL_TEMPLATES] == expected_counter + + +@pytest.mark.integration +@pytest.mark.parametrize( + "command," "test_project," "expected_counters", + [ + # ensure that post deploy scripts are picked up for v1 + ( + "app deploy", + "napp_application_post_deploy_v1", + { + CLICounterField.SNOWPARK_PROCESSOR: 0, + CLICounterField.TEMPLATES_PROCESSOR: 0, + CLICounterField.PDF_TEMPLATES: 0, + CLICounterField.POST_DEPLOY_SCRIPTS: 1, + CLICounterField.PACKAGE_SCRIPTS: 0, + }, + ), + # post deploy scripts should not be available for bundling since there is no deploy + ( + "ws bundle --entity-id=pkg", + "napp_templates_processors_v2", + { + CLICounterField.SNOWPARK_PROCESSOR: 0, + CLICounterField.TEMPLATES_PROCESSOR: 1, + CLICounterField.PDF_TEMPLATES: 1, + }, + ), + # ensure that templates processor is picked up + ( + "app run", + "napp_templates_processors_v1", + { + CLICounterField.SNOWPARK_PROCESSOR: 0, + CLICounterField.TEMPLATES_PROCESSOR: 1, + CLICounterField.PDF_TEMPLATES: 0, + CLICounterField.POST_DEPLOY_SCRIPTS: 0, + CLICounterField.PACKAGE_SCRIPTS: 0, + }, + ), + # ensure that package scripts are picked up + ( + "app deploy", + "integration_external", + { + CLICounterField.SNOWPARK_PROCESSOR: 0, + CLICounterField.TEMPLATES_PROCESSOR: 0, + CLICounterField.POST_DEPLOY_SCRIPTS: 0, + CLICounterField.PACKAGE_SCRIPTS: 1, + }, + ), + # ensure post deploy scripts are picked up for v2 + ( + "app deploy", + "integration_external_v2", + { + CLICounterField.SNOWPARK_PROCESSOR: 0, + CLICounterField.TEMPLATES_PROCESSOR: 0, + CLICounterField.PDF_TEMPLATES: 1, + CLICounterField.POST_DEPLOY_SCRIPTS: 1, + CLICounterField.PACKAGE_SCRIPTS: 0, + }, + ), + ], +) +@mock.patch("snowflake.connector.telemetry.TelemetryClient.try_add_log_to_batch") +def test_nativeapp_feature_counter_has_expected_value( + mock_telemetry, + runner, + nativeapp_teardown, + nativeapp_project_directory, + command: str, + test_project: str, + expected_counters: Dict[str, int], +): + local_test_env = { + "APP_DIR": "app", + "schema_name": "test_schema", + "table_name": "test_table", + "value": "test_value", + } + + with nativeapp_project_directory(test_project): + runner.invoke_with_connection(split(command), env=local_test_env) + + message = _extract_first_result_executing_command_telemetry_message( + mock_telemetry + ) + + assert message["counters"] == expected_counters