Skip to content

Commit

Permalink
Merge pull request #14078 from rgacogne/ddist-harvest-quic
Browse files Browse the repository at this point in the history
dnsdist: Use the correct source IP for outgoing QUIC datagrams
  • Loading branch information
rgacogne authored Apr 25, 2024
2 parents 2af044b + 77c1af6 commit 5e38664
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 96 deletions.
52 changes: 17 additions & 35 deletions pdns/dnsdistdist/dnsdist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,34 +157,25 @@ static constexpr size_t s_maxUDPResponsePacketSize{4096U};
static size_t const s_initialUDPPacketBufferSize = s_maxUDPResponsePacketSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
static_assert(s_initialUDPPacketBufferSize <= UINT16_MAX, "Packet size should fit in a uint16_t");

static ssize_t sendfromto(int sock, const void* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& dest)
static void sendfromto(int sock, const PacketBuffer& buffer, const ComboAddress& from, const ComboAddress& dest)
{
const int flags = 0;
if (from.sin4.sin_family == 0) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
return sendto(sock, data, len, flags, reinterpret_cast<const struct sockaddr*>(&dest), dest.getSocklen());
}
msghdr msgh{};
iovec iov{};
cmsgbuf_aligned cbuf;

/* Set up iov and msgh structures. */
memset(&msgh, 0, sizeof(struct msghdr));
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast): it's the API
iov.iov_base = const_cast<void*>(data);
iov.iov_len = len;
msgh.msg_iov = &iov;
msgh.msg_iovlen = 1;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-type-const-cast)
msgh.msg_name = const_cast<sockaddr*>(reinterpret_cast<const sockaddr*>(&dest));
msgh.msg_namelen = dest.getSocklen();

if (from.sin4.sin_family != 0) {
addCMsgSrcAddr(&msgh, &cbuf, &from, 0);
auto ret = sendto(sock, buffer.data(), buffer.size(), flags, reinterpret_cast<const struct sockaddr*>(&dest), dest.getSocklen());
if (ret == -1) {
int error = errno;
vinfolog("Error sending UDP response to %s: %s", dest.toStringWithPort(), stringerror(error));
}
return;
}
else {
msgh.msg_control = nullptr;

try {
sendMsgWithOptions(sock, buffer.data(), buffer.size(), &dest, &from, 0, 0);
}
catch (const std::exception& exp) {
vinfolog("Error sending UDP response from %s to %s: %s", from.toStringWithPort(), dest.toStringWithPort(), exp.what());
}
return sendmsg(sock, &msgh, flags);
}

static void truncateTC(PacketBuffer& packet, size_t maximumSize, unsigned int qnameWireLength)
Expand Down Expand Up @@ -223,13 +214,9 @@ struct DelayedPacket
PacketBuffer packet;
ComboAddress destination;
ComboAddress origDest;
void operator()()
void operator()() const
{
ssize_t res = sendfromto(fd, packet.data(), packet.size(), 0, origDest, destination);
if (res == -1) {
int err = errno;
vinfolog("Error sending delayed response to %s: %s", destination.toStringWithPort(), stringerror(err));
}
sendfromto(fd, packet, origDest, destination);
}
};

Expand Down Expand Up @@ -667,12 +654,7 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs
}
#endif /* DISABLE_DELAY_PIPE */
// NOLINTNEXTLINE(readability-suspicious-call-argument)
ssize_t res = sendfromto(origFD, response.data(), response.size(), 0, origDest, origRemote);
if (res == -1) {
int err = errno;
vinfolog("Error sending response to %s: %s", origRemote.toStringWithPort(), stringerror(err));
}

sendfromto(origFD, response, origDest, origRemote);
return true;
}

