diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py index 9968895ef73b..e01a0439c9da 100644 --- a/src/py/flwr/cli/utils.py +++ b/src/py/flwr/cli/utils.py @@ -22,7 +22,7 @@ from contextlib import contextmanager from logging import DEBUG from pathlib import Path -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Optional, Union, cast import grpc import typer @@ -148,15 +148,18 @@ def sanitize_project_name(name: str) -> str: return sanitized_name -def get_sha256_hash(file_path: Path) -> str: +def get_sha256_hash(file_path_or_int: Union[Path, int]) -> str: """Calculate the SHA-256 hash of a file.""" sha256 = hashlib.sha256() - with open(file_path, "rb") as f: - while True: - data = f.read(65536) # Read in 64kB blocks - if not data: - break - sha256.update(data) + if isinstance(file_path_or_int, Path): + with open(file_path_or_int, "rb") as f: + while True: + data = f.read(65536) # Read in 64kB blocks + if not data: + break + sha256.update(data) + elif isinstance(file_path_or_int, int): + sha256.update(str(file_path_or_int).encode()) return sha256.hexdigest() diff --git a/src/py/flwr/cli/utils_test.py b/src/py/flwr/cli/utils_test.py new file mode 100644 index 000000000000..e722dee70c3c --- /dev/null +++ b/src/py/flwr/cli/utils_test.py @@ -0,0 +1,109 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for Flower command line interface utils.""" + + +import hashlib +import os +import tempfile +import unittest +from pathlib import Path + +from flwr.cli.utils import get_sha256_hash + + +class TestGetSHA256Hash(unittest.TestCase): + """Unit tests for `get_sha256_hash` function.""" + + def test_hash_with_integer(self) -> None: + """Test the SHA-256 hash calculation when input is an integer.""" + # Prepare + test_int = 13413 + expected_hash = hashlib.sha256(str(test_int).encode()).hexdigest() + + # Execute + result = get_sha256_hash(test_int) + + # Assert + self.assertEqual(result, expected_hash) + + def test_hash_with_file(self) -> None: + """Test the SHA-256 hash calculation when input is a file path.""" + # Prepare - Create a temporary file with known content + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_file.write(b"Test content for SHA-256 hashing.") + temp_file_path = Path(temp_file.name) + + try: + # Execute + sha256 = hashlib.sha256() + with open(temp_file_path, "rb") as f: + while True: + data = f.read(65536) + if not data: + break + sha256.update(data) + expected_hash = sha256.hexdigest() + + result = get_sha256_hash(temp_file_path) + + # Assert + self.assertEqual(result, expected_hash) + finally: + # Clean up the temporary file + os.remove(temp_file_path) + + def test_empty_file(self) -> None: + """Test the SHA-256 hash calculation for an empty file.""" + # Prepare + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_file_path = Path(temp_file.name) + + try: + # Execute + expected_hash = hashlib.sha256(b"").hexdigest() + result = get_sha256_hash(temp_file_path) + + # Assert + self.assertEqual(result, expected_hash) + finally: + os.remove(temp_file_path) + + def test_large_file(self) -> None: + """Test the SHA-256 hash calculation for a large file.""" + # Prepare - Generate large data (e.g., 10 MB) + large_data = b"a" * (10 * 1024 * 1024) # 10 MB of 'a's + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_file.write(large_data) + temp_file_path = Path(temp_file.name) + + try: + expected_hash = hashlib.sha256(large_data).hexdigest() + # Execute + result = get_sha256_hash(temp_file_path) + + # Assert + self.assertEqual(result, expected_hash) + finally: + os.remove(temp_file_path) + + def test_nonexistent_file(self) -> None: + """Test the SHA-256 hash calculation when the file does not exist.""" + # Prepare + nonexistent_path = Path("/path/to/nonexistent/file.txt") + + # Execute & assert + with self.assertRaises(FileNotFoundError): + get_sha256_hash(nonexistent_path)