Skip to content

Commit

Permalink
Add SHA256 hash (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkaMaul authored Dec 18, 2024
1 parent 1442c2b commit 68a5770
Showing 8 changed files with 104 additions and 15 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## Added

- `TimestampRequest` now accepts setting the hash algorithm to `SHA256` (in addition to `SHA512`)
([93](https://github.com/trailofbits/rfc3161-client/pull/93))

## [0.1.2] - 2024-12-11

### Changed
50 changes: 46 additions & 4 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
@@ -623,21 +623,63 @@ pub(crate) fn parse_timestamp_request(
Ok(TimeStampReq { raw: raw.into() })
}

struct HashInfo<'a> {
params: cryptography_x509::common::AlgorithmParameters<'a>,
hash_fn: fn(&[u8]) -> Vec<u8>,
}

fn detect_hash_algorithm<'a>(
py: Python<'a>,
hash_algorithm: Option<pyo3::Bound<'a, pyo3::PyAny>>,
) -> PyResult<HashInfo<'a>> {
let name = if hash_algorithm.is_none() {
"SHA512".to_string()
} else {
let algorithm = hash_algorithm.unwrap();
if !algorithm.is_instance(&crate::util::HASH_ALGORITHM.get(py)?)? {
return Err(pyo3::exceptions::PyValueError::new_err(
"invalid hash algorithm",
));
}
let name_str = algorithm
.getattr(pyo3::intern!(py, "name"))?
.extract::<pyo3::pybacked::PyBackedStr>()?;
name_str.to_string()
};

match name.as_str() {
"SHA256" => Ok(HashInfo {
params: cryptography_x509::common::AlgorithmParameters::Sha256(Some(())),
hash_fn: |data| sha2::Sha256::digest(data).to_vec(),
}),
"SHA512" => Ok(HashInfo {
params: cryptography_x509::common::AlgorithmParameters::Sha512(Some(())),
hash_fn: |data| sha2::Sha512::digest(data).to_vec(),
}),
_ => Err(pyo3::exceptions::PyValueError::new_err(format!(
"unsupported hash algorithm {:?}",
name
))),
}
}

