Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(resolver): race UDP and TCP when connecting upstream #1302

Merged
merged 1 commit into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 59 additions & 12 deletions resolver/upstream_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,24 +188,71 @@ func (r *dnsUpstreamClient) fmtURL(ip net.IP, port uint16, _ string) string {
func (r *dnsUpstreamClient) callExternal(
ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
) (response *dns.Msg, rtt time.Duration, err error) {
if protocol == model.RequestProtocolTCP {
response, rtt, err = r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
if err != nil && r.udpClient != nil {
// try UDP as fallback
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "dial" {
return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
}
if r.udpClient == nil {
return r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
}

return r.raceClients(ctx, msg, upstreamURL, protocol)
}

func (r *dnsUpstreamClient) raceClients(
ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
) (response *dns.Msg, rtt time.Duration, err error) {
type result struct {
proto model.RequestProtocol
msg *dns.Msg
rtt time.Duration
err error
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

// We don't explicitly close the channel, but since the buffer is big enough for all goroutines,
// it will be GC'ed and closed automatically.
ch := make(chan result, 2) //nolint:gomnd // TCP and UDP

exchange := func(client *dns.Client, proto model.RequestProtocol) {
msg, rtt, err := client.ExchangeContext(ctx, msg, upstreamURL)

ch <- result{proto, msg, rtt, err}
}

go exchange(r.tcpClient, model.RequestProtocolTCP)
go exchange(r.udpClient, model.RequestProtocolUDP)

// We don't care about a response too big for the downstream protocol: that's handled by `Server`,
// and returning a larger request from here might allow us to cache it.

res1 := <-ch
if res1.err == nil && !res1.msg.Truncated {
return res1.msg, res1.rtt, nil
}

res2 := <-ch
if res2.err == nil && !res2.msg.Truncated {
return res2.msg, res2.rtt, nil
}

resWhere := func(pred func(*result) bool) *result {
if pred(&res1) {
return &res1
}

return response, rtt, err
return &res2
}

if r.udpClient != nil {
return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
// When both failed, return the result that used the same protocol as the downstream request
if res1.err != nil && res2.err != nil {
sameProto := resWhere(func(r *result) bool { return r.proto == protocol })

return sameProto.msg, sameProto.rtt, sameProto.err
}

return r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
// Only a single one failed, use the one that succeeded
successful := resWhere(func(r *result) bool { return r.err == nil })

return successful.msg, successful.rtt, nil
}

// NewUpstreamResolver creates new resolver instance
Expand Down
2 changes: 1 addition & 1 deletion resolver/upstream_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
sutConfig.Upstream = mockUpstream.Start()
})

It("should retry with UDP", func() {
It("should also try with UDP", func() {
req := newRequest("example.com.", A)
req.Protocol = RequestProtocolTCP

Expand Down
8 changes: 4 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *d
response.Res.MsgHdr.RecursionAvailable = request.MsgHdr.RecursionDesired

// truncate if necessary
response.Res.Truncate(getMaxResponseSize(w.LocalAddr().Network(), request))
response.Res.Truncate(getMaxResponseSize(r))

// enable compression
response.Res.Compress = true
Expand All @@ -624,13 +624,13 @@ func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *d
}

// returns EDNS UDP size or if not present, 512 for UDP and 64K for TCP
func getMaxResponseSize(network string, request *dns.Msg) int {
edns := request.IsEdns0()
func getMaxResponseSize(req *model.Request) int {
edns := req.Req.IsEdns0()
if edns != nil && edns.UDPSize() > 0 {
return int(edns.UDPSize())
}

if network == "tcp" {
if req.Protocol == model.RequestProtocolTCP {
return dns.MaxMsgSize
}

Expand Down