Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/integrate connections toml #298

Merged
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a3d18eb
feat: add connection file path and connection name cli arguments
Oct 18, 2024
be0e37a
feat: add connections.toml options to DeployConfig and parse_cli_args.py
Oct 19, 2024
41a4995
feat: fetch snowflake_password in config object
Oct 19, 2024
660bb0a
feat: prioritize SNOWFLAKE_PRIVATE_KEY_PATH env var in DeployConfig
Oct 19, 2024
aadf7a1
feat: introduce env var and connections.toml layering into get_merged…
Oct 23, 2024
f2635a2
feat: remove get_session_from_config, opt for config.get_session_kwargs
Oct 24, 2024
7dda5de
fix: fix verbose handling in get_yaml_config_kwargs
Oct 24, 2024
fdb7554
fix: return account as session attribute
Oct 24, 2024
8046f5a
fix: account for token path newline
Oct 24, 2024
cec87a5
fix: connections.toml requires token_file_path instead of token-file-…
Oct 25, 2024
8b52000
Merge branch 'master' into feat/integrate-connections-toml
Oct 31, 2024
f9e9917
feat: swap tomllib with tomlkit
Oct 31, 2024
95fbe72
feat: avoid DeployConfig argument repetition
Oct 31, 2024
c724850
fix: clear environment variables to allow GitHub Actions tests to pass
Oct 31, 2024
addf642
fix: remove unused default_config_file_name class attribute
Oct 31, 2024
10b867c
fix: make ownership grant idempotent
Oct 31, 2024
6ee659f
feat: make tests runnable on Snowflake standard edition
Oct 31, 2024
6eaeaa9
feat: respect SNOWFLAKE_DEFAULT_CONNECTION_NAME environment variable
Oct 31, 2024
6e53452
feat: log snowflake.connector.connect arguments
Oct 31, 2024
713ae2d
feat: log get_merged_config steps
Oct 31, 2024
9fec758
fix: fix get_merged_config tests
Oct 31, 2024
19319b6
fix: fix get_merged_config tests
Oct 31, 2024
e82b3b9
fix: fix missing database reference
Oct 31, 2024
d5bf05a
docs: reference new cli arguments
Oct 31, 2024
d218e52
docs: reference new cli arguments
Oct 31, 2024
81aaeae
docs: adding to the change log
Oct 31, 2024
f659acb
feat: prioritize cli arguments over environment variables
Nov 8, 2024
db86db4
feat: prioritize cli arguments over environment variables
Nov 8, 2024
1beb3a5
feat: remove oauth-config support
Nov 8, 2024
eb8882b
feat: deprecate cli connection arguments
Nov 14, 2024
f29925a
feat: drive SnowflakeSession attributes from connection instead of ar…
Nov 14, 2024
f35f1a6
feat: defer to snowflake python connector for connection arguments
Nov 14, 2024
fb2f2a4
feat: restructure integration test
Nov 15, 2024
c07f97c
fix: fix integration tests
Nov 15, 2024
53efca1
docs: clarify integration test seup
Nov 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions schemachange/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from schemachange.config.get_merged_config import get_merged_config
from schemachange.deploy import deploy
from schemachange.redact_config_secrets import redact_config_secrets
from schemachange.session.SnowflakeSession import get_session_from_config
from schemachange.session.SnowflakeSession import SnowflakeSession

