diff --git a/src/pypi_attestations/_impl.py b/src/pypi_attestations/_impl.py index 735494a..0db7e2d 100644 --- a/src/pypi_attestations/_impl.py +++ b/src/pypi_attestations/_impl.py @@ -6,6 +6,7 @@ from __future__ import annotations import base64 +from enum import Enum from typing import TYPE_CHECKING, Annotated, Any, Literal, NewType import sigstore.errors @@ -37,6 +38,13 @@ from sigstore.verify.policy import VerificationPolicy # pragma: no cover +class AttestationType(str, Enum): + """Attestation types known to PyPI.""" + + SLSA_PROVENANCE_V1 = "https://slsa.dev/provenance/v1" + PYPI_PUBLISH_V1 = "https://docs.pypi.org/attestations/publish/v1" + + class AttestationError(ValueError): """Base error for all APIs.""" @@ -119,7 +127,7 @@ def sign(cls, signer: Signer, dist: Path) -> Attestation: ) ] ) - .predicate_type("https://docs.pypi.org/attestations/publish/v1") + .predicate_type(AttestationType.PYPI_PUBLISH_V1) .build() ) except DsseError as e: @@ -186,6 +194,11 @@ def verify( if digest is None or digest != expected_digest: raise VerificationError("subject does not match distribution digest") + try: + AttestationType(statement.predicate_type) + except ValueError: + raise VerificationError(f"unknown attestation type: {statement.predicate_type}") + return statement.predicate_type, statement.predicate def to_bundle(self) -> Bundle: diff --git a/test/test_impl.py b/test/test_impl.py index 5a8c832..ec181b5 100644 --- a/test/test_impl.py +++ b/test/test_impl.py @@ -58,9 +58,7 @@ def test_sign_invalid_dist_filename(self, tmp_path: Path) -> None: ): impl.Attestation.sign(pretend.stub(), bad_dist) - def test_sign_raises_attestation_exception( - self, id_token: IdentityToken, tmp_path: Path - ) -> None: + def test_sign_raises_attestation_exception(self, tmp_path: Path) -> None: non_existing_file = tmp_path / "invalid-name.tar.gz" with pytest.raises(impl.AttestationError, match="No such file"): impl.Attestation.sign(pretend.stub(), non_existing_file) @@ -77,9 +75,7 @@ def test_sign_raises_attestation_exception( with pytest.raises(impl.AttestationError, match="Invalid sdist filename"): impl.Attestation.sign(pretend.stub(), bad_sdist_filename) - def test_wrong_predicate_raises_exception( - self, id_token: IdentityToken, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_wrong_predicate_raises_exception(self, monkeypatch: pytest.MonkeyPatch) -> None: def dummy_predicate(self_: _StatementBuilder, _: str) -> _StatementBuilder: # wrong type here to have a validation error self_._predicate_type = False @@ -89,6 +85,7 @@ def dummy_predicate(self_: _StatementBuilder, _: str) -> _StatementBuilder: with pytest.raises(impl.AttestationError, match="invalid statement"): impl.Attestation.sign(pretend.stub(), artifact_path) + @online def test_expired_certificate( self, id_token: IdentityToken, monkeypatch: pytest.MonkeyPatch ) -> None: @@ -102,6 +99,7 @@ def in_validity_period(_: IdentityToken) -> bool: with pytest.raises(impl.AttestationError): impl.Attestation.sign(signer, artifact_path) + @online def test_multiple_signatures( self, id_token: IdentityToken, monkeypatch: pytest.MonkeyPatch ) -> None: @@ -312,6 +310,43 @@ def test_verify_subject_invalid_name(self) -> None: with pytest.raises(impl.VerificationError, match="invalid subject: Invalid wheel filename"): attestation.verify(verifier, pol, artifact_path) + def test_verify_unknown_attestation_type(self) -> None: + statement = ( + _StatementBuilder() # noqa: SLF001 + .subjects( + [ + _Subject( + name="rfc8785-0.1.2-py3-none-any.whl", + digest=_DigestSet( + root={ + "sha256": ( + "c4e92e9ecc828bef2aa7dba1de8ac983511f7532a0df11c770d39099a25cf201" + ), + } + ), + ), + ] + ) + .predicate_type("foo") + .build() + ._inner.model_dump_json() + ) + + verifier = pretend.stub( + verify_dsse=pretend.call_recorder( + lambda bundle, policy: ( + "application/vnd.in-toto+json", + statement.encode(), + ) + ) + ) + pol = pretend.stub() + + attestation = impl.Attestation.model_validate_json(attestation_path.read_text()) + + with pytest.raises(impl.VerificationError, match="unknown attestation type: foo"): + attestation.verify(verifier, pol, artifact_path) + def test_from_bundle_missing_signatures() -> None: bundle = Bundle.from_json(bundle_path.read_bytes())