Skip to content

Commit

Permalink
Merge pull request #7796 from SparkiDev/dtls_read_write_threaded
Browse files Browse the repository at this point in the history
SSL asynchronous read/write and encrypt
  • Loading branch information
douzzer authored Oct 17, 2024
2 parents 8803f3d + e4a661f commit abc6edf
Show file tree
Hide file tree
Showing 8 changed files with 664 additions and 158 deletions.
120 changes: 93 additions & 27 deletions src/dtls13.c
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,17 @@ static void Dtls13MsgWasProcessed(WOLFSSL* ssl, enum HandShakeType hs)
if (ssl->options.dtlsStateful)
ssl->keys.dtls_expected_peer_handshake_number++;

/* we need to send ACKs on the last message of a flight that needs explicit
acknowledgment */
ssl->dtls13Rtx.sendAcks = Dtls13RtxMsgNeedsAck(ssl, hs);
#ifdef WOLFSSL_RW_THREADED
if (wc_LockMutex(&ssl->dtls13Rtx.mutex) == 0)
#endif
{
/* we need to send ACKs on the last message of a flight that needs
* explicit acknowledgment */
ssl->dtls13Rtx.sendAcks = Dtls13RtxMsgNeedsAck(ssl, hs);
#ifdef WOLFSSL_RW_THREADED
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
}
}

int Dtls13ProcessBufferedMessages(WOLFSSL* ssl)
Expand Down Expand Up @@ -654,8 +662,17 @@ static void Dtls13RtxRecordUnlink(WOLFSSL* ssl, Dtls13RtxRecord** prevNext,
Dtls13RtxRecord* r)
{
/* if r was at the tail of the list, update the tail pointer */
if (r->next == NULL)
ssl->dtls13Rtx.rtxRecordTailPtr = prevNext;
if (r->next == NULL) {
#ifdef WOLFSSL_RW_THREADED
if (wc_LockMutex(&ssl->dtls13Rtx.mutex) == 0)
#endif
{
ssl->dtls13Rtx.rtxRecordTailPtr = prevNext;
#ifdef WOLFSSL_RW_THREADED
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
}
}

/* unlink */
*prevNext = r->next;
Expand Down Expand Up @@ -712,12 +729,20 @@ static int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq)

WOLFSSL_ENTER("Dtls13RtxAddAck");

rn = Dtls13NewRecordNumber(epoch, seq, ssl->heap);
if (rn == NULL)
return MEMORY_E;
#ifdef WOLFSSL_RW_THREADED
if (wc_LockMutex(&ssl->dtls13Rtx.mutex) == 0)
#endif
{
rn = Dtls13NewRecordNumber(epoch, seq, ssl->heap);
if (rn == NULL)
return MEMORY_E;

rn->next = ssl->dtls13Rtx.seenRecords;
ssl->dtls13Rtx.seenRecords = rn;
rn->next = ssl->dtls13Rtx.seenRecords;
ssl->dtls13Rtx.seenRecords = rn;
#ifdef WOLFSSL_RW_THREADED
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
}

return 0;
}
Expand All @@ -730,15 +755,23 @@ static void Dtls13RtxFlushAcks(WOLFSSL* ssl)

WOLFSSL_ENTER("Dtls13RtxFlushAcks");

list = ssl->dtls13Rtx.seenRecords;
#ifdef WOLFSSL_RW_THREADED
if (wc_LockMutex(&ssl->dtls13Rtx.mutex) == 0)
#endif
{
list = ssl->dtls13Rtx.seenRecords;

while (list != NULL) {
rn = list;
list = rn->next;
XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG);
}
while (list != NULL) {
rn = list;
list = rn->next;
XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG);
}

ssl->dtls13Rtx.seenRecords = NULL;
ssl->dtls13Rtx.seenRecords = NULL;
#ifdef WOLFSSL_RW_THREADED
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
}
}

static int Dtls13DetectDisruption(WOLFSSL* ssl, word32 fragOffset)
Expand Down Expand Up @@ -2519,13 +2552,25 @@ static void Dtls13RtxRemoveRecord(WOLFSSL* ssl, w64wrapper epoch,
int Dtls13DoScheduledWork(WOLFSSL* ssl)
{
int ret;
int sendAcks;

WOLFSSL_ENTER("Dtls13DoScheduledWork");

ssl->dtls13SendingAckOrRtx = 1;

if (ssl->dtls13Rtx.sendAcks) {
#ifdef WOLFSSL_RW_THREADED
ret = wc_LockMutex(&ssl->dtls13Rtx.mutex);
if (ret < 0)
return ret;
#endif
sendAcks = ssl->dtls13Rtx.sendAcks;
if (sendAcks) {
ssl->dtls13Rtx.sendAcks = 0;
}
#ifdef WOLFSSL_RW_THREADED
ret = wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
if (sendAcks) {
ret = SendDtls13Ack(ssl);
if (ret != 0)
return ret;
Expand Down Expand Up @@ -2601,13 +2646,28 @@ static int Dtls13RtxHasKeyUpdateBuffered(WOLFSSL* ssl)
return 0;
}

int DoDtls13KeyUpdateAck(WOLFSSL* ssl)
{
int ret = 0;

if (!Dtls13RtxHasKeyUpdateBuffered(ssl)) {
/* we removed the KeyUpdate message because it was ACKed */
ssl->dtls13WaitKeyUpdateAck = 0;
ret = Dtls13KeyUpdateAckReceived(ssl);
}

return ret;
}

int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize,
word32* processedSize)
{
const byte* ackMessage;
w64wrapper epoch, seq;
word16 length;
#ifndef WOLFSSL_RW_THREADED
int ret;
#endif
int i;

if (inputSize < OPAQUE16_LEN)
Expand Down Expand Up @@ -2639,15 +2699,13 @@ int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize,
ssl->options.serverState = SERVER_FINISHED_ACKED;
}

#ifndef WOLFSSL_RW_THREADED
if (ssl->dtls13WaitKeyUpdateAck) {
if (!Dtls13RtxHasKeyUpdateBuffered(ssl)) {
/* we removed the KeyUpdate message because it was ACKed */
ssl->dtls13WaitKeyUpdateAck = 0;
ret = Dtls13KeyUpdateAckReceived(ssl);
if (ret != 0)
return ret;
}
ret = DoDtls13KeyUpdateAck(ssl);
if (ret != 0)
return ret;
}
#endif

*processedSize = length + OPAQUE16_LEN;

Expand Down Expand Up @@ -2698,9 +2756,17 @@ int SendDtls13Ack(WOLFSSL* ssl)
if (ret != 0)
return ret;

ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords, &length);
if (ret != 0)
#ifdef WOLFSSL_RW_THREADED
ret = wc_LockMutex(&ssl->dtls13Rtx.mutex);
if (ret < 0)
return ret;
#endif
ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords, &length);
#ifdef WOLFSSL_RW_THREADED
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
if (ret != 0)
return ret;

output = GetOutputBuffer(ssl);

Expand Down
Loading

0 comments on commit abc6edf

Please sign in to comment.