diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 1a2d02b..c7be13c 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -623,21 +623,63 @@ pub(crate) fn parse_timestamp_request( Ok(TimeStampReq { raw: raw.into() }) } +struct HashInfo<'a> { + params: cryptography_x509::common::AlgorithmParameters<'a>, + hash_fn: fn(&[u8]) -> Vec, +} + +fn detect_hash_algorithm<'a>( + py: Python<'a>, + hash_algorithm: Option>, +) -> PyResult> { + let name = if hash_algorithm.is_none() { + "SHA512".to_string() + } else { + let algorithm = hash_algorithm.unwrap(); + if !algorithm.is_instance(&crate::util::HASH_ALGORITHM.get(py)?)? { + return Err(pyo3::exceptions::PyValueError::new_err( + "invalid hash algorithm", + )); + } + let name_str = algorithm + .getattr(pyo3::intern!(py, "name"))? + .extract::()?; + name_str.to_string() + }; + + match name.as_str() { + "SHA256" => Ok(HashInfo { + params: cryptography_x509::common::AlgorithmParameters::Sha256(Some(())), + hash_fn: |data| sha2::Sha256::digest(data).to_vec(), + }), + "SHA512" => Ok(HashInfo { + params: cryptography_x509::common::AlgorithmParameters::Sha512(Some(())), + hash_fn: |data| sha2::Sha512::digest(data).to_vec(), + }), + _ => Err(pyo3::exceptions::PyValueError::new_err(format!( + "unsupported hash algorithm {:?}", + name + ))), + } +} + #[pyo3::pyfunction] -#[pyo3(signature = (data, nonce, cert))] +#[pyo3(signature = (data, nonce, cert, hash_algorithm=None))] pub(crate) fn create_timestamp_request( py: pyo3::Python<'_>, data: pyo3::Py, nonce: bool, cert: bool, + hash_algorithm: Option>, ) -> PyResult { - let data_bytes = data.as_bytes(py); - let hash = sha2::Sha512::digest(data_bytes); + let hash_info = detect_hash_algorithm(py, hash_algorithm)?; + let data_bytes = data.as_bytes(py); + let hash = (hash_info.hash_fn)(data_bytes); let message_imprint = tsp_asn1::tsp::MessageImprint { hash_algorithm: cryptography_x509::common::AlgorithmIdentifier { oid: asn1::DefinedByMarker::marker(), - params: cryptography_x509::common::AlgorithmParameters::Sha512(Some(())), + params: hash_info.params, }, hashed_message: hash.as_slice(), }; diff --git a/rust/src/util.rs b/rust/src/util.rs index f9fbb47..98644f6 100644 --- a/rust/src/util.rs +++ b/rust/src/util.rs @@ -94,6 +94,9 @@ pub static NAME: LazyPyImport = LazyPyImport::new("cryptography.x509", &["Name"] pub static DIRECTORY_NAME: LazyPyImport = LazyPyImport::new("cryptography.x509", &["DirectoryName"]); +pub static HASH_ALGORITHM: LazyPyImport = + LazyPyImport::new("rfc3161_client.base", &["HashAlgorithm"]); + pub fn generate_random_bytes_for_asn1_biguint() -> Vec { let mut rng = rand::thread_rng(); let nonce_random: u64 = rng.gen_range(0..u64::MAX); diff --git a/src/rfc3161_client/_rust/__init__.pyi b/src/rfc3161_client/_rust/__init__.pyi index 9e1fa9e..45f485d 100644 --- a/src/rfc3161_client/_rust/__init__.pyi +++ b/src/rfc3161_client/_rust/__init__.pyi @@ -1,4 +1,5 @@ from rfc3161_client.tsp import TimeStampRequest, TimeStampResponse +from rfc3161_client.base import HashAlgorithm class PyMessageImprint: ... @@ -18,6 +19,7 @@ def create_timestamp_request( data: bytes, nonce: bool, cert: bool, + hash_algorithm: HashAlgorithm | None = None, ) -> TimeStampRequest: ... diff --git a/src/rfc3161_client/base.py b/src/rfc3161_client/base.py index 9809276..9be006d 100644 --- a/src/rfc3161_client/base.py +++ b/src/rfc3161_client/base.py @@ -14,6 +14,7 @@ class HashAlgorithm(enum.Enum): """Hash algorithms.""" + SHA256 = "SHA256" SHA512 = "SHA512" @@ -83,6 +84,7 @@ def build(self) -> TimeStampRequest: data=self._data, nonce=self._nonce, cert=self._cert_req, + hash_algorithm=self._algorithm, ) diff --git a/test/common.py b/test/common.py new file mode 100644 index 0000000..f8f7a91 --- /dev/null +++ b/test/common.py @@ -0,0 +1,4 @@ +import cryptography.x509 + +SHA256_OID = cryptography.x509.ObjectIdentifier("2.16.840.1.101.3.4.2.1") +SHA512_OID = cryptography.x509.ObjectIdentifier("2.16.840.1.101.3.4.2.3") diff --git a/test/test_base.py b/test/test_base.py index 4b9e9b6..612d38e 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -1,10 +1,9 @@ -import cryptography.x509 import pytest from cryptography.hazmat.primitives import hashes from rfc3161_client.base import HashAlgorithm, TimestampRequestBuilder -SHA512_OID = cryptography.x509.ObjectIdentifier("2.16.840.1.101.3.4.2.3") +from .common import SHA256_OID, SHA512_OID class TestRequestBuilder: @@ -18,13 +17,6 @@ def test_succeeds(self): assert request.nonce is not None assert request.policy is None - message_imprint = request.message_imprint - assert message_imprint.hash_algorithm == SHA512_OID - - digest = hashes.Hash(hashes.SHA512()) - digest.update(message) - assert digest.finalize() == message_imprint.message - def test_data(self): with pytest.raises(ValueError): TimestampRequestBuilder().build() @@ -35,15 +27,33 @@ def test_data(self): with pytest.raises(ValueError, match="once"): TimestampRequestBuilder().data(b"hello").data(b"world") - def test_set_algorithm(self): + def test_algorithm_sha256(self): + message = b"random-message" + request = ( + TimestampRequestBuilder().data(message).hash_algorithm(HashAlgorithm.SHA256).build() + ) + assert request.message_imprint.hash_algorithm == SHA256_OID + + digest = hashes.Hash(hashes.SHA256()) + digest.update(message) + assert digest.finalize() == request.message_imprint.message + + def test_algorithm_sha512(self): + message = b"random-message" request = ( - TimestampRequestBuilder().hash_algorithm(HashAlgorithm.SHA512).data(b"hello").build() + TimestampRequestBuilder().data(message).hash_algorithm(HashAlgorithm.SHA512).build() ) assert request.message_imprint.hash_algorithm == SHA512_OID + digest = hashes.Hash(hashes.SHA512()) + digest.update(message) + assert digest.finalize() == request.message_imprint.message + + def test_set_algorithm(self): with pytest.raises(TypeError): TimestampRequestBuilder().hash_algorithm("invalid hash algorihtm") + # Default hash algorithm request = TimestampRequestBuilder().data(b"hello").build() assert request.message_imprint.hash_algorithm == SHA512_OID diff --git a/test/test_rust.py b/test/test_rust.py new file mode 100644 index 0000000..8faed25 --- /dev/null +++ b/test/test_rust.py @@ -0,0 +1,21 @@ +from rfc3161_client._rust import create_timestamp_request +from rfc3161_client.base import HashAlgorithm + +from .common import SHA256_OID, SHA512_OID + + +def test_create_timestamp_request(): + request = create_timestamp_request( + data=b"hello", nonce=True, cert=False, hash_algorithm=HashAlgorithm.SHA512 + ) + + assert request.message_imprint.hash_algorithm == SHA512_OID + + # Optional parameter + request = create_timestamp_request(data=b"hello", nonce=True, cert=True) + assert request.message_imprint.hash_algorithm == SHA512_OID + + request = create_timestamp_request( + data=b"hello", nonce=True, cert=True, hash_algorithm=HashAlgorithm.SHA256 + ) + assert request.message_imprint.hash_algorithm == SHA256_OID