From 4081c6be48ea51fa658bcd2875ea482e77e474b3 Mon Sep 17 00:00:00 2001 From: pennam Date: Mon, 4 Nov 2024 09:56:02 +0100 Subject: [PATCH] Fix and simplify sni setting --- libraries/SocketWrapper/src/AClient.cpp | 4 ++-- libraries/SocketWrapper/src/AClient.h | 2 +- libraries/SocketWrapper/src/MbedClient.cpp | 10 +--------- libraries/SocketWrapper/src/MbedClient.h | 2 +- libraries/SocketWrapper/src/MbedSSLClient.h | 8 +++++++- 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/libraries/SocketWrapper/src/AClient.cpp b/libraries/SocketWrapper/src/AClient.cpp index 272b18946..9ffa9137a 100644 --- a/libraries/SocketWrapper/src/AClient.cpp +++ b/libraries/SocketWrapper/src/AClient.cpp @@ -46,11 +46,11 @@ int arduino::AClient::connectSSL(IPAddress ip, uint16_t port) { return client->connectSSL(ip, port); } -int arduino::AClient::connectSSL(const char *host, uint16_t port, bool disableSNI) { +int arduino::AClient::connectSSL(const char *host, uint16_t port) { if (!client) { newMbedClient(); } - return client->connectSSL(host, port, disableSNI); + return client->connectSSL(host, port); } void arduino::AClient::stop() { diff --git a/libraries/SocketWrapper/src/AClient.h b/libraries/SocketWrapper/src/AClient.h index 195f7a1f2..4f72020ee 100644 --- a/libraries/SocketWrapper/src/AClient.h +++ b/libraries/SocketWrapper/src/AClient.h @@ -32,7 +32,7 @@ class AClient : public Client { virtual int connect(IPAddress ip, uint16_t port); virtual int connect(const char *host, uint16_t port); int connectSSL(IPAddress ip, uint16_t port); - int connectSSL(const char* host, uint16_t port, bool disableSNI = false); + int connectSSL(const char* host, uint16_t port); virtual void stop(); virtual explicit operator bool(); diff --git a/libraries/SocketWrapper/src/MbedClient.cpp b/libraries/SocketWrapper/src/MbedClient.cpp index f50ff18b7..296b64943 100644 --- a/libraries/SocketWrapper/src/MbedClient.cpp +++ b/libraries/SocketWrapper/src/MbedClient.cpp @@ -186,15 +186,7 @@ int arduino::MbedClient::connectSSL(IPAddress ip, uint16_t port) { return connectSSL(SocketHelpers::socketAddressFromIpAddress(ip, port)); } -int arduino::MbedClient::connectSSL(const char *host, uint16_t port, bool disableSNI) { - if (!disableSNI) { - if (sock == nullptr) { - sock = new TLSSocket(); - _own_socket = true; - } - static_cast(sock)->set_hostname(host); - } - +int arduino::MbedClient::connectSSL(const char *host, uint16_t port) { SocketAddress socketAddress = SocketAddress(); socketAddress.set_port(port); SocketHelpers::gethostbyname(getNetwork(), host, &socketAddress); diff --git a/libraries/SocketWrapper/src/MbedClient.h b/libraries/SocketWrapper/src/MbedClient.h index 2a6777af4..573c4d5b1 100644 --- a/libraries/SocketWrapper/src/MbedClient.h +++ b/libraries/SocketWrapper/src/MbedClient.h @@ -56,7 +56,7 @@ class MbedClient { virtual int connect(const char* host, uint16_t port); int connectSSL(SocketAddress socketAddress); int connectSSL(IPAddress ip, uint16_t port); - int connectSSL(const char* host, uint16_t port, bool disableSNI = false); + int connectSSL(const char* host, uint16_t port); size_t write(uint8_t); size_t write(const uint8_t* buf, size_t size); int available(); diff --git a/libraries/SocketWrapper/src/MbedSSLClient.h b/libraries/SocketWrapper/src/MbedSSLClient.h index 34f4d583d..67c496f5c 100644 --- a/libraries/SocketWrapper/src/MbedSSLClient.h +++ b/libraries/SocketWrapper/src/MbedSSLClient.h @@ -41,7 +41,8 @@ class MbedSSLClient : public arduino::MbedClient { return connectSSL(ip, port); } int connect(const char* host, uint16_t port) { - return connectSSL(host, port, _disableSNI); + _hostname = host; + return connectSSL(host, port); } void disableSNI(bool statusSNI) { _disableSNI = statusSNI; @@ -53,6 +54,7 @@ class MbedSSLClient : public arduino::MbedClient { protected: const char* _ca_cert_custom = NULL; + const char* _hostname = NULL; private: int setRootCA() { @@ -79,6 +81,10 @@ class MbedSSLClient : public arduino::MbedClient { } #endif + if(_hostname && !_disableSNI) { + ((TLSSocket*)sock)->set_hostname(_hostname); + } + if(_ca_cert_custom != NULL) { err = ((TLSSocket*)sock)->append_root_ca_cert(_ca_cert_custom); }