# region Global Variables
# metadata
Expand Down Expand Up @@ -63,11 +63,11 @@ def main():
)
else:
config.check_for_deploy_args()
session = get_session_from_config(
config=config,
session = SnowflakeSession(
schemachange_version=SCHEMACHANGE_VERSION,
snowflake_application_name=SNOWFLAKE_APPLICATION_NAME,
application=SNOWFLAKE_APPLICATION_NAME,
logger=logger,
**config.get_session_kwargs(),
)
deploy(config=config, session=session)

Expand Down
112 changes: 94 additions & 18 deletions schemachange/config/DeployConfig.py
zanebclark marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from schemachange.config.BaseConfig import BaseConfig
from schemachange.config.ChangeHistoryTable import ChangeHistoryTable
from schemachange.config.utils import get_snowflake_identifier_string
from schemachange.config.utils import (
get_snowflake_identifier_string,
validate_file_path,
get_oauth_token,
)


@dataclasses.dataclass(frozen=True, kw_only=True)
Expand All @@ -18,6 +22,12 @@ class DeployConfig(BaseConfig):
snowflake_warehouse: str | None = None
snowflake_database: str | None = None
snowflake_schema: str | None = None
snowflake_authenticator: str | None = "snowflake"
snowflake_password: str | None = None
snowflake_oauth_token: str | None = None
snowflake_private_key_path: Path | None = None
connections_file_path: Path | None = None
connection_name: str | None = None
# TODO: Turn change_history_table into three arguments. There's no need to parse it from a string
change_history_table: ChangeHistoryTable | None = dataclasses.field(
default_factory=ChangeHistoryTable
Expand All @@ -26,41 +36,62 @@ class DeployConfig(BaseConfig):
autocommit: bool = False
dry_run: bool = False
query_tag: str | None = None
oauth_config: dict | None = None

@classmethod
def factory(
cls,
config_file_path: Path,
snowflake_role: str | None = None,
snowflake_warehouse: str | None = None,
snowflake_database: str | None = None,
snowflake_schema: str | None = None,
change_history_table: str | None = None,
**kwargs,
):
if "subcommand" in kwargs:
kwargs.pop("subcommand")

for sf_input in [
"snowflake_role",
"snowflake_warehouse",
"snowflake_database",
"snowflake_schema",
]:
if sf_input in kwargs and kwargs[sf_input] is not None:
kwargs[sf_input] = get_snowflake_identifier_string(
kwargs[sf_input], sf_input
)

for sf_path_input in [
"snowflake_private_key_path",
"snowflake_token_path",
]:
if sf_path_input in kwargs and kwargs[sf_path_input] is not None:
kwargs[sf_path_input] = validate_file_path(
file_path=kwargs[sf_path_input]
)

# If set by an environment variable, pop snowflake_token_path from kwargs
if "snowflake_oauth_token" in kwargs:
kwargs.pop("snowflake_token_path", None)
kwargs.pop("oauth_config", None)
# Load it from a file, if provided
elif "snowflake_token_path" in kwargs:
kwargs.pop("oauth_config", None)
oauth_token_path = kwargs.pop("snowflake_token_path")
with open(oauth_token_path) as f:
kwargs["snowflake_oauth_token"] = f.read()
# Make the oauth call if authenticator == "oauth"

elif "oauth_config" in kwargs:
oauth_config = kwargs.pop("oauth_config")
authenticator = kwargs.get("snowflake_authenticator")
if authenticator is not None and authenticator.lower() == "oauth":
kwargs["snowflake_oauth_token"] = get_oauth_token(oauth_config)

change_history_table = ChangeHistoryTable.from_str(
table_str=change_history_table
)

return super().factory(
subcommand="deploy",
config_file_path=config_file_path,
snowflake_role=get_snowflake_identifier_string(
snowflake_role, "snowflake_role"
),
snowflake_warehouse=get_snowflake_identifier_string(
snowflake_warehouse, "snowflake_warehouse"
),
snowflake_database=get_snowflake_identifier_string(
snowflake_database, "snowflake_database"
),
snowflake_schema=get_snowflake_identifier_string(
snowflake_schema, "snowflake_schema"
),
change_history_table=change_history_table,
**kwargs,
)
Expand All @@ -74,6 +105,31 @@ def check_for_deploy_args(self) -> None:
"snowflake_role": self.snowflake_role,
"snowflake_warehouse": self.snowflake_warehouse,
}

# OAuth based authentication
if self.snowflake_authenticator.lower() == "oauth":
req_args["snowflake_oauth_token"] = self.snowflake_oauth_token

# External Browser based SSO
elif self.snowflake_authenticator.lower() == "externalbrowser":
pass

# IDP based Authentication, limited to Okta
elif self.snowflake_authenticator.lower()[:8] == "https://":
req_args["snowflake_password"] = self.snowflake_password

elif self.snowflake_authenticator.lower() == "snowflake_jwt":
req_args["snowflake_private_key_path"] = self.snowflake_private_key_path

elif self.snowflake_authenticator.lower() == "snowflake":
req_args["snowflake_password"] = self.snowflake_password

else:
raise ValueError(
f"{self.snowflake_authenticator} is not supported authenticator option. "
"Choose from snowflake, snowflake_jwt, externalbrowser, oauth, https://<subdomain>.okta.com."
)

missing_args = [key for key, value in req_args.items() if value is None]

if len(missing_args) == 0:
Expand All @@ -83,3 +139,23 @@ def check_for_deploy_args(self) -> None:
raise ValueError(
f"Missing config values. The following config values are required: {missing_args}"
)

def get_session_kwargs(self) -> dict:
session_kwargs = {
"account": self.snowflake_account,
"user": self.snowflake_user,
"role": self.snowflake_role,
"warehouse": self.snowflake_warehouse,
"database": self.snowflake_database,
"schema": self.snowflake_schema,
"authenticator": self.snowflake_authenticator,
"password": self.snowflake_password,
"oauth_token": self.snowflake_oauth_token,
"private_key_path": self.snowflake_private_key_path,
"connections_file_path": self.connections_file_path,
"connection_name": self.connection_name,
"change_history_table": self.change_history_table,
"autocommit": self.autocommit,
"query_tag": self.query_tag,
}
return {k: v for k, v in session_kwargs.items() if v is not None}
51 changes: 40 additions & 11 deletions schemachange/config/get_merged_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from schemachange.config.DeployConfig import DeployConfig
from schemachange.config.RenderConfig import RenderConfig
from schemachange.config.parse_cli_args import parse_cli_args
from schemachange.config.utils import load_yaml_config, validate_directory
from schemachange.config.utils import (
load_yaml_config,
validate_directory,
get_env_kwargs,
get_connection_kwargs,
validate_file_path,
)


def get_yaml_config_kwargs(config_file_path: Optional[Path]) -> dict:
Expand All @@ -20,28 +26,30 @@ def get_yaml_config_kwargs(config_file_path: Optional[Path]) -> dict:
}

