diff --git a/helpertest/helper.go b/helpertest/helper.go index 29b0421af..cd5415bcc 100644 --- a/helpertest/helper.go +++ b/helpertest/helper.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" @@ -25,6 +26,7 @@ const ( HTTPS = dns.Type(dns.TypeHTTPS) MX = dns.Type(dns.TypeMX) PTR = dns.Type(dns.TypePTR) + SRV = dns.Type(dns.TypeSRV) TXT = dns.Type(dns.TypeTXT) DS = dns.Type(dns.TypeDS) ) @@ -216,6 +218,10 @@ func (matcher *dnsRecordMatcher) matchSingle(rr dns.RR) (success bool, err error return v.Target == matcher.answer, nil case *dns.PTR: return v.Ptr == matcher.answer, nil + case *dns.SRV: + return fmt.Sprintf("%d %d %d %s", v.Priority, v.Weight, v.Port, v.Target) == matcher.answer, nil + case *dns.TXT: + return strings.Join(v.Txt, " ") == matcher.answer, nil case *dns.MX: return v.Mx == matcher.answer, nil } diff --git a/resolver/custom_dns_resolver.go b/resolver/custom_dns_resolver.go index bf160a29f..8cc26dfb9 100644 --- a/resolver/custom_dns_resolver.go +++ b/resolver/custom_dns_resolver.go @@ -177,6 +177,10 @@ func (r *CustomDNSResolver) processDNSEntry( return r.processIP(v.A, question, v.Header().Ttl) case *dns.AAAA: return r.processIP(v.AAAA, question, v.Header().Ttl) + case *dns.TXT: + return r.processTXT(v.Txt, question, v.Header().Ttl) + case *dns.SRV: + return r.processSRV(*v, question, v.Header().Ttl) case *dns.CNAME: return r.processCNAME(ctx, logger, request, *v, resolvedCnames, question, v.Header().Ttl) } @@ -211,6 +215,35 @@ func (r *CustomDNSResolver) processIP(ip net.IP, question dns.Question, ttl uint return result, nil } +func (r *CustomDNSResolver) processTXT(value []string, question dns.Question, ttl uint32) (result []dns.RR, err error) { + if question.Qtype == dns.TypeTXT { + txt := new(dns.TXT) + txt.Hdr = dns.RR_Header{Class: dns.ClassINET, Ttl: ttl, Rrtype: dns.TypeTXT, Name: question.Name} + txt.Txt = value + result = append(result, txt) + } + + return result, nil +} + +func (r *CustomDNSResolver) processSRV( + targetSRV dns.SRV, + question dns.Question, + ttl uint32, +) (result []dns.RR, err error) { + if question.Qtype == dns.TypeSRV { + srv := new(dns.SRV) + srv.Hdr = dns.RR_Header{Class: dns.ClassINET, Ttl: ttl, Rrtype: dns.TypeSRV, Name: question.Name} + srv.Priority = targetSRV.Priority + srv.Weight = targetSRV.Weight + srv.Port = targetSRV.Port + srv.Target = targetSRV.Target + result = append(result, srv) + } + + return result, nil +} + func (r *CustomDNSResolver) processCNAME( ctx context.Context, logger *logrus.Entry, diff --git a/resolver/custom_dns_resolver_test.go b/resolver/custom_dns_resolver_test.go index 6a4abefbd..1ba324ef1 100644 --- a/resolver/custom_dns_resolver_test.go +++ b/resolver/custom_dns_resolver_test.go @@ -58,6 +58,8 @@ var _ = Describe("CustomDNSResolver", func() { "cname.ip6.": {&dns.CNAME{Target: "ip6.domain", Hdr: zoneHdr}}, "cname.example.": {&dns.CNAME{Target: "example.com", Hdr: zoneHdr}}, "cname.recursive.": {&dns.CNAME{Target: "cname.recursive", Hdr: zoneHdr}}, + "srv.": {&dns.SRV{Priority: 0, Weight: 5, Port: 12345, Target: "service", Hdr: zoneHdr}}, + "txt.": {&dns.TXT{Txt: []string{"space", "separated", "value"}, Hdr: zoneHdr}}, "mx.domain.": {&dns.MX{Mx: "mx.domain", Hdr: zoneHdr}}, }, }, @@ -375,6 +377,34 @@ var _ = Describe("CustomDNSResolver", func() { }) }) }) + When("Querying other record types", func() { + It("Returns an SRV response", func() { + Expect(sut.Resolve(ctx, newRequest("srv", SRV))). + Should( + SatisfyAll( + WithTransform(ToAnswer, SatisfyAll( + ContainElements( + BeDNSRecord("srv.", SRV, "0 5 12345 service")), + )), + HaveResponseType(ResponseTypeCUSTOMDNS), + HaveReason("CUSTOM DNS"), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + It("Returns a TXT response", func() { + Expect(sut.Resolve(ctx, newRequest("txt", TXT))). + Should( + SatisfyAll( + WithTransform(ToAnswer, SatisfyAll( + ContainElements( + BeDNSRecord("txt.", TXT, "space separated value")), + )), + HaveResponseType(ResponseTypeCUSTOMDNS), + HaveReason("CUSTOM DNS"), + HaveReturnCode(dns.RcodeSuccess), + )) + }) + }) When("An unsupported DNS query type is queried from the resolver but found in the config mapping ", func() { It("an error should be returned", func() { By("MX query", func() {