Expand Down
44 changes: 28 additions & 16 deletions pdns/dnsdistdist/doh3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ using h3_headers_t = std::map<std::string, std::string>;
class H3Connection
{
public:
H3Connection(const ComboAddress& peer, QuicheConfig config, QuicheConnection&& conn) :
d_peer(peer), d_conn(std::move(conn)), d_config(std::move(config))
H3Connection(const ComboAddress& peer, const ComboAddress& localAddr, QuicheConfig config, QuicheConnection&& conn) :
d_peer(peer), d_localAddr(localAddr), d_conn(std::move(conn)), d_config(std::move(config))
{
}
H3Connection(const H3Connection&) = delete;
Expand All @@ -65,6 +65,7 @@ class H3Connection
~H3Connection() = default;

ComboAddress d_peer;
ComboAddress d_localAddr;
QuicheConnection d_conn;
QuicheConfig d_config;
QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free};
Expand Down Expand Up @@ -421,14 +422,14 @@ static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description)
}
}

static std::optional<std::reference_wrapper<H3Connection>> createConnection(DOH3ServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& local, const ComboAddress& peer)
static std::optional<std::reference_wrapper<H3Connection>> createConnection(DOH3ServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& localAddr, const ComboAddress& peer)
{
auto quicheConfig = std::atomic_load_explicit(&config.config, std::memory_order_acquire);
auto quicheConn = QuicheConnection(quiche_accept(serverSideID.data(), serverSideID.size(),
originalDestinationID.data(), originalDestinationID.size(),
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
reinterpret_cast<const struct sockaddr*>(&local),
local.getSocklen(),
reinterpret_cast<const struct sockaddr*>(&localAddr),
localAddr.getSocklen(),
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
reinterpret_cast<const struct sockaddr*>(&peer),
peer.getSocklen(),
Expand All @@ -439,7 +440,7 @@ static std::optional<std::reference_wrapper<H3Connection>> createConnection(DOH3
quiche_conn_set_keylog_path(quicheConn.get(), config.df->d_quicheParams.d_keyLogFile.c_str());
}

auto conn = H3Connection(peer, std::move(quicheConfig), std::move(quicheConn));
auto conn = H3Connection(peer, localAddr, std::move(quicheConfig), std::move(quicheConn));
auto pair = config.d_connections.emplace(serverSideID, std::move(conn));
return pair.first->second;
}
Expand Down Expand Up @@ -743,7 +744,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten
return;
}
DEBUGLOG("Dispatching GET query");
doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), clientState.local, client, serverConnID, streamID);
doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID);
conn.d_streamBuffers.erase(streamID);
conn.d_headersBuffers.erase(streamID);
return;
Expand Down Expand Up @@ -808,7 +809,7 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend,
}

DEBUGLOG("Dispatching POST query");
doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID);
doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID);
conn.d_headersBuffers.erase(streamID);
conn.d_streamBuffers.erase(streamID);
}
Expand Down Expand Up @@ -856,10 +857,21 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
PacketBuffer tokenBuf;
while (true) {
ComboAddress client;
ComboAddress localAddr;
client.sin4.sin_family = clientState.local.sin4.sin_family;
localAddr.sin4.sin_family = clientState.local.sin4.sin_family;
buffer.resize(4096);
if (!sock.recvFromAsync(buffer, client) || buffer.empty()) {
if (!dnsdist::doq::recvAsync(sock, buffer, client, localAddr)) {
return;
}
if (localAddr.sin4.sin_family == 0) {
localAddr = clientState.local;
}
else {
/* we don't get the port, only the address */
localAddr.sin4.sin_port = clientState.local.sin4.sin_port;
}

DEBUGLOG("Received DoH3 datagram of size " << buffer.size() << " from " << client.toStringWithPort());

uint32_t version{0};
Expand Down Expand Up @@ -896,14 +908,14 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
if (!quiche_version_is_supported(version)) {
DEBUGLOG("Unsupported version");
++frontend.d_doh3UnsupportedVersionErrors;
handleVersionNegociation(sock, clientConnID, serverConnID, client, buffer);
handleVersionNegociation(sock, clientConnID, serverConnID, client, localAddr, buffer);
continue;
}

if (token_len == 0) {
/* stateless retry */
DEBUGLOG("No token received");
handleStatelessRetry(sock, clientConnID, serverConnID, client, version, buffer);
handleStatelessRetry(sock, clientConnID, serverConnID, client, localAddr, version, buffer);
continue;
}

Expand All @@ -916,7 +928,7 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
}

DEBUGLOG("Creating a new connection");
conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, clientState.local, client);
conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, localAddr, client);
if (!conn) {
continue;
}
Expand All @@ -927,8 +939,8 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
reinterpret_cast<struct sockaddr*>(&client),
client.getSocklen(),
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
reinterpret_cast<struct sockaddr*>(&clientState.local),
clientState.local.getSocklen(),
reinterpret_cast<struct sockaddr*>(&localAddr),
localAddr.getSocklen(),
};

