Skip to content

Commit

Permalink
Added class to hide variables in tracebacks
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-astus committed Sep 19, 2024
1 parent c17fde7 commit 6fed55b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 25 deletions.
9 changes: 9 additions & 0 deletions src/snowflake/cli/_app/secret.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class SecretType:
def __init__(self, value):
self.value = value

def __repr__(self):
return "SecretType(***)"

def __str___(self):
return "***"
67 changes: 42 additions & 25 deletions src/snowflake/cli/_app/snow_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from snowflake.cli._app.constants import (
PARAM_APPLICATION_NAME,
)
from snowflake.cli._app.secret import SecretType
from snowflake.cli._app.telemetry import command_info
from snowflake.cli.api.config import (
get_connection_dict,
Expand Down Expand Up @@ -155,6 +156,11 @@ def connect_to_snowflake(
# for cases when external browser and json format are used.
# Redirecting both stdout and stderr for offline usage.
with contextlib.redirect_stdout(None), contextlib.redirect_stderr(None):
# Unpack SecretType values
connection_parameters = {
k: v.value if isinstance(v, SecretType) else v
for k, v in connection_parameters.items()
}
return snowflake.connector.connect(
application=command_info(),
**connection_parameters,
Expand Down Expand Up @@ -217,8 +223,9 @@ def _load_private_key_from_parameters(
connection_parameters: Dict, private_key_var_name: str
) -> None:
if connection_parameters.get("authenticator") == "SNOWFLAKE_JWT":
private_key_pem = connection_parameters[private_key_var_name]
private_key_pem = private_key_pem.encode("utf-8")
private_key_pem = _load_pem_from_parameters(
connection_parameters[private_key_var_name]
)
private_key = _load_pem_to_der(private_key_pem)
connection_parameters["private_key"] = private_key
del connection_parameters[private_key_var_name]
Expand All @@ -236,43 +243,49 @@ def _update_connection_application_name(connection_parameters: Dict):
connection_parameters.update(connection_application_params)


def _load_pem_from_file(private_key_file: str) -> bytes:
def _load_pem_from_file(private_key_file: str) -> SecretType:
with SecurePath(private_key_file).open(
"rb", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB
) as f:
private_key_pem = f.read()
private_key_pem = SecretType(f.read())
return private_key_pem


def _load_pem_to_der(private_key_pem: bytes) -> bytes:
def _load_pem_from_parameters(private_key_raw: str) -> SecretType:
return SecretType(private_key_raw.encode("utf-8"))


def _load_pem_to_der(private_key_pem: SecretType) -> SecretType:
"""
Given a private key file path (in PEM format), decode key data into DER
format
"""
private_key_passphrase = os.getenv("PRIVATE_KEY_PASSPHRASE", None)
private_key_passphrase = SecretType(os.getenv("PRIVATE_KEY_PASSPHRASE", None))
if (
private_key_pem.startswith(ENCRYPTED_PKCS8_PK_HEADER)
and private_key_passphrase is None
private_key_pem.value.startswith(ENCRYPTED_PKCS8_PK_HEADER)
and private_key_passphrase.value is None
):
raise ClickException(
"Encrypted private key, you must provide the"
"passphrase in the environment variable PRIVATE_KEY_PASSPHRASE"
)

if not private_key_pem.startswith(
if not private_key_pem.value.startswith(
ENCRYPTED_PKCS8_PK_HEADER
) and not private_key_pem.startswith(UNENCRYPTED_PKCS8_PK_HEADER):
) and not private_key_pem.value.startswith(UNENCRYPTED_PKCS8_PK_HEADER):
raise ClickException(
"Private key provided is not in PKCS#8 format. Please use correct format."
)

if private_key_pem.startswith(UNENCRYPTED_PKCS8_PK_HEADER):
private_key_passphrase = None
if private_key_pem.value.startswith(UNENCRYPTED_PKCS8_PK_HEADER):
private_key_passphrase = SecretType(None)

return prepare_private_key(private_key_pem, private_key_passphrase)


def prepare_private_key(private_key_pem, private_key_passphrase=None):
def prepare_private_key(
private_key_pem: SecretType, private_key_passphrase: SecretType = SecretType(None)
):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import (
Encoding,
Expand All @@ -281,17 +294,21 @@ def prepare_private_key(private_key_pem, private_key_passphrase=None):
load_pem_private_key,
)

private_key = load_pem_private_key(
private_key_pem,
(
str.encode(private_key_passphrase)
if private_key_passphrase is not None
else private_key_passphrase
),
default_backend(),
private_key = SecretType(
load_pem_private_key(
private_key_pem.value,
(
str.encode(private_key_passphrase.value)
if private_key_passphrase.value is not None
else private_key_passphrase.value
),
default_backend(),
)
)
return private_key.private_bytes(
encoding=Encoding.DER,
format=PrivateFormat.PKCS8,
encryption_algorithm=NoEncryption(),
return SecretType(
private_key.value.private_bytes(
encoding=Encoding.DER,
format=PrivateFormat.PKCS8,
encryption_algorithm=NoEncryption(),
)
)

0 comments on commit 6fed55b

Please sign in to comment.