diff --git a/src/snowflake/cli/_app/secret.py b/src/snowflake/cli/_app/secret.py new file mode 100644 index 0000000000..d833988aee --- /dev/null +++ b/src/snowflake/cli/_app/secret.py @@ -0,0 +1,9 @@ +class SecretType: + def __init__(self, value): + self.value = value + + def __repr__(self): + return "SecretType(***)" + + def __str___(self): + return "***" diff --git a/src/snowflake/cli/_app/snow_connector.py b/src/snowflake/cli/_app/snow_connector.py index d0f3a729c4..ccd48c75c4 100644 --- a/src/snowflake/cli/_app/snow_connector.py +++ b/src/snowflake/cli/_app/snow_connector.py @@ -24,6 +24,7 @@ from snowflake.cli._app.constants import ( PARAM_APPLICATION_NAME, ) +from snowflake.cli._app.secret import SecretType from snowflake.cli._app.telemetry import command_info from snowflake.cli.api.config import ( get_connection_dict, @@ -205,7 +206,7 @@ def _load_private_key(connection_parameters: Dict, private_key_var_name: str) -> connection_parameters[private_key_var_name] ) private_key = _load_pem_to_der(private_key_pem) - connection_parameters["private_key"] = private_key + connection_parameters["private_key"] = private_key.value del connection_parameters[private_key_var_name] else: raise ClickException( @@ -217,10 +218,11 @@ def _load_private_key_from_parameters( connection_parameters: Dict, private_key_var_name: str ) -> None: if connection_parameters.get("authenticator") == "SNOWFLAKE_JWT": - private_key_pem = connection_parameters[private_key_var_name] - private_key_pem = private_key_pem.encode("utf-8") + private_key_pem = _load_pem_from_parameters( + connection_parameters[private_key_var_name] + ) private_key = _load_pem_to_der(private_key_pem) - connection_parameters["private_key"] = private_key + connection_parameters["private_key"] = private_key.value del connection_parameters[private_key_var_name] else: raise ClickException( @@ -236,43 +238,49 @@ def _update_connection_application_name(connection_parameters: Dict): connection_parameters.update(connection_application_params) -def _load_pem_from_file(private_key_file: str) -> bytes: +def _load_pem_from_file(private_key_file: str) -> SecretType: with SecurePath(private_key_file).open( "rb", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB ) as f: - private_key_pem = f.read() + private_key_pem = SecretType(f.read()) return private_key_pem -def _load_pem_to_der(private_key_pem: bytes) -> bytes: +def _load_pem_from_parameters(private_key_raw: str) -> SecretType: + return SecretType(private_key_raw.encode("utf-8")) + + +def _load_pem_to_der(private_key_pem: SecretType) -> SecretType: """ Given a private key file path (in PEM format), decode key data into DER format """ - private_key_passphrase = os.getenv("PRIVATE_KEY_PASSPHRASE", None) + private_key_passphrase = SecretType(os.getenv("PRIVATE_KEY_PASSPHRASE", None)) if ( - private_key_pem.startswith(ENCRYPTED_PKCS8_PK_HEADER) - and private_key_passphrase is None + private_key_pem.value.startswith(ENCRYPTED_PKCS8_PK_HEADER) + and private_key_passphrase.value is None ): raise ClickException( "Encrypted private key, you must provide the" "passphrase in the environment variable PRIVATE_KEY_PASSPHRASE" ) - if not private_key_pem.startswith( + if not private_key_pem.value.startswith( ENCRYPTED_PKCS8_PK_HEADER - ) and not private_key_pem.startswith(UNENCRYPTED_PKCS8_PK_HEADER): + ) and not private_key_pem.value.startswith(UNENCRYPTED_PKCS8_PK_HEADER): raise ClickException( "Private key provided is not in PKCS#8 format. Please use correct format." ) - if private_key_pem.startswith(UNENCRYPTED_PKCS8_PK_HEADER): - private_key_passphrase = None + if private_key_pem.value.startswith(UNENCRYPTED_PKCS8_PK_HEADER): + private_key_passphrase = SecretType(None) return prepare_private_key(private_key_pem, private_key_passphrase) -def prepare_private_key(private_key_pem, private_key_passphrase=None): +def prepare_private_key( + private_key_pem: SecretType, private_key_passphrase: SecretType = SecretType(None) +): from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import ( Encoding, @@ -281,17 +289,21 @@ def prepare_private_key(private_key_pem, private_key_passphrase=None): load_pem_private_key, ) - private_key = load_pem_private_key( - private_key_pem, - ( - str.encode(private_key_passphrase) - if private_key_passphrase is not None - else private_key_passphrase - ), - default_backend(), + private_key = SecretType( + load_pem_private_key( + private_key_pem.value, + ( + str.encode(private_key_passphrase.value) + if private_key_passphrase.value is not None + else private_key_passphrase.value + ), + default_backend(), + ) ) - return private_key.private_bytes( - encoding=Encoding.DER, - format=PrivateFormat.PKCS8, - encryption_algorithm=NoEncryption(), + return SecretType( + private_key.value.private_bytes( + encoding=Encoding.DER, + format=PrivateFormat.PKCS8, + encryption_algorithm=NoEncryption(), + ) ) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6e7bb52a29..f3a9a966fa 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -21,6 +21,7 @@ import pytest import tomlkit +from snowflake.cli._app.secret import SecretType from snowflake.cli.api.constants import ObjectType from tests_common import IS_WINDOWS @@ -705,7 +706,7 @@ def test_key_pair_authentication_from_config( ): ctx = mock_ctx() mock_connector.return_value = ctx - mock_convert.return_value = "secret value" + mock_convert.return_value = SecretType("secret value") with NamedTemporaryFile("w+", suffix="toml") as tmp_file: tmp_file.write( diff --git a/tests/test_snow_connector.py b/tests/test_snow_connector.py index 81b755e64d..302bf684bd 100644 --- a/tests/test_snow_connector.py +++ b/tests/test_snow_connector.py @@ -16,6 +16,7 @@ from unittest import mock import pytest +from snowflake.cli._app.secret import SecretType # Used as a solution to syrupy having some problems with comparing multilines string @@ -118,9 +119,10 @@ def test_private_key_loading_and_aliases( else: overrides[user_input] = override_value + key = SecretType(b"bytes") mock_command_info.return_value = "SNOWCLI.SQL" - mock_load_pem_from_file.return_value = b"bytes" - mock_load_pem_to_der.return_value = b"bytes" + mock_load_pem_from_file.return_value = key + mock_load_pem_to_der.return_value = key conn_dict = get_connection_dict(connection_name) default_value = conn_dict.get("private_key_file", None) or conn_dict.get( @@ -135,7 +137,7 @@ def test_private_key_loading_and_aliases( expected_private_key_args = ( {} if expected_private_key_file_value is None - else dict(private_key=mock_load_pem_to_der.return_value) + else dict(private_key=b"bytes") ) mock_connect.assert_called_once_with( application=mock_command_info.return_value, @@ -145,7 +147,7 @@ def test_private_key_loading_and_aliases( ) if expected_private_key_file_value is not None: mock_load_pem_from_file.assert_called_with(expected_private_key_file_value) - mock_load_pem_to_der.assert_called_with(b"bytes") + mock_load_pem_to_der.assert_called_with(key) @mock.patch.dict(os.environ, {}, clear=True) diff --git a/tests_integration/snowflake_connector.py b/tests_integration/snowflake_connector.py index 8dc7bcbcf5..ee1a6d6f0a 100644 --- a/tests_integration/snowflake_connector.py +++ b/tests_integration/snowflake_connector.py @@ -22,6 +22,7 @@ import pytest from snowflake import connector +from snowflake.cli._app.secret import SecretType from snowflake.cli.api.exceptions import EnvironmentVariableNotFoundError from snowflake.cli._app.snow_connector import update_connection_details_with_private_key