auto done = quiche_conn_recv(conn->get().d_conn.get(), buffer.data(), buffer.size(), &recv_info);
Expand All @@ -950,7 +962,7 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat

processH3Events(clientState, frontend, conn->get(), client, serverConnID, buffer);

flushEgress(sock, conn->get().d_conn, client, buffer);
flushEgress(sock, conn->get().d_conn, client, localAddr, buffer);
}
else {
DEBUGLOG("Connection not established");
Expand Down Expand Up @@ -995,7 +1007,7 @@ void doh3Thread(ClientState* clientState)
for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) {
quiche_conn_on_timeout(conn->second.d_conn.get());

flushEgress(sock, conn->second.d_conn, conn->second.d_peer, buffer);
flushEgress(sock, conn->second.d_conn, conn->second.d_peer, conn->second.d_localAddr, buffer);

if (quiche_conn_is_closed(conn->second.d_conn.get())) {
#ifdef DEBUGLOG_ENABLED
Expand Down
85 changes: 76 additions & 9 deletions pdns/dnsdistdist/doq-common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,28 @@ std::optional<PacketBuffer> validateToken(const PacketBuffer& token, const Combo
}
}

void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version, PacketBuffer& buffer)
/* Sends the whole content of 'buffer' as a single datagram to 'peer' over 'sock'.
   If 'local' has a non-zero address family it is passed to sendMsgWithOptions()
   as the local address to use, otherwise a plain sendto() is issued.
   A failing sendto(), or a std::exception raised by sendMsgWithOptions(), is
   only logged via vinfolog(): this function returns nothing to the caller. */
static void sendFromTo(Socket& sock, const ComboAddress& peer, const ComboAddress& local, PacketBuffer& buffer)
{
  const bool localAddressKnown = local.sin4.sin_family != 0;

  if (localAddressKnown) {
    /* explicit local address: sendMsgWithOptions() reports failures by throwing */
    try {
      sendMsgWithOptions(sock.getHandle(), buffer.data(), buffer.size(), &peer, &local, 0, 0);
    }
    catch (const std::exception& exp) {
      vinfolog("Error while sending QUIC datagram of size %d from %s to %s: %s", buffer.size(), local.toStringWithPort(), peer.toStringWithPort(), exp.what());
    }
    return;
  }

  /* no usable local address: fall back to a regular sendto() */
  const int noFlags = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
  const auto* destination = reinterpret_cast<const struct sockaddr*>(&peer);
  const auto sent = sendto(sock.getHandle(), buffer.data(), buffer.size(), noFlags, destination, peer.getSocklen());
  if (sent < 0) {
    /* save errno right away, before any other call can overwrite it */
    const auto savedErrno = errno;
    vinfolog("Error while sending QUIC datagram of size %d to %s: %s", buffer.size(), peer.toStringWithPort(), stringerror(savedErrno));
  }
}

void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, const ComboAddress& localAddr, uint32_t version, PacketBuffer& buffer)
{
auto newServerConnID = getCID();
if (!newServerConnID) {
Expand All @@ -148,11 +169,11 @@ void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const
return;
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
sock.sendTo(reinterpret_cast<const char*>(buffer.data()), static_cast<size_t>(written), peer);
buffer.resize(static_cast<size_t>(written));
sendFromTo(sock, peer, localAddr, buffer);
}

void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, PacketBuffer& buffer)
void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer)
{
buffer.resize(MAX_DATAGRAM_SIZE);

Expand All @@ -164,11 +185,12 @@ void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, co
DEBUGLOG("failed to create vneg packet " << written);
return;
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
sock.sendTo(reinterpret_cast<const char*>(buffer.data()), static_cast<size_t>(written), peer);

buffer.resize(static_cast<size_t>(written));
sendFromTo(sock, peer, localAddr, buffer);
}