if "verbose" in kwargs:
kwargs["log_level"] = logging.DEBUG
if kwargs["verbose"]:
kwargs["log_level"] = logging.DEBUG
kwargs.pop("verbose")

if "vars" in kwargs:
kwargs["config_vars"] = kwargs.pop("vars")

return kwargs
return {k: v for k, v in kwargs.items() if v is not None}


def get_merged_config() -> Union[DeployConfig, RenderConfig]:
cli_kwargs = parse_cli_args(sys.argv[1:])
env_kwargs: dict[str, str] = get_env_kwargs()

if "verbose" in cli_kwargs and cli_kwargs["verbose"]:
cli_kwargs["log_level"] = logging.DEBUG
cli_kwargs.pop("verbose")
cli_kwargs = parse_cli_args(sys.argv[1:])

cli_config_vars = cli_kwargs.pop("config_vars", None)
if cli_config_vars is None:
cli_config_vars = {}
cli_config_vars = cli_kwargs.pop("config_vars")

connections_file_path = validate_file_path(
file_path=cli_kwargs.pop("connections_file_path", None)
)
connection_name = cli_kwargs.pop("connection_name", None)
config_folder = validate_directory(path=cli_kwargs.pop("config_folder", "."))
config_file_path = Path(config_folder) / "schemachange-config.yml"
config_file_name = cli_kwargs.pop("config_file_name")
config_file_path = Path(config_folder) / config_file_name

