Skip to content

Commit

Permalink
Load private key from env variable
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-astus committed Sep 11, 2024
1 parent 7783f39 commit d7613ff
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 30 deletions.
5 changes: 1 addition & 4 deletions src/snowflake/cli/_app/snow_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,7 @@ def connect_to_snowflake(
k: v for k, v in connection_parameters.items() if v is not None
}

if "private_key_file" in connection_parameters:
update_connection_details_with_private_key(
connection_parameters, "private_key_file"
)
update_connection_details_with_private_key(connection_parameters)

if mfa_passcode:
connection_parameters["passcode"] = mfa_passcode
Expand Down
7 changes: 4 additions & 3 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,13 +698,14 @@ def test_token_file_path_tokens(mock_connector, mock_ctx, runner, temp_dir):
clear=True,
)
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._app.snow_connector._load_pem_from_file")
@mock.patch("snowflake.cli._app.snow_connector._load_pem_to_der")
def test_key_pair_authentication_from_config(
mock_load, mock_connector, mock_ctx, temp_dir, runner
mock_convert, mock_load_file, mock_connector, mock_ctx, temp_dir, runner
):
ctx = mock_ctx()
mock_connector.return_value = ctx
mock_load.return_value = "secret value"
mock_convert.return_value = "secret value"

with NamedTemporaryFile("w+", suffix="toml") as tmp_file:
tmp_file.write(
Expand All @@ -726,7 +727,7 @@ def test_key_pair_authentication_from_config(
)

assert result.exit_code == 0, result.output
mock_load.assert_called_once_with("~/sf_private_key.p8")
mock_load_file.assert_called_once_with("~/sf_private_key.p8")
mock_connector.assert_called_once_with(
application="SNOWCLI.OBJECT.LIST",
account="my_account",
Expand Down
6 changes: 5 additions & 1 deletion tests/test_snow_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def test_command_context_is_passed_to_snowflake_connection(
@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._app.snow_connector.command_info")
@mock.patch("snowflake.cli._app.snow_connector._load_pem_to_der")
@mock.patch("snowflake.cli._app.snow_connector._load_pem_from_file")
def test_private_key_loading_and_aliases(
mock_load_pem_from_file,
mock_load_pem_to_der,
mock_command_info,
mock_connect,
Expand Down Expand Up @@ -117,6 +119,7 @@ def test_private_key_loading_and_aliases(
overrides[user_input] = override_value

mock_command_info.return_value = "SNOWCLI.SQL"
mock_load_pem_from_file.return_value = b"bytes"
mock_load_pem_to_der.return_value = b"bytes"

conn_dict = get_connection_dict(connection_name)
Expand All @@ -141,7 +144,8 @@ def test_private_key_loading_and_aliases(
**expected_private_key_args,
)
if expected_private_key_file_value is not None:
mock_load_pem_to_der.assert_called_with(expected_private_key_file_value)
mock_load_pem_from_file.assert_called_with(expected_private_key_file_value)
mock_load_pem_to_der.assert_called_with(b"bytes")


@mock.patch.dict(os.environ, {}, clear=True)
Expand Down
3 changes: 2 additions & 1 deletion tests_integration/snowflake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def snowflake_session():
"account": _get_from_env("ACCOUNT"),
"user": _get_from_env("USER"),
"private_key_file": _get_private_key_file(),
"private_key_raw": _get_from_env("PRIVATE_KEY_RAW", allow_none=True),
"host": _get_from_env("HOST", allow_none=True),
"warehouse": _get_from_env("WAREHOUSE", allow_none=True),
"role": _get_from_env("ROLE", allow_none=True),
Expand All @@ -121,4 +122,4 @@ def _get_private_key_file() -> Optional[str]:
private_key_file = _get_from_env("PRIVATE_KEY_PATH", allow_none=True)
if private_key_file is not None:
return private_key_file
return _get_from_env("PRIVATE_KEY_FILE")
return _get_from_env("PRIVATE_KEY_FILE", allow_none=True)
47 changes: 26 additions & 21 deletions tests_integration/test_temporary_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile

import pytest
import os
Expand All @@ -28,30 +29,34 @@
"SNOWFLAKE_CONNECTIONS_INTEGRATION_USER": os.environ.get(
"SNOWFLAKE_CONNECTIONS_INTEGRATION_USER", None
),
"SNOWFLAKE_CONNECTIONS_INTEGRATION_PRIVATE_KEY_FILE": os.environ.get(
"SNOWFLAKE_CONNECTIONS_INTEGRATION_PRIVATE_KEY_PATH",
os.environ.get("SNOWFLAKE_CONNECTIONS_INTEGRATION_PRIVATE_KEY_FILE"),
"SNOWFLAKE_CONNECTIONS_INTEGRATION_PRIVATE_KEY_RAW": os.environ.get(
"SNOWFLAKE_CONNECTIONS_INTEGRATION_PRIVATE_KEY_RAW",
),
},
clear=True,
)
def test_temporary_connection(runner, snapshot):

result = runner.invoke(
[
"sql",
"-q",
"select 1",
"--temporary-connection",
"--authenticator",
"SNOWFLAKE_JWT",
"--account",
os.environ["SNOWFLAKE_CONNECTIONS_INTEGRATION_ACCOUNT"],
"--user",
os.environ["SNOWFLAKE_CONNECTIONS_INTEGRATION_USER"],
"--private-key-file",
os.environ["SNOWFLAKE_CONNECTIONS_INTEGRATION_PRIVATE_KEY_FILE"],
]
)
assert result.exit_code == 0
assert result.output == snapshot
with tempfile.TemporaryDirectory() as tmp_dir:
private_key_path = os.path.join(tmp_dir, "private_key.p8")
with open(private_key_path, "w") as f:
f.write(os.environ["SNOWFLAKE_CONNECTIONS_INTEGRATION_PRIVATE_KEY_RAW"])

result = runner.invoke(
[
"sql",
"-q",
"select 1",
"--temporary-connection",
"--authenticator",
"SNOWFLAKE_JWT",
"--account",
os.environ["SNOWFLAKE_CONNECTIONS_INTEGRATION_ACCOUNT"],
"--user",
os.environ["SNOWFLAKE_CONNECTIONS_INTEGRATION_USER"],
"--private-key-file",
str(private_key_path),
]
)
assert result.exit_code == 0
assert result.output == snapshot

0 comments on commit d7613ff

Please sign in to comment.