Skip to content

Commit

Permalink
Added class to hide variables in tracebacks (#1599)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-astus committed Sep 20, 2024
1 parent 394c493 commit dd325d9
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 32 deletions.
9 changes: 9 additions & 0 deletions src/snowflake/cli/_app/secret.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class SecretType:
def __init__(self, value):
self.value = value

def __repr__(self):
return "SecretType(***)"

def __str___(self):
return "***"
66 changes: 39 additions & 27 deletions src/snowflake/cli/_app/snow_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from snowflake.cli._app.constants import (
PARAM_APPLICATION_NAME,
)
from snowflake.cli._app.secret import SecretType
from snowflake.cli._app.telemetry import command_info
from snowflake.cli.api.config import (
get_connection_dict,
Expand Down Expand Up @@ -205,7 +206,7 @@ def _load_private_key(connection_parameters: Dict, private_key_var_name: str) ->
connection_parameters[private_key_var_name]
)
private_key = _load_pem_to_der(private_key_pem)
connection_parameters["private_key"] = private_key
connection_parameters["private_key"] = private_key.value
del connection_parameters[private_key_var_name]
else:
raise ClickException(
Expand All @@ -217,10 +218,11 @@ def _load_private_key_from_parameters(
connection_parameters: Dict, private_key_var_name: str
) -> None:
if connection_parameters.get("authenticator") == "SNOWFLAKE_JWT":
private_key_pem = connection_parameters[private_key_var_name]
private_key_pem = private_key_pem.encode("utf-8")
private_key_pem = _load_pem_from_parameters(
connection_parameters[private_key_var_name]
)
private_key = _load_pem_to_der(private_key_pem)
connection_parameters["private_key"] = private_key
connection_parameters["private_key"] = private_key.value
del connection_parameters[private_key_var_name]
else:
raise ClickException(
Expand All @@ -236,43 +238,49 @@ def _update_connection_application_name(connection_parameters: Dict):
connection_parameters.update(connection_application_params)


def _load_pem_from_file(private_key_file: str) -> bytes:
def _load_pem_from_file(private_key_file: str) -> SecretType:
with SecurePath(private_key_file).open(
"rb", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB
) as f:
private_key_pem = f.read()
private_key_pem = SecretType(f.read())
return private_key_pem


def _load_pem_to_der(private_key_pem: bytes) -> bytes:
def _load_pem_from_parameters(private_key_raw: str) -> SecretType:
return SecretType(private_key_raw.encode("utf-8"))


def _load_pem_to_der(private_key_pem: SecretType) -> SecretType:
"""
Given a private key file path (in PEM format), decode key data into DER
format
"""
private_key_passphrase = os.getenv("PRIVATE_KEY_PASSPHRASE", None)
private_key_passphrase = SecretType(os.getenv("PRIVATE_KEY_PASSPHRASE", None))
if (
private_key_pem.startswith(ENCRYPTED_PKCS8_PK_HEADER)
and private_key_passphrase is None
private_key_pem.value.startswith(ENCRYPTED_PKCS8_PK_HEADER)
and private_key_passphrase.value is None
):
raise ClickException(
"Encrypted private key, you must provide the"
"passphrase in the environment variable PRIVATE_KEY_PASSPHRASE"
)

if not private_key_pem.startswith(
if not private_key_pem.value.startswith(
ENCRYPTED_PKCS8_PK_HEADER
) and not private_key_pem.startswith(UNENCRYPTED_PKCS8_PK_HEADER):
) and not private_key_pem.value.startswith(UNENCRYPTED_PKCS8_PK_HEADER):
raise ClickException(
"Private key provided is not in PKCS#8 format. Please use correct format."
)

if private_key_pem.startswith(UNENCRYPTED_PKCS8_PK_HEADER):
private_key_passphrase = None
if private_key_pem.value.startswith(UNENCRYPTED_PKCS8_PK_HEADER):
private_key_passphrase = SecretType(None)

return prepare_private_key(private_key_pem, private_key_passphrase)


def prepare_private_key(private_key_pem, private_key_passphrase=None):
def prepare_private_key(
private_key_pem: SecretType, private_key_passphrase: SecretType = SecretType(None)
):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import (
Encoding,
Expand All @@ -281,17 +289,21 @@ def prepare_private_key(private_key_pem, private_key_passphrase=None):
load_pem_private_key,
)

private_key = load_pem_private_key(
private_key_pem,
(
str.encode(private_key_passphrase)
if private_key_passphrase is not None
else private_key_passphrase
),
default_backend(),
private_key = SecretType(
load_pem_private_key(
private_key_pem.value,
(
str.encode(private_key_passphrase.value)
if private_key_passphrase.value is not None
else private_key_passphrase.value
),
default_backend(),
)
)
return private_key.private_bytes(
encoding=Encoding.DER,
format=PrivateFormat.PKCS8,
encryption_algorithm=NoEncryption(),
return SecretType(
private_key.value.private_bytes(
encoding=Encoding.DER,
format=PrivateFormat.PKCS8,
encryption_algorithm=NoEncryption(),
)
)
3 changes: 2 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import pytest
import tomlkit
from snowflake.cli._app.secret import SecretType
from snowflake.cli.api.constants import ObjectType

from tests_common import IS_WINDOWS
Expand Down Expand Up @@ -705,7 +706,7 @@ def test_key_pair_authentication_from_config(
):
ctx = mock_ctx()
mock_connector.return_value = ctx
mock_convert.return_value = "secret value"
mock_convert.return_value = SecretType("secret value")

with NamedTemporaryFile("w+", suffix="toml") as tmp_file:
tmp_file.write(
Expand Down
10 changes: 6 additions & 4 deletions tests/test_snow_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from unittest import mock

import pytest
from snowflake.cli._app.secret import SecretType


# Used as a solution to syrupy having some problems with comparing multilines string
Expand Down Expand Up @@ -118,9 +119,10 @@ def test_private_key_loading_and_aliases(
else:
overrides[user_input] = override_value

key = SecretType(b"bytes")
mock_command_info.return_value = "SNOWCLI.SQL"
mock_load_pem_from_file.return_value = b"bytes"
mock_load_pem_to_der.return_value = b"bytes"
mock_load_pem_from_file.return_value = key
mock_load_pem_to_der.return_value = key

conn_dict = get_connection_dict(connection_name)
default_value = conn_dict.get("private_key_file", None) or conn_dict.get(
Expand All @@ -135,7 +137,7 @@ def test_private_key_loading_and_aliases(
expected_private_key_args = (
{}
if expected_private_key_file_value is None
else dict(private_key=mock_load_pem_to_der.return_value)
else dict(private_key=b"bytes")
)
mock_connect.assert_called_once_with(
application=mock_command_info.return_value,
Expand All @@ -145,7 +147,7 @@ def test_private_key_loading_and_aliases(
)
if expected_private_key_file_value is not None:
mock_load_pem_from_file.assert_called_with(expected_private_key_file_value)
mock_load_pem_to_der.assert_called_with(b"bytes")
mock_load_pem_to_der.assert_called_with(key)


@mock.patch.dict(os.environ, {}, clear=True)
Expand Down
1 change: 1 addition & 0 deletions tests_integration/snowflake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytest
from snowflake import connector
from snowflake.cli._app.secret import SecretType
from snowflake.cli.api.exceptions import EnvironmentVariableNotFoundError
from snowflake.cli._app.snow_connector import update_connection_details_with_private_key

Expand Down

0 comments on commit dd325d9

Please sign in to comment.