diff --git a/src/connection.h b/src/connection.h index 0762441732..74bd23e5a5 100644 --- a/src/connection.h +++ b/src/connection.h @@ -56,6 +56,7 @@ typedef enum { #define CONN_FLAG_CLOSE_SCHEDULED (1 << 0) /* Closed scheduled by a handler */ #define CONN_FLAG_WRITE_BARRIER (1 << 1) /* Write barrier requested */ +#define CONN_FLAG_CLIENT (1 << 2) /* Connection is of a client - not a cluster link. */ #define CONN_TYPE_SOCKET "tcp" #define CONN_TYPE_UNIX "unix" diff --git a/src/io_threads.c b/src/io_threads.c index f4471b96d0..b08faa13db 100644 --- a/src/io_threads.c +++ b/src/io_threads.c @@ -9,6 +9,7 @@ static __thread int thread_id = 0; /* Thread local var */ static pthread_t io_threads[IO_THREADS_MAX_NUM] = {0}; static pthread_mutex_t io_threads_mutex[IO_THREADS_MAX_NUM]; +void (*tls_negotiation_cb)(void *); /* IO jobs queue functions - Used to send jobs from the main-thread to the IO thread. */ typedef void (*job_handler)(void *); @@ -554,3 +555,56 @@ void trySendPollJobToIOThreads(void) { aeSetPollProtect(server.el, 1); IOJobQueue_push(jq, IOThreadPoll, server.el); } + +void setTLSNegotiationCallback(void (*cb)(void *)) { + tls_negotiation_cb = cb; +} + +static void ioThreadTLSNegotiation(void *data) { + client *c = (client *)data; + tls_negotiation_cb(c->conn); + c->io_read_state = CLIENT_COMPLETED_IO; +} + +/* + * This function attempts to offload TLS negotiation for a client connection to an I/O thread. + * Returns C_OK if the TLS negotiation was successfully queued for processing by an I/O thread, + * or C_ERR if the client is not eligible for offloading. + * Parameters: + * conn: The connection object for which TLS negotiation should be performed + */ +int trySendTLSNegotiationToIOThreads(connection *conn) { + if (server.io_threads_num <= 1) { + return C_ERR; + } + + if (!(conn->flags & CONN_FLAG_CLIENT)) { + return C_ERR; + } + + client *c = connGetPrivateData(conn); + if (c->io_read_state != CLIENT_IDLE) { + return C_OK; + } + + if (server.active_io_threads_num <= 1) { + return C_ERR; + } + + size_t thread_id = (c->id % (server.active_io_threads_num - 1)) + 1; + IOJobQueue *job_queue = &io_jobs[thread_id]; + + if (IOJobQueue_isFull(job_queue)) { + return C_ERR; + } + + c->read_flags = READ_FLAGS_TLS_NEGOTIATION; + c->io_read_state = CLIENT_PENDING_IO; + c->flag.pending_read = 1; + listLinkNodeTail(server.clients_pending_io_read, &c->pending_read_list_node); + connSetPostponeUpdateState(c->conn, 1); + server.stat_io_tls_negotiation_offloaded++; + IOJobQueue_push(job_queue, ioThreadTLSNegotiation, c); + + return C_OK; +} diff --git a/src/io_threads.h b/src/io_threads.h index f9a9cf762f..ddb8c51ccf 100644 --- a/src/io_threads.h +++ b/src/io_threads.h @@ -13,5 +13,6 @@ int tryOffloadFreeArgvToIOThreads(client *c); void adjustIOThreadsByEventLoad(int numevents, int increase_only); void drainIOThreadsQueue(void); void trySendPollJobToIOThreads(void); - +int trySendTLSNegotiationToIOThreads(connection *conn); +void setTLSNegotiationCallback(void (*cb)(void *)); #endif /* IO_THREADS_H */ diff --git a/src/networking.c b/src/networking.c index 4791055b5a..03d3d8d65f 100644 --- a/src/networking.c +++ b/src/networking.c @@ -134,6 +134,7 @@ client *createClient(connection *conn) { if (server.tcpkeepalive) connKeepAlive(conn, server.tcpkeepalive); connSetReadHandler(conn, readQueryFromClient); connSetPrivateData(conn, c); + conn->flags |= CONN_FLAG_CLIENT; } c->buf = zmalloc_usable(PROTO_REPLY_CHUNK_BYTES, &c->buf_usable_size); selectDb(c, 0); @@ -4725,6 +4726,9 @@ int processIOThreadsReadDone(void) { connSetPostponeUpdateState(c->conn, 0); connUpdateState(c->conn); + /* No client's data was read only TLS handshake. */ + if (c->read_flags & READ_FLAGS_TLS_NEGOTIATION) continue; + /* On read error - stop here. */ if (handleReadResult(c) == C_ERR) { continue; diff --git a/src/server.c b/src/server.c index 12691df8ee..ea625e6715 100644 --- a/src/server.c +++ b/src/server.c @@ -2604,6 +2604,7 @@ void resetServerStats(void) { server.stat_total_reads_processed = 0; server.stat_io_writes_processed = 0; server.stat_io_freed_objects = 0; + server.stat_io_tls_negotiation_offloaded = 0; server.stat_poll_processed_by_io_threads = 0; server.stat_total_writes_processed = 0; server.stat_client_qbuf_limit_disconnections = 0; @@ -5862,6 +5863,7 @@ sds genValkeyInfoString(dict *section_dict, int all_sections, int everything) { "io_threaded_reads_processed:%lld\r\n", server.stat_io_reads_processed, "io_threaded_writes_processed:%lld\r\n", server.stat_io_writes_processed, "io_threaded_freed_objects:%lld\r\n", server.stat_io_freed_objects, + "io_threaded_tls_negotiations:%lld\r\n", server.stat_io_tls_negotiation_offloaded, "io_threaded_poll_processed:%lld\r\n", server.stat_poll_processed_by_io_threads, "io_threaded_total_prefetch_batches:%lld\r\n", server.stat_total_prefetch_batches, "io_threaded_total_prefetch_entries:%lld\r\n", server.stat_total_prefetch_entries, diff --git a/src/server.h b/src/server.h index 5ef04a9080..358297b2f3 100644 --- a/src/server.h +++ b/src/server.h @@ -1841,6 +1841,7 @@ struct valkeyServer { long long stat_io_reads_processed; /* Number of read events processed by IO threads */ long long stat_io_writes_processed; /* Number of write events processed by IO threads */ long long stat_io_freed_objects; /* Number of objects freed by IO threads */ + long long stat_io_tls_negotiation_offloaded; /* Number of TLS negotiation offloads */ long long stat_poll_processed_by_io_threads; /* Total number of poll jobs processed by IO */ long long stat_total_reads_processed; /* Total number of read events processed */ long long stat_total_writes_processed; /* Total number of write events processed */ @@ -2767,6 +2768,7 @@ void dictVanillaFree(void *val); #define READ_FLAGS_PRIMARY (1 << 14) #define READ_FLAGS_DONT_PARSE (1 << 15) #define READ_FLAGS_AUTH_REQUIRED (1 << 16) +#define READ_FLAGS_TLS_NEGOTIATION (1 << 17) /* Write flags for various write errors and states */ #define WRITE_FLAGS_WRITE_ERROR (1 << 0) diff --git a/src/tls.c b/src/tls.c index a1fda2a7ae..1ec7d98096 100644 --- a/src/tls.c +++ b/src/tls.c @@ -32,6 +32,7 @@ #include "server.h" #include "connhelpers.h" #include "adlist.h" +#include "io_threads.h" #if (USE_OPENSSL == 1 /* BUILD_YES */) || ((USE_OPENSSL == 2 /* BUILD_MODULE */) && (BUILD_TLS_MODULE == 2)) @@ -61,6 +62,8 @@ SSL_CTX *valkey_tls_ctx = NULL; SSL_CTX *valkey_tls_client_ctx = NULL; +static void TLSNegotiate(void *conn); + static int parseProtocolsConfig(const char *str) { int i, count = 0; int protocols = 0; @@ -116,6 +119,8 @@ static list *pending_list = NULL; static pthread_mutex_t *openssl_locks; +static void TLSNegotiate(void *conn); + static void sslLockingCallback(int mode, int lock_id, const char *f, int line) { pthread_mutex_t *mt = openssl_locks + lock_id; @@ -170,6 +175,7 @@ static void tlsInit(void) { } pending_list = listCreate(); + setTLSNegotiationCallback(TLSNegotiate); } static void tlsCleanup(void) { @@ -518,16 +524,16 @@ static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, in static void updateSSLEvent(tls_connection *conn); /* Process the return code received from OpenSSL> - * Update the want parameter with expected I/O. + * Update the conn flags with the WANT_READ/WANT_WRITE flags. * Update the connection's error state if a real error has occurred. * Returns an SSL error code, or 0 if no further handling is required. */ -static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *want) { +static int handleSSLReturnCode(tls_connection *conn, int ret_value) { if (ret_value <= 0) { int ssl_err = SSL_get_error(conn->ssl, ret_value); switch (ssl_err) { - case SSL_ERROR_WANT_WRITE: *want = WANT_WRITE; return 0; - case SSL_ERROR_WANT_READ: *want = WANT_READ; return 0; + case SSL_ERROR_WANT_WRITE: conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE; return 0; + case SSL_ERROR_WANT_READ: conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ; return 0; case SSL_ERROR_SYSCALL: conn->c.last_errno = errno; if (conn->ssl_error) zfree(conn->ssl_error); @@ -563,11 +569,8 @@ static int updateStateAfterSSLIO(tls_connection *conn, int ret_value, int update } if (ret_value <= 0) { - WantIOType want = 0; int ssl_err; - if (!(ssl_err = handleSSLReturnCode(conn, ret_value, &want))) { - if (want == WANT_READ) conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ; - if (want == WANT_WRITE) conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE; + if (!(ssl_err = handleSSLReturnCode(conn, ret_value))) { if (update_event) updateSSLEvent(conn); errno = EAGAIN; return -1; @@ -585,19 +588,34 @@ static int updateStateAfterSSLIO(tls_connection *conn, int ret_value, int update return ret_value; } -static void registerSSLEvent(tls_connection *conn, WantIOType want) { +static void updateSSLEvent(tls_connection *conn) { + if (conn->flags & TLS_CONN_FLAG_POSTPONE_UPDATE_STATE) return; + int mask = aeGetFileEvents(server.el, conn->c.fd); + int need_read, need_write; - switch (want) { - case WANT_READ: - if (mask & AE_WRITABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); - if (!(mask & AE_READABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn); - break; - case WANT_WRITE: - if (mask & AE_READABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); - if (!(mask & AE_WRITABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn); - break; - default: serverAssert(0); break; + if (conn->c.state == CONN_STATE_CONNECTED) { + /* When connected, check both flags and handlers. */ + need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ); + need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE); + } else { + /* When not connected, only check flags. */ + need_read = conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ; + need_write = conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE; + } + + /* Add events that are needed */ + if (need_read && !(mask & AE_READABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn); + if (need_write && !(mask & AE_WRITABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn); + + if (conn->c.state == CONN_STATE_CONNECTED) { + /* Remove events that are no longer needed */ + if (!need_read && (mask & AE_READABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); + if (!need_write && (mask & AE_WRITABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); + } else { + /* When not connected, read and write events are mutually exclusive we need to remove the opposite event. */ + if (need_read && (mask & AE_WRITABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); + if (need_write && (mask & AE_READABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); } } @@ -634,28 +652,37 @@ void updateSSLPendingFlag(tls_connection *conn) { } } -static void updateSSLEvent(tls_connection *conn) { - if (conn->flags & TLS_CONN_FLAG_POSTPONE_UPDATE_STATE) return; - - int mask = aeGetFileEvents(server.el, conn->c.fd); - int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ); - int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE); - - if (need_read && !(mask & AE_READABLE)) - aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn); - if (!need_read && (mask & AE_READABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); - - if (need_write && !(mask & AE_WRITABLE)) - aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn); - if (!need_write && (mask & AE_WRITABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); -} - static void updateSSLState(connection *conn_) { tls_connection *conn = (tls_connection *)conn_; updateSSLEvent(conn); + if (conn->c.conn_handler && conn->c.state != CONN_STATE_ACCEPTING) { + /* If the conn handler is set, we need to call it to notify it that the connection is ready. */ + callHandler((connection *)conn, conn->c.conn_handler); + conn->c.conn_handler = NULL; + return; + } updatePendingData(conn); } +static void clearTLSWantFlags(tls_connection *conn) { + conn->flags &= ~(TLS_CONN_FLAG_WRITE_WANT_READ | TLS_CONN_FLAG_READ_WANT_WRITE); +} + +static void TLSNegotiate(void *_conn) { + tls_connection *conn = (tls_connection *)_conn; + ERR_clear_error(); + clearTLSWantFlags(conn); + + int ret = SSL_accept(conn->ssl); + if (ret > 0) { + conn->c.state = CONN_STATE_CONNECTED; + } else if (!handleSSLReturnCode(conn, ret)) { + updateSSLEvent(conn); + } else { + conn->c.state = CONN_STATE_ERROR; + } +} + static void tlsHandleEvent(tls_connection *conn, int mask) { int ret, conn_error; @@ -669,6 +696,7 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { conn->c.last_errno = conn_error; conn->c.state = CONN_STATE_ERROR; } else { + clearTLSWantFlags(conn); ERR_clear_error(); if (!(conn->flags & TLS_CONN_FLAG_FD_SET)) { SSL_set_fd(conn->ssl, conn->c.fd); @@ -676,14 +704,8 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { } ret = SSL_connect(conn->ssl); if (ret <= 0) { - WantIOType want = 0; - if (!handleSSLReturnCode(conn, ret, &want)) { - registerSSLEvent(conn, want); - - /* Avoid hitting UpdateSSLEvent, which knows nothing - * of what SSL_connect() wants and instead looks at our - * R/W handlers. - */ + if (!handleSSLReturnCode(conn, ret)) { + updateSSLEvent(conn); return; } @@ -698,25 +720,9 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { conn->c.conn_handler = NULL; break; case CONN_STATE_ACCEPTING: - ERR_clear_error(); - ret = SSL_accept(conn->ssl); - if (ret <= 0) { - WantIOType want = 0; - if (!handleSSLReturnCode(conn, ret, &want)) { - /* Avoid hitting UpdateSSLEvent, which knows nothing - * of what SSL_connect() wants and instead looks at our - * R/W handlers. - */ - registerSSLEvent(conn, want); - return; - } - - /* If not handled, it's an error */ - conn->c.state = CONN_STATE_ERROR; - } else { - conn->c.state = CONN_STATE_CONNECTED; - } - + if (trySendTLSNegotiationToIOThreads((connection *)conn) == C_OK) return; + TLSNegotiate((connection *)conn); + if (conn->c.state == CONN_STATE_ACCEPTING) return; /* Still pending negotiation */ if (!callHandler((connection *)conn, conn->c.conn_handler)) return; conn->c.conn_handler = NULL; break; @@ -740,20 +746,17 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER; if (!invert && call_read) { - conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; if (!callHandler((connection *)conn, conn->c.read_handler)) return; } /* Fire the writable event. */ if (call_write) { - conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ; if (!callHandler((connection *)conn, conn->c.write_handler)) return; } /* If we have to invert the call, fire the readable event now * after the writable one. */ if (invert && call_read) { - conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; if (!callHandler((connection *)conn, conn->c.read_handler)) return; } updatePendingData(conn); @@ -841,27 +844,17 @@ static void connTLSClose(connection *conn_) { static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) { tls_connection *conn = (tls_connection *)_conn; - int ret; - if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR; - ERR_clear_error(); /* Try to accept */ conn->c.conn_handler = accept_handler; - ret = SSL_accept(conn->ssl); + if (trySendTLSNegotiationToIOThreads(_conn) == C_OK) return C_OK; - if (ret <= 0) { - WantIOType want = 0; - if (!handleSSLReturnCode(conn, ret, &want)) { - registerSSLEvent(conn, want); /* We'll fire back */ - return C_OK; - } else { - conn->c.state = CONN_STATE_ERROR; - return C_ERR; - } - } + TLSNegotiate(_conn); + + if (conn->c.state == CONN_STATE_ERROR) return C_ERR; + if (conn->c.state == CONN_STATE_ACCEPTING) return C_OK; /* Still pending negotiation. */ - conn->c.state = CONN_STATE_CONNECTED; if (!callHandler((connection *)conn, conn->c.conn_handler)) return C_OK; conn->c.conn_handler = NULL; @@ -898,6 +891,7 @@ static int connTLSWrite(connection *conn_, const void *data, size_t data_len) { int ret; if (conn->c.state != CONN_STATE_CONNECTED) return -1; + clearTLSWantFlags(conn); ERR_clear_error(); ret = SSL_write(conn->ssl, data, data_len); return updateStateAfterSSLIO(conn, ret, 1); @@ -946,6 +940,7 @@ static int connTLSRead(connection *conn_, void *buf, size_t buf_len) { if (conn->c.state != CONN_STATE_CONNECTED) return -1; ERR_clear_error(); + clearTLSWantFlags(conn); ret = SSL_read(conn->ssl, buf, buf_len); updateSSLPendingFlag(conn); return updateStateAfterSSLIO(conn, ret, 1); @@ -1014,6 +1009,7 @@ static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long tls_connection *conn = (tls_connection *)conn_; setBlockingTimeout(conn, timeout); + clearTLSWantFlags(conn); SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE); ERR_clear_error(); int ret = SSL_write(conn->ssl, ptr, size); @@ -1028,6 +1024,7 @@ static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long tls_connection *conn = (tls_connection *)conn_; setBlockingTimeout(conn, timeout); + clearTLSWantFlags(conn); ERR_clear_error(); int ret = SSL_read(conn->ssl, ptr, size); updateSSLPendingFlag(conn); @@ -1042,6 +1039,7 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l ssize_t nread = 0; setBlockingTimeout(conn, timeout); + clearTLSWantFlags(conn); size--; while (size) {