def _top_level_dir(entry_name: str) -> Optional[str]:
    """Return the root-level directory an archive entry lives under.

    Parameters
    ----------
    entry_name : str
        Entry name as reported by the archive (always ``/``-separated).

    Returns
    -------
    Optional[str]
        The first real path segment when the entry is nested inside a
        directory (or is itself a directory entry such as ``"proj/"``);
        ``None`` for an entry that sits directly at the archive root.
    """
    # Drop empty and "." segments so "./proj/file" and "proj//file" are
    # treated like "proj/file" instead of reporting "." (or "") as the root.
    segments = [segment for segment in entry_name.split("/") if segment not in ("", ".")]
    if len(segments) > 1 or (segments and entry_name.endswith("/")):
        return segments[0]
    return None


def _resolve_embedded_dir(archive_name: str, *entry_names: str) -> str:
    """Work out the single directory at the archive root, if there is one.

    Entries that sit directly at the root (for example a stray ``.DS_Store``
    or ``README.md``) are not considered; only nested entries contribute a
    root directory.

    Parameters
    ----------
    archive_name : str
        Archive path, used only in the error message.
    *entry_names : str
        Every entry name contained in the archive.

    Returns
    -------
    str
        The name of the only root-level directory, or ``""`` when no entry is
        nested in a directory.

    Raises
    ------
    InvalidConfigurationError
        If nested entries live under more than one root-level directory.
    """
    top_level_dirs = {root for root in map(_top_level_dir, entry_names) if root is not None}
    if len(top_level_dirs) > 1:
        raise InvalidConfigurationError(
            f"the archive {archive_name} can only have one directory at the root and no files"
        )
    return top_level_dirs.pop() if top_level_dirs else ""


def _extract_archive(archive_name: str, extracted_dir_path: str) -> str:
    """Extract an archive and report the directory embedded at its root.

    Archives whose name ends in ``.tar.gz`` are read as gzip-compressed
    tarballs; anything else is read as a zip file. The layout is validated
    before anything is written to ``extracted_dir_path``.

    Parameters
    ----------
    archive_name : str
        Path of the archive file to extract.
    extracted_dir_path : str
        Directory the archive contents are extracted into.

    Returns
    -------
    str
        The single root-level directory of the archive, or ``""`` when the
        archive has no nested directory.

    Raises
    ------
    InvalidConfigurationError
        If the archive has more than one directory at its root.
    """
    if archive_name.endswith(".tar.gz"):
        with tarfile.open(archive_name, "r:gz") as tar_file:
            all_members = tar_file.getmembers()
            embedded_dir = _resolve_embedded_dir(archive_name, *(member.name for member in all_members))
            if hasattr(tarfile, "data_filter"):
                # The archive was downloaded, so treat it as untrusted: the
                # "data" filter refuses members that would land outside of
                # extracted_dir_path (absolute paths, "..", escaping links).
                tar_file.extractall(extracted_dir_path, members=all_members, filter="data")
            else:
                # Interpreter without extraction-filter support.
                tar_file.extractall(extracted_dir_path, members=all_members)
    else:
        with ZipFile(archive_name, "r") as zip_file:
            all_files = zip_file.namelist()
            embedded_dir = _resolve_embedded_dir(archive_name, *all_files)
            zip_file.extractall(extracted_dir_path, members=all_files)
    return embedded_dir
def _build_zip_archive(*entries: Tuple[str, io.BytesIO]) -> bytes:
    """Serialize ``(file_name, data)`` pairs into an in-memory zip archive."""
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zip_file:
        for file_name, data in entries:
            zip_file.writestr(file_name, data.getvalue())
    return zip_buffer.getvalue()


def _build_tar_gz_archive(*entries: Tuple[str, io.BytesIO]) -> bytes:
    """Serialize ``(file_name, data)`` pairs into an in-memory ``.tar.gz`` archive."""
    tar_buffer = io.BytesIO()
    with tarfile.open(fileobj=tar_buffer, mode="w:gz") as tar_file:
        for file_name, data in entries:
            payload = data.getvalue()
            member = tarfile.TarInfo(name=file_name)
            # addfile() copies exactly member.size bytes; a fresh TarInfo has
            # size 0, which would store every file empty.
            member.size = len(payload)
            # Wrap the bytes in a new buffer so the shared module-level
            # BytesIO objects are never consumed by a fixture.
            tar_file.addfile(member, fileobj=io.BytesIO(payload))
    return tar_buffer.getvalue()


@pytest.fixture(scope="function")
def zip_file_data_not_nested() -> Tuple[bytes, str]:
    """Zip archive built from ``example_archive_files_no_nesting``."""
    return _build_zip_archive(*example_archive_files_no_nesting), "zip"


@pytest.fixture(scope="function")
def zip_file_data_single_module() -> Tuple[bytes, str]:
    """Zip archive built from ``example_archive_files_single_module``."""
    return _build_zip_archive(*example_archive_files_single_module), "zip"


@pytest.fixture(scope="function")
def zip_file_data_bad_structure() -> Tuple[bytes, str]:
    """Zip archive built from ``example_archive_files_bad_structure``."""
    return _build_zip_archive(*example_archive_files_bad_structure), "zip"


@pytest.fixture(scope="function")
def tar_file_data_not_nested() -> Tuple[bytes, str]:
    """Gzipped tarball built from ``example_archive_files_no_nesting``."""
    return _build_tar_gz_archive(*example_archive_files_no_nesting), "tar.gz"


@pytest.fixture(scope="function")
def tar_file_data_single_module() -> Tuple[bytes, str]:
    """Gzipped tarball built from ``example_archive_files_single_module``."""
    return _build_tar_gz_archive(*example_archive_files_single_module), "tar.gz"


