From 1b1d4fc15bf0304d97351319595802784e3a9022 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Thu, 7 Sep 2023 14:19:03 +0200 Subject: [PATCH] Output types (#350) --- src/snowcli/app/cli_app.py | 9 +- src/snowcli/cli/connection/commands.py | 14 +- .../cli/snowpark/compute_pool/commands.py | 18 +- src/snowcli/cli/snowpark/function/commands.py | 39 ++-- src/snowcli/cli/snowpark/jobs/commands.py | 19 +- src/snowcli/cli/snowpark/package/commands.py | 16 +- .../cli/snowpark/procedure/commands.py | 41 ++-- .../snowpark/procedure_coverage/commands.py | 8 +- src/snowcli/cli/snowpark/registry/commands.py | 5 +- src/snowcli/cli/snowpark/registry/manager.py | 3 - src/snowcli/cli/snowpark/services/commands.py | 28 +-- src/snowcli/cli/sql/commands.py | 12 +- src/snowcli/cli/stage/commands.py | 26 +-- src/snowcli/cli/streamlit/commands.py | 40 ++-- src/snowcli/cli/warehouse/commands.py | 4 +- src/snowcli/output/decorators.py | 8 +- src/snowcli/output/printing.py | 201 +++++++----------- src/snowcli/output/types.py | 84 ++++++++ tests/__snapshots__/test_package.ambr | 28 +-- tests/__snapshots__/test_warehouse.ambr | 1 + tests/output/test_printing.py | 198 ++++++++++------- .../test_procedure_coverage.ambr | 24 +-- tests/test_main.py | 4 - tests/test_snow_connector.py | 2 +- tests/test_warehouse.py | 2 - tests/testing_utils/fixtures.py | 1 + tests_integration/conftest.py | 2 +- tests_integration/test_package.py | 8 +- .../assertions/test_result_assertions.py | 7 +- .../testing_utils/snowpark_utils.py | 7 +- 30 files changed, 467 insertions(+), 392 deletions(-) create mode 100644 src/snowcli/output/types.py diff --git a/src/snowcli/app/cli_app.py b/src/snowcli/app/cli_app.py index f099eb0a4..c56ad4e17 100644 --- a/src/snowcli/app/cli_app.py +++ b/src/snowcli/app/cli_app.py @@ -12,7 +12,8 @@ from snowcli.config import config_init, cli_config from snowcli.app.dev.docs.generator import generate_docs from snowcli.output.formats import OutputFormat -from snowcli.output.printing import OutputData +from snowcli.output.printing import print_result +from snowcli.output.types import CollectionResult from snowcli.app.dev.pycharm_remote_debug import ( setup_pycharm_remote_debugger_if_provided, ) @@ -36,13 +37,13 @@ def _version_callback(value: bool): def _info_callback(value: bool): if value: - OutputData.from_list( + result = CollectionResult( [ {"key": "version", "value": __about__.VERSION}, {"key": "default_config_file_path", "value": cli_config.file_path}, ], - format_=OutputFormat.JSON, - ).print() + ) + print_result(result, output_format=OutputFormat.JSON) raise typer.Exit() diff --git a/src/snowcli/cli/connection/commands.py b/src/snowcli/cli/connection/commands.py index 4b57db336..ab898f3e1 100644 --- a/src/snowcli/cli/connection/commands.py +++ b/src/snowcli/cli/connection/commands.py @@ -11,7 +11,7 @@ from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS, ConnectionOption from snowcli.config import cli_config from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData +from snowcli.output.types import CollectionResult, CommandResult, MessageResult from snowcli.snow_connector import connect_to_snowflake app = typer.Typer( @@ -41,7 +41,7 @@ def _mask_password(connection_params: dict): @app.command(name="list") @with_output @global_options -def list_connections(**options) -> OutputData: +def list_connections(**options) -> CommandResult: """ List configured connections. """ @@ -50,7 +50,7 @@ def list_connections(**options) -> OutputData: {"connection_name": k, "parameters": _mask_password(v)} for k, v in connections.items() ) - return OutputData(stream=result) + return CollectionResult(result) def require_integer(field_name: str): @@ -159,7 +159,7 @@ def add( prompt="Snowflake region", help="Region name if not the default Snowflake deployment.", ), -) -> OutputData: +) -> CommandResult: """Add connection to configuration file.""" connection_entry = { "account": account, @@ -180,16 +180,16 @@ def add( except KeyAlreadyPresent: raise ClickException(f"Connection {connection_name} already exists") - return OutputData.from_string( + return MessageResult( f"Wrote new connection {connection_name} to {cli_config.file_path}" ) @app.command() @with_output -def test(connection: str = ConnectionOption) -> OutputData: +def test(connection: str = ConnectionOption) -> CommandResult: """ Tests connection to Snowflake. """ connect_to_snowflake(connection_name=connection) - return OutputData.from_string("OK") + return MessageResult("OK") diff --git a/src/snowcli/cli/snowpark/compute_pool/commands.py b/src/snowcli/cli/snowpark/compute_pool/commands.py index 760cb2c24..a4adc7386 100644 --- a/src/snowcli/cli/snowpark/compute_pool/commands.py +++ b/src/snowcli/cli/snowpark/compute_pool/commands.py @@ -5,7 +5,7 @@ from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS from snowcli.cli.snowpark.compute_pool.manager import ComputePoolManager from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData +from snowcli.output.types import SingleQueryResult, QueryResult, CommandResult app = typer.Typer( context_settings=DEFAULT_CONTEXT_SETTINGS, @@ -22,25 +22,25 @@ def create( num_instances: int = typer.Option(..., "--num", "-d", help="Number of instances"), instance_family: str = typer.Option(..., "--family", "-f", help="Instance family"), **options, -) -> OutputData: +) -> CommandResult: """ Create compute pool """ cursor = ComputePoolManager().create( pool_name=name, num_instances=num_instances, instance_family=instance_family ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command() @with_output @global_options_with_connection -def list(**options) -> OutputData: +def list(**options) -> CommandResult: """ List compute pools """ cursor = ComputePoolManager().show() - return OutputData.from_cursor(cursor) + return QueryResult(cursor) @app.command() @@ -48,12 +48,12 @@ def list(**options) -> OutputData: @global_options_with_connection def drop( name: str = typer.Argument(..., help="Compute Pool Name"), **options -) -> OutputData: +) -> CommandResult: """ Drop compute pool """ cursor = ComputePoolManager().drop(pool_name=name) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command() @@ -61,12 +61,12 @@ def drop( @global_options_with_connection def stop( name: str = typer.Argument(..., help="Compute Pool Name"), **options -) -> OutputData: +) -> CommandResult: """ Stop and delete all services running on Compute Pool """ cursor = ComputePoolManager().stop(pool_name=name) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) app_cp = build_alias( diff --git a/src/snowcli/cli/snowpark/function/commands.py b/src/snowcli/cli/snowpark/function/commands.py index 7ec7accd6..2914ef45e 100644 --- a/src/snowcli/cli/snowpark/function/commands.py +++ b/src/snowcli/cli/snowpark/function/commands.py @@ -18,7 +18,12 @@ ) from snowcli.cli.stage.manager import StageManager from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData +from snowcli.output.types import ( + MessageResult, + SingleQueryResult, + QueryResult, + CommandResult, +) from snowcli.utils import ( prepare_app_zip, get_snowflake_packages, @@ -71,7 +76,7 @@ def function_init(): Initialize this directory with a sample set of files to create a function. """ create_project_template("default_function") - return OutputData.from_string("Done") + return MessageResult("Done") @app.command("create") @@ -104,7 +109,7 @@ def function_create( help="Replace if existing function", ), **options, -) -> OutputData: +) -> CommandResult: """Creates a python UDF/UDTF using local artifact.""" snowpark_package( pypi_download, # type: ignore[arg-type] @@ -133,7 +138,7 @@ def function_create( packages=packages, overwrite=overwrite, ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) def upload_snowpark_artifact( @@ -182,7 +187,7 @@ def function_update( help="Replace function, even if no detected changes to metadata", ), **options, -) -> OutputData: +) -> CommandResult: """Updates an existing python UDF/UDTF using local artifact.""" snowpark_package( pypi_download, # type: ignore[arg-type] @@ -241,9 +246,9 @@ def function_update( packages=packages, overwrite=True, ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) - return OutputData.from_string("No packages to update. Deployment complete!") + return MessageResult("No packages to update. Deployment complete!") @app.command("package") @@ -252,14 +257,14 @@ def function_package( pypi_download: str = PyPiDownloadOption, check_anaconda_for_pypi_deps: bool = CheckAnacondaForPyPiDependancies, package_native_libraries: str = PackageNativeLibrariesOption, -) -> OutputData: +) -> CommandResult: """Packages function code into zip file.""" snowpark_package( pypi_download, # type: ignore[arg-type] check_anaconda_for_pypi_deps, package_native_libraries, # type: ignore[arg-type] ) - return OutputData.from_string("Done") + return MessageResult("Done") @app.command("execute") @@ -273,10 +278,10 @@ def function_execute( help="Function with inputs. E.g. 'hello(int, string)'", ), **options, -) -> OutputData: +) -> CommandResult: """Executes a Snowflake function.""" cursor = FunctionManager().execute(expression=function) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command("describe") @@ -292,14 +297,14 @@ def function_describe( help="Function signature with inputs. E.g. 'hello(int, string)'", ), **options, -) -> OutputData: +) -> CommandResult: """Describes a Snowflake function.""" cursor = FunctionManager().describe( identifier=FunctionManager.identifier( name=name, signature=input_parameters, name_and_signature=function ) ) - return OutputData.from_cursor(cursor) + return QueryResult(cursor) @app.command("list") @@ -313,10 +318,10 @@ def function_list( help='Filter functions by name - e.g. "hello%"', ), **options, -) -> OutputData: +) -> CommandResult: """Lists Snowflake functions.""" cursor = FunctionManager().show(like=like) - return OutputData.from_cursor(cursor) + return QueryResult(cursor) @app.command("drop") @@ -332,11 +337,11 @@ def function_drop( help="Function signature with inputs. E.g. 'hello(int, string)'", ), **options, -) -> OutputData: +) -> CommandResult: """Drops a Snowflake function.""" cursor = FunctionManager().drop( identifier=FunctionManager.identifier( name=name, signature=input_parameters, name_and_signature=signature ) ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) diff --git a/src/snowcli/cli/snowpark/jobs/commands.py b/src/snowcli/cli/snowpark/jobs/commands.py index fb5129eb0..d74d4a2c8 100644 --- a/src/snowcli/cli/snowpark/jobs/commands.py +++ b/src/snowcli/cli/snowpark/jobs/commands.py @@ -9,7 +9,8 @@ from snowcli.cli.snowpark.jobs.manager import JobManager from snowcli.cli.stage.manager import StageManager from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData + +from snowcli.output.types import SingleQueryResult, CommandResult app = typer.Typer( context_settings=DEFAULT_CONTEXT_SETTINGS, name="jobs", help="Manage jobs" @@ -32,7 +33,7 @@ def create( ), stage: str = typer.Option("SOURCE_STAGE", "--stage", "-l", help="Stage name"), **options, -) -> OutputData: +) -> CommandResult: """ Create Job """ @@ -43,18 +44,18 @@ def create( cursor = JobManager().create( compute_pool=compute_pool, spec_path=spec_path, stage=stage ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command() @with_output @global_options_with_connection -def desc(id: str = typer.Argument(..., help="Job id"), **options) -> OutputData: +def desc(id: str = typer.Argument(..., help="Job id"), **options) -> CommandResult: """ Desc Service """ cursor = JobManager().desc(job_name=id) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command() @@ -78,20 +79,20 @@ def logs( @app.command() @with_output @global_options_with_connection -def status(id: str = typer.Argument(..., help="Job id"), **options) -> OutputData: +def status(id: str = typer.Argument(..., help="Job id"), **options) -> CommandResult: """ Returns status of a job. """ cursor = JobManager().status(job_name=id) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command() @with_output @global_options_with_connection -def drop(id: str = typer.Argument(..., help="Job id"), **options) -> OutputData: +def drop(id: str = typer.Argument(..., help="Job id"), **options) -> CommandResult: """ Drop Service """ cursor = JobManager().drop(job_name=id) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) diff --git a/src/snowcli/cli/snowpark/package/commands.py b/src/snowcli/cli/snowpark/package/commands.py index 95c47c976..b7196a664 100644 --- a/src/snowcli/cli/snowpark/package/commands.py +++ b/src/snowcli/cli/snowpark/package/commands.py @@ -14,14 +14,12 @@ upload, ) from snowcli.cli.snowpark.package.utils import ( - InAnaconda, NotInAnaconda, RequiresPackages, - NothingFound, CreatedSuccessfully, ) from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData +from snowcli.output.types import MessageResult, CommandResult app = typer.Typer( name="package", @@ -43,7 +41,7 @@ def package_lookup( help="Install packages that are not available on the Snowflake anaconda channel", ), **options, -) -> OutputData: +) -> CommandResult: """ Checks if a package is available on the Snowflake anaconda channel. In install_packages flag is set to True, command will check all the dependencies of the packages @@ -51,7 +49,7 @@ def package_lookup( """ lookup_result = lookup(name=name, install_packages=install_packages) cleanup_after_install() - return OutputData.from_string(lookup_result.message) + return MessageResult(lookup_result.message) @app.command("upload") @@ -78,11 +76,11 @@ def package_upload( help="Overwrite the file if it already exists", ), **options, -) -> OutputData: +) -> CommandResult: """ Upload a python package zip file to a Snowflake stage, so it can be referenced in the imports of a procedure or function. """ - return OutputData.from_string(upload(file=file, stage=stage, overwrite=overwrite)) + return MessageResult(upload(file=file, stage=stage, overwrite=overwrite)) @app.command("create") @@ -100,7 +98,7 @@ def package_create( help="Install packages that are not available on the Snowflake anaconda channel", ), **options, -) -> OutputData: +) -> CommandResult: """ Create a python package as a zip file that can be uploaded to a stage and imported for a Snowpark python app. """ @@ -120,4 +118,4 @@ def package_create( message = lookup_result.message cleanup_after_install() - return OutputData.from_string(message) + return MessageResult(message) diff --git a/src/snowcli/cli/snowpark/procedure/commands.py b/src/snowcli/cli/snowpark/procedure/commands.py index 0d457ea44..5739388de 100644 --- a/src/snowcli/cli/snowpark/procedure/commands.py +++ b/src/snowcli/cli/snowpark/procedure/commands.py @@ -23,7 +23,12 @@ ) from snowcli.cli.stage.manager import StageManager from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData +from snowcli.output.types import ( + MessageResult, + CommandResult, + SingleQueryResult, + QueryResult, +) from snowcli.utils import ( create_project_template, prepare_app_zip, @@ -45,12 +50,12 @@ @app.command("init") @with_output -def procedure_init() -> OutputData: +def procedure_init() -> CommandResult: """ Initialize this directory with a sample set of files to create a procedure. """ create_project_template("default_procedure") - return OutputData.from_string("Done") + return MessageResult("Done") @app.command("create") @@ -108,7 +113,7 @@ def procedure_create( help="Wraps the procedure with a code coverage measurement tool, so that a coverage report can be later retrieved.", ), **options, -) -> OutputData: +) -> CommandResult: """Creates a python procedure using local artifact.""" snowpark_package( pypi_download, # type: ignore[arg-type] @@ -142,7 +147,7 @@ def procedure_create( overwrite=overwrite, execute_as_caller=execute_as_caller, ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) def _upload_procedure_artifact( @@ -236,7 +241,7 @@ def procedure_update( help="Wraps the procedure with a code coverage measurement tool, so that a coverage report can be later retrieved.", ), **options, -) -> OutputData: +) -> CommandResult: """Updates an existing python procedure using local artifact.""" snowpark_package( pypi_download, # type: ignore[arg-type] @@ -308,9 +313,9 @@ def procedure_update( overwrite=True, execute_as_caller=execute_as_caller, ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) - return OutputData.from_string("No packages to update. Deployment complete!") + return MessageResult("No packages to update. Deployment complete!") @app.command("package") @@ -319,14 +324,14 @@ def procedure_package( pypi_download: str = PyPiDownloadOption, check_anaconda_for_pypi_deps: bool = CheckAnacondaForPyPiDependancies, package_native_libraries: str = PackageNativeLibrariesOption, -) -> OutputData: +) -> CommandResult: """Packages procedure code into zip file.""" snowpark_package( pypi_download, # type: ignore[arg-type] check_anaconda_for_pypi_deps, package_native_libraries, # type: ignore[arg-type] ) - return OutputData.from_string("Done") + return MessageResult("Done") @app.command("execute") @@ -340,10 +345,10 @@ def procedure_execute( help="Procedure with inputs. E.g. 'hello(int, string)'. Must exactly match those provided when creating the procedure.", ), **options, -) -> OutputData: +) -> CommandResult: """Executes a Snowflake procedure.""" cursor = ProcedureManager().execute(expression=signature) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command("describe") @@ -364,7 +369,7 @@ def procedure_describe( help="Procedure signature with inputs. E.g. 'hello(int, string)'", ), **options, -) -> OutputData: +) -> CommandResult: """Describes a Snowflake procedure.""" cursor = ProcedureManager().describe( ProcedureManager.identifier( @@ -373,7 +378,7 @@ def procedure_describe( name_and_signature=signature, ) ) - return OutputData.from_cursor(cursor) + return QueryResult(cursor) @app.command("list") @@ -387,10 +392,10 @@ def procedure_list( help='Filter procedures by name - e.g. "hello%"', ), **options, -) -> OutputData: +) -> CommandResult: """Lists Snowflake procedures.""" cursor = ProcedureManager().show(like=like) - return OutputData.from_cursor(cursor) + return QueryResult(cursor) @app.command("drop") @@ -411,7 +416,7 @@ def procedure_drop( help="Procedure signature with inputs. E.g. 'hello(int, string)'", ), **options, -) -> OutputData: +) -> CommandResult: """Drops a Snowflake procedure.""" cursor = ProcedureManager().drop( ProcedureManager.identifier( @@ -420,7 +425,7 @@ def procedure_drop( name_and_signature=signature, ) ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) def _replace_handler_in_zip( diff --git a/src/snowcli/cli/snowpark/procedure_coverage/commands.py b/src/snowcli/cli/snowpark/procedure_coverage/commands.py index 7559c449b..fd5d3e1d4 100644 --- a/src/snowcli/cli/snowpark/procedure_coverage/commands.py +++ b/src/snowcli/cli/snowpark/procedure_coverage/commands.py @@ -7,7 +7,7 @@ ReportOutputOptions, ) from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData +from snowcli.output.types import MessageResult, SingleQueryResult, CommandResult app: typer.Typer = typer.Typer( name="coverage", @@ -55,7 +55,7 @@ def procedure_coverage_report( store_as_comment=store_as_comment, ) - return OutputData.from_string(message) + return MessageResult(message) @app.command( @@ -78,8 +78,8 @@ def procedure_coverage_clear( help="Input parameters - such as (message string, count int). Must exactly match those provided when creating the procedure.", ), **options, -) -> OutputData: +) -> CommandResult: cursor = ProcedureCoverageManager().clear( name=name, input_parameters=input_parameters ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) diff --git a/src/snowcli/cli/snowpark/registry/commands.py b/src/snowcli/cli/snowpark/registry/commands.py index 105fac67d..7fcf19b9d 100644 --- a/src/snowcli/cli/snowpark/registry/commands.py +++ b/src/snowcli/cli/snowpark/registry/commands.py @@ -4,8 +4,7 @@ from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS, ConnectionOption from snowcli.cli.snowpark.registry.manager import get_token from snowcli.output.decorators import with_output -from snowcli.output.formats import OutputFormat -from snowcli.output.printing import OutputData +from snowcli.output.types import ObjectResult app = typer.Typer( context_settings=DEFAULT_CONTEXT_SETTINGS, name="registry", help="Manage registry" @@ -19,4 +18,4 @@ def token(environment: str = ConnectionOption, **options): """ Get token to authenticate with registry. """ - return OutputData.from_list([get_token(environment)]) + return ObjectResult(get_token(environment)) diff --git a/src/snowcli/cli/snowpark/registry/manager.py b/src/snowcli/cli/snowpark/registry/manager.py index 4d04e604c..a130ce4e2 100644 --- a/src/snowcli/cli/snowpark/registry/manager.py +++ b/src/snowcli/cli/snowpark/registry/manager.py @@ -1,7 +1,4 @@ -import json - from snowcli.cli.common.flags import ConnectionOption -from snowcli.output.printing import OutputData from snowcli.snow_connector import connect_to_snowflake diff --git a/src/snowcli/cli/snowpark/services/commands.py b/src/snowcli/cli/snowpark/services/commands.py index 74cf4b1f7..b13770756 100644 --- a/src/snowcli/cli/snowpark/services/commands.py +++ b/src/snowcli/cli/snowpark/services/commands.py @@ -4,12 +4,12 @@ import typer from snowcli.cli.common.decorators import global_options_with_connection -from snowcli.cli.common.flags import ConnectionOption, DEFAULT_CONTEXT_SETTINGS +from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS from snowcli.cli.snowpark.common import print_log_lines from snowcli.cli.snowpark.services.manager import ServiceManager from snowcli.cli.stage.manager import StageManager from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData +from snowcli.output.types import QueryResult, SingleQueryResult, CommandResult app = typer.Typer( context_settings=DEFAULT_CONTEXT_SETTINGS, name="services", help="Manage services" @@ -36,7 +36,7 @@ def create( ), stage: str = typer.Option("SOURCE_STAGE", "--stage", "-l", help="Stage name"), **options, -) -> OutputData: +) -> CommandResult: """ Create service """ @@ -51,18 +51,20 @@ def create( spec_path=spec_path, stage=stage, ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command() @with_output @global_options_with_connection -def desc(name: str = typer.Argument(..., help="Service Name"), **options) -> OutputData: +def desc( + name: str = typer.Argument(..., help="Service Name"), **options +) -> CommandResult: """ Desc Service """ cursor = ServiceManager().desc(service_name=name) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command() @@ -70,34 +72,36 @@ def desc(name: str = typer.Argument(..., help="Service Name"), **options) -> Out @global_options_with_connection def status( name: str = typer.Argument(..., help="Service Name"), **options -) -> OutputData: +) -> CommandResult: """ Logs Service """ cursor = ServiceManager().status(service_name=name) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command() @with_output @global_options_with_connection -def list(**options) -> OutputData: +def list(**options) -> CommandResult: """ List Service """ cursor = ServiceManager().show() - return OutputData.from_cursor(cursor) + return QueryResult(cursor) @app.command() @with_output @global_options_with_connection -def drop(name: str = typer.Argument(..., help="Service Name"), **options) -> OutputData: +def drop( + name: str = typer.Argument(..., help="Service Name"), **options +) -> CommandResult: """ Drop Service """ cursor = ServiceManager().drop(service_name=name) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command() diff --git a/src/snowcli/cli/sql/commands.py b/src/snowcli/cli/sql/commands.py index 512a15519..58ecf3689 100644 --- a/src/snowcli/cli/sql/commands.py +++ b/src/snowcli/cli/sql/commands.py @@ -6,7 +6,7 @@ from snowcli.cli.common.decorators import global_options_with_connection from snowcli.cli.sql.manager import SqlManager from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData +from snowcli.output.types import QueryResult, CommandResult, MultipleResults @with_output @@ -29,7 +29,7 @@ def execute_sql( help="File to execute.", ), **options -) -> OutputData: +) -> CommandResult: """ Executes Snowflake query. @@ -38,5 +38,9 @@ def execute_sql( """ cursors = SqlManager().execute(query, file) if len(cursors) > 1: - return OutputData(stream=(OutputData.from_cursor(cur) for cur in cursors)) - return OutputData.from_cursor(cursors[0]) + result = MultipleResults() + for curr in cursors: + result.add(QueryResult(curr)) + else: + result = QueryResult(cursors[0]) + return result diff --git a/src/snowcli/cli/stage/commands.py b/src/snowcli/cli/stage/commands.py index 799477382..bde37c3d5 100644 --- a/src/snowcli/cli/stage/commands.py +++ b/src/snowcli/cli/stage/commands.py @@ -7,7 +7,7 @@ from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS from snowcli.cli.stage.manager import StageManager from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData +from snowcli.output.types import QueryResult, SingleQueryResult, CommandResult app = typer.Typer( name="stage", @@ -23,7 +23,7 @@ @global_options_with_connection def stage_list( stage_name: str = typer.Argument(None, help="Name of stage"), **options -) -> OutputData: +) -> CommandResult: """ List stage contents or shows available stages if stage name not provided. """ @@ -33,7 +33,7 @@ def stage_list( cursor = manager.list(stage_name=stage_name) else: cursor = manager.show() - return OutputData.from_cursor(cursor) + return QueryResult(cursor) @app.command("get") @@ -51,12 +51,12 @@ def stage_get( help="Directory location to store downloaded files", ), **options, -) -> OutputData: +) -> CommandResult: """ Download all files from a stage to a local directory. """ cursor = StageManager().get(stage_name=stage_name, dest_path=path) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command("put") @@ -86,7 +86,7 @@ def stage_put( help="Number of parallel threads to use for upload", ), **options, -) -> OutputData: +) -> CommandResult: """ Upload files to a stage from a local client """ @@ -96,29 +96,29 @@ def stage_put( cursor = manager.put( local_path=local_path, stage_path=name, overwrite=overwrite, parallel=parallel ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command("create") @with_output @global_options_with_connection -def stage_create(name: str = StageNameOption, **options) -> OutputData: +def stage_create(name: str = StageNameOption, **options) -> CommandResult: """ Create stage if not exists. """ cursor = StageManager().create(stage_name=name) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command("drop") @with_output @global_options_with_connection -def stage_drop(name: str = StageNameOption, **options) -> OutputData: +def stage_drop(name: str = StageNameOption, **options) -> CommandResult: """ Drop stage """ cursor = StageManager().drop(stage_name=name) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command("remove") @@ -128,10 +128,10 @@ def stage_remove( stage_name: str = StageNameOption, file_name: str = typer.Argument(..., help="File name"), **options, -) -> OutputData: +) -> CommandResult: """ Remove file from stage """ cursor = StageManager().remove(stage_name=stage_name, path=file_name) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) diff --git a/src/snowcli/cli/streamlit/commands.py b/src/snowcli/cli/streamlit/commands.py index 391de37da..72ddc79b5 100644 --- a/src/snowcli/cli/streamlit/commands.py +++ b/src/snowcli/cli/streamlit/commands.py @@ -12,7 +12,14 @@ PackageNativeLibrariesOption, PyPiDownloadOption, ) -from snowcli.output.printing import OutputData +from snowcli.output.types import ( + CommandResult, + QueryResult, + CollectionResult, + SingleQueryResult, + MessageResult, + MultipleResults, +) app = typer.Typer( context_settings=DEFAULT_CONTEXT_SETTINGS, @@ -25,12 +32,12 @@ @app.command("list") @with_output @global_options_with_connection -def streamlit_list(**options) -> OutputData: +def streamlit_list(**options) -> CommandResult: """ List streamlit apps. """ cursor = StreamlitManager().list() - return OutputData.from_cursor(cursor) + return QueryResult(cursor) @app.command("describe") @@ -39,14 +46,15 @@ def streamlit_list(**options) -> OutputData: def streamlit_describe( name: str = typer.Argument(..., help="Name of streamlit to be deployed."), **options, -) -> OutputData: +) -> CommandResult: """ Describe a streamlit app. """ description, url = StreamlitManager().describe(streamlit_name=name) - return OutputData.from_list( - [OutputData.from_cursor(description), OutputData.from_cursor(url)] - ) + result = MultipleResults() + result.add(QueryResult(description)) + result.add(SingleQueryResult(url)) + return result @app.command("create") @@ -71,7 +79,7 @@ def streamlit_create( + "This should be considered a temporary workaround until native support is available.", ), **options, -) -> OutputData: +) -> CommandResult: """ Create a streamlit app. """ @@ -81,7 +89,7 @@ def streamlit_create( from_stage=from_stage, use_packaging_workaround=use_packaging_workaround, ) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command("share") @@ -93,12 +101,12 @@ def streamlit_share( ..., help="Role that streamlit should be shared with." ), **options, -) -> OutputData: +) -> CommandResult: """ Share a streamlit app with a role. """ cursor = StreamlitManager().share(streamlit_name=name, to_role=to_role) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command("drop") @@ -107,12 +115,12 @@ def streamlit_share( def streamlit_drop( name: str = typer.Argument(..., help="Name of streamlit to be deleted."), **options, -) -> OutputData: +) -> CommandResult: """ Drop a streamlit app. """ cursor = StreamlitManager().drop(streamlit_name=name) - return OutputData.from_cursor(cursor) + return SingleQueryResult(cursor) @app.command("deploy") @@ -154,7 +162,7 @@ def streamlit_deploy( + "environment.yml (noting the risk of runtime errors).", ), **options, -) -> OutputData: +) -> CommandResult: """ Deploy a streamlit app. """ @@ -170,5 +178,5 @@ def streamlit_deploy( excluded_anaconda_deps=excluded_anaconda_deps, ) if result is not None: - return OutputData.from_string(result) - return OutputData.from_string("Done") + return MessageResult(result) + return MessageResult("Done") diff --git a/src/snowcli/cli/warehouse/commands.py b/src/snowcli/cli/warehouse/commands.py index fd21ff407..c71a579a2 100644 --- a/src/snowcli/cli/warehouse/commands.py +++ b/src/snowcli/cli/warehouse/commands.py @@ -6,7 +6,7 @@ from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS from snowcli.cli.warehouse.manager import WarehouseManager from snowcli.output.decorators import with_output -from snowcli.output.printing import OutputData +from snowcli.output.types import QueryResult app = typer.Typer( name="warehouse", @@ -23,4 +23,4 @@ def warehouse_status(**options): Show the status of each warehouse in the configured environment. """ cursor = WarehouseManager().show() - return OutputData.from_cursor(cursor) + return QueryResult(cursor) diff --git a/src/snowcli/output/decorators.py b/src/snowcli/output/decorators.py index 57227b4e5..a5ec740ab 100644 --- a/src/snowcli/output/decorators.py +++ b/src/snowcli/output/decorators.py @@ -3,18 +3,18 @@ from functools import wraps from snowcli.exception import CommandReturnTypeError -from snowcli.output.printing import OutputData +from snowcli.output.printing import print_result from snowflake.connector.cursor import SnowflakeCursor +from snowcli.output.types import CommandResult def with_output(func): @wraps(func) def wrapper(*args, **kwargs): output_data = func(*args, **kwargs) - - if not isinstance(output_data, OutputData): + if not isinstance(output_data, CommandResult): raise CommandReturnTypeError(type(output_data)) - output_data.print() + print_result(output_data) return wrapper diff --git a/src/snowcli/output/printing.py b/src/snowcli/output/printing.py index 211819044..e40b93923 100644 --- a/src/snowcli/output/printing.py +++ b/src/snowcli/output/printing.py @@ -1,25 +1,34 @@ from __future__ import annotations +import sys from datetime import datetime from json import JSONEncoder from pathlib import Path -from rich import box, print, print_json +from rich import box, print from rich.live import Live from rich.table import Table -from snowflake.connector.cursor import SnowflakeCursor -from typing import List, Optional, Dict, Union, Iterator +from typing import Union from snowcli.cli.common.snow_cli_global_context import snow_cli_global_context_manager -from snowcli.exception import OutputDataTypeError from snowcli.output.formats import OutputFormat +from snowcli.output.types import ( + MessageResult, + ObjectResult, + CollectionResult, + CommandResult, + MultipleResults, + QueryResult, +) class CustomJSONEncoder(JSONEncoder): """Custom JSON encoder handling serialization of non-standard types""" def default(self, o): - if isinstance(o, OutputData): - return o.as_json() + if isinstance(o, (ObjectResult, MessageResult)): + return o.result + if isinstance(o, (CollectionResult, MultipleResults)): + return list(o.result) if isinstance(o, datetime): return o.isoformat() if isinstance(o, Path): @@ -27,110 +36,6 @@ def default(self, o): return super().default(o) -class OutputData: - """ - This class constitutes base for returning output of commands. Every command wishing to return some - information to end users should return `OutputData` and use `@with_output` decorator. - - This implementation can handle streams of outputs. This helps with automated iteration through snowflake - cursors as well with cases when you want to stream constant output (for example logs). - """ - - def __init__( - self, - stream: Optional[Iterator[Union[Dict, OutputData]]] = None, - format_: Optional[OutputFormat] = None, - ) -> None: - self._stream = stream - self._format = format_ - - @classmethod - def from_cursor( - cls, cursor: SnowflakeCursor, format_: Optional[OutputFormat] = None - ) -> OutputData: - """Converts Snowflake cursor to stream of data""" - if not isinstance(cursor, SnowflakeCursor): - raise OutputDataTypeError(type(cursor), SnowflakeCursor) - return OutputData(stream=_get_data_from_cursor(cursor), format_=format_) - - @classmethod - def from_string( - cls, message: str, format_: Optional[OutputFormat] = None - ) -> OutputData: - """Coverts string to stream of data""" - if not isinstance(message, str): - raise OutputDataTypeError(type(message), str) - return cls(stream=({"result": message} for _ in range(1)), format_=format_) - - @classmethod - def from_list( - cls, data: List[Union[Dict, OutputData]], format_: Optional[OutputFormat] = None - ) -> OutputData: - """Converts list to stream of data.""" - if not isinstance(data, list) or ( - len(data) > 0 and not isinstance(data[0], (dict, OutputData)) - ): - raise OutputDataTypeError(type(data), List[Union[Dict, OutputData]]) - return cls(stream=(item for item in data), format_=format_) - - @property - def format(self) -> OutputFormat: - if not self._format: - self._format = _get_format_type() - return self._format - - def is_empty(self) -> bool: - return self._stream is None - - def get_data(self) -> Iterator[Union[Dict, OutputData]]: - """Returns iterator over output data""" - if not self._stream: - return None - - yield from self._stream - - def print(self): - _print_output(self) - - def as_json(self): - return list(self._stream) - - -def _print_output(output_data: Optional[OutputData] = None) -> None: - if output_data is None: - print("Done") - return - - if output_data.is_empty(): - print("No data") - return - - if output_data.format == OutputFormat.TABLE: - _render_table_output(output_data) - elif output_data.format == OutputFormat.JSON: - _print_json(output_data) - else: - raise Exception(f"Unknown {output_data.format} format option") - - -def _print_json(output_data: OutputData) -> None: - import json - - print_json(json.dumps(output_data.as_json(), cls=CustomJSONEncoder)) - - -def _get_data_from_cursor( - cursor: SnowflakeCursor, columns: Optional[List[str]] = None -) -> Iterator[Dict]: - column_names = [col.name for col in cursor.description] - columns_to_include = columns or column_names - - return ( - {k: v for k, v in zip(column_names, row) if k in columns_to_include} - for row in cursor - ) - - def _get_format_type() -> OutputFormat: output_format = ( snow_cli_global_context_manager.get_global_context_copy().output_format @@ -140,22 +45,74 @@ def _get_format_type() -> OutputFormat: return OutputFormat.TABLE -def _render_table_output(data: OutputData) -> None: - stream = data.get_data() - for item in stream: - if isinstance(item, OutputData): - _render_table_output(item) - else: - _print_table(item, stream) +def _get_table(): + return Table(show_header=True, box=box.ASCII) -def _print_table(item, stream): - table = Table(show_header=True, box=box.ASCII) - for column in item.keys(): +def _print_multiple_table_results(obj: CollectionResult): + if isinstance(obj, QueryResult): + print(obj.query) + items = obj.result + first_item = next(items) + table = _get_table() + for column in first_item.keys(): table.add_column(column) with Live(table, refresh_per_second=4): - table.add_row(*[str(i) for i in item.values()]) - for item in stream: + table.add_row(*[str(i) for i in first_item.values()]) + for item in items: table.add_row(*[str(i) for i in item.values()]) # Add separator between tables print() + + +def is_structured_format(output_format): + return output_format == OutputFormat.JSON + + +def print_structured(result: CommandResult): + """Handles outputs like json, yml and other structured and parsable formats.""" + import json + + return json.dump(result, sys.stdout, cls=CustomJSONEncoder) + + +def print_unstructured(obj: CommandResult | None): + """Handles outputs like table, plain text and other unstructured types.""" + if not obj: + print("Done") + elif not obj.result: + print("No data") + elif isinstance(obj, MessageResult): + print(obj.message) + else: + if isinstance(obj, ObjectResult): + _print_single_table(obj) + elif isinstance(obj, CollectionResult): + _print_multiple_table_results(obj) + else: + raise TypeError(f"No print strategy for type: {type(obj)}") + + +def _print_single_table(obj): + table = _get_table() + table.add_column("key") + table.add_column("value") + for key, value in obj.result.items(): + table.add_row(str(key), str(value)) + print(table) + + +def print_result(cmd_result: CommandResult, output_format: OutputFormat | None = None): + output_format = output_format or _get_format_type() + if is_structured_format(output_format): + print_structured(cmd_result) + elif isinstance(cmd_result, MultipleResults): + for res in cmd_result.result: + print_result(res) + elif ( + isinstance(cmd_result, (MessageResult, ObjectResult, CollectionResult)) + or cmd_result is None + ): + print_unstructured(cmd_result) + else: + raise ValueError(f"Unexpected type {type(cmd_result)}") diff --git a/src/snowcli/output/types.py b/src/snowcli/output/types.py new file mode 100644 index 000000000..84938af1b --- /dev/null +++ b/src/snowcli/output/types.py @@ -0,0 +1,84 @@ +from __future__ import annotations +import typing as t + +from snowflake.connector.cursor import SnowflakeCursor + + +class CommandResult: + @property + def result(self): + raise NotImplemented() + + +class ObjectResult(CommandResult): + def __init__(self, element: t.Dict): + self._element = element + + @property + def result(self): + return self._element + + +class CollectionResult(CommandResult): + def __init__(self, elements: t.Iterable[t.Dict]): + self._elements = elements + + @property + def result(self): + yield from self._elements + + +class MultipleResults(CommandResult): + def __init__(self, elements: t.List[CommandResult] | None = None): + self._elements = elements or [] + + def add(self, element: CommandResult): + self._elements.append(element) + + @property + def result(self): + return self._elements + + +class QueryResult(CollectionResult): + def __init__(self, cursor: SnowflakeCursor): + self.column_names = [col.name for col in cursor.description] + super().__init__(elements=self._prepare_payload(cursor)) + self._query = cursor.query + + def _prepare_payload(self, cursor): + return ({k: v for k, v in zip(self.column_names, row)} for row in cursor) + + @property + def query(self): + return self._query + + +class SingleQueryResult(ObjectResult): + def __init__(self, cursor: SnowflakeCursor): + super().__init__(element=self._prepare_payload(cursor)) + + def _prepare_payload(self, cursor): + results = list(QueryResult(cursor).result) + if results: + return results[0] + return None + + +class StreamResult(CommandResult): + def __init__(self, generator_func, mapper): + self._generator_func = generator_func + self._mapper = mapper + + +class MessageResult(CommandResult): + def __init__(self, message: str): + self._message = message + + @property + def message(self): + return self._message + + @property + def result(self): + return {"message": self._message} diff --git a/tests/__snapshots__/test_package.ambr b/tests/__snapshots__/test_package.ambr index ef285584a..24e004df6 100644 --- a/tests/__snapshots__/test_package.ambr +++ b/tests/__snapshots__/test_package.ambr @@ -1,36 +1,24 @@ # serializer version: 1 # name: TestPackage.test_package_lookup[argument0] ''' - +------------------------------------------------------------------------------+ - | result | - |------------------------------------------------------------------------------| - | Package snowflake-connector-python is available on the Snowflake anaconda | - | channel. | - +------------------------------------------------------------------------------+ + Package snowflake-connector-python is available on the Snowflake anaconda + channel. ''' # --- # name: TestPackage.test_package_lookup[argument1] ''' - +------------------------------------------------------------------------------+ - | result | - |------------------------------------------------------------------------------| - | Lookup for package some-weird-package-we-dont-know resulted in some error. | - | Please check the package name or try again with -y option | - +------------------------------------------------------------------------------+ + Lookup for package some-weird-package-we-dont-know resulted in some error. + Please check the package name or try again with -y option ''' # --- # name: TestPackage.test_package_lookup_with_install_packages ''' - +------------------------------------------------------------------------------+ - | result | - |------------------------------------------------------------------------------| - | The package some-other-package is supported, but does depend on the | - | following Snowflake supported native libraries. You should | - | include the following in your packages: [] | - +------------------------------------------------------------------------------+ + The package some-other-package is supported, but does depend on the + following Snowflake supported native libraries. You should + include the following in your packages: [] ''' # --- diff --git a/tests/__snapshots__/test_warehouse.ambr b/tests/__snapshots__/test_warehouse.ambr index 1e0db3766..fa606994e 100644 --- a/tests/__snapshots__/test_warehouse.ambr +++ b/tests/__snapshots__/test_warehouse.ambr @@ -1,6 +1,7 @@ # serializer version: 1 # name: test_show_warehouses ''' + SELECT A MOCK QUERY +------------------+ | name | state | |------+-----------| diff --git a/tests/output/test_printing.py b/tests/output/test_printing.py index 56d3d1924..8787542c6 100644 --- a/tests/output/test_printing.py +++ b/tests/output/test_printing.py @@ -4,9 +4,16 @@ from click import Context, Command -from snowcli.exception import OutputDataTypeError from snowcli.output.formats import OutputFormat -from snowcli.output.printing import OutputData +from snowcli.output.printing import print_result +from snowcli.output.types import ( + MultipleResults, + QueryResult, + MessageResult, + CollectionResult, + SingleQueryResult, + ObjectResult, +) from tests.testing_utils.fixtures import * @@ -15,25 +22,90 @@ class MockResultMetadata(NamedTuple): name: str -def test_print_multi_cursors_table(capsys, _create_mock_cursor): - output_data = OutputData.from_list( +def test_single_value_from_query(capsys, mock_cursor): + output_data = SingleQueryResult( + mock_cursor( + columns=["array", "object", "date"], + rows=[ + (["array"], {"k": "object"}, datetime(2022, 3, 21)), + ], + ) + ) + + print_result(output_data, output_format=OutputFormat.TABLE) + assert _get_output(capsys) == dedent( + """\ + +------------------------------+ + | key | value | + |--------+---------------------| + | array | ['array'] | + | object | {'k': 'object'} | + | date | 2022-03-21 00:00:00 | + +------------------------------+ + """ + ) + + +def test_single_object_result(capsys, mock_cursor): + output_data = ObjectResult( + {"array": ["array"], "object": {"k": "object"}, "date": datetime(2022, 3, 21)} + ) + + print_result(output_data, output_format=OutputFormat.TABLE) + assert _get_output(capsys) == dedent( + """\ + +------------------------------+ + | key | value | + |--------+---------------------| + | array | ['array'] | + | object | {'k': 'object'} | + | date | 2022-03-21 00:00:00 | + +------------------------------+ + """ + ) + + +def test_single_collection_result(capsys, mock_cursor): + output_data = { + "array": ["array"], + "object": {"k": "object"}, + "date": datetime(2022, 3, 21), + } + collection = CollectionResult([output_data, output_data]) + + print_result(collection, output_format=OutputFormat.TABLE) + assert _get_output(capsys) == dedent( + """\ + +---------------------------------------------------+ + | array | object | date | + |-----------+-----------------+---------------------| + | ['array'] | {'k': 'object'} | 2022-03-21 00:00:00 | + | ['array'] | {'k': 'object'} | 2022-03-21 00:00:00 | + +---------------------------------------------------+ + """ + ) + + +def test_print_multi_results_table(capsys, _create_mock_cursor): + output_data = MultipleResults( [ - OutputData.from_cursor(_create_mock_cursor()), - OutputData.from_cursor(_create_mock_cursor()), + QueryResult(_create_mock_cursor()), + QueryResult(_create_mock_cursor()), ], - format_=OutputFormat.TABLE, ) - output_data.print() + print_result(output_data, output_format=OutputFormat.TABLE) assert _get_output(capsys) == dedent( """\ + SELECT A MOCK QUERY +---------------------------------------------------------------------+ | string | number | array | object | date | |--------+--------+-----------+-----------------+---------------------| | string | 42 | ['array'] | {'k': 'object'} | 2022-03-21 00:00:00 | | string | 43 | ['array'] | {'k': 'object'} | 2022-03-21 00:00:00 | +---------------------------------------------------------------------+ + SELECT A MOCK QUERY +---------------------------------------------------------------------+ | string | number | array | object | date | |--------+--------+-----------+-----------------+---------------------| @@ -44,10 +116,10 @@ def test_print_multi_cursors_table(capsys, _create_mock_cursor): ) -def test_print_different_multi_cursors_table(capsys, mock_cursor): - output_data = OutputData.from_list( +def test_print_different_multi_results_table(capsys, mock_cursor): + output_data = MultipleResults( [ - OutputData.from_cursor( + QueryResult( mock_cursor( columns=["string", "number"], rows=[ @@ -68,7 +140,7 @@ def test_print_different_multi_cursors_table(capsys, mock_cursor): ], ) ), - OutputData.from_cursor( + QueryResult( mock_cursor( columns=["array", "object", "date"], rows=[ @@ -78,19 +150,20 @@ def test_print_different_multi_cursors_table(capsys, mock_cursor): ) ), ], - format_=OutputFormat.TABLE, ) - output_data.print() + print_result(output_data, output_format=OutputFormat.TABLE) assert _get_output(capsys) == dedent( """\ + SELECT A MOCK QUERY +-----------------+ | string | number | |--------+--------| | string | 42 | | string | 43 | +-----------------+ + SELECT A MOCK QUERY +---------------------------------------------------+ | array | object | date | |-----------+-----------------+---------------------| @@ -102,30 +175,26 @@ def test_print_different_multi_cursors_table(capsys, mock_cursor): def test_print_different_data_sources_table(capsys, _create_mock_cursor): - output_data = OutputData.from_list( + output_data = MultipleResults( [ - OutputData.from_cursor(_create_mock_cursor()), - OutputData.from_string("Command done"), - OutputData.from_list([{"key": "value"}]), + QueryResult(_create_mock_cursor()), + MessageResult("Command done"), + CollectionResult(({"key": "value"} for _ in range(1))), ], - format_=OutputFormat.TABLE, ) - output_data.print() + print_result(output_data, output_format=OutputFormat.TABLE) assert _get_output(capsys) == dedent( """\ + SELECT A MOCK QUERY +---------------------------------------------------------------------+ | string | number | array | object | date | |--------+--------+-----------+-----------------+---------------------| | string | 42 | ['array'] | {'k': 'object'} | 2022-03-21 00:00:00 | | string | 43 | ['array'] | {'k': 'object'} | 2022-03-21 00:00:00 | +---------------------------------------------------------------------+ - +--------------+ - | result | - |--------------| - | Command done | - +--------------+ + Command done +-------+ | key | |-------| @@ -136,14 +205,13 @@ def test_print_different_data_sources_table(capsys, _create_mock_cursor): def test_print_multi_db_cursor_json(capsys, _create_mock_cursor): - output_data = OutputData.from_list( + output_data = MultipleResults( [ - OutputData.from_cursor(_create_mock_cursor()), - OutputData.from_cursor(_create_mock_cursor()), + QueryResult(_create_mock_cursor()), + QueryResult(_create_mock_cursor()), ], - format_=OutputFormat.JSON, ) - output_data.print() + print_result(output_data, output_format=OutputFormat.JSON) assert _get_output_as_json(capsys) == [ [ @@ -182,16 +250,15 @@ def test_print_multi_db_cursor_json(capsys, _create_mock_cursor): def test_print_different_data_sources_json(capsys, _create_mock_cursor): - output_data = OutputData.from_list( + output_data = MultipleResults( [ - OutputData.from_cursor(_create_mock_cursor()), - OutputData.from_string("Command done"), - OutputData.from_list([{"key": "value"}]), + QueryResult(_create_mock_cursor()), + MessageResult("Command done"), + CollectionResult(({"key": f"value_{i}"} for i in range(2))), ], - format_=OutputFormat.JSON, ) - output_data.print() + print_result(output_data, output_format=OutputFormat.JSON) assert _get_output_as_json(capsys) == [ [ @@ -210,55 +277,30 @@ def test_print_different_data_sources_json(capsys, _create_mock_cursor): "date": "2022-03-21T00:00:00", }, ], - [{"result": "Command done"}], - [{"key": "value"}], + {"message": "Command done"}, + [{"key": "value_0"}, {"key": "value_1"}], ] def test_print_with_no_data_table(capsys): - output_data = OutputData(format_=OutputFormat.TABLE) + print_result(None) + assert _get_output(capsys) == "Done\n" - output_data.print() - assert _get_output(capsys) == "No data\n" +def test_print_with_no_data_in_query_json(capsys, _empty_cursor): + print_result(QueryResult(_empty_cursor()), output_format=OutputFormat.JSON) + assert _get_output(capsys) == "[]" -def test_print_with_no_data_json(capsys): - output_data = OutputData(format_=OutputFormat.JSON) +def test_print_with_no_data_in_single_value_query_json(capsys, _empty_cursor): + print_result(SingleQueryResult(_empty_cursor()), output_format=OutputFormat.JSON) + assert _get_output(capsys) == "null" - output_data.print() - assert _get_output(capsys) == "No data\n" +def test_print_with_no_response_json(capsys): + print_result(None, output_format=OutputFormat.JSON) - -def test_raise_error_when_try_add_wrong_data_type_to_from_cursor(): - with pytest.raises(OutputDataTypeError) as exception: - OutputData.from_cursor("") - - assert ( - exception.value.args[0] - == "Got type but expected " - ) - - -def test_raise_error_when_try_add_wrong_data_type_to_from_string(): - with pytest.raises(OutputDataTypeError) as exception: - OutputData.from_string(0) - - assert ( - exception.value.args[0] == "Got type but expected " - ) - - -def test_raise_error_when_try_add_wrong_data_type_to_from_list(): - with pytest.raises(OutputDataTypeError) as exception: - OutputData.from_list("") - - assert ( - exception.value.args[0] - == "Got type but expected typing.List[typing.Union[typing.Dict, " - "snowcli.output.printing.OutputData]]" - ) + assert _get_output(capsys) == "null" def _mock_output_format(mock_context, format): @@ -285,3 +327,11 @@ def _create_mock_cursor(mock_cursor): ("string", 43, ["array"], {"k": "object"}, datetime(2022, 3, 21)), ], ) + + +@pytest.fixture +def _empty_cursor(mock_cursor): + return lambda: mock_cursor( + columns=["string", "number", "array", "object", "date"], + rows=[], + ) diff --git a/tests/snowpark/__snapshots__/test_procedure_coverage.ambr b/tests/snowpark/__snapshots__/test_procedure_coverage.ambr index befb289ee..e875a6c9f 100644 --- a/tests/snowpark/__snapshots__/test_procedure_coverage.ambr +++ b/tests/snowpark/__snapshots__/test_procedure_coverage.ambr @@ -1,28 +1,10 @@ # serializer version: 1 # name: test_procedure_coverage_report_create_html_report - ''' - +--------------------------------------------------------------------------+ - | result | - |--------------------------------------------------------------------------| - | Your HTML code coverage report is now available in 'htmlcov/index.html'. | - +--------------------------------------------------------------------------+ - ''' + "Your HTML code coverage report is now available in 'htmlcov/index.html'." # --- # name: test_procedure_coverage_report_create_json_report - ''' - +---------------------------------------------------------------------+ - | result | - |---------------------------------------------------------------------| - | Your JSON code coverage report is now available in 'coverage.json'. | - +---------------------------------------------------------------------+ - ''' + "Your JSON code coverage report is now available in 'coverage.json'." # --- # name: test_procedure_coverage_report_create_lcov_report - ''' - +---------------------------------------------------------------------+ - | result | - |---------------------------------------------------------------------| - | Your lcov code coverage report is now available in 'coverage.lcov'. | - +---------------------------------------------------------------------+ - ''' + "Your lcov code coverage report is now available in 'coverage.lcov'." # --- diff --git a/tests/test_main.py b/tests/test_main.py index 8ee9fe493..124d19ef6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2,10 +2,6 @@ from __future__ import annotations import json -import os -from pathlib import Path -from textwrap import dedent -from unittest import mock import typing as t import click diff --git a/tests/test_snow_connector.py b/tests/test_snow_connector.py index c9ec41f80..35ca4ac77 100644 --- a/tests/test_snow_connector.py +++ b/tests/test_snow_connector.py @@ -53,7 +53,7 @@ def test_registry_get_token(mock_conn, runner): } result = runner.invoke(["snowpark", "registry", "token", "--format", "JSON"]) assert result.exit_code == 0, result.output - assert json.loads(result.stdout) == [{"token": "token1234", "expires_in": 42}] + assert json.loads(result.stdout) == {"token": "token1234", "expires_in": 42} @mock.patch.dict(os.environ, {}, clear=True) diff --git a/tests/test_warehouse.py b/tests/test_warehouse.py index bd26ec822..1365b886d 100644 --- a/tests/test_warehouse.py +++ b/tests/test_warehouse.py @@ -1,5 +1,3 @@ -from unittest import mock - from tests.testing_utils.fixtures import * diff --git a/tests/testing_utils/fixtures.py b/tests/testing_utils/fixtures.py index dd3adef9c..a11392858 100644 --- a/tests/testing_utils/fixtures.py +++ b/tests/testing_utils/fixtures.py @@ -115,6 +115,7 @@ def __init__(self, rows: List[tuple], columns: List[str]): super().__init__(mock.Mock()) self._rows = rows self._columns = [MockResultMetadata(c) for c in columns] + self.query = "SELECT A MOCK QUERY" def fetchone(self): if self._rows: diff --git a/tests_integration/conftest.py b/tests_integration/conftest.py index dd62236d4..82943661d 100644 --- a/tests_integration/conftest.py +++ b/tests_integration/conftest.py @@ -17,7 +17,7 @@ @dataclass class CommandResult: exit_code: int - json: Optional[List[Dict[str, Any]]] = None + json: Optional[List[Dict[str, Any]] | Dict[str, Any]] = None output: Optional[str] = None diff --git a/tests_integration/test_package.py b/tests_integration/test_package.py index 58c8284f8..d0321a6de 100644 --- a/tests_integration/test_package.py +++ b/tests_integration/test_package.py @@ -64,11 +64,9 @@ def test_package_create_with_non_anaconda_package_without_install( result = runner.invoke_integration(["snowpark", "package", "create", "PyRTF3"]) assert_that_result_is_successful(result) - assert result.json == [ - { - "result": "Lookup for package PyRTF3 resulted in some error. Please check the package name or try again with -y option" - } - ] + assert result.json == { + "message": "Lookup for package PyRTF3 resulted in some error. Please check the package name or try again with -y option" + } assert not os.path.exists("PyRTF3.zip") @pytest.fixture diff --git a/tests_integration/testing_utils/assertions/test_result_assertions.py b/tests_integration/testing_utils/assertions/test_result_assertions.py index 006f59657..e89b19a1c 100644 --- a/tests_integration/testing_utils/assertions/test_result_assertions.py +++ b/tests_integration/testing_utils/assertions/test_result_assertions.py @@ -27,13 +27,12 @@ def assert_that_result_is_successful_and_output_json_equals( def assert_that_result_contains_row_with(result: CommandResult, expect: Dict) -> None: assert result.json is not None - assert contains_row_with(result.json, expect) + assert contains_row_with(result.json, expect) # type: ignore def assert_that_result_is_successful_and_done_is_on_output( result: CommandResult, ) -> None: assert_that_result_is_successful(result) - assert result.output is not None and json.loads(result.output) == [ - {"result": "Done"} - ] + assert result.output is not None + assert json.loads(result.output) == {"message": "Done"} diff --git a/tests_integration/testing_utils/snowpark_utils.py b/tests_integration/testing_utils/snowpark_utils.py index 682fbdd14..b1989d3ac 100644 --- a/tests_integration/testing_utils/snowpark_utils.py +++ b/tests_integration/testing_utils/snowpark_utils.py @@ -201,8 +201,6 @@ def snowpark_describe_should_return_entity_description( result, {"property": "signature", "value": arguments} ) assert result.json is not None - imports = [i for i in result.json if i["property"] == "imports"][0] - assert imports["value"].__contains__(entity_name) def snowpark_init_should_initialize_files_with_default_content( self, @@ -302,7 +300,7 @@ def snowpark_update_should_finish_successfully( ) assert_that_result_is_successful_and_output_json_equals( result, - [{"status": f"Function {entity_name.upper()} successfully created."}], + {"status": f"Function {entity_name.upper()} successfully created."}, ) def snowpark_drop_should_finish_successfully( @@ -388,7 +386,8 @@ def coverage_clear_should_execute_succesfully(self, procedure_name, arguments): ) assert result.exit_code == 0 - assert result.json[0]["result"] == "removed" + print(result.json) + assert result.json["result"] == "removed" def assert_that_no_entities_are_in_snowflake(self) -> None: self.assert_that_only_these_entities_are_in_snowflake()