Skip to content

Commit

Permalink
adjust Keys for DTLS 1.3 and fix warning of unused argument
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobBarthelmeh committed Aug 23, 2024
1 parent e3c392d commit dd22966
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 60 deletions.
3 changes: 3 additions & 0 deletions mplabx/small-psk-build/example-client-psk.c
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,9 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t sz)
#ifndef USE_LIBFUZZER
printf("ret of connect = %d\n", ret);
#endif
if (ret < 0) {
goto exit;
}

/* write string to the server */
if (wolfSSL_write_inline(ssl, recvline, strlen(recvline), MAXLINE) < 0) {
Expand Down
5 changes: 5 additions & 0 deletions mplabx/small-psk-build/psk-ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -2452,6 +2452,11 @@ static int DoHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx,
if (inputLength - HANDSHAKE_HEADER_SZ < size) {
ssl->arrays->pendingMsgType = type;
ssl->arrays->pendingMsgSz = size + HANDSHAKE_HEADER_SZ;

if (ssl->arrays->pendingMsg != NULL) {
XFREE(ssl->arrays->pendingMsg, ssl->heap, DYNAMIC_TYPE_ARRAYS);
}

ssl->arrays->pendingMsg = (byte*)XMALLOC(size + HANDSHAKE_HEADER_SZ,
ssl->heap,
DYNAMIC_TYPE_ARRAYS);
Expand Down
120 changes: 60 additions & 60 deletions src/dtls13.c
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ static int Dtls13HandshakeAddHeaderFrag(WOLFSSL* ssl, byte* output,

hdr->msg_type = msg_type;
c32to24((word32)msg_length, hdr->length);
c16toa(ssl->keys.dtls_handshake_number, hdr->messageSeq);
c16toa(ssl->keys->dtls_handshake_number, hdr->messageSeq);

c32to24(frag_offset, hdr->fragmentOffset);
c32to24(frag_length, hdr->fragmentLength);
Expand Down Expand Up @@ -337,7 +337,7 @@ static byte Dtls13RtxMsgNeedsAck(WOLFSSL* ssl, enum HandShakeType hs)
static void Dtls13MsgWasProcessed(WOLFSSL* ssl, enum HandShakeType hs)
{
if (ssl->options.dtlsStateful)
ssl->keys.dtls_expected_peer_handshake_number++;
ssl->keys->dtls_expected_peer_handshake_number++;

/* we need to send ACKs on the last message of a flight that needs explicit
acknowledgment */
Expand All @@ -357,7 +357,7 @@ int Dtls13ProcessBufferedMessages(WOLFSSL* ssl)
idx = 0;

/* message not in order */
if (ssl->keys.dtls_expected_peer_handshake_number != msg->seq)
if (ssl->keys->dtls_expected_peer_handshake_number != msg->seq)
break;

/* message not complete */
Expand Down Expand Up @@ -404,7 +404,7 @@ int Dtls13ProcessBufferedMessages(WOLFSSL* ssl)
/* DoHandShakeMsgType normally handles the hs number but if
* DoTls13HandShakeMsgType processed 1.2 msgs then this wasn't
* incremented. */
ssl->keys.dtls_expected_peer_handshake_number++;
ssl->keys->dtls_expected_peer_handshake_number++;

ssl->dtls_rx_msg_list = msg->next;
DtlsMsgDelete(msg, ssl->heap);
Expand All @@ -426,7 +426,7 @@ static int Dtls13NextMessageComplete(WOLFSSL* ssl)
return ssl->dtls_rx_msg_list != NULL &&
ssl->dtls_rx_msg_list->ready &&
ssl->dtls_rx_msg_list->seq ==
ssl->keys.dtls_expected_peer_handshake_number;
ssl->keys->dtls_expected_peer_handshake_number;
}

static WC_INLINE int FragIsInOutputBuffer(WOLFSSL* ssl, const byte* frag)
Expand Down Expand Up @@ -742,13 +742,13 @@ static int Dtls13DetectDisruption(WOLFSSL* ssl, word32 fragOffset)
{
/* retransmission. The other peer may have lost our flight or our ACKs. We
don't account this as a disruption */
if (ssl->keys.dtls_peer_handshake_number <
ssl->keys.dtls_expected_peer_handshake_number)
if (ssl->keys->dtls_peer_handshake_number <
ssl->keys->dtls_expected_peer_handshake_number)
return 0;

/* out of order message */
if (ssl->keys.dtls_peer_handshake_number >
ssl->keys.dtls_expected_peer_handshake_number) {
if (ssl->keys->dtls_peer_handshake_number >
ssl->keys->dtls_expected_peer_handshake_number) {
return 1;
}

Expand Down Expand Up @@ -785,8 +785,8 @@ static void Dtls13RtxRemoveCurAck(WOLFSSL* ssl)
rn = ssl->dtls13Rtx.seenRecords;

while (rn != NULL) {
if (w64Equal(rn->epoch, ssl->keys.curEpoch64) &&
w64Equal(rn->seq, ssl->keys.curSeq)) {
if (w64Equal(rn->epoch, ssl->keys->curEpoch64) &&
w64Equal(rn->seq, ssl->keys->curSeq)) {
*prevNext = rn->next;
XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG);
return;
Expand Down Expand Up @@ -830,8 +830,8 @@ static int Dtls13RtxMsgRecvd(WOLFSSL* ssl, enum HandShakeType hs,
WOLFSSL_ENTER("Dtls13RtxMsgRecvd");

if (!ssl->options.handShakeDone &&
ssl->keys.dtls_peer_handshake_number >=
ssl->keys.dtls_expected_peer_handshake_number) {
ssl->keys->dtls_peer_handshake_number >=
ssl->keys->dtls_expected_peer_handshake_number) {

if (hs == server_hello)
Dtls13MaybeSaveClientHello(ssl);
Expand All @@ -849,8 +849,8 @@ static int Dtls13RtxMsgRecvd(WOLFSSL* ssl, enum HandShakeType hs,
DtlsMsgPoolReset(ssl);
}

if (ssl->keys.dtls_peer_handshake_number <
ssl->keys.dtls_expected_peer_handshake_number) {
if (ssl->keys->dtls_peer_handshake_number <
ssl->keys->dtls_expected_peer_handshake_number) {

/* retransmission detected. */
ssl->dtls13Rtx.retransmit = 1;
Expand All @@ -861,8 +861,8 @@ static int Dtls13RtxMsgRecvd(WOLFSSL* ssl, enum HandShakeType hs,
ssl->dtls13Rtx.sendAcks = (byte)ssl->options.dtls13SendMoreAcks;
}

if (ssl->keys.dtls_peer_handshake_number ==
ssl->keys.dtls_expected_peer_handshake_number &&
if (ssl->keys->dtls_peer_handshake_number ==
ssl->keys->dtls_expected_peer_handshake_number &&
ssl->options.handShakeDone && hs == certificate_request) {

/* the current record, containing a post-handshake certificate request,
Expand Down Expand Up @@ -1199,7 +1199,7 @@ int Dtls13HandshakeAddHeader(WOLFSSL* ssl, byte* output,

hdr->msg_type = msg_type;
c32to24((word32)length, hdr->length);
c16toa(ssl->keys.dtls_handshake_number, hdr->messageSeq);
c16toa(ssl->keys->dtls_handshake_number, hdr->messageSeq);

/* send unfragmented first */
c32to24(0, hdr->fragmentOffset);
Expand Down Expand Up @@ -1487,7 +1487,7 @@ int Dtls13RecordRecvd(WOLFSSL* ssl)
if (!ssl->options.dtls13SendMoreAcks)
ssl->dtls13FastTimeout = 1;

ret = Dtls13RtxAddAck(ssl, ssl->keys.curEpoch64, ssl->keys.curSeq);
ret = Dtls13RtxAddAck(ssl, ssl->keys->curEpoch64, ssl->keys->curSeq);
if (ret != 0)
WOLFSSL_MSG("can't save ack fragment");

Expand Down Expand Up @@ -1647,10 +1647,10 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size,
/* To be able to operate in stateless mode, we assume the ClientHello
* is in order and we use its Handshake Message number and Sequence
* Number for our Tx. */
ssl->keys.dtls_expected_peer_handshake_number =
ssl->keys.dtls_handshake_number =
ssl->keys.dtls_peer_handshake_number;
ssl->dtls13Epochs[0].nextSeqNumber = ssl->keys.curSeq;
ssl->keys->dtls_expected_peer_handshake_number =
ssl->keys->dtls_handshake_number =
ssl->keys->dtls_peer_handshake_number;
ssl->dtls13Epochs[0].nextSeqNumber = ssl->keys->curSeq;
}

if (idx + fragLength > size) {
Expand All @@ -1665,16 +1665,16 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size,
if (ret != 0)
return ret;

if (ssl->keys.dtls_peer_handshake_number <
ssl->keys.dtls_expected_peer_handshake_number) {
if (ssl->keys->dtls_peer_handshake_number <
ssl->keys->dtls_expected_peer_handshake_number) {

#ifdef WOLFSSL_DEBUG_TLS
WOLFSSL_MSG(
"DTLS1.3 retransmission detected - discard and schedule a rtx");
#endif /* WOLFSSL_DEBUG_TLS */

/* ignore the message */
*processedSize = idx + fragLength + ssl->keys.padSz;
*processedSize = idx + fragLength + ssl->keys->padSz;

return 0;
}
Expand Down Expand Up @@ -1708,7 +1708,7 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size,
WOLFSSL_MSG("DTLS1.3 not accepting fragmented plaintext message");
#endif /* WOLFSSL_DEBUG_TLS */
/* ignore the message */
*processedSize = idx + fragLength + ssl->keys.padSz;
*processedSize = idx + fragLength + ssl->keys->padSz;
return 0;
}
}
Expand All @@ -1721,12 +1721,12 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size,
* if the message is stored in the buffer.
*/
if (!isComplete ||
ssl->keys.dtls_peer_handshake_number >
ssl->keys.dtls_expected_peer_handshake_number ||
ssl->keys->dtls_peer_handshake_number >
ssl->keys->dtls_expected_peer_handshake_number ||
usingAsyncCrypto) {
if (ssl->dtls_rx_msg_list_sz < DTLS_POOL_SZ) {
DtlsMsgStore(ssl, (word16)w64GetLow32(ssl->keys.curEpoch64),
ssl->keys.dtls_peer_handshake_number,
DtlsMsgStore(ssl, (word16)w64GetLow32(ssl->keys->curEpoch64),
ssl->keys->dtls_peer_handshake_number,
input + DTLS_HANDSHAKE_HEADER_SZ, messageLength, handshakeType,
fragOff, fragLength, ssl->heap);
}
Expand All @@ -1736,7 +1736,7 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size,
return DTLS_TOO_MANY_FRAGMENTS_E;
}

*processedSize = idx + fragLength + ssl->keys.padSz;
*processedSize = idx + fragLength + ssl->keys->padSz;
if (Dtls13NextMessageComplete(ssl))
return Dtls13ProcessBufferedMessages(ssl);

Expand Down Expand Up @@ -1785,7 +1785,7 @@ int Dtls13FragmentsContinue(WOLFSSL* ssl)

ret = Dtls13SendFragmentedInternal(ssl);
if (ret == 0)
ssl->keys.dtls_handshake_number++;
ssl->keys->dtls_handshake_number++;

return ret;
}
Expand Down Expand Up @@ -1875,13 +1875,13 @@ int Dtls13HandshakeSend(WOLFSSL* ssl, byte* message, word16 outputSize,
ret = Dtls13SendOneFragmentRtx(ssl, handshakeType, outputSize, message,
length, hashOutput);
if (ret == 0 || ret == WANT_WRITE)
ssl->keys.dtls_handshake_number++;
ssl->keys->dtls_handshake_number++;
}
else {
ret = Dtls13SendFragmented(ssl, message, length, handshakeType,
hashOutput);
if (ret == 0)
ssl->keys.dtls_handshake_number++;
ssl->keys->dtls_handshake_number++;
}

return ret;
Expand All @@ -1908,7 +1908,7 @@ int Dtls13DeriveSnKeys(WOLFSSL* ssl, int provision)
if (ret != 0)
goto end;

XMEMCPY(ssl->keys.client_sn_key, key_dig, ssl->specs.key_size);
XMEMCPY(ssl->keys->client_sn_key, key_dig, ssl->specs.key_size);
}

if (provision & PROVISION_SERVER) {
Expand All @@ -1919,7 +1919,7 @@ int Dtls13DeriveSnKeys(WOLFSSL* ssl, int provision)
if (ret != 0)
goto end;

XMEMCPY(ssl->keys.server_sn_key, key_dig, ssl->specs.key_size);
XMEMCPY(ssl->keys->server_sn_key, key_dig, ssl->specs.key_size);
}

end:
Expand Down Expand Up @@ -2062,7 +2062,7 @@ int Dtls13GetSeq(WOLFSSL* ssl, int order, word32* seq, byte increment)
w64wrapper* nativeSeq;

if (order == PEER_ORDER) {
nativeSeq = &ssl->keys.curSeq;
nativeSeq = &ssl->keys->curSeq;
/* never increment seq number for current record. In DTLS seq number are
explicit */
increment = 0;
Expand Down Expand Up @@ -2147,7 +2147,7 @@ int Dtls13NewEpoch(WOLFSSL* ssl, w64wrapper epochNumber, int side)
return BAD_STATE_E;
}

Dtls13EpochCopyKeys(ssl, e, &ssl->keys, side);
Dtls13EpochCopyKeys(ssl, e, ssl->keys, side);

if (!e->isValid) {
/* fresh epoch, initialize fields */
Expand Down Expand Up @@ -2224,33 +2224,33 @@ int Dtls13SetEpochKeys(WOLFSSL* ssl, w64wrapper epochNumber,
return 0;

if (clientWrite) {
XMEMCPY(ssl->keys.client_write_key, e->client_write_key,
sizeof(ssl->keys.client_write_key));
XMEMCPY(ssl->keys->client_write_key, e->client_write_key,
sizeof(ssl->keys->client_write_key));

XMEMCPY(ssl->keys.client_write_IV, e->client_write_IV,
sizeof(ssl->keys.client_write_IV));
XMEMCPY(ssl->keys->client_write_IV, e->client_write_IV,
sizeof(ssl->keys->client_write_IV));

XMEMCPY(ssl->keys.client_sn_key, e->client_sn_key,
sizeof(ssl->keys.client_sn_key));
XMEMCPY(ssl->keys->client_sn_key, e->client_sn_key,
sizeof(ssl->keys->client_sn_key));
}

if (serverWrite) {
XMEMCPY(ssl->keys.server_write_key, e->server_write_key,
sizeof(ssl->keys.server_write_key));
XMEMCPY(ssl->keys->server_write_key, e->server_write_key,
sizeof(ssl->keys->server_write_key));

XMEMCPY(ssl->keys.server_write_IV, e->server_write_IV,
sizeof(ssl->keys.server_write_IV));
XMEMCPY(ssl->keys->server_write_IV, e->server_write_IV,
sizeof(ssl->keys->server_write_IV));

XMEMCPY(ssl->keys.server_sn_key, e->server_sn_key,
sizeof(ssl->keys.server_sn_key));
XMEMCPY(ssl->keys->server_sn_key, e->server_sn_key,
sizeof(ssl->keys->server_sn_key));
}

if (enc)
XMEMCPY(ssl->keys.aead_enc_imp_IV, e->aead_enc_imp_IV,
sizeof(ssl->keys.aead_enc_imp_IV));
XMEMCPY(ssl->keys->aead_enc_imp_IV, e->aead_enc_imp_IV,
sizeof(ssl->keys->aead_enc_imp_IV));
if (dec)
XMEMCPY(ssl->keys.aead_dec_imp_IV, e->aead_dec_imp_IV,
sizeof(ssl->keys.aead_dec_imp_IV));
XMEMCPY(ssl->keys->aead_dec_imp_IV, e->aead_dec_imp_IV,
sizeof(ssl->keys->aead_dec_imp_IV));

return SetKeysSide(ssl, side);
}
Expand Down Expand Up @@ -2281,16 +2281,16 @@ int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side)

if (enc) {
if (ssl->options.side == WOLFSSL_CLIENT_END)
encKey = ssl->keys.client_sn_key;
encKey = ssl->keys->client_sn_key;
else
encKey = ssl->keys.server_sn_key;
encKey = ssl->keys->server_sn_key;
}

if (dec) {
if (ssl->options.side == WOLFSSL_CLIENT_END)
decKey = ssl->keys.server_sn_key;
decKey = ssl->keys->server_sn_key;
else
decKey = ssl->keys.client_sn_key;
decKey = ssl->keys->client_sn_key;
}

/* DTLSv1.3 supports only AEAD algorithm. */
Expand Down Expand Up @@ -2844,7 +2844,7 @@ int Dtls13CheckAEADFailLimit(WOLFSSL* ssl)
else if (w64GT(ssl->dtls13DecryptEpoch->dropCount, keyUpdateLimit)) {
WOLFSSL_MSG("Connection exceeded key update limit. Issuing key update");
/* If not waiting for a response then request a key update. */
if (!ssl->keys.updateResponseReq) {
if (!ssl->keys->updateResponseReq) {
ssl->dtls13DoKeyUpdate = 1;
ssl->dtls13InvalidateBefore = ssl->dtls13PeerEpoch;
w64Increment(&ssl->dtls13InvalidateBefore);
Expand Down
1 change: 1 addition & 0 deletions src/internal.c
Original file line number Diff line number Diff line change
Expand Up @@ -6463,6 +6463,7 @@ static void InitSuites_EitherSide(Suites* suites, ProtocolVersion pv, int keySz,
haveECC, TRUE, haveStaticECC, haveFalconSig,
haveDilithiumSig, haveAnon, TRUE, side);
}
(void)haveDH; /* not used when no server support is compiled in */
}

void InitSSL_CTX_Suites(WOLFSSL_CTX* ctx)
Expand Down

0 comments on commit dd22966

Please sign in to comment.