diff --git a/pdns/dnsdistdist/dnsdist.cc b/pdns/dnsdistdist/dnsdist.cc index 4de6f6eae5cb..80ee94299d8a 100644 --- a/pdns/dnsdistdist/dnsdist.cc +++ b/pdns/dnsdistdist/dnsdist.cc @@ -157,34 +157,25 @@ static constexpr size_t s_maxUDPResponsePacketSize{4096U}; static size_t const s_initialUDPPacketBufferSize = s_maxUDPResponsePacketSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; static_assert(s_initialUDPPacketBufferSize <= UINT16_MAX, "Packet size should fit in a uint16_t"); -static ssize_t sendfromto(int sock, const void* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& dest) +static void sendfromto(int sock, const PacketBuffer& buffer, const ComboAddress& from, const ComboAddress& dest) { + const int flags = 0; if (from.sin4.sin_family == 0) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - return sendto(sock, data, len, flags, reinterpret_cast(&dest), dest.getSocklen()); - } - msghdr msgh{}; - iovec iov{}; - cmsgbuf_aligned cbuf; - - /* Set up iov and msgh structures. */ - memset(&msgh, 0, sizeof(struct msghdr)); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast): it's the API - iov.iov_base = const_cast(data); - iov.iov_len = len; - msgh.msg_iov = &iov; - msgh.msg_iovlen = 1; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-type-const-cast) - msgh.msg_name = const_cast(reinterpret_cast(&dest)); - msgh.msg_namelen = dest.getSocklen(); - - if (from.sin4.sin_family != 0) { - addCMsgSrcAddr(&msgh, &cbuf, &from, 0); + auto ret = sendto(sock, buffer.data(), buffer.size(), flags, reinterpret_cast(&dest), dest.getSocklen()); + if (ret == -1) { + int error = errno; + vinfolog("Error sending UDP response to %s: %s", dest.toStringWithPort(), stringerror(error)); + } + return; } - else { - msgh.msg_control = nullptr; + + try { + sendMsgWithOptions(sock, buffer.data(), buffer.size(), &dest, &from, 0, 0); + } + catch (const std::exception& exp) { + vinfolog("Error sending UDP response from %s to %s: %s", from.toStringWithPort(), dest.toStringWithPort(), exp.what()); } - return sendmsg(sock, &msgh, flags); } static void truncateTC(PacketBuffer& packet, size_t maximumSize, unsigned int qnameWireLength) @@ -223,13 +214,9 @@ struct DelayedPacket PacketBuffer packet; ComboAddress destination; ComboAddress origDest; - void operator()() + void operator()() const { - ssize_t res = sendfromto(fd, packet.data(), packet.size(), 0, origDest, destination); - if (res == -1) { - int err = errno; - vinfolog("Error sending delayed response to %s: %s", destination.toStringWithPort(), stringerror(err)); - } + sendfromto(fd, packet, origDest, destination); } }; @@ -667,12 +654,7 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs } #endif /* DISABLE_DELAY_PIPE */ // NOLINTNEXTLINE(readability-suspicious-call-argument) - ssize_t res = sendfromto(origFD, response.data(), response.size(), 0, origDest, origRemote); - if (res == -1) { - int err = errno; - vinfolog("Error sending response to %s: %s", origRemote.toStringWithPort(), stringerror(err)); - } - + sendfromto(origFD, response, origDest, origRemote); return true; } diff --git a/pdns/dnsdistdist/doh3.cc b/pdns/dnsdistdist/doh3.cc index 26b3cf5686ae..7ff4748f1551 100644 --- a/pdns/dnsdistdist/doh3.cc +++ b/pdns/dnsdistdist/doh3.cc @@ -54,8 +54,8 @@ using h3_headers_t = std::map; class H3Connection { public: - H3Connection(const ComboAddress& peer, QuicheConfig config, QuicheConnection&& conn) : - d_peer(peer), d_conn(std::move(conn)), d_config(std::move(config)) + H3Connection(const ComboAddress& peer, const ComboAddress& localAddr, QuicheConfig config, QuicheConnection&& conn) : + d_peer(peer), d_localAddr(localAddr), d_conn(std::move(conn)), d_config(std::move(config)) { } H3Connection(const H3Connection&) = delete; @@ -65,6 +65,7 @@ class H3Connection ~H3Connection() = default; ComboAddress d_peer; + ComboAddress d_localAddr; QuicheConnection d_conn; QuicheConfig d_config; QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free}; @@ -421,14 +422,14 @@ static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description) } } -static std::optional> createConnection(DOH3ServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& local, const ComboAddress& peer) +static std::optional> createConnection(DOH3ServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& localAddr, const ComboAddress& peer) { auto quicheConfig = std::atomic_load_explicit(&config.config, std::memory_order_acquire); auto quicheConn = QuicheConnection(quiche_accept(serverSideID.data(), serverSideID.size(), originalDestinationID.data(), originalDestinationID.size(), // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - reinterpret_cast(&local), - local.getSocklen(), + reinterpret_cast(&localAddr), + localAddr.getSocklen(), // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) reinterpret_cast(&peer), peer.getSocklen(), @@ -439,7 +440,7 @@ static std::optional> createConnection(DOH3 quiche_conn_set_keylog_path(quicheConn.get(), config.df->d_quicheParams.d_keyLogFile.c_str()); } - auto conn = H3Connection(peer, std::move(quicheConfig), std::move(quicheConn)); + auto conn = H3Connection(peer, localAddr, std::move(quicheConfig), std::move(quicheConn)); auto pair = config.d_connections.emplace(serverSideID, std::move(conn)); return pair.first->second; } @@ -743,7 +744,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten return; } DEBUGLOG("Dispatching GET query"); - doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), clientState.local, client, serverConnID, streamID); + doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID); conn.d_streamBuffers.erase(streamID); conn.d_headersBuffers.erase(streamID); return; @@ -808,7 +809,7 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend, } DEBUGLOG("Dispatching POST query"); - doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID); + doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID); conn.d_headersBuffers.erase(streamID); conn.d_streamBuffers.erase(streamID); } @@ -856,10 +857,21 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat PacketBuffer tokenBuf; while (true) { ComboAddress client; + ComboAddress localAddr; + client.sin4.sin_family = clientState.local.sin4.sin_family; + localAddr.sin4.sin_family = clientState.local.sin4.sin_family; buffer.resize(4096); - if (!sock.recvFromAsync(buffer, client) || buffer.empty()) { + if (!dnsdist::doq::recvAsync(sock, buffer, client, localAddr)) { return; } + if (localAddr.sin4.sin_family == 0) { + localAddr = clientState.local; + } + else { + /* we don't get the port, only the address */ + localAddr.sin4.sin_port = clientState.local.sin4.sin_port; + } + DEBUGLOG("Received DoH3 datagram of size " << buffer.size() << " from " << client.toStringWithPort()); uint32_t version{0}; @@ -896,14 +908,14 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat if (!quiche_version_is_supported(version)) { DEBUGLOG("Unsupported version"); ++frontend.d_doh3UnsupportedVersionErrors; - handleVersionNegociation(sock, clientConnID, serverConnID, client, buffer); + handleVersionNegociation(sock, clientConnID, serverConnID, client, localAddr, buffer); continue; } if (token_len == 0) { /* stateless retry */ DEBUGLOG("No token received"); - handleStatelessRetry(sock, clientConnID, serverConnID, client, version, buffer); + handleStatelessRetry(sock, clientConnID, serverConnID, client, localAddr, version, buffer); continue; } @@ -916,7 +928,7 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat } DEBUGLOG("Creating a new connection"); - conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, clientState.local, client); + conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, localAddr, client); if (!conn) { continue; } @@ -927,8 +939,8 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat reinterpret_cast(&client), client.getSocklen(), // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - reinterpret_cast(&clientState.local), - clientState.local.getSocklen(), + reinterpret_cast(&localAddr), + localAddr.getSocklen(), }; auto done = quiche_conn_recv(conn->get().d_conn.get(), buffer.data(), buffer.size(), &recv_info); @@ -950,7 +962,7 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat processH3Events(clientState, frontend, conn->get(), client, serverConnID, buffer); - flushEgress(sock, conn->get().d_conn, client, buffer); + flushEgress(sock, conn->get().d_conn, client, localAddr, buffer); } else { DEBUGLOG("Connection not established"); @@ -995,7 +1007,7 @@ void doh3Thread(ClientState* clientState) for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) { quiche_conn_on_timeout(conn->second.d_conn.get()); - flushEgress(sock, conn->second.d_conn, conn->second.d_peer, buffer); + flushEgress(sock, conn->second.d_conn, conn->second.d_peer, conn->second.d_localAddr, buffer); if (quiche_conn_is_closed(conn->second.d_conn.get())) { #ifdef DEBUGLOG_ENABLED diff --git a/pdns/dnsdistdist/doq-common.cc b/pdns/dnsdistdist/doq-common.cc index e92ccffdea4e..bb79ddc21849 100644 --- a/pdns/dnsdistdist/doq-common.cc +++ b/pdns/dnsdistdist/doq-common.cc @@ -126,7 +126,28 @@ std::optional validateToken(const PacketBuffer& token, const Combo } } -void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version, PacketBuffer& buffer) +static void sendFromTo(Socket& sock, const ComboAddress& peer, const ComboAddress& local, PacketBuffer& buffer) +{ + const int flags = 0; + if (local.sin4.sin_family == 0) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + auto ret = sendto(sock.getHandle(), buffer.data(), buffer.size(), flags, reinterpret_cast(&peer), peer.getSocklen()); + if (ret < 0) { + auto error = errno; + vinfolog("Error while sending QUIC datagram of size %d to %s: %s", buffer.size(), peer.toStringWithPort(), stringerror(error)); + } + return; + } + + try { + sendMsgWithOptions(sock.getHandle(), buffer.data(), buffer.size(), &peer, &local, 0, 0); + } + catch (const std::exception& exp) { + vinfolog("Error while sending QUIC datagram of size %d from %s to %s: %s", buffer.size(), local.toStringWithPort(), peer.toStringWithPort(), exp.what()); + } +} + +void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, const ComboAddress& localAddr, uint32_t version, PacketBuffer& buffer) { auto newServerConnID = getCID(); if (!newServerConnID) { @@ -148,11 +169,11 @@ void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const return; } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - sock.sendTo(reinterpret_cast(buffer.data()), static_cast(written), peer); + buffer.resize(static_cast(written)); + sendFromTo(sock, peer, localAddr, buffer); } -void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, PacketBuffer& buffer) +void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer) { buffer.resize(MAX_DATAGRAM_SIZE); @@ -164,11 +185,12 @@ void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, co DEBUGLOG("failed to create vneg packet " << written); return; } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - sock.sendTo(reinterpret_cast(buffer.data()), static_cast(written), peer); + + buffer.resize(static_cast(written)); + sendFromTo(sock, peer, localAddr, buffer); } -void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, PacketBuffer& buffer) +void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer) { buffer.resize(MAX_DATAGRAM_SIZE); quiche_send_info send_info; @@ -183,8 +205,8 @@ void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, return; } // FIXME pacing (as send_info.at should tell us when to send the packet) ? - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - sock.sendTo(reinterpret_cast(buffer.data()), static_cast(written), peer); + buffer.resize(static_cast(written)); + sendFromTo(sock, peer, localAddr, buffer); } } @@ -258,6 +280,51 @@ void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHT } } +bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, ComboAddress& localAddr) +{ + msghdr msgh{}; + iovec iov{}; + /* used by HarvestDestinationAddress */ + cmsgbuf_aligned cbuf; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), reinterpret_cast(&buffer.at(0)), buffer.size(), &clientAddr); + + ssize_t got = recvmsg(socket.getHandle(), &msgh, 0); + if (got < 0) { + int error = errno; + if (error != EAGAIN) { + throw NetworkError("Error in recvmsg: " + stringerror(error)); + } + return false; + } + + if ((msgh.msg_flags & MSG_TRUNC) != 0) { + return false; + } + + buffer.resize(static_cast(got)); + + if (HarvestDestinationAddress(&msgh, &localAddr)) { + /* so it turns out that sometimes the kernel lies to us: + the address is set to 0.0.0.0:0 which makes our sendfromto() use + the wrong address. In that case it's better to let the kernel + do the work by itself and use sendto() instead. + This is indicated by setting the family to 0 which is acted upon + in sendUDPResponse() and DelayedPacket::(). + */ + const ComboAddress bogusV4("0.0.0.0:0"); + const ComboAddress bogusV6("[::]:0"); + if ((localAddr.sin4.sin_family == AF_INET && localAddr == bogusV4) || (localAddr.sin4.sin_family == AF_INET6 && localAddr == bogusV6)) { + localAddr.sin4.sin_family = 0; + } + } + else { + localAddr.sin4.sin_family = 0; + } + + return !buffer.empty(); +} + }; #endif diff --git a/pdns/dnsdistdist/doq-common.hh b/pdns/dnsdistdist/doq-common.hh index d2222c683382..9b04e4c83581 100644 --- a/pdns/dnsdistdist/doq-common.hh +++ b/pdns/dnsdistdist/doq-common.hh @@ -92,10 +92,11 @@ void fillRandom(PacketBuffer& buffer, size_t size); std::optional getCID(); PacketBuffer mintToken(const PacketBuffer& dcid, const ComboAddress& peer); std::optional validateToken(const PacketBuffer& token, const ComboAddress& peer); -void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version, PacketBuffer& buffer); -void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, PacketBuffer& buffer); -void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, PacketBuffer& buffer); +void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, const ComboAddress& localAddr, uint32_t version, PacketBuffer& buffer); +void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer); +void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer); void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHTTP); +bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, ComboAddress& localAddr); }; diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index b6a3da6cb7a0..e757b1a96883 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -51,8 +51,8 @@ using namespace dnsdist::doq; class Connection { public: - Connection(const ComboAddress& peer, QuicheConfig config, QuicheConnection conn) : - d_peer(peer), d_conn(std::move(conn)), d_config(std::move(config)) + Connection(const ComboAddress& peer, const ComboAddress& localAddr, QuicheConfig config, QuicheConnection conn) : + d_peer(peer), d_localAddr(localAddr), d_conn(std::move(conn)), d_config(std::move(config)) { } Connection(const Connection&) = delete; @@ -62,6 +62,7 @@ class Connection ~Connection() = default; ComboAddress d_peer; + ComboAddress d_localAddr; QuicheConnection d_conn; QuicheConfig d_config; @@ -338,14 +339,14 @@ static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description) } } -static std::optional> createConnection(DOQServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& local, const ComboAddress& peer) +static std::optional> createConnection(DOQServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& peer, const ComboAddress& localAddr) { auto quicheConfig = std::atomic_load_explicit(&config.config, std::memory_order_acquire); auto quicheConn = QuicheConnection(quiche_accept(serverSideID.data(), serverSideID.size(), originalDestinationID.data(), originalDestinationID.size(), // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - reinterpret_cast(&local), - local.getSocklen(), + reinterpret_cast(&localAddr), + localAddr.getSocklen(), // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) reinterpret_cast(&peer), peer.getSocklen(), @@ -356,7 +357,7 @@ static std::optional> createConnection(DOQSer quiche_conn_set_keylog_path(quicheConn.get(), config.df->d_quicheParams.d_keyLogFile.c_str()); } - auto conn = Connection(peer, std::move(quicheConfig), std::move(quicheConn)); + auto conn = Connection(peer, localAddr, std::move(quicheConfig), std::move(quicheConn)); auto pair = config.d_connections.emplace(serverSideID, std::move(conn)); return pair.first->second; } @@ -641,7 +642,7 @@ static void handleReadableStream(DOQFrontend& frontend, ClientState& clientState return; } DEBUGLOG("Dispatching query"); - doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID); + doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID); conn.d_streamBuffers.erase(streamID); } @@ -654,10 +655,21 @@ static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState PacketBuffer tokenBuf; while (true) { ComboAddress client; + ComboAddress localAddr; + client.sin4.sin_family = clientState.local.sin4.sin_family; + localAddr.sin4.sin_family = clientState.local.sin4.sin_family; buffer.resize(4096); - if (!sock.recvFromAsync(buffer, client) || buffer.empty()) { + if (!dnsdist::doq::recvAsync(sock, buffer, client, localAddr)) { return; } + if (localAddr.sin4.sin_family == 0) { + localAddr = clientState.local; + } + else { + /* we don't get the port, only the address */ + localAddr.sin4.sin_port = clientState.local.sin4.sin_port; + } + DEBUGLOG("Received DoQ datagram of size " << buffer.size() << " from " << client.toStringWithPort()); uint32_t version{0}; @@ -693,14 +705,14 @@ static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState if (!quiche_version_is_supported(version)) { DEBUGLOG("Unsupported version"); ++frontend.d_doqUnsupportedVersionErrors; - handleVersionNegociation(sock, clientConnID, serverConnID, client, buffer); + handleVersionNegociation(sock, clientConnID, serverConnID, client, localAddr, buffer); continue; } if (token_len == 0) { /* stateless retry */ DEBUGLOG("No token received"); - handleStatelessRetry(sock, clientConnID, serverConnID, client, version, buffer); + handleStatelessRetry(sock, clientConnID, serverConnID, client, localAddr, version, buffer); continue; } @@ -713,7 +725,7 @@ static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState } DEBUGLOG("Creating a new connection"); - conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, clientState.local, client); + conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, client, localAddr); if (!conn) { continue; } @@ -724,8 +736,8 @@ static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState reinterpret_cast(&client), client.getSocklen(), // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - reinterpret_cast(&clientState.local), - clientState.local.getSocklen(), + reinterpret_cast(&localAddr), + localAddr.getSocklen(), }; auto done = quiche_conn_recv(conn->get().d_conn.get(), buffer.data(), buffer.size(), &recv_info); @@ -741,7 +753,7 @@ static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState handleReadableStream(frontend, clientState, *conn, streamID, client, serverConnID); } - flushEgress(sock, conn->get().d_conn, client, buffer); + flushEgress(sock, conn->get().d_conn, client, localAddr, buffer); } else { DEBUGLOG("Connection not established"); @@ -786,7 +798,7 @@ void doqThread(ClientState* clientState) for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) { quiche_conn_on_timeout(conn->second.d_conn.get()); - flushEgress(sock, conn->second.d_conn, conn->second.d_peer, buffer); + flushEgress(sock, conn->second.d_conn, conn->second.d_peer, conn->second.d_localAddr, buffer); if (quiche_conn_is_closed(conn->second.d_conn.get())) { #ifdef DEBUGLOG_ENABLED diff --git a/pdns/iputils.cc b/pdns/iputils.cc index 4409997bded1..bd1204e3ac86 100644 --- a/pdns/iputils.cc +++ b/pdns/iputils.cc @@ -366,17 +366,18 @@ void ComboAddress::truncate(unsigned int bits) noexcept *place &= (~((1<(const_cast(dest)); msgh.msg_namelen = dest->getSocklen(); } @@ -387,11 +388,12 @@ size_t sendMsgWithOptions(int fd, const char* buffer, size_t len, const ComboAdd msgh.msg_flags = 0; - if (localItf != 0 && local) { - addCMsgSrcAddr(&msgh, &cbuf, local, localItf); + if (local != nullptr && local->sin4.sin_family != 0) { + addCMsgSrcAddr(&msgh, &cbuf, local, static_cast(localItf)); } - iov.iov_base = reinterpret_cast(const_cast(buffer)); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast): it's the API + iov.iov_base = const_cast(buffer); iov.iov_len = len; msgh.msg_iov = &iov; msgh.msg_iovlen = 1; @@ -405,15 +407,15 @@ size_t sendMsgWithOptions(int fd, const char* buffer, size_t len, const ComboAdd do { #ifdef MSG_FASTOPEN - if (flags & MSG_FASTOPEN && firstTry == false) { + if ((flags & MSG_FASTOPEN) != 0 && !firstTry) { flags &= ~MSG_FASTOPEN; } #endif /* MSG_FASTOPEN */ - ssize_t res = sendmsg(fd, &msgh, flags); + ssize_t res = sendmsg(socketDesc, &msgh, flags); if (res > 0) { - size_t written = static_cast(res); + auto written = static_cast(res); sent += written; if (sent == len) { @@ -425,6 +427,7 @@ size_t sendMsgWithOptions(int fd, const char* buffer, size_t len, const ComboAdd firstTry = false; #endif iov.iov_len -= written; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): it's the API iov.iov_base = reinterpret_cast(reinterpret_cast(iov.iov_base) + written); } else if (res == 0) { @@ -435,14 +438,12 @@ size_t sendMsgWithOptions(int fd, const char* buffer, size_t len, const ComboAdd if (err == EINTR) { continue; } - else if (err == EAGAIN || err == EWOULDBLOCK || err == EINPROGRESS || err == ENOTCONN) { + if (err == EAGAIN || err == EWOULDBLOCK || err == EINPROGRESS || err == ENOTCONN) { /* EINPROGRESS might happen with non blocking socket, especially with TCP Fast Open */ return sent; } - else { - unixDie("failed in sendMsgWithTimeout"); - } + unixDie("failed in sendMsgWithOptions"); } } while (true); diff --git a/pdns/iputils.hh b/pdns/iputils.hh index e5943c8e7dcf..2a318f732726 100644 --- a/pdns/iputils.hh +++ b/pdns/iputils.hh @@ -1736,7 +1736,7 @@ bool HarvestDestinationAddress(const struct msghdr* msgh, ComboAddress* destinat bool HarvestTimestamp(struct msghdr* msgh, struct timeval* tv); void fillMSGHdr(struct msghdr* msgh, struct iovec* iov, cmsgbuf_aligned* cbuf, size_t cbufsize, char* data, size_t datalen, ComboAddress* addr); int sendOnNBSocket(int fd, const struct msghdr *msgh); -size_t sendMsgWithOptions(int fd, const char* buffer, size_t len, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int flags); +size_t sendMsgWithOptions(int socketDesc, const void* buffer, size_t len, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int flags); /* requires a non-blocking, connected TCP socket */ bool isTCPSocketUsable(int sock); diff --git a/regression-tests.dnsdist/quictests.py b/regression-tests.dnsdist/quictests.py index 62cf24e757a4..743de28db561 100644 --- a/regression-tests.dnsdist/quictests.py +++ b/regression-tests.dnsdist/quictests.py @@ -169,3 +169,25 @@ def testCached(self): total += self._responsesCounter[key] self.assertEqual(total, 1) + +class QUICGetLocalAddressOnAnyBindTests(object): + + def testGetLocalAddressOnAnyBind(self): + """ + QUIC: Return CNAME containing the local address for an ANY bind + """ + name = 'local-address-any.quic.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.CNAME, + 'address-was-127-0-0-1.local-address-any.advanced.tests.powerdns.com.') + response.answer.append(rrset) + + (_, receivedResponse) = self.sendQUICQuery(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, response) diff --git a/regression-tests.dnsdist/test_DOH3.py b/regression-tests.dnsdist/test_DOH3.py index 4704c26901e2..4a91c433f25e 100644 --- a/regression-tests.dnsdist/test_DOH3.py +++ b/regression-tests.dnsdist/test_DOH3.py @@ -4,7 +4,7 @@ from dnsdisttests import DNSDistTest from dnsdisttests import pickAvailablePort -from quictests import QUICTests, QUICWithCacheTests, QUICACLTests +from quictests import QUICTests, QUICWithCacheTests, QUICACLTests, QUICGetLocalAddressOnAnyBindTests import doh3client class TestDOH3(QUICTests, DNSDistTest): @@ -92,3 +92,33 @@ def testDOH3Post(self): receivedQuery.id = expectedQuery.id self.assertEqual(expectedQuery, receivedQuery) self.assertEqual(receivedResponse, response) + +class TestDOH3GetLocalAddressOnAnyBind(QUICGetLocalAddressOnAnyBindTests, DNSDistTest): + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _doqServerPort = pickAvailablePort() + _dohBaseURL = ("https://%s:%d/" % (_serverName, _doqServerPort)) + _config_template = """ + function answerBasedOnLocalAddress(dq) + local dest = tostring(dq.localaddr) + local i, j = string.find(dest, "[0-9.]+") + local addr = string.sub(dest, i, j) + local dashAddr = string.gsub(addr, "[.]", "-") + return DNSAction.Spoof, "address-was-"..dashAddr..".local-address-any.advanced.tests.powerdns.com." + end + addAction("local-address-any.quic.tests.powerdns.com.", LuaAction(answerBasedOnLocalAddress)) + newServer{address="127.0.0.1:%s"} + addDOH3Local("0.0.0.0:%d", "%s", "%s") + addDOH3Local("[::]:%d", "%s", "%s") + """ + _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey', '_doqServerPort','_serverCert', '_serverKey'] + _acl = ['127.0.0.1/32', '::1/128'] + _skipListeningOnCL = True + + def getQUICConnection(self): + return self.getDOQConnection(self._doqServerPort, self._caCert) + + def sendQUICQuery(self, query, response=None, useQueue=True, connection=None): + return self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, connection=connection) diff --git a/regression-tests.dnsdist/test_DOQ.py b/regression-tests.dnsdist/test_DOQ.py index 9af5d8a9387b..657df001a993 100644 --- a/regression-tests.dnsdist/test_DOQ.py +++ b/regression-tests.dnsdist/test_DOQ.py @@ -6,7 +6,7 @@ from dnsdisttests import DNSDistTest from dnsdisttests import pickAvailablePort from doqclient import quic_bogus_query -from quictests import QUICTests, QUICWithCacheTests, QUICACLTests +from quictests import QUICTests, QUICWithCacheTests, QUICACLTests, QUICGetLocalAddressOnAnyBindTests import doqclient from doqclient import quic_query @@ -142,3 +142,32 @@ def testCertificateReloaded(self): (_, secondSerial) = quic_query(query, '127.0.0.1', 0.5, self._doqServerPort, verify=self._caCert, server_hostname=self._serverName) # check that the serial is different self.assertNotEqual(serial, secondSerial) + +class TestDOQGetLocalAddressOnAnyBind(QUICGetLocalAddressOnAnyBindTests, DNSDistTest): + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _doqServerPort = pickAvailablePort() + _config_template = """ + function answerBasedOnLocalAddress(dq) + local dest = tostring(dq.localaddr) + local i, j = string.find(dest, "[0-9.]+") + local addr = string.sub(dest, i, j) + local dashAddr = string.gsub(addr, "[.]", "-") + return DNSAction.Spoof, "address-was-"..dashAddr..".local-address-any.advanced.tests.powerdns.com." + end + addAction("local-address-any.quic.tests.powerdns.com.", LuaAction(answerBasedOnLocalAddress)) + newServer{address="127.0.0.1:%s"} + addDOQLocal("0.0.0.0:%d", "%s", "%s") + addDOQLocal("[::]:%d", "%s", "%s") + """ + _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey', '_doqServerPort','_serverCert', '_serverKey'] + _acl = ['127.0.0.1/32', '::1/128'] + _skipListeningOnCL = True + + def getQUICConnection(self): + return self.getDOQConnection(self._doqServerPort, self._caCert) + + def sendQUICQuery(self, query, response=None, useQueue=True, connection=None): + return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, connection=connection)