Skip to content

Commit

Permalink
tests: improve test coverage to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
narenaryan committed Nov 10, 2024
1 parent e1ad83c commit e4b31fc
Show file tree
Hide file tree
Showing 8 changed files with 353 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[run]
source = whispr
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ lib64/
.env
*.creds
.coverage*
!.coveragerc
20 changes: 20 additions & 0 deletions src/whispr/utils/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

import yaml

from whispr.logging import logger

def write_to_yaml_file(config: dict, file_path: str):
"""Writes a given config object to a file in YAML format"""
if not os.path.exists(file_path):
with open(file_path, "w", encoding="utf-8") as file:
yaml.dump(config, file)
logger.info(f"{file_path} has been created.")

def load_config(file_path: str) -> dict:
"""Loads a given config file"""
try:
with open(file_path, "r", encoding="utf-8") as file:
return yaml.safe_load(file)
except Exception as e:
raise e
31 changes: 31 additions & 0 deletions src/whispr/utils/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
import subprocess
import shlex

from whispr.logging import logger

def execute_command(command: tuple, no_env: bool, secrets: dict):
"""Executes a Unix/Windows command.
Arg: `no_env` decides whether secrets are passed vai environment or K:V pairs in command arguments.
"""
if not secrets:
secrets = {}

try:
usr_command = shlex.split(command[0])

if no_env:
# Pass as --env K=V format (secure)
usr_command.extend([
f"{k}={v}" for k,v in secrets.items()
])
else:
# Pass via environment (slightly insecure)
os.environ.update(secrets)

sp = subprocess.run(usr_command, env=os.environ, shell=False, check=True)
except subprocess.CalledProcessError as e:
logger.error(
f"Encountered a problem while running command: '{command[0]}'. Aborting."
)
raise e
72 changes: 72 additions & 0 deletions src/whispr/utils/vault.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import json

from dotenv import dotenv_values

from whispr.factory import VaultFactory
from whispr.logging import logger
from whispr.enums import VaultType


def fetch_secrets(config: dict) -> dict:
"""Fetch secret from relevant vault"""
kwargs = config
kwargs["logger"] = logger

vault = config.get("vault")
secret_name = config.get("secret_name")

if not vault or not secret_name:
logger.error(
"Vault type or secret name not specified in the configuration file."
)
return {}

try:
vault_instance = VaultFactory.get_vault(**kwargs)
except ValueError as e:
logger.error(f"Error creating vault instance: {str(e)}")
return {}

secret_string = vault_instance.fetch_secrets(secret_name)
if not secret_string:
return {}

return json.loads(secret_string)


def get_filled_secrets(env_file: str, vault_secrets: dict) -> dict:
"""Inject vault secret values into local empty secrets"""

filled_secrets = {}
env_vars = dotenv_values(dotenv_path=env_file)

# Iterate over .env variables and check if they exist in the fetched secrets
for key in env_vars:
if key in vault_secrets:
filled_secrets[key] = vault_secrets[key] # Collect the matching secrets
else:
logger.warning(
f"The given key: '{key}' is not found in vault. So ignoring it."
)

# Return the dictionary of matched secrets for further use if needed
return filled_secrets


def prepare_vault_config(vault_type: str) -> dict:
"""Prepares in-memory configuration for a given vault"""
config = {
"env_file": ".env",
"secret_name": "<your_secret_name>",
"vault": VaultType.AWS.value,
}

# Add more configuration fields as needed for other secret managers.
if vault_type == VaultType.GCP.value:
config["project_id"] = "<gcp_project_id>"
config["vault"] = VaultType.GCP.value
elif vault_type == VaultType.AZURE.value:
config["vault_url"] = "<azure_vault_url>"
config["vault"] = VaultType.AZURE.value

return config
61 changes: 61 additions & 0 deletions tests/test_io_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@

import os
import yaml
import unittest

from unittest.mock import MagicMock, patch, mock_open
from whispr.utils.io import write_to_yaml_file, load_config

class IOUtilsTestCase(unittest.TestCase):
"""Unit tests for the file utilities: write_to_yaml_file and load_config."""

def setUp(self):
"""Set up mocks for logger and os.path methods."""
self.mock_logger = MagicMock()
self.config = {"key": "value"}
self.file_path = "test_config.yaml"

@patch("whispr.utils.io.logger", new_callable=lambda: MagicMock())
@patch("builtins.open", new_callable=mock_open)
@patch("os.path.exists", return_value=False)
def test_write_to_yaml_file_creates_file(self, mock_exists, mock_open_file, mock_logger):
"""Test that write_to_yaml_file creates a new file and writes config data as YAML."""
write_to_yaml_file(self.config, self.file_path)