#[pyo3::pyfunction]
#[pyo3(signature = (data, nonce, cert))]
#[pyo3(signature = (data, nonce, cert, hash_algorithm=None))]
pub(crate) fn create_timestamp_request(
py: pyo3::Python<'_>,
data: pyo3::Py<pyo3::types::PyBytes>,
nonce: bool,
cert: bool,
hash_algorithm: Option<pyo3::Bound<'_, pyo3::PyAny>>,
) -> PyResult<TimeStampReq> {
let data_bytes = data.as_bytes(py);
let hash = sha2::Sha512::digest(data_bytes);
let hash_info = detect_hash_algorithm(py, hash_algorithm)?;

let data_bytes = data.as_bytes(py);
let hash = (hash_info.hash_fn)(data_bytes);
let message_imprint = tsp_asn1::tsp::MessageImprint {
hash_algorithm: cryptography_x509::common::AlgorithmIdentifier {
oid: asn1::DefinedByMarker::marker(),
params: cryptography_x509::common::AlgorithmParameters::Sha512(Some(())),
params: hash_info.params,
},
hashed_message: hash.as_slice(),
};
3 changes: 3 additions & 0 deletions rust/src/util.rs
Original file line number Diff line number Diff line change
@@ -94,6 +94,9 @@ pub static NAME: LazyPyImport = LazyPyImport::new("cryptography.x509", &["Name"]
pub static DIRECTORY_NAME: LazyPyImport =
LazyPyImport::new("cryptography.x509", &["DirectoryName"]);

pub static HASH_ALGORITHM: LazyPyImport =
LazyPyImport::new("rfc3161_client.base", &["HashAlgorithm"]);

pub fn generate_random_bytes_for_asn1_biguint() -> Vec<u8> {
let mut rng = rand::thread_rng();
let nonce_random: u64 = rng.gen_range(0..u64::MAX);
2 changes: 2 additions & 0 deletions src/rfc3161_client/_rust/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rfc3161_client.tsp import TimeStampRequest, TimeStampResponse
from rfc3161_client.base import HashAlgorithm

class PyMessageImprint: ...

@@ -18,6 +19,7 @@ def create_timestamp_request(
data: bytes,
nonce: bool,
cert: bool,
hash_algorithm: HashAlgorithm | None = None,
) -> TimeStampRequest: ...


2 changes: 2 additions & 0 deletions src/rfc3161_client/base.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
class HashAlgorithm(enum.Enum):
"""Hash algorithms."""

SHA256 = "SHA256"
SHA512 = "SHA512"


@@ -83,6 +84,7 @@ def build(self) -> TimeStampRequest:
data=self._data,
nonce=self._nonce,
cert=self._cert_req,
hash_algorithm=self._algorithm,
)


4 changes: 4 additions & 0 deletions test/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import cryptography.x509

SHA256_OID = cryptography.x509.ObjectIdentifier("2.16.840.1.101.3.4.2.1")
SHA512_OID = cryptography.x509.ObjectIdentifier("2.16.840.1.101.3.4.2.3")
32 changes: 21 additions & 11 deletions test/test_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import cryptography.x509
import pytest
from cryptography.hazmat.primitives import hashes

from rfc3161_client.base import HashAlgorithm, TimestampRequestBuilder

SHA512_OID = cryptography.x509.ObjectIdentifier("2.16.840.1.101.3.4.2.3")
from .common import SHA256_OID, SHA512_OID


class TestRequestBuilder:
@@ -18,13 +17,6 @@ def test_succeeds(self):
assert request.nonce is not None
assert request.policy is None

message_imprint = request.message_imprint
assert message_imprint.hash_algorithm == SHA512_OID

digest = hashes.Hash(hashes.SHA512())
digest.update(message)
assert digest.finalize() == message_imprint.message

def test_data(self):
with pytest.raises(ValueError):
TimestampRequestBuilder().build()
@@ -35,15 +27,33 @@ def test_data(self):
with pytest.raises(ValueError, match="once"):
TimestampRequestBuilder().data(b"hello").data(b"world")

def test_set_algorithm(self):
def test_algorithm_sha256(self):
message = b"random-message"
request = (
TimestampRequestBuilder().data(message).hash_algorithm(HashAlgorithm.SHA256).build()
)
assert request.message_imprint.hash_algorithm == SHA256_OID

digest = hashes.Hash(hashes.SHA256())
digest.update(message)
assert digest.finalize() == request.message_imprint.message

def test_algorithm_sha512(self):
message = b"random-message"
request = (
TimestampRequestBuilder().hash_algorithm(HashAlgorithm.SHA512).data(b"hello").build()
TimestampRequestBuilder().data(message).hash_algorithm(HashAlgorithm.SHA512).build()
)
assert request.message_imprint.hash_algorithm == SHA512_OID

digest = hashes.Hash(hashes.SHA512())
digest.update(message)
assert digest.finalize() == request.message_imprint.message

def test_set_algorithm(self):
with pytest.raises(TypeError):
TimestampRequestBuilder().hash_algorithm("invalid hash algorihtm")

# Default hash algorithm
request = TimestampRequestBuilder().data(b"hello").build()
assert request.message_imprint.hash_algorithm == SHA512_OID

21 changes: 21 additions & 0 deletions test/test_rust.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from rfc3161_client._rust import create_timestamp_request
from rfc3161_client.base import HashAlgorithm

from .common import SHA256_OID, SHA512_OID


def test_create_timestamp_request():
request = create_timestamp_request(
data=b"hello", nonce=True, cert=False, hash_algorithm=HashAlgorithm.SHA512
)

assert request.message_imprint.hash_algorithm == SHA512_OID

# Optional parameter
request = create_timestamp_request(data=b"hello", nonce=True, cert=True)
assert request.message_imprint.hash_algorithm == SHA512_OID

request = create_timestamp_request(
data=b"hello", nonce=True, cert=True, hash_algorithm=HashAlgorithm.SHA256
)
assert request.message_imprint.hash_algorithm == SHA256_OID

0 comments on commit 68a5770

Please sign in to comment.