From dd229669cccb87c30c282b927de5ad436c00e64e Mon Sep 17 00:00:00 2001 From: JacobBarthelmeh Date: Fri, 23 Aug 2024 11:19:26 -0700 Subject: [PATCH] adjust Keys for DTLS 1.3 and fix warning of unused argument --- mplabx/small-psk-build/example-client-psk.c | 3 + mplabx/small-psk-build/psk-ssl.c | 5 + src/dtls13.c | 120 ++++++++++---------- src/internal.c | 1 + 4 files changed, 69 insertions(+), 60 deletions(-) diff --git a/mplabx/small-psk-build/example-client-psk.c b/mplabx/small-psk-build/example-client-psk.c index 7c84217f83..07301063ef 100644 --- a/mplabx/small-psk-build/example-client-psk.c +++ b/mplabx/small-psk-build/example-client-psk.c @@ -279,6 +279,9 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t sz) #ifndef USE_LIBFUZZER printf("ret of connect = %d\n", ret); #endif + if (ret < 0) { + goto exit; + } /* write string to the server */ if (wolfSSL_write_inline(ssl, recvline, strlen(recvline), MAXLINE) < 0) { diff --git a/mplabx/small-psk-build/psk-ssl.c b/mplabx/small-psk-build/psk-ssl.c index 60f8aefe28..deb80ecda6 100644 --- a/mplabx/small-psk-build/psk-ssl.c +++ b/mplabx/small-psk-build/psk-ssl.c @@ -2452,6 +2452,11 @@ static int DoHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx, if (inputLength - HANDSHAKE_HEADER_SZ < size) { ssl->arrays->pendingMsgType = type; ssl->arrays->pendingMsgSz = size + HANDSHAKE_HEADER_SZ; + + if (ssl->arrays->pendingMsg != NULL) { + XFREE(ssl->arrays->pendingMsg, ssl->heap, DYNAMIC_TYPE_ARRAYS); + } + ssl->arrays->pendingMsg = (byte*)XMALLOC(size + HANDSHAKE_HEADER_SZ, ssl->heap, DYNAMIC_TYPE_ARRAYS); diff --git a/src/dtls13.c b/src/dtls13.c index 4d2365f38a..ae7433de06 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -201,7 +201,7 @@ static int Dtls13HandshakeAddHeaderFrag(WOLFSSL* ssl, byte* output, hdr->msg_type = msg_type; c32to24((word32)msg_length, hdr->length); - c16toa(ssl->keys.dtls_handshake_number, hdr->messageSeq); + c16toa(ssl->keys->dtls_handshake_number, hdr->messageSeq); c32to24(frag_offset, hdr->fragmentOffset); c32to24(frag_length, hdr->fragmentLength); @@ -337,7 +337,7 @@ static byte Dtls13RtxMsgNeedsAck(WOLFSSL* ssl, enum HandShakeType hs) static void Dtls13MsgWasProcessed(WOLFSSL* ssl, enum HandShakeType hs) { if (ssl->options.dtlsStateful) - ssl->keys.dtls_expected_peer_handshake_number++; + ssl->keys->dtls_expected_peer_handshake_number++; /* we need to send ACKs on the last message of a flight that needs explicit acknowledgment */ @@ -357,7 +357,7 @@ int Dtls13ProcessBufferedMessages(WOLFSSL* ssl) idx = 0; /* message not in order */ - if (ssl->keys.dtls_expected_peer_handshake_number != msg->seq) + if (ssl->keys->dtls_expected_peer_handshake_number != msg->seq) break; /* message not complete */ @@ -404,7 +404,7 @@ int Dtls13ProcessBufferedMessages(WOLFSSL* ssl) /* DoHandShakeMsgType normally handles the hs number but if * DoTls13HandShakeMsgType processed 1.2 msgs then this wasn't * incremented. */ - ssl->keys.dtls_expected_peer_handshake_number++; + ssl->keys->dtls_expected_peer_handshake_number++; ssl->dtls_rx_msg_list = msg->next; DtlsMsgDelete(msg, ssl->heap); @@ -426,7 +426,7 @@ static int Dtls13NextMessageComplete(WOLFSSL* ssl) return ssl->dtls_rx_msg_list != NULL && ssl->dtls_rx_msg_list->ready && ssl->dtls_rx_msg_list->seq == - ssl->keys.dtls_expected_peer_handshake_number; + ssl->keys->dtls_expected_peer_handshake_number; } static WC_INLINE int FragIsInOutputBuffer(WOLFSSL* ssl, const byte* frag) @@ -742,13 +742,13 @@ static int Dtls13DetectDisruption(WOLFSSL* ssl, word32 fragOffset) { /* retransmission. The other peer may have lost our flight or our ACKs. We don't account this as a disruption */ - if (ssl->keys.dtls_peer_handshake_number < - ssl->keys.dtls_expected_peer_handshake_number) + if (ssl->keys->dtls_peer_handshake_number < + ssl->keys->dtls_expected_peer_handshake_number) return 0; /* out of order message */ - if (ssl->keys.dtls_peer_handshake_number > - ssl->keys.dtls_expected_peer_handshake_number) { + if (ssl->keys->dtls_peer_handshake_number > + ssl->keys->dtls_expected_peer_handshake_number) { return 1; } @@ -785,8 +785,8 @@ static void Dtls13RtxRemoveCurAck(WOLFSSL* ssl) rn = ssl->dtls13Rtx.seenRecords; while (rn != NULL) { - if (w64Equal(rn->epoch, ssl->keys.curEpoch64) && - w64Equal(rn->seq, ssl->keys.curSeq)) { + if (w64Equal(rn->epoch, ssl->keys->curEpoch64) && + w64Equal(rn->seq, ssl->keys->curSeq)) { *prevNext = rn->next; XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG); return; @@ -830,8 +830,8 @@ static int Dtls13RtxMsgRecvd(WOLFSSL* ssl, enum HandShakeType hs, WOLFSSL_ENTER("Dtls13RtxMsgRecvd"); if (!ssl->options.handShakeDone && - ssl->keys.dtls_peer_handshake_number >= - ssl->keys.dtls_expected_peer_handshake_number) { + ssl->keys->dtls_peer_handshake_number >= + ssl->keys->dtls_expected_peer_handshake_number) { if (hs == server_hello) Dtls13MaybeSaveClientHello(ssl); @@ -849,8 +849,8 @@ static int Dtls13RtxMsgRecvd(WOLFSSL* ssl, enum HandShakeType hs, DtlsMsgPoolReset(ssl); } - if (ssl->keys.dtls_peer_handshake_number < - ssl->keys.dtls_expected_peer_handshake_number) { + if (ssl->keys->dtls_peer_handshake_number < + ssl->keys->dtls_expected_peer_handshake_number) { /* retransmission detected. */ ssl->dtls13Rtx.retransmit = 1; @@ -861,8 +861,8 @@ static int Dtls13RtxMsgRecvd(WOLFSSL* ssl, enum HandShakeType hs, ssl->dtls13Rtx.sendAcks = (byte)ssl->options.dtls13SendMoreAcks; } - if (ssl->keys.dtls_peer_handshake_number == - ssl->keys.dtls_expected_peer_handshake_number && + if (ssl->keys->dtls_peer_handshake_number == + ssl->keys->dtls_expected_peer_handshake_number && ssl->options.handShakeDone && hs == certificate_request) { /* the current record, containing a post-handshake certificate request, @@ -1199,7 +1199,7 @@ int Dtls13HandshakeAddHeader(WOLFSSL* ssl, byte* output, hdr->msg_type = msg_type; c32to24((word32)length, hdr->length); - c16toa(ssl->keys.dtls_handshake_number, hdr->messageSeq); + c16toa(ssl->keys->dtls_handshake_number, hdr->messageSeq); /* send unfragmented first */ c32to24(0, hdr->fragmentOffset); @@ -1487,7 +1487,7 @@ int Dtls13RecordRecvd(WOLFSSL* ssl) if (!ssl->options.dtls13SendMoreAcks) ssl->dtls13FastTimeout = 1; - ret = Dtls13RtxAddAck(ssl, ssl->keys.curEpoch64, ssl->keys.curSeq); + ret = Dtls13RtxAddAck(ssl, ssl->keys->curEpoch64, ssl->keys->curSeq); if (ret != 0) WOLFSSL_MSG("can't save ack fragment"); @@ -1647,10 +1647,10 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size, /* To be able to operate in stateless mode, we assume the ClientHello * is in order and we use its Handshake Message number and Sequence * Number for our Tx. */ - ssl->keys.dtls_expected_peer_handshake_number = - ssl->keys.dtls_handshake_number = - ssl->keys.dtls_peer_handshake_number; - ssl->dtls13Epochs[0].nextSeqNumber = ssl->keys.curSeq; + ssl->keys->dtls_expected_peer_handshake_number = + ssl->keys->dtls_handshake_number = + ssl->keys->dtls_peer_handshake_number; + ssl->dtls13Epochs[0].nextSeqNumber = ssl->keys->curSeq; } if (idx + fragLength > size) { @@ -1665,8 +1665,8 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size, if (ret != 0) return ret; - if (ssl->keys.dtls_peer_handshake_number < - ssl->keys.dtls_expected_peer_handshake_number) { + if (ssl->keys->dtls_peer_handshake_number < + ssl->keys->dtls_expected_peer_handshake_number) { #ifdef WOLFSSL_DEBUG_TLS WOLFSSL_MSG( @@ -1674,7 +1674,7 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size, #endif /* WOLFSSL_DEBUG_TLS */ /* ignore the message */ - *processedSize = idx + fragLength + ssl->keys.padSz; + *processedSize = idx + fragLength + ssl->keys->padSz; return 0; } @@ -1708,7 +1708,7 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size, WOLFSSL_MSG("DTLS1.3 not accepting fragmented plaintext message"); #endif /* WOLFSSL_DEBUG_TLS */ /* ignore the message */ - *processedSize = idx + fragLength + ssl->keys.padSz; + *processedSize = idx + fragLength + ssl->keys->padSz; return 0; } } @@ -1721,12 +1721,12 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size, * if the message is stored in the buffer. */ if (!isComplete || - ssl->keys.dtls_peer_handshake_number > - ssl->keys.dtls_expected_peer_handshake_number || + ssl->keys->dtls_peer_handshake_number > + ssl->keys->dtls_expected_peer_handshake_number || usingAsyncCrypto) { if (ssl->dtls_rx_msg_list_sz < DTLS_POOL_SZ) { - DtlsMsgStore(ssl, (word16)w64GetLow32(ssl->keys.curEpoch64), - ssl->keys.dtls_peer_handshake_number, + DtlsMsgStore(ssl, (word16)w64GetLow32(ssl->keys->curEpoch64), + ssl->keys->dtls_peer_handshake_number, input + DTLS_HANDSHAKE_HEADER_SZ, messageLength, handshakeType, fragOff, fragLength, ssl->heap); } @@ -1736,7 +1736,7 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size, return DTLS_TOO_MANY_FRAGMENTS_E; } - *processedSize = idx + fragLength + ssl->keys.padSz; + *processedSize = idx + fragLength + ssl->keys->padSz; if (Dtls13NextMessageComplete(ssl)) return Dtls13ProcessBufferedMessages(ssl); @@ -1785,7 +1785,7 @@ int Dtls13FragmentsContinue(WOLFSSL* ssl) ret = Dtls13SendFragmentedInternal(ssl); if (ret == 0) - ssl->keys.dtls_handshake_number++; + ssl->keys->dtls_handshake_number++; return ret; } @@ -1875,13 +1875,13 @@ int Dtls13HandshakeSend(WOLFSSL* ssl, byte* message, word16 outputSize, ret = Dtls13SendOneFragmentRtx(ssl, handshakeType, outputSize, message, length, hashOutput); if (ret == 0 || ret == WANT_WRITE) - ssl->keys.dtls_handshake_number++; + ssl->keys->dtls_handshake_number++; } else { ret = Dtls13SendFragmented(ssl, message, length, handshakeType, hashOutput); if (ret == 0) - ssl->keys.dtls_handshake_number++; + ssl->keys->dtls_handshake_number++; } return ret; @@ -1908,7 +1908,7 @@ int Dtls13DeriveSnKeys(WOLFSSL* ssl, int provision) if (ret != 0) goto end; - XMEMCPY(ssl->keys.client_sn_key, key_dig, ssl->specs.key_size); + XMEMCPY(ssl->keys->client_sn_key, key_dig, ssl->specs.key_size); } if (provision & PROVISION_SERVER) { @@ -1919,7 +1919,7 @@ int Dtls13DeriveSnKeys(WOLFSSL* ssl, int provision) if (ret != 0) goto end; - XMEMCPY(ssl->keys.server_sn_key, key_dig, ssl->specs.key_size); + XMEMCPY(ssl->keys->server_sn_key, key_dig, ssl->specs.key_size); } end: @@ -2062,7 +2062,7 @@ int Dtls13GetSeq(WOLFSSL* ssl, int order, word32* seq, byte increment) w64wrapper* nativeSeq; if (order == PEER_ORDER) { - nativeSeq = &ssl->keys.curSeq; + nativeSeq = &ssl->keys->curSeq; /* never increment seq number for current record. In DTLS seq number are explicit */ increment = 0; @@ -2147,7 +2147,7 @@ int Dtls13NewEpoch(WOLFSSL* ssl, w64wrapper epochNumber, int side) return BAD_STATE_E; } - Dtls13EpochCopyKeys(ssl, e, &ssl->keys, side); + Dtls13EpochCopyKeys(ssl, e, ssl->keys, side); if (!e->isValid) { /* fresh epoch, initialize fields */ @@ -2224,33 +2224,33 @@ int Dtls13SetEpochKeys(WOLFSSL* ssl, w64wrapper epochNumber, return 0; if (clientWrite) { - XMEMCPY(ssl->keys.client_write_key, e->client_write_key, - sizeof(ssl->keys.client_write_key)); + XMEMCPY(ssl->keys->client_write_key, e->client_write_key, + sizeof(ssl->keys->client_write_key)); - XMEMCPY(ssl->keys.client_write_IV, e->client_write_IV, - sizeof(ssl->keys.client_write_IV)); + XMEMCPY(ssl->keys->client_write_IV, e->client_write_IV, + sizeof(ssl->keys->client_write_IV)); - XMEMCPY(ssl->keys.client_sn_key, e->client_sn_key, - sizeof(ssl->keys.client_sn_key)); + XMEMCPY(ssl->keys->client_sn_key, e->client_sn_key, + sizeof(ssl->keys->client_sn_key)); } if (serverWrite) { - XMEMCPY(ssl->keys.server_write_key, e->server_write_key, - sizeof(ssl->keys.server_write_key)); + XMEMCPY(ssl->keys->server_write_key, e->server_write_key, + sizeof(ssl->keys->server_write_key)); - XMEMCPY(ssl->keys.server_write_IV, e->server_write_IV, - sizeof(ssl->keys.server_write_IV)); + XMEMCPY(ssl->keys->server_write_IV, e->server_write_IV, + sizeof(ssl->keys->server_write_IV)); - XMEMCPY(ssl->keys.server_sn_key, e->server_sn_key, - sizeof(ssl->keys.server_sn_key)); + XMEMCPY(ssl->keys->server_sn_key, e->server_sn_key, + sizeof(ssl->keys->server_sn_key)); } if (enc) - XMEMCPY(ssl->keys.aead_enc_imp_IV, e->aead_enc_imp_IV, - sizeof(ssl->keys.aead_enc_imp_IV)); + XMEMCPY(ssl->keys->aead_enc_imp_IV, e->aead_enc_imp_IV, + sizeof(ssl->keys->aead_enc_imp_IV)); if (dec) - XMEMCPY(ssl->keys.aead_dec_imp_IV, e->aead_dec_imp_IV, - sizeof(ssl->keys.aead_dec_imp_IV)); + XMEMCPY(ssl->keys->aead_dec_imp_IV, e->aead_dec_imp_IV, + sizeof(ssl->keys->aead_dec_imp_IV)); return SetKeysSide(ssl, side); } @@ -2281,16 +2281,16 @@ int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side) if (enc) { if (ssl->options.side == WOLFSSL_CLIENT_END) - encKey = ssl->keys.client_sn_key; + encKey = ssl->keys->client_sn_key; else - encKey = ssl->keys.server_sn_key; + encKey = ssl->keys->server_sn_key; } if (dec) { if (ssl->options.side == WOLFSSL_CLIENT_END) - decKey = ssl->keys.server_sn_key; + decKey = ssl->keys->server_sn_key; else - decKey = ssl->keys.client_sn_key; + decKey = ssl->keys->client_sn_key; } /* DTLSv1.3 supports only AEAD algorithm. */ @@ -2844,7 +2844,7 @@ int Dtls13CheckAEADFailLimit(WOLFSSL* ssl) else if (w64GT(ssl->dtls13DecryptEpoch->dropCount, keyUpdateLimit)) { WOLFSSL_MSG("Connection exceeded key update limit. Issuing key update"); /* If not waiting for a response then request a key update. */ - if (!ssl->keys.updateResponseReq) { + if (!ssl->keys->updateResponseReq) { ssl->dtls13DoKeyUpdate = 1; ssl->dtls13InvalidateBefore = ssl->dtls13PeerEpoch; w64Increment(&ssl->dtls13InvalidateBefore); diff --git a/src/internal.c b/src/internal.c index f0621a3fcd..fbe8d37252 100644 --- a/src/internal.c +++ b/src/internal.c @@ -6463,6 +6463,7 @@ static void InitSuites_EitherSide(Suites* suites, ProtocolVersion pv, int keySz, haveECC, TRUE, haveStaticECC, haveFalconSig, haveDilithiumSig, haveAnon, TRUE, side); } + (void)haveDH; /* not used when no server support is compiled in */ } void InitSSL_CTX_Suites(WOLFSSL_CTX* ctx)