mock_open_file.assert_called_once_with(self.file_path, "w", encoding="utf-8")
mock_open_file().write.assert_called() # Ensures that some content was written
mock_logger.info.assert_called_once_with(f"{self.file_path} has been created.")

@patch("whispr.utils.io.logger", new_callable=lambda: MagicMock())
@patch("builtins.open", new_callable=mock_open)
@patch("os.path.exists", return_value=True)
def test_write_to_yaml_file_does_not_overwrite_existing_file(self, mock_exists, mock_open_file, mock_logger):
"""Test that write_to_yaml_file does not overwrite an existing file."""
write_to_yaml_file(self.config, self.file_path)

mock_open_file.assert_not_called()
mock_logger.info.assert_not_called()

@patch("builtins.open", new_callable=mock_open, read_data="key: value")
def test_load_config_success(self, mock_open_file):
"""Test that load_config loads a YAML file and returns a config dictionary."""
result = load_config(self.file_path)

mock_open_file.assert_called_once_with(self.file_path, "r", encoding="utf-8")
self.assertEqual(result, {"key": "value"})

@patch("builtins.open", new_callable=mock_open)
def test_load_config_file_not_found(self, mock_open_file):
"""Test load_config raises an error if the file does not exist."""
mock_open_file.side_effect = FileNotFoundError

with self.assertRaises(FileNotFoundError):
load_config("non_existent.yaml")

@patch("builtins.open", new_callable=mock_open)
def test_load_config_yaml_error(self, mock_open_file):
"""Test load_config raises an error for an invalid YAML file."""
mock_open_file.side_effect = yaml.YAMLError

with self.assertRaises(yaml.YAMLError):
load_config(self.file_path)
58 changes: 58 additions & 0 deletions tests/test_process_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import unittest
from unittest.mock import patch, MagicMock
import subprocess
import os

from whispr.utils.process import execute_command


class ProcessUtilsTestCase(unittest.TestCase):
"""Unit tests for the execute_command function, which executes commands with optional secrets."""

def setUp(self):
"""Set up test data and mocks for logger and os environment."""
self.command = ("echo Hello",)
self.secrets = {"API_KEY": "123456"}
self.no_env = True
self.mock_logger = MagicMock()

@patch("whispr.utils.process.logger", new_callable=lambda: MagicMock())
@patch("subprocess.run")
def test_execute_command_with_no_env(self, mock_subprocess_run, mock_logger):
"""Test execute_command with `no_env=True`, passing secrets as command arguments."""
execute_command(self.command, self.no_env, self.secrets)

expected_command = ["echo", "Hello", "API_KEY=123456"]
mock_subprocess_run.assert_called_once_with(expected_command, env=os.environ, shell=False, check=True)

@patch("whispr.utils.process.logger", new_callable=lambda: MagicMock())
@patch("subprocess.run")
@patch("os.environ.update")
def test_execute_command_with_env(self, mock_env_update, mock_subprocess_run, mock_logger):
"""Test execute_command with `no_env=False`, passing secrets via environment variables."""
execute_command(self.command, no_env=False, secrets=self.secrets)

mock_env_update.assert_called_once_with(self.secrets)
expected_command = ["echo", "Hello"]
mock_subprocess_run.assert_called_once_with(expected_command, env=os.environ, shell=False, check=True)

@patch("whispr.utils.process.logger", new_callable=lambda: MagicMock())
@patch("subprocess.run", side_effect=subprocess.CalledProcessError(1, "test"))
def test_execute_command_called_process_error(self, mock_subprocess_run, mock_logger):
"""Test execute_command handles CalledProcessError and logs an error message."""
with self.assertRaises(subprocess.CalledProcessError):
execute_command(self.command, no_env=True, secrets=self.secrets)

mock_logger.error.assert_called_once_with(
f"Encountered a problem while running command: '{self.command[0]}'. Aborting."
)

@patch("whispr.utils.process.logger", new_callable=lambda: MagicMock())
@patch("subprocess.run")
def test_execute_command_without_secrets(self, mock_subprocess_run, mock_logger):
"""Test execute_command without any secrets."""
execute_command(self.command, no_env=True, secrets={})

expected_command = ["echo", "Hello"]
mock_subprocess_run.assert_called_once_with(expected_command, env=os.environ, shell=False, check=True)
mock_logger.error.assert_not_called()
108 changes: 108 additions & 0 deletions tests/test_vault_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import unittest
from unittest.mock import patch, MagicMock
import json
from dotenv import dotenv_values