@pytest.fixture(scope="function")
def tar_file_data_bad_structure() -> Tuple[bytes, str]:
    """Gzipped tarball built from ``example_archive_files_bad_structure``."""
    return _build_tar_gz_archive(*example_archive_files_bad_structure), "tar.gz"


@pytest.fixture(params=["zip_file_data_not_nested", "tar_file_data_not_nested"])
def archive_file_data_not_nested(
    request: pytest.FixtureRequest,
    zip_file_data_not_nested: Tuple[bytes, str],
    tar_file_data_not_nested: Tuple[bytes, str],
) -> Tuple[bytes, str]:
    """Yield the non-nested archive once as zip and once as tar.gz."""
    return request.getfixturevalue(request.param)


@pytest.fixture(params=["zip_file_data_single_module", "tar_file_data_single_module"])
def archive_file_data_single_module(
    request: pytest.FixtureRequest,
    zip_file_data_single_module: Tuple[bytes, str],
    tar_file_data_single_module: Tuple[bytes, str],
) -> Tuple[bytes, str]:
    """Yield the single-module archive once as zip and once as tar.gz."""
    return request.getfixturevalue(request.param)


@pytest.fixture(params=["zip_file_data_bad_structure", "tar_file_data_bad_structure"])
def archive_file_data_bad_structure(
    request: pytest.FixtureRequest,
    zip_file_data_bad_structure: Tuple[bytes, str],
    tar_file_data_bad_structure: Tuple[bytes, str],
) -> Tuple[bytes, str]:
    """Yield the two-root-directory archive once as zip and once as tar.gz."""
    return request.getfixturevalue(request.param)
def _mock_archive_download(archive_bytes: bytes) -> MagicMock:
    """Build a stand-in for a successful ``requests.get`` response carrying *archive_bytes*."""
    response = MagicMock()
    response.status_code = 200
    response.content = archive_bytes
    return response


def _assert_signed_single_download(mock_requests_get: MagicMock, expected_url: str) -> None:
    """Assert one download was issued to *expected_url* with the signing headers present."""
    mock_requests_get.assert_called_once()
    request_kwargs = mock_requests_get.call_args.kwargs

    assert request_kwargs["url"] == expected_url
    # check that the sha256 header was added and that the request was signed
    assert "x-amz-content-sha256" in request_kwargs["headers"]
    assert "Authorization" in request_kwargs["headers"]


@pytest.mark.mgmt
@pytest.mark.mgmt_archive_support
@pytest.mark.parametrize(
    "s3_bucket_http_url", ["testing-bucket.s3.amazonaws.com", "testing-bucket.s3.us-west-2.amazonaws.com"]
)
def test_fetch_module_repo_from_s3_non_nested(
    session_manager: None, archive_file_data_not_nested: Tuple[bytes, str], s3_bucket_http_url: str
) -> None:
    """Fetching module ``./`` from a flat archive returns a path that directly holds deployspec.yaml."""
    payload, extension = archive_file_data_not_nested
    s3_http_url = f"https://{s3_bucket_http_url}/testing-modules.{extension}"
    module_name = "./"

    with patch("requests.get", return_value=_mock_archive_download(payload)) as mock_requests_get:
        archive_path, subdir = archive.fetch_archived_module(
            release_path=f"archive::{s3_http_url}?module={module_name}"
        )

    _assert_signed_single_download(mock_requests_get, s3_http_url)
    assert os.path.exists(os.path.join(archive_path, "deployspec.yaml"))
    assert subdir == "./"


@pytest.mark.mgmt
@pytest.mark.mgmt_archive_support
@pytest.mark.parametrize(
    "s3_bucket_http_url", ["testing-bucket.s3.amazonaws.com", "testing-bucket.s3.us-west-2.amazonaws.com"]
)
def test_fetch_module_repo_from_s3_single_module(
    session_manager: None, archive_file_data_single_module: Tuple[bytes, str], s3_bucket_http_url: str
) -> None:
    """Fetching a module from an archive with one project directory finds its deployspec.yaml."""
    payload, extension = archive_file_data_single_module
    s3_http_url = f"https://{s3_bucket_http_url}/testing-modules.{extension}"
    module_name = "modules/test-module/"

    with patch("requests.get", return_value=_mock_archive_download(payload)) as mock_requests_get:
        archive_path, subdir = archive.fetch_archived_module(
            release_path=f"archive::{s3_http_url}?module={module_name}"
        )

    _assert_signed_single_download(mock_requests_get, s3_http_url)
    assert os.path.exists(os.path.join(archive_path, module_name, "deployspec.yaml"))
    assert subdir == "modules/test-module/"


@pytest.mark.mgmt
@pytest.mark.mgmt_archive_support
@pytest.mark.parametrize(
    "s3_bucket_http_url", ["testing-bucket.s3.amazonaws.com", "testing-bucket.s3.us-west-2.amazonaws.com"]
)
def test_fetch_module_repo_bad_archive_structure(
    session_manager: None, archive_file_data_bad_structure: Tuple[bytes, str], s3_bucket_http_url: str
) -> None:
    """An archive with two directories at its root is rejected with InvalidConfigurationError."""
    payload, extension = archive_file_data_bad_structure
    s3_http_url = f"https://{s3_bucket_http_url}/testing-modules.{extension}"
    module_name = "modules/test-module/"

    with patch("requests.get", return_value=_mock_archive_download(payload)) as mock_requests_get:
        release_path = f"archive::{s3_http_url}?module={module_name}"
        with pytest.raises(InvalidConfigurationError):
            archive.fetch_archived_module(release_path=release_path)
        mock_requests_get.assert_called_once()