Skip to content

Commit

Permalink
Clean up output of the setup script generation step (#1124)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bdufour authored May 27, 2024
1 parent 73930ee commit b81c6cd
Show file tree
Hide file tree
Showing 4 changed files with 360 additions and 524 deletions.
32 changes: 21 additions & 11 deletions src/snowflake/cli/plugins/nativeapp/codegen/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Dict, Optional

from snowflake.cli.api.console import cli_console as cc
from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp
from snowflake.cli.api.project.schemas.native_app.path_mapping import (
PathMapping,
Expand Down Expand Up @@ -54,19 +55,28 @@ def compile_artifacts(self):
Go through every artifact object in the project definition of a native app, and execute processors in order of specification for each of the artifact object.
May have side-effects on the filesystem by either directly editing source files or the deploy root.
"""
should_proceed = False
for artifact in self.artifacts:
for processor in artifact.processors:
artifact_processor = self._try_create_processor(
processor_mapping=processor,
)
if artifact_processor is None:
raise UnsupportedArtifactProcessorError(
processor_name=processor.name
)
else:
artifact_processor.process(
artifact_to_process=artifact, processor_mapping=processor
if artifact.processors:
should_proceed = True
break
if not should_proceed:
return

with cc.phase("Invoking artifact processors"):
for artifact in self.artifacts:
for processor in artifact.processors:
artifact_processor = self._try_create_processor(
processor_mapping=processor,
)
if artifact_processor is None:
raise UnsupportedArtifactProcessorError(
processor_name=processor.name
)
else:
artifact_processor.process(
artifact_to_process=artifact, processor_mapping=processor
)

def _try_create_processor(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import json
import pprint
import re
from pathlib import Path
from textwrap import dedent
Expand Down Expand Up @@ -173,7 +172,7 @@ def process(
artifact_to_process: PathMapping,
processor_mapping: Optional[ProcessorMapping],
**kwargs,
) -> str: # String output is temporary until we have better e2e testing mechanism
) -> None:
"""
Collects code annotations from Snowpark python files containing extension functions and augments the existing
setup script with generated SQL that registers these functions.
Expand Down Expand Up @@ -201,22 +200,16 @@ def process(
continue

relative_py_file = py_file.relative_to(bundle_map.deploy_root())
cc.message(
"-- Generating Snowpark annotation SQL code for {}".format(
relative_py_file
)
)
cc.message(create_stmt)
collected_output.append(f"-- {relative_py_file}")
collected_output.append(create_stmt)

grant_statements = generate_grant_sql_ddl_statements(extension_fn)
if grant_statements is not None:
cc.message(grant_statements)
collected_output.append(grant_statements)

with open(sql_file, "a") as file:
file.write("\n")
file.write(
f"-- Generated by the Snowflake CLI from {relative_py_file}\n"
)
file.write(f"-- DO NOT EDIT\n")
file.write(create_stmt)
if grant_statements is not None:
file.write("\n")
Expand All @@ -231,8 +224,6 @@ def process(
generated_root=self.generated_root,
)

return "\n".join(collected_output)

def _normalize_imports(
self,
extension_fn: NativeAppExtensionFunction,
Expand Down Expand Up @@ -308,6 +299,11 @@ def collect_extension_functions(
for src_file, dest_file in bundle_map.all_mappings(
absolute=True, expand_directories=True, predicate=_is_python_file_artifact
):
cc.step(
"Processing Snowpark annotations from {}".format(
dest_file.relative_to(bundle_map.deploy_root())
)
)
collected_extension_function_json = _execute_in_sandbox(
py_file=str(dest_file.resolve()),
deploy_root=self.deploy_root,
Expand All @@ -333,10 +329,6 @@ def collect_extension_functions(
cc.warning("Invalid extension function definition")

if collected_extension_functions:
cc.message(f"This is the file path in deploy root: {dest_file}\n")
cc.message("This is the list of collected extension functions:")
cc.message(pprint.pformat(collected_extension_functions))

collected_extension_fns_by_path[
dest_file
] = collected_extension_functions
Expand Down Expand Up @@ -479,7 +471,7 @@ def edit_setup_script_with_exec_imm_sql(
sql_file_relative_path = sql_file.relative_to(
deploy_root
) # Path on stage, without the leading slash
file.write(f"\nEXECUTE IMMEDIATE FROM '/{sql_file_relative_path}';")
file.write(f"EXECUTE IMMEDIATE FROM '/{sql_file_relative_path}';\n")

# Find the setup script in the deploy root.
setup_file_path = find_setup_script_file(deploy_root=deploy_root)
Expand All @@ -493,4 +485,5 @@ def edit_setup_script_with_exec_imm_sql(
generated_file_relative_path = generated_file_path.relative_to(deploy_root)
with open(setup_file_path, "w", encoding="utf-8") as file:
file.write(code)
file.write(f"\nEXECUTE IMMEDIATE FROM '/{generated_file_relative_path}';\n")
file.write(f"\nEXECUTE IMMEDIATE FROM '/{generated_file_relative_path}';")
file.write(f"\n")
Loading

0 comments on commit b81c6cd

Please sign in to comment.