from whispr.utils.vault import fetch_secrets, get_filled_secrets, prepare_vault_config
from whispr.enums import VaultType


class SecretUtilsTestCase(unittest.TestCase):
"""Unit tests for the secret utilities: fetch_secrets, get_filled_secrets, and prepare_vault_config."""

def setUp(self):
"""Set up test configuration and mock logger."""
self.config = {
"vault": VaultType.AWS.value,
"secret_name": "test_secret",
}
self.vault_secrets = {"API_KEY": "123456"}
self.env_file = ".env"
self.mock_logger = MagicMock()

@patch("whispr.utils.vault.logger", new_callable=lambda: MagicMock())
@patch("whispr.utils.vault.VaultFactory.get_vault")
def test_fetch_secrets_success(self, mock_get_vault, mock_logger):
"""Test fetch_secrets successfully retrieves and parses a secret."""
mock_vault_instance = MagicMock()
mock_vault_instance.fetch_secrets.return_value = json.dumps(self.vault_secrets)
mock_get_vault.return_value = mock_vault_instance

result = fetch_secrets(self.config)
self.assertEqual(result, self.vault_secrets)

@patch("whispr.utils.vault.logger", new_callable=lambda: MagicMock())
def test_fetch_secrets_missing_config(self, mock_logger):
"""Test fetch_secrets logs an error if the vault type or secret name is missing."""
config = {"vault": None, "secret_name": None}

result = fetch_secrets(config)
self.assertEqual(result, {})
mock_logger.error.assert_called_once_with(
"Vault type or secret name not specified in the configuration file."
)

@patch("whispr.utils.vault.logger", new_callable=lambda: MagicMock())
@patch("whispr.utils.vault.VaultFactory.get_vault", side_effect=ValueError("Invalid vault type"))
def test_fetch_secrets_invalid_vault(self, mock_get_vault, mock_logger):
"""Test fetch_secrets logs an error if the vault factory raises a ValueError."""
result = fetch_secrets({
"vault": "UNKOWN",
"secret_name": "test_secret",
})

self.assertEqual(result, {})
mock_logger.error.assert_called_once_with("Error creating vault instance: Invalid vault type")

@patch("whispr.utils.vault.dotenv_values", return_value={"API_KEY": "", "OTHER_KEY": ""})
@patch("whispr.utils.vault.logger", new_callable=lambda: MagicMock())
def test_get_filled_secrets_partial_match(self, mock_logger, mock_dotenv_values):
"""Test get_filled_secrets fills only matching secrets from vault_secrets."""
filled_secrets = get_filled_secrets(self.env_file, self.vault_secrets)

self.assertEqual(filled_secrets, {"API_KEY": "123456"})
mock_logger.warning.assert_called_once_with(
"The given key: 'OTHER_KEY' is not found in vault. So ignoring it."
)

@patch("whispr.utils.vault.dotenv_values", return_value={"NON_MATCHING_KEY": ""})
@patch("whispr.utils.vault.logger", new_callable=lambda: MagicMock())
def test_get_filled_secrets_no_match(self, mock_logger, mock_dotenv_values):
"""Test get_filled_secrets returns an empty dictionary if no env variables match vault secrets."""
filled_secrets = get_filled_secrets(self.env_file, self.vault_secrets)
self.assertEqual(filled_secrets, {})
mock_logger.warning.assert_called_once_with(
"The given key: 'NON_MATCHING_KEY' is not found in vault. So ignoring it."
)

def test_prepare_vault_config_aws(self):
"""Test prepare_vault_config generates AWS configuration."""
config = prepare_vault_config(VaultType.AWS.value)
expected_config = {
"env_file": ".env",
"secret_name": "<your_secret_name>",
"vault": VaultType.AWS.value,
}
self.assertEqual(config, expected_config)

def test_prepare_vault_config_gcp(self):
"""Test prepare_vault_config generates GCP configuration."""
config = prepare_vault_config(VaultType.GCP.value)
expected_config = {
"env_file": ".env",
"secret_name": "<your_secret_name>",
"vault": VaultType.GCP.value,
"project_id": "<gcp_project_id>",
}
self.assertEqual(config, expected_config)

def test_prepare_vault_config_azure(self):
"""Test prepare_vault_config generates Azure configuration."""
config = prepare_vault_config(VaultType.AZURE.value)
expected_config = {
"env_file": ".env",
"secret_name": "<your_secret_name>",
"vault": VaultType.AZURE.value,
"vault_url": "<azure_vault_url>",
}
self.assertEqual(config, expected_config)

0 comments on commit e4b31fc

Please sign in to comment.