Skip to content

Commit

Permalink
feat: global recursion upstream handler
Browse files Browse the repository at this point in the history
  • Loading branch information
wintbiit committed Dec 12, 2023
1 parent 94bde68 commit 375a4ca
Show file tree
Hide file tree
Showing 11 changed files with 34 additions and 108 deletions.
15 changes: 2 additions & 13 deletions resolver/A.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package resolver

import (
"fmt"

"github.com/miekg/dns"
"github.com/wintbiit/ninedns/model"
)
Expand All @@ -13,19 +11,10 @@ func init() {
resolvers[dns.TypeA] = &A{}
}

func (_ *A) Resolve(s model.RecordProvider, r *dns.Msg, name string) ([]dns.RR, error) {
func (_ *A) Resolve(s model.RecordProvider, name string) ([]dns.RR, error) {
record := s.FindRecord(name, dns.TypeA)
if record == nil {
if !s.Recursion() {
return nil, fmt.Errorf("no record found for question: %+v", name)
}

resp, err := s.Exchange(r)
if err != nil {
return nil, err
}

return resp.Answer, nil
return nil, nil
}

rr := &dns.A{
Expand Down
15 changes: 2 additions & 13 deletions resolver/AAAA.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package resolver

import (
"fmt"

"github.com/miekg/dns"
"github.com/wintbiit/ninedns/model"
)
Expand All @@ -13,19 +11,10 @@ func init() {
resolvers[dns.TypeAAAA] = &AAAA{}
}

func (_ *AAAA) Resolve(s model.RecordProvider, r *dns.Msg, name string) ([]dns.RR, error) {
func (_ *AAAA) Resolve(s model.RecordProvider, name string) ([]dns.RR, error) {
record := s.FindRecord(name, dns.TypeAAAA)
if record == nil {
if !s.Recursion() {
return nil, fmt.Errorf("no record found for question: %+v", name)
}

resp, err := s.Exchange(r)
if err != nil {
return nil, err
}

return resp.Answer, nil
return nil, nil
}

rr := &dns.AAAA{
Expand Down
14 changes: 2 additions & 12 deletions resolver/CNAME.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package resolver

import (
"fmt"

"github.com/miekg/dns"
"github.com/wintbiit/ninedns/model"
)
Expand All @@ -13,18 +11,10 @@ func init() {
resolvers[dns.TypeCNAME] = &CNAME{}
}

func (_ *CNAME) Resolve(s model.RecordProvider, r *dns.Msg, name string) ([]dns.RR, error) {
func (_ *CNAME) Resolve(s model.RecordProvider, name string) ([]dns.RR, error) {
record := s.FindRecord(name, dns.TypeCNAME)
if record == nil {
if !s.Recursion() {
return nil, fmt.Errorf("no record found for question: %+v", name)
}
resp, err := s.Exchange(r)
if err != nil {
return nil, err
}

return resp.Answer, nil
return nil, nil
}

cname := record.Value.String()
Expand Down
15 changes: 2 additions & 13 deletions resolver/MX.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package resolver

import (
"fmt"

"github.com/miekg/dns"
"github.com/wintbiit/ninedns/model"
)
Expand All @@ -13,19 +11,10 @@ func init() {
resolvers[dns.TypeMX] = &MX{}
}

func (_ *MX) Resolve(s model.RecordProvider, r *dns.Msg, name string) ([]dns.RR, error) {
func (_ *MX) Resolve(s model.RecordProvider, name string) ([]dns.RR, error) {
record := s.FindRecord(name, dns.TypeMX)
if record == nil {
if !s.Recursion() {
return nil, fmt.Errorf("no record found for question: %+v", name)
}

resp, err := s.Exchange(r)
if err != nil {
return nil, err
}

return resp.Answer, nil
return nil, nil
}

mx, err := record.Value.MX()
Expand Down
15 changes: 2 additions & 13 deletions resolver/NS.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package resolver

import (
"fmt"

"github.com/miekg/dns"
"github.com/wintbiit/ninedns/model"
)
Expand All @@ -13,19 +11,10 @@ func init() {
resolvers[dns.TypeNS] = &NS{}
}

func (_ *NS) Resolve(s model.RecordProvider, r *dns.Msg, name string) ([]dns.RR, error) {
func (_ *NS) Resolve(s model.RecordProvider, name string) ([]dns.RR, error) {
record := s.FindRecord(name, dns.TypeNS)
if record == nil {
if !s.Recursion() {
return nil, fmt.Errorf("no record found for question: %+v", name)
}

resp, err := s.Exchange(r)
if err != nil {
return nil, err
}

return resp.Answer, nil
return nil, nil
}

rr := &dns.NS{
Expand Down
15 changes: 2 additions & 13 deletions resolver/SOA.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package resolver

import (
"fmt"

"github.com/miekg/dns"
"github.com/wintbiit/ninedns/model"
)
Expand All @@ -13,20 +11,11 @@ func init() {
resolvers[dns.TypeSOA] = &SOA{}
}

func (_ *SOA) Resolve(s model.RecordProvider, r *dns.Msg, name string) ([]dns.RR, error) {
func (_ *SOA) Resolve(s model.RecordProvider, name string) ([]dns.RR, error) {
record := s.FindRecord(name, dns.TypeSOA)

if record == nil {
if !s.Recursion() {
return nil, fmt.Errorf("no record found for question: %+v", name)
}

resp, err := s.Exchange(r)
if err != nil {
return nil, err
}

return resp.Answer, nil
return nil, nil
}

soa, err := record.Value.SOA()
Expand Down
15 changes: 2 additions & 13 deletions resolver/SRV.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package resolver

import (
"fmt"

"github.com/miekg/dns"
"github.com/wintbiit/ninedns/model"
)
Expand All @@ -13,20 +11,11 @@ func init() {
resolvers[dns.TypeSRV] = &SRV{}
}

func (_ *SRV) Resolve(s model.RecordProvider, r *dns.Msg, name string) ([]dns.RR, error) {
func (_ *SRV) Resolve(s model.RecordProvider, name string) ([]dns.RR, error) {
record := s.FindRecord(name, dns.TypeSRV)

if record == nil {
if !s.Recursion() {
return nil, fmt.Errorf("no record found for question: %+v", name)
}

resp, err := s.Exchange(r)
if err != nil {
return nil, err
}

return resp.Answer, nil
return nil, nil
}

srv, err := record.Value.SRV()
Expand Down
13 changes: 2 additions & 11 deletions resolver/TXT.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,11 @@ func init() {
resolvers[dns.TypeTXT] = &TXT{}
}

func (_ *TXT) Resolve(s model.RecordProvider, r *dns.Msg, name string) ([]dns.RR, error) {
func (_ *TXT) Resolve(s model.RecordProvider, name string) ([]dns.RR, error) {
records := s.FindRecords(name, dns.TypeTXT)

if records == nil {
if !s.Recursion() {
return nil, nil
}

resp, err := s.Exchange(r)
if err != nil {
return nil, err
}

return resp.Answer, nil
return nil, nil
}

rrs := make([]dns.RR, len(records))
Expand Down
6 changes: 3 additions & 3 deletions resolver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ import (
)

type Resolver interface {
Resolve(model.RecordProvider, *dns.Msg, string) ([]dns.RR, error)
Resolve(model.RecordProvider, string) ([]dns.RR, error)
}

var resolvers = make(map[uint16]Resolver)

func Resolve(typ uint16, p model.RecordProvider, r *dns.Msg, name string) ([]dns.RR, error) {
func Resolve(typ uint16, p model.RecordProvider, name string) ([]dns.RR, error) {
resolver, ok := resolvers[typ]
if !ok {
return nil, nil
}

return resolver.Resolve(p, r, name)
return resolver.Resolve(p, name)
}
17 changes: 14 additions & 3 deletions server/ruleset.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (s *Server) newRuleSet(name string, rule model.Rule) *RuleSet {
return ruleSet
}

func (s *RuleSet) query(r, m *dns.Msg) {
func (s *RuleSet) resolve(r, m *dns.Msg) {
for _, q := range r.Question {
// 1. Try CNAME
name := q.Name
Expand All @@ -74,11 +74,22 @@ func (s *RuleSet) query(r, m *dns.Msg) {
}

func (s *RuleSet) question(q *dns.Question, r, m *dns.Msg, name string) error {
records, err := resolver.Resolve(q.Qtype, s, r, name)
records, err := resolver.Resolve(q.Qtype, s, name)
if err != nil {
return err
}

if len(records) == 0 && s.Recursion() {
s.l.Debugf("Question %s not found, try upstream", q.String())
resp, err := s.Exchange(r)
if err != nil {
return err
}

m.Answer = resp.Answer
return nil
}

for _, record := range records {
record.Header().Name = name
m.Answer = append(m.Answer, record)
Expand Down Expand Up @@ -132,7 +143,7 @@ func (s *RuleSet) FindRecords(name string, quesType uint16) []model.Record {
}
records, err := s.cacheClient.FindRecords(name, model.ReadRecordType(quesType).String(), s.Name)
if err != nil {
s.l.Errorf("Failed to query records: %s", err)
s.l.Errorf("Failed to resolve records: %s", err)
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (s *Server) handle(w dns.ResponseWriter, r *dns.Msg) {

s.l.Debugf("Found rule for %s: %+v", remoteAddr, handler)

handler.query(r, m)
handler.resolve(r, m)

if err := w.WriteMsg(m); err != nil {
s.l.Errorf("Failed to write response: %s", err)
Expand Down

0 comments on commit 375a4ca

Please sign in to comment.