Skip to content

Commit

Permalink
feat: multi upstream support
Browse files Browse the repository at this point in the history
  • Loading branch information
wintbiit committed Dec 12, 2023
1 parent 375a4ca commit abb4f68
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
2 changes: 1 addition & 1 deletion model/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type Domain struct {
Rules map[string]Rule `json:"rules"`
Authoritative bool `json:"authoritative,default=true"`
Recursion bool `json:"recursion,default=false"`
Upstream string `json:"upstream,default=127.0.0.1:53"`
Upstreams []string `json:"upstreams"`
Providers map[string]string `json:"providers"`
TTL uint32 `json:"ttl,default=60"`
Tsig *TSIG `json:"tsig"`
Expand Down
41 changes: 39 additions & 2 deletions server/ruleset.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package server

import (
"fmt"
"math/rand"
"net"
"sort"
"strings"
"time"

"github.com/wintbiit/ninedns/log"

Expand All @@ -28,8 +30,43 @@ func (s *RuleSet) Recursion() bool {
}

func (s *RuleSet) Exchange(r *dns.Msg) (*dns.Msg, error) {
m, _, err := s.dnsClient.Exchange(r, s.Upstream)
return m, err
upstream, err := s.cacheClient.GetRuntimeCache("upstream:" + s.DomainName + s.Name)
if err != nil {
upstream, err = s.findUpstream(r)
if err != nil {
return nil, err
}

if err := s.cacheClient.AddRuntimeCache("upstream:"+s.DomainName+s.Name, upstream, time.Duration(s.TTL)*time.Second); err != nil {
s.l.Warnf("Failed to add runtime cache: %s", err)
}
}

m, _, err := s.dnsClient.Exchange(r, upstream)
if err != nil {
return nil, err
}

return m, nil
}

func (s *RuleSet) findUpstream(r *dns.Msg) (string, error) {
for _, upstream := range s.Upstreams {
m, _, err := s.dnsClient.Exchange(r, upstream)
if err != nil {
s.l.Debugf("Failed to exchange with upstream %s: %s", upstream, err)
continue
}

if m == nil || m.Rcode != dns.RcodeSuccess {
s.l.Debugf("Failed to exchange with upstream %s: %s", upstream, err)
continue
}

return upstream, nil
}

return "", fmt.Errorf("failed to find upstream")
}

func (s *Server) newRuleSet(name string, rule model.Rule) *RuleSet {
Expand Down
5 changes: 0 additions & 5 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,6 @@ func (s *Server) checkConfig() {
s.l.Warn("Server TTL is 0, automatically set it to 60.")
}

if s.Domain.Upstream == "" {
s.Domain.Upstream = "223.5.5.5:53"
s.l.Warn("Server upstream is empty, automatically set it to %s.", s.Domain.Upstream)
}

if s.Domain.Rules == nil {
s.l.Warn("Server rules is empty, automatically added general rule.")
s.Domain.Rules = map[string]model.Rule{"": {}}
Expand Down

0 comments on commit abb4f68

Please sign in to comment.