Skip to content

Commit

Permalink
Preserve original fd blocking state in TLS I/O operations
Browse files Browse the repository at this point in the history
This change prevents unintended side effects on connection state
and improves consistency with non-TLS sync operations.

Signed-off-by: xbasel <103044017+xbasel@users.noreply.github.com>
  • Loading branch information
xbasel committed Nov 21, 2024
1 parent 2df56d8 commit aa0db5c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
30 changes: 26 additions & 4 deletions src/anet.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,24 @@ int anetGetError(int fd) {
return sockerr;
}

int anetSetBlock(char *err, int fd, int non_block) {
static int anetGetSocketFlags(char *err, int fd) {
int flags;

/* Set the socket blocking (if non_block is zero) or non-blocking.
* Note that fcntl(2) for F_GETFL and F_SETFL can't be
* interrupted by a signal. */
if ((flags = fcntl(fd, F_GETFL)) == -1) {
anetSetError(err, "fcntl(F_GETFL): %s", strerror(errno));
return ANET_ERR;
}

return flags;
}

int anetSetBlock(char *err, int fd, int non_block) {
int flags = anetGetSocketFlags(err, fd);

if (flags == ANET_ERR) {
return ANET_ERR;
}

/* Check if this flag has been set or unset, if so,
* then there is no need to call fcntl to set/unset it again. */
if (!!(flags & O_NONBLOCK) == !!non_block) return ANET_OK;
Expand All @@ -105,6 +112,21 @@ int anetBlock(char *err, int fd) {
return anetSetBlock(err, fd, 0);
}

int anetIsBlock(char *err, int fd) {
int flags = anetGetSocketFlags(err, fd);

if (flags == ANET_ERR) {
return ANET_ERR;
}

/* Check if the O_NONBLOCK flag is set */
if (flags & O_NONBLOCK) {
return 0; /* Socket is non-blocking */
} else {
return 1; /* Socket is blocking */
}
}

/* Enable the FD_CLOEXEC on the given fd to avoid fd leaks.
* This function should be invoked for fd's on specific places
* where fork + execve system calls are called. */
Expand Down
1 change: 1 addition & 0 deletions src/anet.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ int anetTcpAccept(char *err, int serversock, char *ip, size_t ip_len, int *port)
int anetUnixAccept(char *err, int serversock);
int anetNonBlock(char *err, int fd);
int anetBlock(char *err, int fd);
int anetIsBlock(char *err, int fd);
int anetCloexec(int fd);
int anetEnableTcpNoDelay(char *err, int fd);
int anetDisableTcpNoDelay(char *err, int fd);
Expand Down
21 changes: 16 additions & 5 deletions src/tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,10 @@ static int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func)
return C_OK;
}

static int isBlocking(tls_connection *conn) {
return anetIsBlock(NULL, conn->c.fd);
}

static void setBlockingTimeout(tls_connection *conn, long long timeout) {
anetBlock(NULL, conn->c.fd);
anetSendTimeout(NULL, conn->c.fd, timeout);
Expand Down Expand Up @@ -1005,26 +1009,30 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port,

static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long long timeout) {
tls_connection *conn = (tls_connection *)conn_;

int blocking = isBlocking(conn);
setBlockingTimeout(conn, timeout);
SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
ERR_clear_error();
int ret = SSL_write(conn->ssl, ptr, size);
ret = updateStateAfterSSLIO(conn, ret, 0);
SSL_set_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
unsetBlockingTimeout(conn);
if (!blocking) {
unsetBlockingTimeout(conn);
}

return ret;
}

static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) {
tls_connection *conn = (tls_connection *)conn_;

int blocking = isBlocking(conn);
setBlockingTimeout(conn, timeout);
ERR_clear_error();
int ret = SSL_read(conn->ssl, ptr, size);
ret = updateStateAfterSSLIO(conn, ret, 0);
unsetBlockingTimeout(conn);
if (!blocking) {
unsetBlockingTimeout(conn);
}

return ret;
}
Expand All @@ -1033,6 +1041,7 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l
tls_connection *conn = (tls_connection *)conn_;
ssize_t nread = 0;

int blocking = isBlocking(conn);
setBlockingTimeout(conn, timeout);

size--;
Expand All @@ -1058,7 +1067,9 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l
size--;
}
exit:
unsetBlockingTimeout(conn);
if (!blocking) {
unsetBlockingTimeout(conn);
}
return nread;
}

Expand Down

0 comments on commit aa0db5c

Please sign in to comment.