yaml_kwargs = get_yaml_config_kwargs(
config_file_path=config_file_path,
Expand All @@ -50,6 +58,21 @@ def get_merged_config() -> Union[DeployConfig, RenderConfig]:
if yaml_config_vars is None:
yaml_config_vars = {}

if connections_file_path is None:
connections_file_path = yaml_kwargs.pop("connections_file_path", None)
if config_folder is not None and connections_file_path is not None:
connections_file_path = config_folder / connections_file_path

connections_file_path = validate_file_path(file_path=connections_file_path)

if connection_name is None:
connection_name = yaml_kwargs.pop("connection_name", None)

connection_kwargs: dict[str, str] = get_connection_kwargs(
connections_file_path=connections_file_path,
connection_name=connection_name,
)

config_vars = {
**yaml_config_vars,
**cli_config_vars,
Expand All @@ -59,9 +82,15 @@ def get_merged_config() -> Union[DeployConfig, RenderConfig]:
kwargs = {
"config_file_path": config_file_path,
"config_vars": config_vars,
**{k: v for k, v in connection_kwargs.items() if v is not None},
**{k: v for k, v in yaml_kwargs.items() if v is not None},
**{k: v for k, v in cli_kwargs.items() if v is not None},
**{k: v for k, v in env_kwargs.items() if v is not None},
}
if connections_file_path is not None:
kwargs["connections_file_path"] = connections_file_path
if connection_name is not None:
kwargs["connection_name"] = connection_name

if cli_kwargs["subcommand"] == "deploy":
return DeployConfig.factory(**kwargs)
Expand Down
55 changes: 53 additions & 2 deletions schemachange/config/parse_cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import json
import logging
from enum import Enum

import structlog
Expand Down Expand Up @@ -58,6 +59,14 @@ def parse_cli_args(args) -> dict:
"(the default is the current working directory)",
required=False,
)
parent_parser.add_argument(
"--config-file-name",
type=str,
default="schemachange-config.yml",
help="The schemachange config YAML file name. Must be in the directory supplied as the config-folder "
"(Default: schemachange-config.yml)",
required=False,
)
parent_parser.add_argument(
"-f",
"--root-folder",
Expand Down Expand Up @@ -134,6 +143,39 @@ def parse_cli_args(args) -> dict:
help="The name of the default schema to use. Can be overridden in the change scripts.",
required=False,
)
parser_deploy.add_argument(
"-A",
"--snowflake-authenticator",
type=str,
help="The Snowflake Authenticator to use. One of snowflake, oauth, externalbrowser, or https://<okta_account_name>.okta.com",
required=False,
)
parser_deploy.add_argument(
"-k",
"--snowflake-private-key-path",
type=str,
help="Path to file containing private key.",
required=False,
)
parser_deploy.add_argument(
"-t",
"--snowflake-token-path",
type=str,
help="Path to the file containing the OAuth token to be used when authenticating with Snowflake.",
required=False,
)
parser_deploy.add_argument(
"--connections-file-path",
type=str,
help="Override the default connections file path at snowflake.connector.constants.CONNECTIONS_FILE (OS specific)",
required=False,
)
parser_deploy.add_argument(
"--connection-name",
type=str,
help="Override the default connection name. Other connection-related values will override these connection values.",
required=False,
)
parser_deploy.add_argument(
"-c",
"--change-history-table",
Expand Down Expand Up @@ -204,7 +246,16 @@ def parse_cli_args(args) -> dict:
if "log_level" in parsed_kwargs and isinstance(parsed_kwargs["log_level"], Enum):
parsed_kwargs["log_level"] = parsed_kwargs["log_level"].value

parsed_kwargs["config_vars"] = {}
if "vars" in parsed_kwargs:
parsed_kwargs["config_vars"] = parsed_kwargs.pop("vars")
config_vars = parsed_kwargs.pop("vars")
if config_vars is not None:
parsed_kwargs["config_vars"] = config_vars

if "verbose" in parsed_kwargs:
parsed_kwargs["log_level"] = (
logging.DEBUG if parsed_kwargs["verbose"] else logging.INFO
)
parsed_kwargs.pop("verbose")

return parsed_kwargs
return {k: v for k, v in parsed_kwargs.items() if v is not None}
Loading