diff --git a/ports/esp32/modsocket.c b/ports/esp32/modsocket.c index fe76b5f7601c..cacc52dec568 100644 --- a/ports/esp32/modsocket.c +++ b/ports/esp32/modsocket.c @@ -82,6 +82,8 @@ typedef struct _socket_obj_t { #endif } socket_obj_t; +STATIC const char *TAG = "modsocket"; + void _socket_settimeout(socket_obj_t *sock, uint64_t timeout_ms); #if MICROPY_PY_SOCKET_EVENTS @@ -364,16 +366,96 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_1(socket_accept_obj, socket_accept); STATIC mp_obj_t socket_connect(const mp_obj_t arg0, const mp_obj_t arg1) { socket_obj_t *self = MP_OBJ_TO_PTR(arg0); struct addrinfo *res; + bool blocking = false; + int flags; + int raise_err = 0; + _socket_getaddrinfo(arg1, &res); MP_THREAD_GIL_EXIT(); self->state = SOCKET_STATE_CONNECTED; - int r = lwip_connect(self->fd, res->ai_addr, res->ai_addrlen); - MP_THREAD_GIL_ENTER(); + + flags = fcntl(self->fd, F_GETFL); + + blocking = (flags & O_NONBLOCK) == 0; + + if (blocking) { + // For blocking sockets, make the socket temporarily non-blocking and emulate + // blocking using select. + // + // This has two benefits: + // + // - Allows handling external exceptions while waiting for connect. + // + // - Allows emulating a connect timeout, which is not supported by LWIP or + // required by POSIX but is normal behaviour for CPython. + if (fcntl(self->fd, F_SETFL, flags | O_NONBLOCK) < 0) { + ESP_LOGE(TAG, "fcntl set failed %d", errno); // Unexpected internal failure + raise_err = errno; + } + } + + if (raise_err == 0) { + // Try performing the actual connect. Expected to always return immediately. + int r = lwip_connect(self->fd, res->ai_addr, res->ai_addrlen); + if (r != 0) { + raise_err = errno; + } + } + + if (blocking) { + // Set the socket back to blocking. We can still pass it to select() in this state. + int r = fcntl(self->fd, F_SETFL, flags); + if (r != 0 && (raise_err == 0 || raise_err == EINPROGRESS)) { + ESP_LOGE(TAG, "fcntl restore failed %d", errno); // Unexpected internal failure + raise_err = errno; + } + } + lwip_freeaddrinfo(res); - if (r != 0) { - mp_raise_OSError(errno); + + if (blocking && raise_err == EINPROGRESS) { + // Keep calling select() until the socket is marked writable (i.e. connected), + // or an error or a timeout occurs + + // Note: _socket_settimeout() always sets self->retries != 0 on blocking sockets. + + for (unsigned int i = 0; i <= self->retries; i++) { + struct timeval timeout = { + .tv_sec = 0, + .tv_usec = SOCKET_POLL_US, + }; + fd_set wfds; + FD_ZERO(&wfds); + FD_SET(self->fd, &wfds); + + int r = select(self->fd + 1, NULL, &wfds, NULL, &timeout); + if (r < 0) { + // Error condition + raise_err = errno; + break; + } else if (r > 0) { + // Select indicated the socket is writable. Check for any error. + socklen_t socklen = sizeof(raise_err); + r = getsockopt(self->fd, SOL_SOCKET, SO_ERROR, &raise_err, &socklen); + if (r < 0) { + raise_err = errno; + } + break; + } else { + // Select timed out + raise_err = ETIMEDOUT; + + MP_THREAD_GIL_ENTER(); + check_for_exceptions(); + MP_THREAD_GIL_EXIT(); + } + } } + MP_THREAD_GIL_ENTER(); + if (raise_err) { + mp_raise_OSError(raise_err); + } return mp_const_none; } STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_connect_obj, socket_connect); @@ -847,7 +929,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(esp_socket_getaddrinfo_obj, 2, 6, esp STATIC mp_obj_t esp_socket_initialize() { static int initialized = 0; if (!initialized) { - ESP_LOGI("modsocket", "Initializing"); + ESP_LOGI(TAG, "Initializing"); esp_netif_init(); initialized = 1; }