Skip to content

Commit

Permalink
src, test: remove I/O from sign/verify APIs
Browse files Browse the repository at this point in the history
Signed-off-by: William Woodruff <william@trailofbits.com>
  • Loading branch information
woodruffw committed Jul 15, 2024
1 parent 9a09d22 commit a307b3c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 53 deletions.
20 changes: 15 additions & 5 deletions src/pypi_attestations/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sigstore.oidc
from cryptography import x509
from pydantic import ValidationError
from sigstore._utils import _sha256_streaming
from sigstore.oidc import IdentityError, IdentityToken, Issuer
from sigstore.sign import SigningContext
from sigstore.verify import Verifier, policy
Expand Down Expand Up @@ -183,15 +184,19 @@ def _sign(args: argparse.Namespace) -> None:
for file_path in args.files:
_logger.debug(f"Signing {file_path}")

signature_path = Path(f"{file_path}.publish.attestation")
with file_path.open(mode="rb", buffering=0) as io:
# Replace this with `hashlib.file_digest()` once
# our minimum supported Python is >=3.11
digest = _sha256_streaming(io).hex()

try:
attestation = Attestation.sign(signer, file_path)
attestation = Attestation.sign(signer, file_path.name, digest)
except AttestationError as e:
_die(f"Failed to sign: {e}")

_logger.debug("Attestation saved for %s saved in %s", file_path, signature_path)

signature_path = Path(f"{file_path}.publish.attestation")
signature_path.write_text(attestation.model_dump_json())
_logger.debug("Attestation saved for %s saved in %s", file_path, signature_path)


def _inspect(args: argparse.Namespace) -> None:
Expand Down Expand Up @@ -266,8 +271,13 @@ def _verify(args: argparse.Namespace) -> None:
except ValidationError as validation_error:
_die(f"Invalid attestation ({file_path}): {validation_error}")

with file_path.open(mode="rb", buffering=0) as io:
# Replace this with `hashlib.file_digest()` once
# our minimum supported Python is >=3.11
digest = _sha256_streaming(io).hex()

try:
attestation.verify(verifier, pol, file_path)
attestation.verify(verifier, pol, file_path.name, digest)
except VerificationError as verification_error:
_logger.error("Verification failed for %s: %s", file_path, verification_error)
continue
Expand Down
36 changes: 13 additions & 23 deletions src/pypi_attestations/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
)
from pydantic import Base64Bytes, BaseModel
from pydantic_core import ValidationError
from sigstore._utils import _sha256_streaming
from sigstore.dsse import Envelope as DsseEnvelope
from sigstore.dsse import Error as DsseError
from sigstore.dsse import _DigestSet, _Statement, _StatementBuilder, _Subject
Expand All @@ -31,8 +30,6 @@
from sigstore_protobuf_specs.io.intoto import Signature as _Signature

if TYPE_CHECKING:
from pathlib import Path # pragma: no cover

from sigstore.sign import Signer # pragma: no cover
from sigstore.verify import Verifier # pragma: no cover
from sigstore.verify.policy import VerificationPolicy # pragma: no cover
Expand Down Expand Up @@ -98,21 +95,13 @@ class Attestation(BaseModel):
"""

@classmethod
def sign(cls, signer: Signer, dist: Path) -> Attestation:
"""Create an envelope, with signature, from a distribution file.
def sign(cls, signer: Signer, dist_filename: str, dist_digest: str) -> Attestation:
"""Create an envelope, with signature, from the given filename and digest.
On failure, raises `AttestationError`.
"""
try:
with dist.open(mode="rb", buffering=0) as io:
# Replace this with `hashlib.file_digest()` once
# our minimum supported Python is >=3.11
digest = _sha256_streaming(io).hex()
except OSError as e:
raise AttestationError(str(e))

try:
name = _ultranormalize_dist_filename(dist.name)
name = _ultranormalize_dist_filename(dist_filename)
except (ValueError, InvalidWheelFilename, InvalidSdistFilename) as e:
raise AttestationError(str(e))

