From bf946d28c5843cf6c0b1d91fb023aa6811897139 Mon Sep 17 00:00:00 2001 From: Juliusz Sosinowicz Date: Fri, 20 Sep 2024 11:54:09 +0200 Subject: [PATCH] Address code review --- src/dtls.c | 96 +++++++++++++++++++++++++++++----------------- src/dtls13.c | 29 ++------------ src/internal.c | 71 +++++++++++++++------------------- src/tls.c | 13 +++---- wolfssl/internal.h | 4 +- 5 files changed, 103 insertions(+), 110 deletions(-) diff --git a/src/dtls.c b/src/dtls.c index c30066be23..5b2356a922 100644 --- a/src/dtls.c +++ b/src/dtls.c @@ -1063,7 +1063,7 @@ static int DtlsCidGetSize(WOLFSSL* ssl, unsigned int* size, int rx) ConnectionID* id; CIDInfo* info; - if (ssl == NULL) + if (ssl == NULL || size == NULL) return BAD_FUNC_ARG; info = DtlsCidGetInfo(ssl); @@ -1071,14 +1071,12 @@ static int DtlsCidGetSize(WOLFSSL* ssl, unsigned int* size, int rx) return WOLFSSL_FAILURE; id = rx ? info->rx : info->tx; - if (id == NULL || id->length == 0) { - if (size != NULL) - *size = 0; - return WOLFSSL_FAILURE; + if (id == NULL) { + *size = 0; + return WOLFSSL_SUCCESS; } - if (size != NULL) - *size = id->length; + *size = id->length; return WOLFSSL_SUCCESS; } @@ -1234,46 +1232,42 @@ int TLSX_ConnectionID_Parse(WOLFSSL* ssl, const byte* input, word16 length, } } + if (length < OPAQUE8_LEN) + return BUFFER_ERROR; + + cidSz = *input; + if (cidSz + OPAQUE8_LEN > length) + return BUFFER_ERROR; + info = DtlsCidGetInfo(ssl); if (info == NULL) return BAD_STATE_E; /* it may happen if we process two ClientHello because the server sent an * HRR/HVR request */ - if (info->tx != NULL) { + if (info->tx != NULL || info->negotiated) { if (ssl->options.side != WOLFSSL_SERVER_END && ssl->options.serverState != SERVER_HELLO_RETRY_REQUEST_COMPLETE && !IsSCR(ssl)) return BAD_STATE_E; - if (!info->negotiated) { - XFREE(info->tx, ssl->heap, DYNAMIC_TYPE_TLSX); - info->tx = NULL; - } - } - - if (length < OPAQUE8_LEN) - return BUFFER_ERROR; - - cidSz = *input; - if (cidSz + OPAQUE8_LEN > length) - return BUFFER_ERROR; + /* Should not be null if negotiated */ + if (info->tx == NULL) + return BAD_STATE_E; - if (cidSz > 0) { - if (!info->negotiated) { - ConnectionID* id = (ConnectionID*)XMALLOC(sizeof(*id) + cidSz, - ssl->heap, DYNAMIC_TYPE_TLSX); - if (id == NULL) - return MEMORY_ERROR; - XMEMCPY(id->id, input + OPAQUE8_LEN, cidSz); - id->length = cidSz; - info->tx = id; - } - else { - /* For now we don't support changing the CID on a rehandshake */ - if (XMEMCMP(info->tx->id, input + OPAQUE8_LEN, cidSz) != 0) - return DTLS_CID_ERROR; - } + /* For now we don't support changing the CID on a rehandshake */ + if (cidSz != info->tx->length || + XMEMCMP(info->tx->id, input + OPAQUE8_LEN, cidSz) != 0) + return DTLS_CID_ERROR; + } + else if (cidSz > 0) { + ConnectionID* id = (ConnectionID*)XMALLOC(sizeof(*id) + cidSz, + ssl->heap, DYNAMIC_TYPE_TLSX); + if (id == NULL) + return MEMORY_ERROR; + XMEMCPY(id->id, input + OPAQUE8_LEN, cidSz); + id->length = cidSz; + info->tx = id; } info->negotiated = 1; @@ -1382,8 +1376,38 @@ int wolfSSL_dtls_cid_max_size(void) { return DTLS_CID_MAX_SIZE; } - #endif /* WOLFSSL_DTLS_CID */ + +byte DtlsGetCidTxSize(WOLFSSL* ssl) +{ +#ifdef WOLFSSL_DTLS_CID + unsigned int cidSz; + int ret; + ret = wolfSSL_dtls_cid_get_tx_size(ssl, &cidSz); + if (ret != WOLFSSL_SUCCESS) + return 0; + return (byte)cidSz; +#else + (void)ssl; + return 0; +#endif +} + +byte DtlsGetCidRxSize(WOLFSSL* ssl) +{ +#ifdef WOLFSSL_DTLS_CID + unsigned int cidSz; + int ret; + ret = wolfSSL_dtls_cid_get_rx_size(ssl, &cidSz); + if (ret != WOLFSSL_SUCCESS) + return 0; + return (byte)cidSz; +#else + (void)ssl; + return 0; +#endif +} + #endif /* WOLFSSL_DTLS */ #endif /* WOLFCRYPT_ONLY */ diff --git a/src/dtls13.c b/src/dtls13.c index 31b3e53740..aa630d3d57 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -1054,25 +1054,6 @@ static WC_INLINE word8 Dtls13GetEpochBits(w64wrapper epoch) } #ifdef WOLFSSL_DTLS_CID -static byte Dtls13GetCidTxSize(WOLFSSL* ssl) -{ - unsigned int cidSz; - int ret; - ret = wolfSSL_dtls_cid_get_tx_size(ssl, &cidSz); - if (ret != WOLFSSL_SUCCESS) - return 0; - return (byte)cidSz; -} - -static byte Dtls13GetCidRxSize(WOLFSSL* ssl) -{ - unsigned int cidSz; - int ret; - ret = wolfSSL_dtls_cid_get_rx_size(ssl, &cidSz); - if (ret != WOLFSSL_SUCCESS) - return 0; - return (byte)cidSz; -} static int Dtls13AddCID(WOLFSSL* ssl, byte* flags, byte* out, word16* idx) { @@ -1082,7 +1063,7 @@ static int Dtls13AddCID(WOLFSSL* ssl, byte* flags, byte* out, word16* idx) if (!wolfSSL_dtls_cid_is_enabled(ssl)) return 0; - cidSz = Dtls13GetCidTxSize(ssl); + cidSz = DtlsGetCidTxSize(ssl); /* no cid */ if (cidSz == 0) @@ -1138,8 +1119,6 @@ static int Dtls13UnifiedHeaderParseCID(WOLFSSL* ssl, byte flags, #else #define Dtls13AddCID(a, b, c, d) 0 -#define Dtls13GetCidRxSize(a) 0 -#define Dtls13GetCidTxSize(a) 0 #define Dtls13UnifiedHeaderParseCID(a, b, c, d, e) 0 #endif /* WOLFSSL_DTLS_CID */ @@ -1245,7 +1224,7 @@ int Dtls13EncryptRecordNumber(WOLFSSL* ssl, byte* hdr, word16 recordLength) seqLength = (*hdr & DTLS13_LEN_BIT) ? DTLS13_SEQ_16_LEN : DTLS13_SEQ_8_LEN; - cidSz = Dtls13GetCidTxSize(ssl); + cidSz = DtlsGetCidTxSize(ssl); /* header flags + seq number + CID size*/ hdrLength = OPAQUE8_LEN + seqLength + cidSz; @@ -1276,7 +1255,7 @@ word16 Dtls13GetRlHeaderLength(WOLFSSL* ssl, byte isEncrypted) if (!isEncrypted) return DTLS_RECORD_HEADER_SZ; - return DTLS13_UNIFIED_HEADER_SIZE + Dtls13GetCidTxSize(ssl); + return DTLS13_UNIFIED_HEADER_SIZE + DtlsGetCidTxSize(ssl); } /** @@ -1403,7 +1382,7 @@ int Dtls13GetUnifiedHeaderSize(WOLFSSL* ssl, const byte input, word16* size) return BAD_FUNC_ARG; /* flags (1) + CID + seq 8bit (1) */ - *size = OPAQUE8_LEN + Dtls13GetCidRxSize(ssl) + OPAQUE8_LEN; + *size = OPAQUE8_LEN + DtlsGetCidRxSize(ssl) + OPAQUE8_LEN; if (input & DTLS13_SEQ_LEN_BIT) *size += OPAQUE8_LEN; if (input & DTLS13_LEN_BIT) diff --git a/src/internal.c b/src/internal.c index d099876578..51729eaf32 100644 --- a/src/internal.c +++ b/src/internal.c @@ -10135,9 +10135,8 @@ int HashOutput(WOLFSSL* ssl, const byte* output, int sz, int ivSz) #endif /* WOLFSSL_DTLS13 */ } else { #ifdef WOLFSSL_DTLS_CID - unsigned int cidSz = 0; - if (IsEncryptionOn(ssl, 1) && - wolfSSL_dtls_cid_get_tx_size(ssl, &cidSz) == WOLFSSL_SUCCESS) { + byte cidSz = DtlsGetCidTxSize(ssl); + if (IsEncryptionOn(ssl, 1) && cidSz > 0) { adj += cidSz; sz -= cidSz + 1; /* +1 to not hash the real content type */ } @@ -10225,9 +10224,8 @@ static void AddRecordHeader(byte* output, word32 length, byte type, /* dtls record layer header extensions */ DtlsRecordLayerHeader* dtls = (DtlsRecordLayerHeader*)output; #ifdef WOLFSSL_DTLS_CID - unsigned int cidSz = 0; - if (type == dtls12_cid && - wolfSSL_dtls_cid_get_tx_size(ssl, &cidSz) == WOLFSSL_SUCCESS) { + byte cidSz = 0; + if (type == dtls12_cid && (cidSz = DtlsGetCidTxSize(ssl)) > 0) { wolfSSL_dtls_cid_get_tx(ssl, output + DTLS12_CID_OFFSET, cidSz); c16toa((word16)length, output + DTLS12_CID_OFFSET + cidSz); } @@ -11343,8 +11341,8 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, word32* inOutIdx, static int GetDtlsRecordHeader(WOLFSSL* ssl, word32* inOutIdx, RecordLayerHeader* rh, word16* size) { -#if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) - unsigned int cidSz = 0; +#ifdef WOLFSSL_DTLS_CID + byte cidSz = 0; #endif #ifdef HAVE_FUZZER @@ -11399,10 +11397,8 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, word32* inOutIdx, *inOutIdx += ENUM_LEN + VERSION_SZ; ato16(ssl->buffers.inputBuffer.buffer + *inOutIdx, &ssl->keys.curEpoch); -#if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) - if (rh->type == dtls12_cid && - (wolfSSL_dtls_cid_get_rx_size(ssl, &cidSz) != WOLFSSL_SUCCESS || - cidSz == 0)) +#ifdef WOLFSSL_DTLS_CID + if (rh->type == dtls12_cid && (cidSz = DtlsGetCidRxSize(ssl)) == 0) return DTLS_CID_ERROR; #endif @@ -11437,10 +11433,11 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, word32* inOutIdx, ssl->keys.curSeq = w64From32(ssl->keys.curSeq_hi, ssl->keys.curSeq_lo); #endif /* WOLFSSL_DTLS13 */ -#if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) +#ifdef WOLFSSL_DTLS_CID if (rh->type == dtls12_cid) { byte cid[DTLS_CID_MAX_SIZE]; - if (ssl->buffers.inputBuffer.length - *inOutIdx < cidSz + LENGTH_SZ) + if (ssl->buffers.inputBuffer.length - *inOutIdx < + (word32)cidSz + LENGTH_SZ) return LENGTH_ERROR; if (cidSz > DTLS_CID_MAX_SIZE || wolfSSL_dtls_cid_get_rx(ssl, cid, cidSz) != WOLFSSL_SUCCESS) @@ -18927,9 +18924,9 @@ typedef int (*Sm4AuthDecryptFunc)(wc_Sm4* sm4, byte* out, const byte* in, #endif #if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) -#define TLS_AEAD_CID_SZ(s, dec, c) \ - ((dec) ? wolfSSL_dtls_cid_get_rx_size((s), (c)) \ - : wolfSSL_dtls_cid_get_tx_size((s), (c))) +#define TLS_AEAD_CID_SZ(s, dec) \ + ((dec) ? DtlsGetCidRxSize((s)) \ + : DtlsGetCidTxSize((s))) #define TLS_AEAD_CID(s, dec, b, c) \ ((dec) ? wolfSSL_dtls_cid_get_rx((s), (b), (c)) \ : wolfSSL_dtls_cid_get_tx((s), (b), (c))) @@ -18941,17 +18938,16 @@ typedef int (*Sm4AuthDecryptFunc)(wc_Sm4* sm4, byte* out, const byte* in, * @param type Record content type * @param additional AAD output buffer. Assumed AEAD_AUTH_DATA_SZ length. * @param dec Are we decrypting - * @return > 0 length of auth data - * <=0 error + * @return >= 0 length of auth data + * < 0 error */ int writeAeadAuthData(WOLFSSL* ssl, word16 sz, byte type, byte* additional, byte dec, byte** seq, int verifyOrder) { word32 idx = 0; #if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) - unsigned int cidSz = 0; - if (ssl->options.dtls && - TLS_AEAD_CID_SZ(ssl, dec, &cidSz) == WOLFSSL_SUCCESS) { + byte cidSz = 0; + if (ssl->options.dtls && (cidSz = TLS_AEAD_CID_SZ(ssl, dec)) > 0) { if (cidSz > DTLS_CID_MAX_SIZE) { WOLFSSL_MSG("DTLS CID too large"); return DTLS_CID_ERROR; @@ -18960,7 +18956,7 @@ int writeAeadAuthData(WOLFSSL* ssl, word16 sz, byte type, XMEMSET(additional + idx, 0xFF, SEQ_SZ); idx += SEQ_SZ; additional[idx++] = dtls12_cid; - additional[idx++] = (byte)cidSz; + additional[idx++] = cidSz; additional[idx++] = dtls12_cid; additional[idx++] = dec ? ssl->curRL.pvMajor : ssl->version.major; additional[idx++] = dec ? ssl->curRL.pvMinor : ssl->version.minor; @@ -18968,7 +18964,7 @@ int writeAeadAuthData(WOLFSSL* ssl, word16 sz, byte type, if (seq != NULL) *seq = additional + idx; idx += SEQ_SZ; - if (TLS_AEAD_CID(ssl, dec, additional + idx, cidSz) + if (TLS_AEAD_CID(ssl, dec, additional + idx, (unsigned int)cidSz) == WC_NO_ERR_TRACE(WOLFSSL_FAILURE)) { WOLFSSL_MSG("DTLS CID write failed"); return DTLS_CID_ERROR; @@ -21785,8 +21781,6 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr) } #if defined(HAVE_ENCRYPT_THEN_MAC) && !defined(WOLFSSL_AEAD_ONLY) if (IsEncryptionOn(ssl, 0) && ssl->options.startedETMRead) { - /* For TLS v1.1 the block size and explicit IV are added to idx, - * so it needs to be included in this limit check */ if ((ssl->curSize - ssl->keys.padSz > MAX_PLAINTEXT_SZ) #ifdef WOLFSSL_ASYNC_CRYPT && ssl->buffers.inputBuffer.length != @@ -21804,8 +21798,6 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr) else #endif /* TLS13 plaintext limit is checked earlier before decryption */ - /* For TLS v1.1 the block size and explicit IV are added to idx, - * so it needs to be included in this limit check */ if (!IsAtLeastTLSv1_3(ssl->version) && ssl->curSize - ssl->keys.padSz > MAX_PLAINTEXT_SZ #ifdef WOLFSSL_ASYNC_CRYPT @@ -22816,9 +22808,8 @@ int BuildMessage(WOLFSSL* ssl, byte* output, int outSz, const byte* input, args->headerSz += DTLS_RECORD_EXTRA; #ifdef WOLFSSL_DTLS_CID if (ssl->options.dtls) { - unsigned int cidSz = 0; - if (wolfSSL_dtls_cid_get_tx_size(ssl, &cidSz) - == WOLFSSL_SUCCESS) { + byte cidSz = 0; + if ((cidSz = DtlsGetCidTxSize(ssl)) > 0) { args->sz += cidSz; args->idx += cidSz; args->headerSz += cidSz; @@ -22909,8 +22900,7 @@ int BuildMessage(WOLFSSL* ssl, byte* output, int outSz, const byte* input, args->size = (word16)(args->sz - args->headerSz); /* include mac and digest */ #if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) - if (ssl->options.dtls && - wolfSSL_dtls_cid_get_tx_size(ssl, NULL) == WOLFSSL_SUCCESS) + if (ssl->options.dtls && DtlsGetCidTxSize(ssl) > 0) args->type = dtls12_cid; #endif AddRecordHeader(output, args->size, args->type, ssl, epochOrder); @@ -22924,8 +22914,7 @@ int BuildMessage(WOLFSSL* ssl, byte* output, int outSz, const byte* input, XMEMCPY(output + args->idx, input, inSz); args->idx += (word32)inSz; #if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) - if (ssl->options.dtls && - wolfSSL_dtls_cid_get_tx_size(ssl, NULL) == WOLFSSL_SUCCESS) { + if (ssl->options.dtls && DtlsGetCidTxSize(ssl) > 0) { output[args->idx++] = (byte)type; /* type goes after input */ inSz++; } @@ -23238,8 +23227,8 @@ int SendFinished(WOLFSSL* ssl) outputSz = sizeof(input) + MAX_MSG_EXTRA; #if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) if (ssl->options.dtls) { - unsigned int cidSz = 0; - if (wolfSSL_dtls_cid_get_tx_size(ssl, &cidSz) == WOLFSSL_SUCCESS) + byte cidSz = 0; + if ((cidSz = DtlsGetCidTxSize(ssl)) > 0) outputSz += cidSz + 1; /* +1 for inner content type */ } #endif @@ -23549,8 +23538,8 @@ int cipherExtraData(WOLFSSL* ssl) /* Add space needed for the CID */ #if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) if (ssl->options.dtls) { - unsigned int cidSz = 0; - if (wolfSSL_dtls_cid_get_tx_size(ssl, &cidSz) == WOLFSSL_SUCCESS) + byte cidSz = 0; + if ((cidSz = DtlsGetCidTxSize(ssl)) > 0) cipherExtra += cidSz + 1; /* +1 for inner content type */ } #endif @@ -24757,8 +24746,8 @@ int SendData(WOLFSSL* ssl, const void* data, int sz) #if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) if (ssl->options.dtls) { - unsigned int cidSz = 0; - if (wolfSSL_dtls_cid_get_tx_size(ssl, &cidSz) == WOLFSSL_SUCCESS) + byte cidSz = 0; + if ((cidSz = DtlsGetCidTxSize(ssl)) > 0) outputSz += cidSz + 1; /* +1 for inner content type */ } #endif diff --git a/src/tls.c b/src/tls.c index 3c1e0e7fe9..71f1c3e817 100644 --- a/src/tls.c +++ b/src/tls.c @@ -762,8 +762,7 @@ int wolfSSL_SetTlsHmacInner(WOLFSSL* ssl, byte* inner, word32 sz, int content, if (content == dtls12_cid #if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) - || (ssl->options.dtls && - wolfSSL_dtls_cid_get_tx_size(ssl, NULL) == WOLFSSL_SUCCESS) + || (ssl->options.dtls && DtlsGetCidTxSize(ssl) > 0) #endif ) { WOLFSSL_MSG("wolfSSL_SetTlsHmacInner doesn't support CID"); @@ -915,6 +914,7 @@ static int Hmac_OuterHash(Hmac* hmac, unsigned char* mac) if (ret == 0) ret = wc_HashFinal(&hash, hashType, mac); } + wc_HashFree(&hash, hashType); return ret; } @@ -1221,9 +1221,9 @@ static int Hmac_UpdateFinal(Hmac* hmac, byte* digest, const byte* in, #endif #if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) -#define TLS_HMAC_CID_SZ(s, v, c) \ - ((v) ? wolfSSL_dtls_cid_get_rx_size((s), (c)) \ - : wolfSSL_dtls_cid_get_tx_size((s), (c))) +#define TLS_HMAC_CID_SZ(s, v) \ + ((v) ? DtlsGetCidRxSize((s)) \ + : DtlsGetCidTxSize((s))) #define TLS_HMAC_CID(s, v, b, c) \ ((v) ? wolfSSL_dtls_cid_get_rx((s), (b), (c)) \ : wolfSSL_dtls_cid_get_tx((s), (b), (c))) @@ -1234,8 +1234,7 @@ static int TLS_hmac_SetInner(WOLFSSL* ssl, byte* inner, word32* innerSz, { #if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID) unsigned int cidSz = 0; - if (ssl->options.dtls && - TLS_HMAC_CID_SZ(ssl, verify, &cidSz) == WOLFSSL_SUCCESS) { + if (ssl->options.dtls && (cidSz = TLS_HMAC_CID_SZ(ssl, verify)) > 0) { word32 idx = 0; if (cidSz > DTLS_CID_MAX_SIZE) { WOLFSSL_MSG("DTLS CID too large"); diff --git a/wolfssl/internal.h b/wolfssl/internal.h index e5e486366f..d3a03e1d4b 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -3694,6 +3694,8 @@ WOLFSSL_LOCAL void DtlsCIDOnExtensionsParsed(WOLFSSL* ssl); WOLFSSL_LOCAL byte DtlsCIDCheck(WOLFSSL* ssl, const byte* input, word16 inputSize); #endif /* WOLFSSL_DTLS_CID */ +WOLFSSL_LOCAL byte DtlsGetCidTxSize(WOLFSSL* ssl); +WOLFSSL_LOCAL byte DtlsGetCidRxSize(WOLFSSL* ssl); #ifdef OPENSSL_EXTRA enum SetCBIO { @@ -7013,7 +7015,7 @@ WOLFSSL_LOCAL int tlsShowSecrets(WOLFSSL* ssl, void* secret, /* Optional Pre-Master-Secret logging for Wireshark */ #if !defined(NO_FILESYSTEM) && defined(WOLFSSL_SSLKEYLOGFILE) #ifndef WOLFSSL_SSLKEYLOGFILE_OUTPUT - #define WOLFSSL_SSLKEYLOGFILE_OUTPUT "/tmp/secrets" + #define WOLFSSL_SSLKEYLOGFILE_OUTPUT "sslkeylog.log" #endif #endif