diff --git a/dheater/__main__.py b/dheater/__main__.py index 37bed85..e677e1f 100755 --- a/dheater/__main__.py +++ b/dheater/__main__.py @@ -17,12 +17,16 @@ from cryptodatahub.common.algorithm import Authentication from cryptodatahub.common.parameter import DHParamWellKnown +from cryptodatahub.tls.algorithm import TlsCipherSuite, TlsNamedCurve, TlsSignatureAndHashAlgorithm from cryptoparser.common.exception import InvalidType, NotEnoughData -from cryptoparser.tls.algorithm import TlsSignatureAndHashAlgorithm -from cryptoparser.tls.ciphersuite import TlsCipherSuite -from cryptoparser.tls.extension import TlsNamedCurve, TlsExtensionEllipticCurves +from cryptoparser.tls.extension import ( + TlsExtensionsClient, + TlsExtensionKeyShareClient, + TlsExtensionKeyShareReservedClient, + TlsExtensionType, +) from cryptoparser.tls.record import TlsRecord from cryptoparser.tls.subprotocol import TlsHandshakeType from cryptoparser.tls.version import TlsProtocolVersion, TlsVersion @@ -49,6 +53,7 @@ L7ClientTlsBase, TlsHandshakeClientHelloKeyExchangeDHE, TlsHandshakeClientHelloSpecalization, + key_share_entry_from_named_curve, ) from cryptolyzer.tls.exception import TlsAlert import cryptolyzer.tls.dhparams @@ -457,16 +462,24 @@ def _prepare_packets(self): client_hello_class = TlsHandshakeClientHelloSpecalization if protocol_version > TlsProtocolVersion(TlsVersion.TLS1_2): - signature_algorithms = None - extensions = client_hello_class._get_tls1_3_extensions( # pylint: disable=protected-access - [protocol_version, ], [self.pre_check_result.dh_public_key, ], signature_algorithms - ) - extensions.append(TlsExtensionEllipticCurves([self.pre_check_result.dh_public_key, ])) client_hello = TlsHandshakeClientHelloKeyExchangeDHE( protocol_version=protocol_version, hostname=self.uri.host, named_curves=[self.pre_check_result.dh_public_key, ] ) + + extensions = [ + extension + for extension in client_hello.extensions + if extension.extension_type not in (TlsExtensionType.KEY_SHARE, TlsExtensionType.KEY_SHARE_RESERVED) + ] + + key_share_entry = key_share_entry_from_named_curve(self.pre_check_result.dh_public_key) + + client_hello.extensions = TlsExtensionsClient(extensions + [ + TlsExtensionKeyShareClient([key_share_entry]), + TlsExtensionKeyShareReservedClient([key_share_entry]), + ]) else: client_hello = client_hello_class( protocol_versions=[protocol_version, ], diff --git a/requirements.txt b/requirements.txt index 9883b36..c11de12 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ attrs>=19.2.0 -cryptolyzer>=0.11.0 +cryptolyzer==0.12.1 urllib3