From 928b93c500b847a07416c4a93dd1d19b549ce64f Mon Sep 17 00:00:00 2001 From: Huiba Li Date: Wed, 18 Sep 2024 21:24:23 +0800 Subject: [PATCH] add an emulation layer for state threads, passing its UT (#561) --- include/photon/thread/st.h | 1 + thread/st.cpp | 415 ++++++++++++++++++++++++++++ thread/st.h | 141 ++++++++++ thread/test/CMakeLists.txt | 6 +- thread/test/st_utest.cpp | 43 +++ thread/test/st_utest.hpp | 127 +++++++++ thread/test/st_utest_coroutines.cpp | 120 ++++++++ thread/test/st_utest_tcp.cpp | 92 ++++++ thread/thread.cpp | 27 +- thread/thread.h | 5 +- 10 files changed, 962 insertions(+), 15 deletions(-) create mode 120000 include/photon/thread/st.h create mode 100644 thread/st.cpp create mode 100644 thread/st.h create mode 100644 thread/test/st_utest.cpp create mode 100644 thread/test/st_utest.hpp create mode 100644 thread/test/st_utest_coroutines.cpp create mode 100644 thread/test/st_utest_tcp.cpp diff --git a/include/photon/thread/st.h b/include/photon/thread/st.h new file mode 120000 index 00000000..22fd5bdf --- /dev/null +++ b/include/photon/thread/st.h @@ -0,0 +1 @@ +../../../thread/st.h \ No newline at end of file diff --git a/thread/st.cpp b/thread/st.cpp new file mode 100644 index 00000000..de8510b0 --- /dev/null +++ b/thread/st.cpp @@ -0,0 +1,415 @@ +#include "st.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int _eventsys = 0; +int st_get_eventsys(void) { + return _eventsys; +} + +int st_set_eventsys(int eventsys) { + auto es = (uint32_t)eventsys; + if (es > ST_EVENTSYS_IOURING) + LOG_ERROR_RETURN(EINVAL, -1, "unknown eventsys ", eventsys); + _eventsys = es; + return 0; +} + +int st_init(void) { + auto engine = photon::INIT_EVENT_DEFAULT; +#if defined(__linux__) + if (_eventsys == ST_EVENTSYS_IOURING) { + engine ^= photon::INIT_EVENT_EPOLL; + } +#endif + return photon::init(engine, 0, { + .libaio_queue_depth = 0, + .use_pooled_stack_allocator = true, + .bypass_threadpool = true, + }); +} + +int st_getfdlimit(void) { + struct rlimit rlim; + if (getrlimit(RLIMIT_NOFILE, &rlim) < 0) + LOG_ERRNO_RETURN(0, -1, "failed to getrlimit()"); + return rlim.rlim_max; +} + +const char *st_get_eventsys_name(void) { + return "event_sys_name"; +} + +// st_switch_cb_t st_set_switch_in_cb(st_switch_cb_t cb); +// st_switch_cb_t st_set_switch_out_cb(st_switch_cb_t cb); + +st_thread_t st_thread_create(void *(*start)(void *arg), void *arg, + int joinable, int stack_size) { + if (stack_size == 0) + stack_size = photon::DEFAULT_STACK_SIZE; + auto th = photon::thread_create(start, arg, stack_size); + if (joinable) photon::thread_enable_join(th); + return th; +} + +void st_thread_exit(void *retval) { + photon::thread_exit(retval); +} + +int st_thread_join(st_thread_t thread, void **retvalp) { + auto retval = photon::thread_join((photon::join_handle*)thread); + if (retvalp) *retvalp = retval; + return 0; +} + +st_thread_t st_thread_self(void) { + return photon::CURRENT; +} + +void st_thread_interrupt(st_thread_t th) { + thread_interrupt((photon::thread*)th); +} + +int st_sleep(int secs) { + return photon::thread_sleep(secs); +} + +int st_usleep(st_utime_t usecs) { + return photon::thread_usleep(usecs); +} + +int st_randomize_stacks(int on) { + return 0; +} + +int st_key_create(int *keyp, void (*destructor)(void *)) { + photon::thread_key_t key; + auto ret = photon::thread_key_create(&key, destructor); + if (ret < 0) + LOG_ERRNO_RETURN(0, -1, "failed to thread_key_create()"); + if (key > INT_MAX) { + photon::thread_key_delete(key); + LOG_ERROR_RETURN(EOVERFLOW, -1, "thread key space overflow"); + } + *keyp = key; + return 0; +} + +int st_key_getlimit(void) { + return photon::THREAD_KEYS_MAX; +} + +int st_thread_setspecific(int key, void *value) { + photon::thread_key_t k = key; + return photon::thread_setspecific(k, value); +} + +void *st_thread_getspecific(int key) { + photon::thread_key_t k = key; + return photon::thread_getspecific(k); +} + +// Synchronization +st_cond_t st_cond_new(void) { + return new photon::condition_variable; +} + +int st_cond_destroy(st_cond_t cvar) { + delete (photon::condition_variable*)cvar; + return 0; +} + +int st_cond_wait(st_cond_t cvar) { + return st_cond_timedwait(cvar, -1UL); +} + +int st_cond_timedwait(st_cond_t cvar, st_utime_t timeout) { + auto cv = (photon::condition_variable*)cvar; + return cv->wait_no_lock(timeout); +} + +int st_cond_signal(st_cond_t cvar) { + auto cv = (photon::condition_variable*)cvar; + return cv->signal(), 0; +} + +int st_cond_broadcast(st_cond_t cvar) { + auto cv = (photon::condition_variable*)cvar; + return cv->broadcast(), 0; +} + +st_mutex_t st_mutex_new(void) { + return new photon::mutex; +} + +int st_mutex_destroy(st_mutex_t lock) { + delete (photon::mutex*) lock; + return 0; +} + +int st_mutex_lock(st_mutex_t lock) { + auto m = (photon::mutex*) lock; + return m->lock(); +} + +int st_mutex_trylock(st_mutex_t lock) { + auto m = (photon::mutex*) lock; + return m->try_lock(); +} + +int st_mutex_unlock(st_mutex_t lock) { + auto m = (photon::mutex*) lock; + return m->unlock(), 0; +} + +time_t st_time(void) { + return photon::now / 1000 /1000; +} + +st_utime_t st_utime(void) { + return photon::now; +} + +int st_set_utime_function(st_utime_t (*func)(void)) { + return 0; +} + +int st_timecache_set(int on) { + return 0; +} + +struct netfd { + int fd; + void* specific = nullptr; + void (*destructor)(void *); + netfd(int fd) : fd(fd) { } + netfd(int fd, bool) : fd(fd) { + photon::net::set_fd_nonblocking(fd); + } + ~netfd() { + if (specific && destructor) + destructor(specific); + if (fd >= 0) + ::close(fd); + } +}; + +inline int getfd(st_netfd_t fd) { + return static_cast(fd)->fd; +} + +// I/O Functions +st_netfd_t st_netfd_open(int osfd) { + return new netfd(osfd); +} + +st_netfd_t st_netfd_open_socket(int osfd) { + return new netfd(osfd, true); +} + +void st_netfd_free(st_netfd_t fd) { + delete (netfd*)fd; +} + +int st_netfd_close(st_netfd_t fd) { + return ::close(getfd(fd)); +} + +int st_netfd_fileno(st_netfd_t fd) { + return getfd(fd); +} + +void st_netfd_setspecific(st_netfd_t fd, void *value, void (*destructor)(void *)) { + auto _fd = (netfd*)fd; + _fd->specific = value; + _fd->destructor = destructor; +} + +void *st_netfd_getspecific(st_netfd_t fd) { + auto _fd = (netfd*)fd; + return _fd->specific; +} + +// On some platforms (e.g., Solaris 2.5 and possibly other SVR4 implementations) +// accept(3) calls from different processes on the same listening socket (see +// bind(3), listen(3)) must be serialized. This function causes all subsequent +// accept(3) calls made by st_accept() on the specified file descriptor object +// to be serialized. +int st_netfd_serialize_accept(st_netfd_t fd) { + return 0; // we do not support thoses platforms +} + +inline uint32_t to_photon_events(int poll_event) { + uint32_t events = 0; + if (poll_event & POLLIN) events |= photon::EVENT_READ; + if (poll_event & POLLOUT) events |= photon::EVENT_WRITE; + if (poll_event & POLLPRI) events |= photon::EVENT_ERROR; + return events; +} + +int st_netfd_poll(st_netfd_t fd, int how, st_utime_t timeout) { + return photon::get_vcpu()->master_event_engine-> + wait_for_fd(getfd(fd), to_photon_events(how), timeout); +} + +st_netfd_t st_accept(st_netfd_t fd, struct sockaddr *addr, int *addrlen, st_utime_t timeout) { + static_assert(sizeof(socklen_t) == sizeof(int), "..."); + auto connection = photon::net::accept( + getfd(fd), addr, (socklen_t*)addrlen, timeout); + if (connection < 0) + LOG_ERRNO_RETURN(0, nullptr, "failed to accept new connection"); + return st_netfd_open_socket(connection); +} + +int st_connect(st_netfd_t fd, const struct sockaddr *addr, int addrlen, st_utime_t timeout) { + return photon::net::connect(getfd(fd), addr, addrlen, timeout); +} + +ssize_t st_read(st_netfd_t fd, void *buf, size_t nbyte, st_utime_t timeout) { + return photon::net::read(getfd(fd), buf, nbyte, timeout); +} + +ssize_t st_read_fully(st_netfd_t fd, void *buf, size_t nbyte, st_utime_t timeout) { + return photon::net::read_n(getfd(fd), buf, nbyte, timeout); +} + +int st_read_resid(st_netfd_t fd, void *buf, size_t *resid, st_utime_t timeout) { + auto ret = photon::net::read_n(getfd(fd), buf, *resid, timeout); + if (ret > 0) *resid -= ret; + return ret; +} + +ssize_t st_readv(st_netfd_t fd, const struct iovec *iov, int iov_size, st_utime_t timeout) { + return photon::net::readv(getfd(fd), iov, iov_size, timeout); +} + +int st_readv_resid(st_netfd_t fd, struct iovec **iov, int *iov_size, st_utime_t timeout) { + if (unlikely(!iov || !*iov || !iov_size || *iov_size <= 0)) + LOG_ERROR_RETURN(EINVAL, -1, "invalid arguments"); + photon::Timeout tmo(timeout); + iovector_view v(*iov, *iov_size); + auto ret = DOIO_LOOP(photon::net::readv(getfd(fd), v.iov, v.iovcnt, tmo), + photon::net::BufStepV(v)); + *iov = v.iov; + *iov_size = v.iovcnt; + return ret; +} + +ssize_t st_write(st_netfd_t fd, const void *buf, size_t nbyte, st_utime_t timeout) { + return photon::net::write(getfd(fd), buf, nbyte, timeout); +} + +int st_write_resid(st_netfd_t fd, const void *buf, size_t *resid, st_utime_t timeout) { + auto ret = photon::net::write_n(getfd(fd), buf, *resid, timeout); + if (ret > 0) *resid -= ret; + return ret; +} + +ssize_t st_writev(st_netfd_t fd, const struct iovec *iov, int iov_size, st_utime_t timeout) { + return photon::net::writev(getfd(fd), iov, iov_size, timeout); +} + +int st_writev_resid(st_netfd_t fd, struct iovec **iov, int *iov_size, st_utime_t timeout) { + if (unlikely(!iov || !*iov || !iov_size || *iov_size <= 0)) + LOG_ERROR_RETURN(EINVAL, -1, "invalid arguments"); + photon::Timeout tmo(timeout); + iovector_view v(*iov, *iov_size); + // TODO:: this implementation of DOIO_LOOP incurs an extra wait_for_fd() + // in every iteration, should fix it. + auto ret = DOIO_LOOP(photon::net::writev(getfd(fd), v.iov, v.iovcnt, tmo), + photon::net::BufStepV(v)); + *iov = v.iov; + *iov_size = v.iovcnt; + return ret; +} + +using photon::net::doio_once; +using photon::net::doio_loop; + +int st_recvfrom(st_netfd_t fd, void *buf, int len, struct sockaddr *addr, int *addrlen, st_utime_t timeout) { + iovec iov{buf, (size_t)len}; + struct msghdr hdr { + .msg_name = (void*)addr, + .msg_namelen = addrlen ? (socklen_t)*addrlen : 0, + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = nullptr, + .msg_controllen = 0, + .msg_flags = 0, + }; + auto ret = st_recvmsg(fd, &hdr, 0, timeout); + if (addrlen) *addrlen = hdr.msg_namelen; + return ret; +} + +int st_sendto(st_netfd_t fd, const void *buf, int len, struct sockaddr *addr, int addrlen, st_utime_t timeout) { + iovec iov{(void*)buf, (size_t)len}; + struct msghdr hdr { + .msg_name = (void*)addr, + .msg_namelen = (socklen_t)addrlen, + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = nullptr, + .msg_controllen = 0, + .msg_flags = 0, + }; + return st_sendmsg(fd, &hdr, 0, timeout); +} + +int st_recvmsg(st_netfd_t fd, struct msghdr *msg, int flags, st_utime_t timeout) { + return DOIO_ONCE(::recvmsg(getfd(fd), msg, flags | MSG_DONTWAIT), + photon::wait_for_fd_readable(getfd(fd), timeout)); +} + +int st_sendmsg(st_netfd_t fd, const struct msghdr *msg, int flags, st_utime_t timeout) { + return DOIO_ONCE(::sendmsg(getfd(fd), msg, flags | MSG_DONTWAIT | MSG_NOSIGNAL), + photon::wait_for_fd_writable(getfd(fd), timeout)); +} + +st_netfd_t st_open(const char *path, int oflags, mode_t mode) { + int fd = ::open(path, oflags, mode); + if (fd < 0) + LOG_ERRNO_RETURN(0, nullptr, "failed to open(`, `, `) file", path, oflags, mode); + return new netfd(fd); +} + +int st_poll(struct pollfd *pds, int npds, st_utime_t timeout) { + if (!pds || !npds) return 0; + auto eng = photon::new_default_cascading_engine(); + DEFER(delete eng); + for (int i = 0; i < npds; ++i) { + auto& p = pds[i]; + if (p.fd < 0) + LOG_ERROR_RETURN(EINVAL, -1, "invalid fd ", p.fd); + auto events = to_photon_events(p.events); + eng->add_interest({p.fd, events, (void*)(int64_t)i}); + } + constexpr int MAX = 32; + void* data[MAX]; + int n = 0; +again: + auto ret = eng->wait_for_events(data, MAX, timeout); + if (ret < 0) + LOG_ERRNO_RETURN(0, -1, "failed to wait_for_events() via default cascading engine"); + n += ret; + for (ssize_t i = 0; i < ret; ++i) { + auto j = (uint64_t)data[i]; + if (j >= (uint64_t)npds) + LOG_ERROR_RETURN(EOVERFLOW, -1, "reap event data overflow"); + pds[j].revents = pds[j].events; + } + if (ret == MAX) + goto again; + return n; +} diff --git a/thread/st.h b/thread/st.h new file mode 100644 index 00000000..b26ac209 --- /dev/null +++ b/thread/st.h @@ -0,0 +1,141 @@ +/* +Copyright 2022 The Photon Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef __STATE_THREADS__ +#define __STATE_THREADS__ +#include +#include +#include + +// This file provides an emulation layer for state threads (https://state-threads.sourceforge.net) +// #include +// #include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void * st_thread_t; +typedef void * st_cond_t; +typedef void * st_mutex_t; +typedef void * st_netfd_t; +typedef unsigned long long st_utime_t; +// typedef void (*st_switch_cb_t)(void); + + + +#define ST_EVENTSYS_DEFAULT 0 // epoll in Linux, kqueue in MacOSX +#define ST_EVENTSYS_SELECT 1 // epoll in Linux, kqueue in MacOSX +#define ST_EVENTSYS_POLL 2 // epoll in Linux, kqueue in MacOSX +#define ST_EVENTSYS_ALT 3 // epoll in Linux, kqueue in MacOSX +#define ST_EVENTSYS_IOURING 4 // io_uring in Linux, kqueue in MacOSX + +int st_set_eventsys(int eventsys); +int st_get_eventsys(void); +int st_init(void); +int st_getfdlimit(void); +const char *st_get_eventsys_name(void); +// st_switch_cb_t st_set_switch_in_cb(st_switch_cb_t cb); +// st_switch_cb_t st_set_switch_out_cb(st_switch_cb_t cb); + +st_thread_t st_thread_create(void *(*start)(void *arg), void *arg, + int joinable, int stack_size); +void st_thread_exit(void *retval); +int st_thread_join(st_thread_t thread, void **retvalp); +st_thread_t st_thread_self(void); +void st_thread_interrupt(st_thread_t thread); +int st_sleep(int secs); +int st_usleep(st_utime_t usecs); +int st_randomize_stacks(int on); +int st_key_create(int *keyp, void (*destructor)(void *)); +int st_key_getlimit(void); +int st_thread_setspecific(int key, void *value); +void *st_thread_getspecific(int key); + +// Synchronization +st_cond_t st_cond_new(void); +int st_cond_destroy(st_cond_t cvar); +int st_cond_wait(st_cond_t cvar); +int st_cond_timedwait(st_cond_t cvar, st_utime_t timeout); +int st_cond_signal(st_cond_t cvar); +int st_cond_broadcast(st_cond_t cvar); + +st_mutex_t st_mutex_new(void); +int st_mutex_destroy(st_mutex_t lock); +int st_mutex_lock(st_mutex_t lock); +int st_mutex_trylock(st_mutex_t lock); +int st_mutex_unlock(st_mutex_t lock); + +time_t st_time(void); +st_utime_t st_utime(void); +int st_set_utime_function(st_utime_t (*func)(void)); +int st_timecache_set(int on); + + +// I/O Functions +st_netfd_t st_netfd_open(int osfd); +st_netfd_t st_netfd_open_socket(int osfd); +void st_netfd_free(st_netfd_t fd); +int st_netfd_close(st_netfd_t fd); +int st_netfd_fileno(st_netfd_t fd); +void st_netfd_setspecific(st_netfd_t fd, void *value, + void (*destructor)(void *)); +void *st_netfd_getspecific(st_netfd_t fd); +int st_netfd_serialize_accept(st_netfd_t fd); +int st_netfd_poll(st_netfd_t fd, int how, st_utime_t timeout); +st_netfd_t st_accept(st_netfd_t fd, struct sockaddr *addr, int *addrlen, + st_utime_t timeout); +int st_connect(st_netfd_t fd, const struct sockaddr *addr, int addrlen, + st_utime_t timeout); +ssize_t st_read(st_netfd_t fd, void *buf, size_t nbyte, st_utime_t timeout); +ssize_t st_read_fully(st_netfd_t fd, void *buf, size_t nbyte, + st_utime_t timeout); +int st_read_resid(st_netfd_t fd, void *buf, size_t *resid, + st_utime_t timeout); +ssize_t st_readv(st_netfd_t fd, const struct iovec *iov, int iov_size, + st_utime_t timeout); +int st_readv_resid(st_netfd_t fd, struct iovec **iov, int *iov_size, + st_utime_t timeout); +ssize_t st_write(st_netfd_t fd, const void *buf, size_t nbyte, + st_utime_t timeout); +int st_write_resid(st_netfd_t fd, const void *buf, size_t *resid, + st_utime_t timeout); +ssize_t st_writev(st_netfd_t fd, const struct iovec *iov, int iov_size, + st_utime_t timeout); +ssize_t st_writev(st_netfd_t fd, const struct iovec *iov, int iov_size, + st_utime_t timeout); +int st_writev_resid(st_netfd_t fd, struct iovec **iov, int *iov_size, + st_utime_t timeout); +int st_recvfrom(st_netfd_t fd, void *buf, int len, struct sockaddr *from, + int *fromlen, st_utime_t timeout); +int st_sendto(st_netfd_t fd, const void *msg, int len, struct sockaddr *to, + int tolen, st_utime_t timeout); +int st_recvmsg(st_netfd_t fd, struct msghdr *msg, int flags, + st_utime_t timeout); +int st_sendmsg(st_netfd_t fd, const struct msghdr *msg, int flags, + st_utime_t timeout); +st_netfd_t st_open(const char *path, int oflags, mode_t mode); +int st_poll(struct pollfd *pds, int npds, st_utime_t timeout); + + + + +#ifdef __cplusplus +} +#endif + + +#endif \ No newline at end of file diff --git a/thread/test/CMakeLists.txt b/thread/test/CMakeLists.txt index 6e6a5d48..c31f1c27 100644 --- a/thread/test/CMakeLists.txt +++ b/thread/test/CMakeLists.txt @@ -44,4 +44,8 @@ add_test(NAME test-multi-vcpu-locking COMMAND $) \ No newline at end of file +add_test(NAME test-pooled-stack-allocator COMMAND $) + +add_executable(test-st-utest st_utest.cpp st_utest_tcp.cpp st_utest_coroutines.cpp) +target_link_libraries(test-st-utest PRIVATE photon_shared) +add_test(NAME test-st-utest COMMAND $) diff --git a/thread/test/st_utest.cpp b/thread/test/st_utest.cpp new file mode 100644 index 00000000..8562edb4 --- /dev/null +++ b/thread/test/st_utest.cpp @@ -0,0 +1,43 @@ +/* SPDX-License-Identifier: MIT */ +/* Copyright (c) 2013-2024 The SRS Authors */ + +#include "st_utest.hpp" + +// #include +#include + +std::ostream& operator<<(std::ostream& out, const ErrorObject* err) { + if (!err) return out; + if (err->r0_) out << "r0=" << err->r0_; + if (err->errno_) out << ", errno=" << err->errno_; + if (!err->message_.empty()) out << ", msg=" << err->message_; + return out; +} + +// We could do something in the main of utest. +// Copy from gtest-1.6.0/src/gtest_main.cc +GTEST_API_ int main(int argc, char **argv) { + // Select the best event system available on the OS. In Linux this is + // epoll(). On BSD it will be kqueue. On Cygwin it will be select. +#if __CYGWIN__ + assert(st_set_eventsys(ST_EVENTSYS_SELECT) != -1); +#else + assert(st_set_eventsys(ST_EVENTSYS_ALT) != -1); +#endif + + // Initialize state-threads, create idle coroutine. + assert(st_init() == 0); + + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +// basic test and samples. +VOID TEST(SampleTest, ExampleIntSizeTest) +{ + EXPECT_EQ(1, (int)sizeof(int8_t)); + EXPECT_EQ(2, (int)sizeof(int16_t)); + EXPECT_EQ(4, (int)sizeof(int32_t)); + EXPECT_EQ(8, (int)sizeof(int64_t)); +} + diff --git a/thread/test/st_utest.hpp b/thread/test/st_utest.hpp new file mode 100644 index 00000000..7e832e0f --- /dev/null +++ b/thread/test/st_utest.hpp @@ -0,0 +1,127 @@ +/* SPDX-License-Identifier: MIT */ +/* Copyright (c) 2013-2024 The SRS Authors */ + +#ifndef ST_UTEST_PUBLIC_HPP +#define ST_UTEST_PUBLIC_HPP + +// Before define the private/protected, we must include some system header files. +// Or it may fail with: +// redeclared with different access struct __xfer_bufptrs +// @see https://stackoverflow.com/questions/47839718/sstream-redeclared-with-public-access-compiler-error +#include + +#include +#include + +#define VOID + +// Close the fd automatically. +#define StFdCleanup(fd, stfd) impl__StFdCleanup _ST_free_##fd(&fd, &stfd) +#define StStfdCleanup(stfd) impl__StFdCleanup _ST_free_##stfd(NULL, &stfd) +class impl__StFdCleanup { + int* fd_; + st_netfd_t* stfd_; +public: + impl__StFdCleanup(int* fd, st_netfd_t* stfd) : fd_(fd), stfd_(stfd) { + } + virtual ~impl__StFdCleanup() { + if (stfd_ && *stfd_) { + st_netfd_close(*stfd_); + } else if (fd_ && *fd_ > 0) { + ::close(*fd_); + } + } +}; + +// For coroutine function to return with error object. +struct ErrorObject { + int r0_; + int errno_; + std::string message_; + + ErrorObject(int r0, std::string message) : r0_(r0), errno_(errno), message_(message) { + } +}; +extern std::ostream& operator<<(std::ostream& out, const ErrorObject* err); +#define ST_ASSERT_ERROR(error, r0, message) if (error) return new ErrorObject(r0, message) +#define ST_COROUTINE_JOIN(trd, r0) ErrorObject* r0 = NULL; SrsAutoFree(ErrorObject, r0); if (trd) st_thread_join(trd, (void**)&r0) +#define ST_EXPECT_SUCCESS(r0) EXPECT_TRUE(!r0) << r0 +#define ST_EXPECT_FAILED(r0) EXPECT_TRUE(r0) << r0 + +#include + +// To free the instance in the current scope, for instance, MyClass* ptr, +// which is a ptr and this class will: +// 1. free the ptr. +// 2. set ptr to NULL. +// +// Usage: +// MyClass* po = new MyClass(); +// // ...... use po +// SrsAutoFree(MyClass, po); +// +// Usage for array: +// MyClass** pa = new MyClass*[size]; +// // ....... use pa +// SrsAutoFreeA(MyClass*, pa); +// +// @remark the MyClass can be basic type, for instance, SrsAutoFreeA(char, pstr), +// where the char* pstr = new char[size]. +// To delete object. +#define SrsAutoFree(className, instance) \ + impl_SrsAutoFree _auto_free_##instance(&instance, false, false, NULL) +// To delete array. +#define SrsAutoFreeA(className, instance) \ + impl_SrsAutoFree _auto_free_array_##instance(&instance, true, false, NULL) +// Use free instead of delete. +#define SrsAutoFreeF(className, instance) \ + impl_SrsAutoFree _auto_free_##instance(&instance, false, true, NULL) +// Use hook instead of delete. +#define SrsAutoFreeH(className, instance, hook) \ + impl_SrsAutoFree _auto_free_##instance(&instance, false, false, hook) +// The template implementation. +template +class impl_SrsAutoFree +{ +private: + T** ptr; + bool is_array; + bool _use_free; + void (*_hook)(T*); +public: + // If use_free, use free(void*) to release the p. + // If specified hook, use hook(p) to release it. + // Use delete to release p, or delete[] if p is an array. + impl_SrsAutoFree(T** p, bool array, bool use_free, void (*hook)(T*)) { + ptr = p; + is_array = array; + _use_free = use_free; + _hook = hook; + } + + virtual ~impl_SrsAutoFree() { + if (ptr == NULL || *ptr == NULL) { + return; + } + + if (_use_free) { + free(*ptr); + } else if (_hook) { + _hook(*ptr); + } else { + if (is_array) { + delete[] *ptr; + } else { + delete *ptr; + } + } + + *ptr = NULL; + } +}; + +// The time unit in ms, for example 100 * SRS_UTIME_MILLISECONDS means 100ms. +#define SRS_UTIME_MILLISECONDS 1000 + +#endif + diff --git a/thread/test/st_utest_coroutines.cpp b/thread/test/st_utest_coroutines.cpp new file mode 100644 index 00000000..f9f2d24d --- /dev/null +++ b/thread/test/st_utest_coroutines.cpp @@ -0,0 +1,120 @@ +/* SPDX-License-Identifier: MIT */ +/* Copyright (c) 2013-2024 The SRS Authors */ + +#include "st_utest.hpp" + +// #include + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// The utest for empty coroutine. +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void* coroutine(void* /*arg*/) +{ + st_usleep(0); + return NULL; +} + +VOID TEST(CoroutineTest, StartCoroutine) +{ + st_thread_t trd = st_thread_create(coroutine, NULL, 1, 0); + EXPECT_TRUE(trd != NULL); + + // Wait for joinable coroutine to quit. + st_thread_join(trd, NULL); +} + +VOID TEST(CoroutineTest, StartCoroutineX3) +{ + st_thread_t trd0 = st_thread_create(coroutine, NULL, 1, 0); + st_thread_t trd1 = st_thread_create(coroutine, NULL, 1, 0); + st_thread_t trd2 = st_thread_create(coroutine, NULL, 1, 0); + EXPECT_TRUE(trd0 != NULL && trd1 != NULL && trd2 != NULL); + + // Wait for joinable coroutine to quit. + st_thread_join(trd1, NULL); + st_thread_join(trd2, NULL); + st_thread_join(trd0, NULL); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// The utest for adding coroutine. +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void* coroutine_add(void* arg) +{ + int v = 0; + int* pi = (int*)arg; + + // Load the change of arg. + while (v != *pi) { + v = *pi; + st_usleep(0); + } + + // Add with const. + v += 100; + *pi = v; + + return NULL; +} + +VOID TEST(CoroutineTest, StartCoroutineAdd) +{ + int v = 0; + st_thread_t trd = st_thread_create(coroutine_add, &v, 1, 0); + EXPECT_TRUE(trd != NULL); + + // Wait for joinable coroutine to quit. + st_thread_join(trd, NULL); + + EXPECT_EQ(100, v); +} + +VOID TEST(CoroutineTest, StartCoroutineAddX3) +{ + int v = 0; + st_thread_t trd0 = st_thread_create(coroutine_add, &v, 1, 0); + st_thread_t trd1 = st_thread_create(coroutine_add, &v, 1, 0); + st_thread_t trd2 = st_thread_create(coroutine_add, &v, 1, 0); + EXPECT_TRUE(trd0 != NULL && trd1 != NULL && trd2 != NULL); + + // Wait for joinable coroutine to quit. + st_thread_join(trd0, NULL); + st_thread_join(trd1, NULL); + st_thread_join(trd2, NULL); + + EXPECT_EQ(300, v); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// The utest for output params coroutine. +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +int coroutine_params_x4(int a, int b, int c, int d) +{ + int e = 0; + + st_usleep(0); + + e += a + b + c + d; + e += 100; + return e; +} + +void* coroutine_params(void* arg) +{ + int r0 = coroutine_params_x4(1, 2, 3, 4); + *(int*)arg = r0; + return NULL; +} + +VOID TEST(CoroutineTest, StartCoroutineParams) +{ + int r0 = 0; + st_thread_t trd = st_thread_create(coroutine_params, &r0, 1, 0); + EXPECT_TRUE(trd != NULL); + + // Wait for joinable coroutine to quit. + st_thread_join(trd, NULL); + + EXPECT_EQ(110, r0); +} + diff --git a/thread/test/st_utest_tcp.cpp b/thread/test/st_utest_tcp.cpp new file mode 100644 index 00000000..52fe38a3 --- /dev/null +++ b/thread/test/st_utest_tcp.cpp @@ -0,0 +1,92 @@ +/* SPDX-License-Identifier: MIT */ +/* Copyright (c) 2013-2024 The SRS Authors */ + +#include "st_utest.hpp" + +// #include +#include + +#include +#include +#include + +#define ST_UTEST_PORT 26878 +#define ST_UTEST_TIMEOUT (100 * SRS_UTIME_MILLISECONDS) + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// The utest for ping-pong TCP server coroutine. +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void* tcp_server(void* /*arg*/) +{ + int fd = -1; + st_netfd_t stfd = NULL; + StFdCleanup(fd, stfd); + + fd = socket(AF_INET, SOCK_STREAM, 0); + ST_ASSERT_ERROR(fd == -1, fd, "Create socket"); + + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_ANY); + addr.sin_port = htons(ST_UTEST_PORT); + + int v = 1; + int r0 = setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &v, sizeof(int)); + ST_ASSERT_ERROR(r0, r0, "Set SO_REUSEADDR"); + + r0 = ::bind(fd, (const sockaddr*)&addr, sizeof(addr)); + ST_ASSERT_ERROR(r0, r0, "Bind socket"); + + r0 = ::listen(fd, 10); + ST_ASSERT_ERROR(r0, r0, "Listen socket"); + + stfd = st_netfd_open_socket(fd); + ST_ASSERT_ERROR(!stfd, fd, "Open ST socket"); + + st_netfd_t client = NULL; + StStfdCleanup(client); + + client = st_accept(stfd, NULL, NULL, ST_UTEST_TIMEOUT); + ST_ASSERT_ERROR(!client, fd, "Accept client"); + + return NULL; +} + +void* tcp_client(void* /*arg*/) +{ + int fd = -1; + st_netfd_t stfd = NULL; + StFdCleanup(fd, stfd); + + fd = socket(AF_INET, SOCK_STREAM, 0); + ST_ASSERT_ERROR(fd == -1, fd, "Create socket"); + + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = inet_addr("127.0.0.1"); + addr.sin_port = htons(ST_UTEST_PORT); + + stfd = st_netfd_open_socket(fd); + ST_ASSERT_ERROR(!stfd, fd, "Open ST socket"); + + int r0 = st_connect(stfd, (const sockaddr*)&addr, sizeof(addr), ST_UTEST_TIMEOUT); + ST_ASSERT_ERROR(r0, r0, "Connect to server"); + + return NULL; +} + +VOID TEST(TcpTest, TcpConnection) +{ + st_thread_t svr = st_thread_create(tcp_server, NULL, 1, 0); + EXPECT_TRUE(svr != NULL); + + st_thread_t client = st_thread_create(tcp_client, NULL, 1, 0); + EXPECT_TRUE(client != NULL); + + ST_COROUTINE_JOIN(svr, r0); + ST_COROUTINE_JOIN(client, r1); + + ST_EXPECT_SUCCESS(r0); + ST_EXPECT_SUCCESS(r1); +} + diff --git a/thread/thread.cpp b/thread/thread.cpp index 4518502a..55b6d51b 100644 --- a/thread/thread.cpp +++ b/thread/thread.cpp @@ -268,13 +268,6 @@ namespace photon #endif } - void go() { - assert(this == CURRENT); - auto _arg = arg; - arg = nullptr; - retval = start(_arg); - die(); - } void die() __attribute__((always_inline)); void dequeue_ready_atomic(states newstat = states::READY); vcpu_t* get_vcpu() { @@ -874,9 +867,11 @@ R"( #endif // x86 or arm - extern "C" void _photon_switch_context_defer_die(void* arg,uint64_t defer_func_addr, void** to) - asm ("_photon_switch_context_defer_die"); + extern "C" __attribute__((noreturn)) + void _photon_switch_context_defer_die(void* arg, uint64_t defer_func_addr, + void** to) asm ("_photon_switch_context_defer_die"); + __attribute__((noreturn)) inline void thread::die() { deallocate_tls(&tls); // if CURRENT is idle stub and during vcpu_fini @@ -902,7 +897,7 @@ R"( _photon_switch_context_defer_die( arg, func, sw.to->stack.pointer_ref()); } - __attribute__((used)) static + static __attribute__((used, noreturn)) void _photon_thread_die(thread* th) { assert(th == CURRENT); th->die(); @@ -1389,22 +1384,28 @@ R"( return (join_handle*)th; } - void thread_join(join_handle* jh) - { + void* thread_join(join_handle* jh) { auto th = (thread*)jh; + assert(th->is_joinable()); if (!th->is_joinable()) - LOG_ERROR_RETURN(ENOSYS, , "join is not enabled for thread ", th); + LOG_ERROR_RETURN(ENOSYS, nullptr, "join is not enabled for thread ", th); th->lock.lock(); while (th->state != states::DONE) { th->cond.wait(th->lock); } + auto retval = th->retval; th->dispose(); + return retval; } inline void thread_join(thread* th) { thread_join((join_handle*)th); } + void thread_exit(void* retval) { + CURRENT->retval = retval; + _photon_thread_die(CURRENT); + } int thread_shutdown(thread* th, bool flag) { diff --git a/thread/thread.h b/thread/thread.h index 93b78d6e..e99999eb 100644 --- a/thread/thread.h +++ b/thread/thread.h @@ -75,7 +75,10 @@ namespace photon // Failing to do so will cause resource leak. struct join_handle; join_handle* thread_enable_join(thread* th, bool flag = true); - void thread_join(join_handle* jh); + void* thread_join(join_handle* jh); + + // terminates CURRENT with return value `retval` + void thread_exit(void* retval) __attribute__((noreturn)); // switching to other threads (without going into sleep queue) // return error_number if interrupted during the rolling