-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tests: improve test coverage to utils
- Loading branch information
1 parent
e1ad83c
commit e4b31fc
Showing
8 changed files
with
353 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[run] | ||
source = whispr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,3 +29,4 @@ lib64/ | |
.env | ||
*.creds | ||
.coverage* | ||
!.coveragerc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import os | ||
|
||
import yaml | ||
|
||
from whispr.logging import logger | ||
|
||
def write_to_yaml_file(config: dict, file_path: str): | ||
"""Writes a given config object to a file in YAML format""" | ||
if not os.path.exists(file_path): | ||
with open(file_path, "w", encoding="utf-8") as file: | ||
yaml.dump(config, file) | ||
logger.info(f"{file_path} has been created.") | ||
|
||
def load_config(file_path: str) -> dict: | ||
"""Loads a given config file""" | ||
try: | ||
with open(file_path, "r", encoding="utf-8") as file: | ||
return yaml.safe_load(file) | ||
except Exception as e: | ||
raise e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import os | ||
import subprocess | ||
import shlex | ||
|
||
from whispr.logging import logger | ||
|
||
def execute_command(command: tuple, no_env: bool, secrets: dict): | ||
"""Executes a Unix/Windows command. | ||
Arg: `no_env` decides whether secrets are passed vai environment or K:V pairs in command arguments. | ||
""" | ||
if not secrets: | ||
secrets = {} | ||
|
||
try: | ||
usr_command = shlex.split(command[0]) | ||
|
||
if no_env: | ||
# Pass as --env K=V format (secure) | ||
usr_command.extend([ | ||
f"{k}={v}" for k,v in secrets.items() | ||
]) | ||
else: | ||
# Pass via environment (slightly insecure) | ||
os.environ.update(secrets) | ||
|
||
sp = subprocess.run(usr_command, env=os.environ, shell=False, check=True) | ||
except subprocess.CalledProcessError as e: | ||
logger.error( | ||
f"Encountered a problem while running command: '{command[0]}'. Aborting." | ||
) | ||
raise e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import json | ||
|
||
from dotenv import dotenv_values | ||
|
||
from whispr.factory import VaultFactory | ||
from whispr.logging import logger | ||
from whispr.enums import VaultType | ||
|
||
|
||
def fetch_secrets(config: dict) -> dict: | ||
"""Fetch secret from relevant vault""" | ||
kwargs = config | ||
kwargs["logger"] = logger | ||
|
||
vault = config.get("vault") | ||
secret_name = config.get("secret_name") | ||
|
||
if not vault or not secret_name: | ||
logger.error( | ||
"Vault type or secret name not specified in the configuration file." | ||
) | ||
return {} | ||
|
||
try: | ||
vault_instance = VaultFactory.get_vault(**kwargs) | ||
except ValueError as e: | ||
logger.error(f"Error creating vault instance: {str(e)}") | ||
return {} | ||
|
||
secret_string = vault_instance.fetch_secrets(secret_name) | ||
if not secret_string: | ||
return {} | ||
|
||
return json.loads(secret_string) | ||
|
||
|
||
def get_filled_secrets(env_file: str, vault_secrets: dict) -> dict: | ||
"""Inject vault secret values into local empty secrets""" | ||
|
||
filled_secrets = {} | ||
env_vars = dotenv_values(dotenv_path=env_file) | ||
|
||
# Iterate over .env variables and check if they exist in the fetched secrets | ||
for key in env_vars: | ||
if key in vault_secrets: | ||
filled_secrets[key] = vault_secrets[key] # Collect the matching secrets | ||
else: | ||
logger.warning( | ||
f"The given key: '{key}' is not found in vault. So ignoring it." | ||
) | ||
|
||
# Return the dictionary of matched secrets for further use if needed | ||
return filled_secrets | ||
|
||
|
||
def prepare_vault_config(vault_type: str) -> dict: | ||
"""Prepares in-memory configuration for a given vault""" | ||
config = { | ||
"env_file": ".env", | ||
"secret_name": "<your_secret_name>", | ||
"vault": VaultType.AWS.value, | ||
} | ||
|
||
# Add more configuration fields as needed for other secret managers. | ||
if vault_type == VaultType.GCP.value: | ||
config["project_id"] = "<gcp_project_id>" | ||
config["vault"] = VaultType.GCP.value | ||
elif vault_type == VaultType.AZURE.value: | ||
config["vault_url"] = "<azure_vault_url>" | ||
config["vault"] = VaultType.AZURE.value | ||
|
||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
|
||
import os | ||
import yaml | ||
import unittest | ||
|
||
from unittest.mock import MagicMock, patch, mock_open | ||
from whispr.utils.io import write_to_yaml_file, load_config | ||
|
||
class IOUtilsTestCase(unittest.TestCase): | ||
"""Unit tests for the file utilities: write_to_yaml_file and load_config.""" | ||
|
||
def setUp(self): | ||
"""Set up mocks for logger and os.path methods.""" | ||
self.mock_logger = MagicMock() | ||
self.config = {"key": "value"} | ||
self.file_path = "test_config.yaml" | ||
|
||
@patch("whispr.utils.io.logger", new_callable=lambda: MagicMock()) | ||
@patch("builtins.open", new_callable=mock_open) | ||
@patch("os.path.exists", return_value=False) | ||
def test_write_to_yaml_file_creates_file(self, mock_exists, mock_open_file, mock_logger): | ||
"""Test that write_to_yaml_file creates a new file and writes config data as YAML.""" | ||
write_to_yaml_file(self.config, self.file_path) | ||
|
||
mock_open_file.assert_called_once_with(self.file_path, "w", encoding="utf-8") | ||
mock_open_file().write.assert_called() # Ensures that some content was written | ||
mock_logger.info.assert_called_once_with(f"{self.file_path} has been created.") | ||
|
||
@patch("whispr.utils.io.logger", new_callable=lambda: MagicMock()) | ||
@patch("builtins.open", new_callable=mock_open) | ||
@patch("os.path.exists", return_value=True) | ||
def test_write_to_yaml_file_does_not_overwrite_existing_file(self, mock_exists, mock_open_file, mock_logger): | ||
"""Test that write_to_yaml_file does not overwrite an existing file.""" | ||
write_to_yaml_file(self.config, self.file_path) | ||
|
||
mock_open_file.assert_not_called() | ||
mock_logger.info.assert_not_called() | ||
|
||
@patch("builtins.open", new_callable=mock_open, read_data="key: value") | ||
def test_load_config_success(self, mock_open_file): | ||
"""Test that load_config loads a YAML file and returns a config dictionary.""" | ||
result = load_config(self.file_path) | ||
|
||
mock_open_file.assert_called_once_with(self.file_path, "r", encoding="utf-8") | ||
self.assertEqual(result, {"key": "value"}) | ||
|
||
@patch("builtins.open", new_callable=mock_open) | ||
def test_load_config_file_not_found(self, mock_open_file): | ||
"""Test load_config raises an error if the file does not exist.""" | ||
mock_open_file.side_effect = FileNotFoundError | ||
|
||
with self.assertRaises(FileNotFoundError): | ||
load_config("non_existent.yaml") | ||
|
||
@patch("builtins.open", new_callable=mock_open) | ||
def test_load_config_yaml_error(self, mock_open_file): | ||
"""Test load_config raises an error for an invalid YAML file.""" | ||
mock_open_file.side_effect = yaml.YAMLError | ||
|
||
with self.assertRaises(yaml.YAMLError): | ||
load_config(self.file_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import unittest | ||
from unittest.mock import patch, MagicMock | ||
import subprocess | ||
import os | ||
|
||
from whispr.utils.process import execute_command | ||
|
||
|
||
class ProcessUtilsTestCase(unittest.TestCase): | ||
"""Unit tests for the execute_command function, which executes commands with optional secrets.""" | ||
|
||
def setUp(self): | ||
"""Set up test data and mocks for logger and os environment.""" | ||
self.command = ("echo Hello",) | ||
self.secrets = {"API_KEY": "123456"} | ||
self.no_env = True | ||
self.mock_logger = MagicMock() | ||
|
||
@patch("whispr.utils.process.logger", new_callable=lambda: MagicMock()) | ||
@patch("subprocess.run") | ||
def test_execute_command_with_no_env(self, mock_subprocess_run, mock_logger): | ||
"""Test execute_command with `no_env=True`, passing secrets as command arguments.""" | ||
execute_command(self.command, self.no_env, self.secrets) | ||
|
||
expected_command = ["echo", "Hello", "API_KEY=123456"] | ||
mock_subprocess_run.assert_called_once_with(expected_command, env=os.environ, shell=False, check=True) | ||
|
||
@patch("whispr.utils.process.logger", new_callable=lambda: MagicMock()) | ||
@patch("subprocess.run") | ||
@patch("os.environ.update") | ||
def test_execute_command_with_env(self, mock_env_update, mock_subprocess_run, mock_logger): | ||
"""Test execute_command with `no_env=False`, passing secrets via environment variables.""" | ||
execute_command(self.command, no_env=False, secrets=self.secrets) | ||
|
||
mock_env_update.assert_called_once_with(self.secrets) | ||
expected_command = ["echo", "Hello"] | ||
mock_subprocess_run.assert_called_once_with(expected_command, env=os.environ, shell=False, check=True) | ||
|
||
@patch("whispr.utils.process.logger", new_callable=lambda: MagicMock()) | ||
@patch("subprocess.run", side_effect=subprocess.CalledProcessError(1, "test")) | ||
def test_execute_command_called_process_error(self, mock_subprocess_run, mock_logger): | ||
"""Test execute_command handles CalledProcessError and logs an error message.""" | ||
with self.assertRaises(subprocess.CalledProcessError): | ||
execute_command(self.command, no_env=True, secrets=self.secrets) | ||
|
||
mock_logger.error.assert_called_once_with( | ||
f"Encountered a problem while running command: '{self.command[0]}'. Aborting." | ||
) | ||
|
||
@patch("whispr.utils.process.logger", new_callable=lambda: MagicMock()) | ||
@patch("subprocess.run") | ||
def test_execute_command_without_secrets(self, mock_subprocess_run, mock_logger): | ||
"""Test execute_command without any secrets.""" | ||
execute_command(self.command, no_env=True, secrets={}) | ||
|
||
expected_command = ["echo", "Hello"] | ||
mock_subprocess_run.assert_called_once_with(expected_command, env=os.environ, shell=False, check=True) | ||
mock_logger.error.assert_not_called() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import unittest | ||
from unittest.mock import patch, MagicMock | ||
import json | ||
from dotenv import dotenv_values | ||
|
||
from whispr.utils.vault import fetch_secrets, get_filled_secrets, prepare_vault_config | ||
from whispr.enums import VaultType | ||
|
||
|
||
class SecretUtilsTestCase(unittest.TestCase): | ||
"""Unit tests for the secret utilities: fetch_secrets, get_filled_secrets, and prepare_vault_config.""" | ||
|
||
def setUp(self): | ||
"""Set up test configuration and mock logger.""" | ||
self.config = { | ||
"vault": VaultType.AWS.value, | ||
"secret_name": "test_secret", | ||
} | ||
self.vault_secrets = {"API_KEY": "123456"} | ||
self.env_file = ".env" | ||
self.mock_logger = MagicMock() | ||
|
||
@patch("whispr.utils.vault.logger", new_callable=lambda: MagicMock()) | ||
@patch("whispr.utils.vault.VaultFactory.get_vault") | ||
def test_fetch_secrets_success(self, mock_get_vault, mock_logger): | ||
"""Test fetch_secrets successfully retrieves and parses a secret.""" | ||
mock_vault_instance = MagicMock() | ||
mock_vault_instance.fetch_secrets.return_value = json.dumps(self.vault_secrets) | ||
mock_get_vault.return_value = mock_vault_instance | ||
|
||
result = fetch_secrets(self.config) | ||
self.assertEqual(result, self.vault_secrets) | ||
|
||
@patch("whispr.utils.vault.logger", new_callable=lambda: MagicMock()) | ||
def test_fetch_secrets_missing_config(self, mock_logger): | ||
"""Test fetch_secrets logs an error if the vault type or secret name is missing.""" | ||
config = {"vault": None, "secret_name": None} | ||
|
||
result = fetch_secrets(config) | ||
self.assertEqual(result, {}) | ||
mock_logger.error.assert_called_once_with( | ||
"Vault type or secret name not specified in the configuration file." | ||
) | ||
|
||
@patch("whispr.utils.vault.logger", new_callable=lambda: MagicMock()) | ||
@patch("whispr.utils.vault.VaultFactory.get_vault", side_effect=ValueError("Invalid vault type")) | ||
def test_fetch_secrets_invalid_vault(self, mock_get_vault, mock_logger): | ||
"""Test fetch_secrets logs an error if the vault factory raises a ValueError.""" | ||
result = fetch_secrets({ | ||
"vault": "UNKOWN", | ||
"secret_name": "test_secret", | ||
}) | ||
|
||
self.assertEqual(result, {}) | ||
mock_logger.error.assert_called_once_with("Error creating vault instance: Invalid vault type") | ||
|
||
@patch("whispr.utils.vault.dotenv_values", return_value={"API_KEY": "", "OTHER_KEY": ""}) | ||
@patch("whispr.utils.vault.logger", new_callable=lambda: MagicMock()) | ||
def test_get_filled_secrets_partial_match(self, mock_logger, mock_dotenv_values): | ||
"""Test get_filled_secrets fills only matching secrets from vault_secrets.""" | ||
filled_secrets = get_filled_secrets(self.env_file, self.vault_secrets) | ||
|
||
self.assertEqual(filled_secrets, {"API_KEY": "123456"}) | ||
mock_logger.warning.assert_called_once_with( | ||
"The given key: 'OTHER_KEY' is not found in vault. So ignoring it." | ||
) | ||
|
||
@patch("whispr.utils.vault.dotenv_values", return_value={"NON_MATCHING_KEY": ""}) | ||
@patch("whispr.utils.vault.logger", new_callable=lambda: MagicMock()) | ||
def test_get_filled_secrets_no_match(self, mock_logger, mock_dotenv_values): | ||
"""Test get_filled_secrets returns an empty dictionary if no env variables match vault secrets.""" | ||
filled_secrets = get_filled_secrets(self.env_file, self.vault_secrets) | ||
self.assertEqual(filled_secrets, {}) | ||
mock_logger.warning.assert_called_once_with( | ||
"The given key: 'NON_MATCHING_KEY' is not found in vault. So ignoring it." | ||
) | ||
|
||
def test_prepare_vault_config_aws(self): | ||
"""Test prepare_vault_config generates AWS configuration.""" | ||
config = prepare_vault_config(VaultType.AWS.value) | ||
expected_config = { | ||
"env_file": ".env", | ||
"secret_name": "<your_secret_name>", | ||
"vault": VaultType.AWS.value, | ||
} | ||
self.assertEqual(config, expected_config) | ||
|
||
def test_prepare_vault_config_gcp(self): | ||
"""Test prepare_vault_config generates GCP configuration.""" | ||
config = prepare_vault_config(VaultType.GCP.value) | ||
expected_config = { | ||
"env_file": ".env", | ||
"secret_name": "<your_secret_name>", | ||
"vault": VaultType.GCP.value, | ||
"project_id": "<gcp_project_id>", | ||
} | ||
self.assertEqual(config, expected_config) | ||
|
||
def test_prepare_vault_config_azure(self): | ||
"""Test prepare_vault_config generates Azure configuration.""" | ||
config = prepare_vault_config(VaultType.AZURE.value) | ||
expected_config = { | ||
"env_file": ".env", | ||
"secret_name": "<your_secret_name>", | ||
"vault": VaultType.AZURE.value, | ||
"vault_url": "<azure_vault_url>", | ||
} | ||
self.assertEqual(config, expected_config) |