void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, PacketBuffer& buffer)
void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer)
{
buffer.resize(MAX_DATAGRAM_SIZE);
quiche_send_info send_info;
Expand All @@ -183,8 +205,8 @@ void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer,
return;
}
// FIXME pacing (as send_info.at should tell us when to send the packet) ?
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
sock.sendTo(reinterpret_cast<const char*>(buffer.data()), static_cast<size_t>(written), peer);
buffer.resize(static_cast<size_t>(written));
sendFromTo(sock, peer, localAddr, buffer);
}
}

Expand Down Expand Up @@ -258,6 +280,51 @@ void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHT
}
}

/* Receives one datagram from 'socket' into 'buffer' using recvmsg().
   On entry 'buffer' must be non-empty: its current size is the maximum number
   of bytes read, and buffer.at(0) throws on an empty buffer. On success it is
   shrunk to the number of bytes actually received.
   'clientAddr' is handed to fillMSGHdr() (presumably so that recvmsg() stores
   the sender address there -- confirm against fillMSGHdr).
   'localAddr' is filled by HarvestDestinationAddress(); its family is set to 0
   when no usable destination address could be obtained.
   Returns true if a non-empty datagram was received, false if recvmsg() failed
   with EAGAIN, if the datagram was truncated (MSG_TRUNC) or if it was empty.
   Throws NetworkError on any other recvmsg() failure. */
bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, ComboAddress& localAddr)
{
  msghdr header{};
  iovec vec{};
  /* control message storage, read back by HarvestDestinationAddress */
  cmsgbuf_aligned control;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
  fillMSGHdr(&header, &vec, &control, sizeof(control), reinterpret_cast<char*>(&buffer.at(0)), buffer.size(), &clientAddr);

  const ssize_t received = recvmsg(socket.getHandle(), &header, 0);
  if (received < 0) {
    const int savedErrno = errno;
    if (savedErrno == EAGAIN) {
      /* nothing to read right now */
      return false;
    }
    throw NetworkError("Error in recvmsg: " + stringerror(savedErrno));
  }

  /* the datagram did not fit into the buffer, discard it */
  if ((header.msg_flags & MSG_TRUNC) != 0) {
    return false;
  }

  buffer.resize(static_cast<size_t>(received));

  if (!HarvestDestinationAddress(&header, &localAddr)) {
    /* no destination address available at all */
    localAddr.sin4.sin_family = 0;
  }
  else {
    /* The kernel sometimes reports the wildcard address (0.0.0.0:0 or [::]:0)
       as the destination. Sending from that address would pick the wrong
       source, so treat it like a missing address: a family of 0 signals
       that the local address should not be used as-is. */
    const ComboAddress bogusV4("0.0.0.0:0");
    const ComboAddress bogusV6("[::]:0");
    bool wildcard = false;
    switch (localAddr.sin4.sin_family) {
    case AF_INET:
      wildcard = (localAddr == bogusV4);
      break;
    case AF_INET6:
      wildcard = (localAddr == bogusV6);
      break;
    default:
      break;
    }
    if (wildcard) {
      localAddr.sin4.sin_family = 0;
    }
  }

  return !buffer.empty();
}

};

#endif
7 changes: 4 additions & 3 deletions pdns/dnsdistdist/doq-common.hh
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ void fillRandom(PacketBuffer& buffer, size_t size);
std::optional<PacketBuffer> getCID();
PacketBuffer mintToken(const PacketBuffer& dcid, const ComboAddress& peer);
std::optional<PacketBuffer> validateToken(const PacketBuffer& token, const ComboAddress& peer);
void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version, PacketBuffer& buffer);
void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, PacketBuffer& buffer);
void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, PacketBuffer& buffer);
void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, const ComboAddress& localAddr, uint32_t version, PacketBuffer& buffer);
void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer);
void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer);
void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHTTP);
bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, ComboAddress& localAddr);

};

Expand Down
Loading

0 comments on commit 5e38664

Please sign in to comment.