From fe7b1f92345b18e7082f9c1e88c8cb554b9b1d2e Mon Sep 17 00:00:00 2001 From: Jukka Rissanen Date: Tue, 22 Oct 2024 16:18:00 +0300 Subject: [PATCH] net: sockets: Cleanup socket properly if POSIX API is enabled MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The sock_obj_core_dealloc() was not called if close() is called instead of zsock_close(). This happens if POSIX API is enabled. Fix this by calling zvfs_close() from zsock_close() and then pass the socket number to zsock_close_ctx() so that the cleanup can be done properly. Reported-by: Andreas Ålgård Signed-off-by: Jukka Rissanen --- include/zephyr/sys/fdtable.h | 5 +++- lib/os/fdtable.c | 9 ++++++- subsys/net/lib/sockets/sockets.c | 32 ++++++----------------- subsys/net/lib/sockets/sockets_inet.c | 22 +++++++++++----- subsys/net/lib/sockets/sockets_internal.h | 2 +- subsys/net/lib/sockets/sockets_packet.c | 6 ++--- subsys/net/lib/sockets/sockets_tls.c | 12 ++++++--- 7 files changed, 48 insertions(+), 40 deletions(-) diff --git a/include/zephyr/sys/fdtable.h b/include/zephyr/sys/fdtable.h index dbe134e70f6d37..3d0b71494e0d02 100644 --- a/include/zephyr/sys/fdtable.h +++ b/include/zephyr/sys/fdtable.h @@ -64,7 +64,10 @@ struct fd_op_vtable { ssize_t (*write)(void *obj, const void *buf, size_t sz); ssize_t (*write_offs)(void *obj, const void *buf, size_t sz, size_t offset); }; - int (*close)(void *obj); + union { + int (*close)(void *obj); + int (*close2)(void *obj, int fd); + }; int (*ioctl)(void *obj, unsigned int request, va_list args); }; diff --git a/lib/os/fdtable.c b/lib/os/fdtable.c index 5baa8412e93216..a2514b1c296bee 100644 --- a/lib/os/fdtable.c +++ b/lib/os/fdtable.c @@ -389,7 +389,14 @@ int zvfs_close(int fd) (void)k_mutex_lock(&fdtable[fd].lock, K_FOREVER); if (fdtable[fd].vtable->close != NULL) { /* close() is optional - e.g. stdinout_fd_op_vtable */ - res = fdtable[fd].vtable->close(fdtable[fd].obj); + if (fdtable[fd].mode & ZVFS_MODE_IFSOCK) { + /* Network socket needs to know socket number so pass + * it via close2() call. + */ + res = fdtable[fd].vtable->close2(fdtable[fd].obj, fd); + } else { + res = fdtable[fd].vtable->close(fdtable[fd].obj); + } } k_mutex_unlock(&fdtable[fd].lock); diff --git a/subsys/net/lib/sockets/sockets.c b/subsys/net/lib/sockets/sockets.c index 007606ac84f8f3..441ed438472b44 100644 --- a/subsys/net/lib/sockets/sockets.c +++ b/subsys/net/lib/sockets/sockets.c @@ -147,42 +147,26 @@ static inline int z_vrfy_zsock_socket(int family, int type, int proto) #include #endif /* CONFIG_USERSPACE */ +extern int zvfs_close(int fd); + int z_impl_zsock_close(int sock) +{ + return zvfs_close(sock); +} + +#ifdef CONFIG_USERSPACE +static inline int z_vrfy_zsock_close(int sock) { const struct socket_op_vtable *vtable; struct k_mutex *lock; void *ctx; - int ret; - - SYS_PORT_TRACING_OBJ_FUNC_ENTER(socket, close, sock); ctx = get_sock_vtable(sock, &vtable, &lock); if (ctx == NULL) { errno = EBADF; - SYS_PORT_TRACING_OBJ_FUNC_EXIT(socket, close, sock, -errno); return -1; } - (void)k_mutex_lock(lock, K_FOREVER); - - NET_DBG("close: ctx=%p, fd=%d", ctx, sock); - - ret = vtable->fd_vtable.close(ctx); - - k_mutex_unlock(lock); - - SYS_PORT_TRACING_OBJ_FUNC_EXIT(socket, close, sock, ret < 0 ? -errno : ret); - - zvfs_free_fd(sock); - - (void)sock_obj_core_dealloc(sock); - - return ret; -} - -#ifdef CONFIG_USERSPACE -static inline int z_vrfy_zsock_close(int sock) -{ return z_impl_zsock_close(sock); } #include diff --git a/subsys/net/lib/sockets/sockets_inet.c b/subsys/net/lib/sockets/sockets_inet.c index 88e4e47bfe4a6d..acb9d470a5ea95 100644 --- a/subsys/net/lib/sockets/sockets_inet.c +++ b/subsys/net/lib/sockets/sockets_inet.c @@ -137,10 +137,14 @@ static int zsock_socket_internal(int family, int type, int proto) return fd; } -int zsock_close_ctx(struct net_context *ctx) +int zsock_close_ctx(struct net_context *ctx, int sock) { int ret; + SYS_PORT_TRACING_OBJ_FUNC_ENTER(socket, close, sock); + + NET_DBG("close: ctx=%p, fd=%d", ctx, sock); + /* Reset callbacks to avoid any race conditions while * flushing queues. No need to check return values here, * as these are fail-free operations and we're closing @@ -160,10 +164,16 @@ int zsock_close_ctx(struct net_context *ctx) ret = net_context_put(ctx); if (ret < 0) { errno = -ret; - return -1; + ret = -1; } - return 0; + SYS_PORT_TRACING_OBJ_FUNC_EXIT(socket, close, sock, ret < 0 ? -errno : ret); + + if (ret == 0) { + (void)sock_obj_core_dealloc(sock); + } + + return ret; } static void zsock_accepted_cb(struct net_context *new_ctx, @@ -2771,9 +2781,9 @@ static int sock_setsockopt_vmeth(void *obj, int level, int optname, return zsock_setsockopt_ctx(obj, level, optname, optval, optlen); } -static int sock_close_vmeth(void *obj) +static int sock_close2_vmeth(void *obj, int fd) { - return zsock_close_ctx(obj); + return zsock_close_ctx(obj, fd); } static int sock_getpeername_vmeth(void *obj, struct sockaddr *addr, socklen_t *addrlen) @@ -2791,7 +2801,7 @@ const struct socket_op_vtable sock_fd_op_vtable = { .fd_vtable = { .read = sock_read_vmeth, .write = sock_write_vmeth, - .close = sock_close_vmeth, + .close2 = sock_close2_vmeth, .ioctl = sock_ioctl_vmeth, }, .shutdown = sock_shutdown_vmeth, diff --git a/subsys/net/lib/sockets/sockets_internal.h b/subsys/net/lib/sockets/sockets_internal.h index 7efd6745ed387a..647d45834439e6 100644 --- a/subsys/net/lib/sockets/sockets_internal.h +++ b/subsys/net/lib/sockets/sockets_internal.h @@ -15,7 +15,7 @@ #define SOCK_NONBLOCK 2 #define SOCK_ERROR 4 -int zsock_close_ctx(struct net_context *ctx); +int zsock_close_ctx(struct net_context *ctx, int sock); int zsock_poll_internal(struct zsock_pollfd *fds, int nfds, k_timeout_t timeout); int zsock_wait_data(struct net_context *ctx, k_timeout_t *timeout); diff --git a/subsys/net/lib/sockets/sockets_packet.c b/subsys/net/lib/sockets/sockets_packet.c index 3fe9258c7996b7..ec2d7e18539323 100644 --- a/subsys/net/lib/sockets/sockets_packet.c +++ b/subsys/net/lib/sockets/sockets_packet.c @@ -462,16 +462,16 @@ static int packet_sock_setsockopt_vmeth(void *obj, int level, int optname, return zpacket_setsockopt_ctx(obj, level, optname, optval, optlen); } -static int packet_sock_close_vmeth(void *obj) +static int packet_sock_close2_vmeth(void *obj, int fd) { - return zsock_close_ctx(obj); + return zsock_close_ctx(obj, fd); } static const struct socket_op_vtable packet_sock_fd_op_vtable = { .fd_vtable = { .read = packet_sock_read_vmeth, .write = packet_sock_write_vmeth, - .close = packet_sock_close_vmeth, + .close2 = packet_sock_close2_vmeth, .ioctl = packet_sock_ioctl_vmeth, }, .bind = packet_sock_bind_vmeth, diff --git a/subsys/net/lib/sockets/sockets_tls.c b/subsys/net/lib/sockets/sockets_tls.c index 57120321b3a2bc..bb7b44097d09ff 100644 --- a/subsys/net/lib/sockets/sockets_tls.c +++ b/subsys/net/lib/sockets/sockets_tls.c @@ -2108,7 +2108,7 @@ static int ztls_socket(int family, int type, int proto) return -1; } -int ztls_close_ctx(struct tls_context *ctx) +int ztls_close_ctx(struct tls_context *ctx, int sock) { int ret, err = 0; @@ -2120,6 +2120,10 @@ int ztls_close_ctx(struct tls_context *ctx) err = tls_release(ctx); ret = zsock_close(ctx->sock); + if (ret == 0) { + (void)sock_obj_core_dealloc(sock); + } + /* In case close fails, we propagate errno value set by close. * In case close succeeds, but tls_release fails, set errno * according to tls_release return value. @@ -3826,9 +3830,9 @@ static int tls_sock_setsockopt_vmeth(void *obj, int level, int optname, return ztls_setsockopt_ctx(obj, level, optname, optval, optlen); } -static int tls_sock_close_vmeth(void *obj) +static int tls_sock_close2_vmeth(void *obj, int sock) { - return ztls_close_ctx(obj); + return ztls_close_ctx(obj, sock); } static int tls_sock_getpeername_vmeth(void *obj, struct sockaddr *addr, @@ -3851,7 +3855,7 @@ static const struct socket_op_vtable tls_sock_fd_op_vtable = { .fd_vtable = { .read = tls_sock_read_vmeth, .write = tls_sock_write_vmeth, - .close = tls_sock_close_vmeth, + .close2 = tls_sock_close2_vmeth, .ioctl = tls_sock_ioctl_vmeth, }, .shutdown = tls_sock_shutdown_vmeth,