Expand All @@ -123,7 +112,7 @@ def sign(cls, signer: Signer, dist: Path) -> Attestation:
[
_Subject(
name=name,
digest=_DigestSet(root={"sha256": digest}),
digest=_DigestSet(root={"sha256": dist_digest}),
)
]
)
Expand All @@ -144,19 +133,20 @@ def sign(cls, signer: Signer, dist: Path) -> Attestation:
raise AttestationError(str(e))

def verify(
self, verifier: Verifier, policy: VerificationPolicy, dist: Path
self,
verifier: Verifier,
policy: VerificationPolicy,
dist_filename: str,
dist_digest: str,
) -> tuple[str, dict[str, Any] | None]:
"""Verify against an existing Python artifact.
The artifact is identified by its distribution filename (sdist or wheel)
and its SHA-256 digest, as a hex string.
Returns a tuple of the in-toto predicate type and optional deserialized JSON predicate.
On failure, raises an appropriate subclass of `AttestationError`.
"""
with dist.open(mode="rb", buffering=0) as io:
# Replace this with `hashlib.file_digest()` once
# our minimum supported Python is >=3.11
expected_digest = _sha256_streaming(io).hex()

bundle = self.to_bundle()
try:
type_, payload = verifier.verify_dsse(bundle, policy)
Expand Down Expand Up @@ -185,7 +175,7 @@ def verify(
raise VerificationError(f"invalid subject: {str(e)}")

try:
normalized = _ultranormalize_dist_filename(dist.name)
normalized = _ultranormalize_dist_filename(dist_filename)
except ValueError as e:
raise VerificationError(f"invalid distribution name: {str(e)}")

Expand All @@ -195,7 +185,7 @@ def verify(
)

digest = subject.digest.root.get("sha256")
if digest is None or digest != expected_digest:
if digest is None or digest != dist_digest:
raise VerificationError("subject does not match distribution digest")

try:
Expand Down
60 changes: 35 additions & 25 deletions test/test_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Internal implementation tests."""

import os
from hashlib import sha256
from pathlib import Path

import pretend
Expand All @@ -21,11 +22,13 @@
_ASSETS = _HERE / "assets"

artifact_path = _ASSETS / "rfc8785-0.1.2-py3-none-any.whl"
artifact_digest = sha256(artifact_path.read_bytes()).hexdigest()
bundle_path = _ASSETS / "rfc8785-0.1.2-py3-none-any.whl.sigstore"
attestation_path = _ASSETS / "rfc8785-0.1.2-py3-none-any.whl.attestation"

# produced by actions/attest@v1
gh_signed_artifact_path = _ASSETS / "pypi_attestation_models-0.0.4a2.tar.gz"
gh_signed_artifact_digest = sha256(gh_signed_artifact_path.read_bytes()).hexdigest()
gh_signed_bundle_path = _ASSETS / "pypi_attestation_models-0.0.4a2.tar.gz.sigstore"


Expand All @@ -36,17 +39,19 @@ def test_roundtrip(self, id_token: IdentityToken) -> None:
verifier = Verifier.staging()

with sign_ctx.signer(id_token) as signer:
attestation = impl.Attestation.sign(signer, artifact_path)
attestation = impl.Attestation.sign(signer, artifact_path.name, artifact_digest)

attestation.verify(verifier, policy.UnsafeNoOp(), artifact_path)
attestation.verify(verifier, policy.UnsafeNoOp(), artifact_path.name, artifact_digest)

# converting to a bundle and verifying as a bundle also works
bundle = attestation.to_bundle()
verifier.verify_dsse(bundle, policy.UnsafeNoOp())

# converting back also works
roundtripped_attestation = impl.Attestation.from_bundle(bundle)
roundtripped_attestation.verify(verifier, policy.UnsafeNoOp(), artifact_path)
roundtripped_attestation.verify(
verifier, policy.UnsafeNoOp(), artifact_path.name, artifact_digest
)

def test_sign_invalid_dist_filename(self, tmp_path: Path) -> None:
bad_dist = tmp_path / "invalid-name.tar.gz"
Expand All @@ -56,24 +61,20 @@ def test_sign_invalid_dist_filename(self, tmp_path: Path) -> None:
impl.AttestationError,
match=r"Invalid sdist filename \(invalid version\): invalid-name\.tar\.gz",
):
impl.Attestation.sign(pretend.stub(), bad_dist)
impl.Attestation.sign(pretend.stub(), bad_dist.name, "abcd")

def test_sign_raises_attestation_exception(self, tmp_path: Path) -> None:
non_existing_file = tmp_path / "invalid-name.tar.gz"
with pytest.raises(impl.AttestationError, match="No such file"):
impl.Attestation.sign(pretend.stub(), non_existing_file)

bad_wheel_filename = tmp_path / "invalid-name.whl"
bad_wheel_filename.write_bytes(b"junk")

with pytest.raises(impl.AttestationError, match="Invalid wheel filename"):
impl.Attestation.sign(pretend.stub(), bad_wheel_filename)
impl.Attestation.sign(pretend.stub(), bad_wheel_filename.name, "abcd")

bad_sdist_filename = tmp_path / "invalid_name.tar.gz"
bad_sdist_filename.write_bytes(b"junk")

with pytest.raises(impl.AttestationError, match="Invalid sdist filename"):
impl.Attestation.sign(pretend.stub(), bad_sdist_filename)
impl.Attestation.sign(pretend.stub(), bad_sdist_filename.name, "abcd")

def test_wrong_predicate_raises_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
def dummy_predicate(self_: _StatementBuilder, _: str) -> _StatementBuilder:
Expand All @@ -83,7 +84,7 @@ def dummy_predicate(self_: _StatementBuilder, _: str) -> _StatementBuilder:

monkeypatch.setattr(sigstore.dsse._StatementBuilder, "predicate_type", dummy_predicate)
with pytest.raises(impl.AttestationError, match="invalid statement"):
impl.Attestation.sign(pretend.stub(), artifact_path)
impl.Attestation.sign(pretend.stub(), artifact_path.name, artifact_digest)

@online
def test_expired_certificate(
Expand All @@ -97,7 +98,7 @@ def in_validity_period(_: IdentityToken) -> bool:
sign_ctx = SigningContext.staging()
with sign_ctx.signer(id_token, cache=False) as signer:
with pytest.raises(impl.AttestationError):
impl.Attestation.sign(signer, artifact_path)
impl.Attestation.sign(signer, artifact_path.name, artifact_digest)

@online
def test_multiple_signatures(
Expand All @@ -115,7 +116,7 @@ def get_bundle(*_) -> Bundle: # noqa: ANN002

with pytest.raises(impl.AttestationError):
with sign_ctx.signer(id_token) as signer:
impl.Attestation.sign(signer, artifact_path)
impl.Attestation.sign(signer, artifact_path.name, artifact_digest)

def test_verify_github_attested(self) -> None:
verifier = Verifier.production()
Expand All @@ -131,7 +132,9 @@ def test_verify_github_attested(self) -> None:
bundle = Bundle.from_json(gh_signed_bundle_path.read_bytes())
attestation = impl.Attestation.from_bundle(bundle)

predicate_type, predicate = attestation.verify(verifier, pol, gh_signed_artifact_path)
predicate_type, predicate = attestation.verify(
verifier, pol, gh_signed_artifact_path.name, gh_signed_artifact_digest
)
assert predicate_type == "https://docs.pypi.org/attestations/publish/v1"
assert predicate == {}

Expand All @@ -143,7 +146,9 @@ def test_verify(self) -> None:
)

attestation = impl.Attestation.model_validate_json(attestation_path.read_text())
predicate_type, predicate = attestation.verify(verifier, pol, artifact_path)
predicate_type, predicate = attestation.verify(
verifier, pol, artifact_path.name, artifact_digest
)

assert predicate_type == "https://docs.pypi.org/attestations/publish/v1"
assert predicate is None
Expand All @@ -168,7 +173,12 @@ def test_verify_digest_mismatch(self, tmp_path: Path) -> None:
with pytest.raises(
impl.VerificationError, match="subject does not match distribution digest"
):
attestation.verify(verifier, pol, modified_artifact_path)
attestation.verify(
verifier,
pol,
modified_artifact_path.name,
sha256(modified_artifact_path.read_bytes()).hexdigest(),
)

def test_verify_filename_mismatch(self, tmp_path: Path) -> None:
verifier = Verifier.staging()
Expand All @@ -186,7 +196,7 @@ def test_verify_filename_mismatch(self, tmp_path: Path) -> None:
with pytest.raises(
impl.VerificationError, match="subject does not match distribution name"
):
attestation.verify(verifier, pol, modified_artifact_path)
attestation.verify(verifier, pol, modified_artifact_path.name, artifact_digest)

def test_verify_policy_mismatch(self) -> None:
verifier = Verifier.staging()
Expand All @@ -196,7 +206,7 @@ def test_verify_policy_mismatch(self) -> None:
attestation = impl.Attestation.model_validate_json(attestation_path.read_text())

with pytest.raises(impl.VerificationError, match=r"Certificate's SANs do not match"):
attestation.verify(verifier, pol, artifact_path)
attestation.verify(verifier, pol, artifact_path.name, artifact_digest)

def test_verify_wrong_envelope(self) -> None:
verifier = pretend.stub(
Expand All @@ -207,7 +217,7 @@ def test_verify_wrong_envelope(self) -> None:
attestation = impl.Attestation.model_validate_json(attestation_path.read_text())

with pytest.raises(impl.VerificationError, match="expected JSON envelope, got fake-type"):
attestation.verify(verifier, pol, artifact_path)
attestation.verify(verifier, pol, artifact_path.name, artifact_digest)

def test_verify_bad_payload(self) -> None:
verifier = pretend.stub(
Expand All @@ -220,7 +230,7 @@ def test_verify_bad_payload(self) -> None:
attestation = impl.Attestation.model_validate_json(attestation_path.read_text())

with pytest.raises(impl.VerificationError, match="invalid statement"):
attestation.verify(verifier, pol, artifact_path)
attestation.verify(verifier, pol, artifact_path.name, artifact_digest)

def test_verify_too_many_subjects(self) -> None:
statement = (
Expand Down Expand Up @@ -249,7 +259,7 @@ def test_verify_too_many_subjects(self) -> None:
attestation = impl.Attestation.model_validate_json(attestation_path.read_text())

with pytest.raises(impl.VerificationError, match="too many subjects in statement"):
attestation.verify(verifier, pol, artifact_path)
attestation.verify(verifier, pol, artifact_path.name, artifact_digest)

def test_verify_subject_missing_name(self) -> None:
statement = (
Expand Down Expand Up @@ -277,7 +287,7 @@ def test_verify_subject_missing_name(self) -> None:
attestation = impl.Attestation.model_validate_json(attestation_path.read_text())

with pytest.raises(impl.VerificationError, match="invalid subject: missing name"):
attestation.verify(verifier, pol, artifact_path)
attestation.verify(verifier, pol, artifact_path.name, artifact_digest)

def test_verify_subject_invalid_name(self) -> None:
statement = (
Expand Down Expand Up @@ -308,7 +318,7 @@ def test_verify_subject_invalid_name(self) -> None:
attestation = impl.Attestation.model_validate_json(attestation_path.read_text())

with pytest.raises(impl.VerificationError, match="invalid subject: Invalid wheel filename"):
attestation.verify(verifier, pol, artifact_path)
attestation.verify(verifier, pol, artifact_path.name, artifact_digest)

def test_verify_distribution_invalid_name(self, tmp_path: Path) -> None:
statement = (
Expand Down Expand Up @@ -343,7 +353,7 @@ def test_verify_distribution_invalid_name(self, tmp_path: Path) -> None:
with pytest.raises(
impl.VerificationError, match="invalid distribution name: Invalid wheel filename"
):
attestation.verify(verifier, pol, bad_artifact)
attestation.verify(verifier, pol, bad_artifact.name, artifact_digest)

def test_verify_unknown_attestation_type(self) -> None:
statement = (
Expand Down Expand Up @@ -380,7 +390,7 @@ def test_verify_unknown_attestation_type(self) -> None:
attestation = impl.Attestation.model_validate_json(attestation_path.read_text())

with pytest.raises(impl.VerificationError, match="unknown attestation type: foo"):
attestation.verify(verifier, pol, artifact_path)
attestation.verify(verifier, pol, artifact_path.name, artifact_digest)


def test_from_bundle_missing_signatures() -> None:
Expand Down

0 comments on commit a307b3c

Please sign in to comment.