Skip to content

Commit

Permalink
feat(substituter): add substituter middleware and `MIDDLEWARE_SUBSTIT…
Browse files Browse the repository at this point in the history
…UTER_SUBSTITUTIONS`
  • Loading branch information
qdm12 committed Mar 28, 2024
1 parent 56a3c24 commit 156a695
Show file tree
Hide file tree
Showing 10 changed files with 717 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ For example, the environment variable `UPSTREAM_TYPE` corresponds to the CLI fla
| `METRICS_PROMETHEUS_SUBSYSTEM` | `dns` | Prometheus metrics prefix/subsystem |
| `MIDDLEWARE_LOCALDNS_ENABLED` | `on` | Enable or disable the local DNS middleware |
| `MIDDLEWARE_LOCALDNS_RESOLVERS` | Local DNS servers | Comma separated list of local DNS resolvers to use for local names DNS requests |
| `MIDDLEWARE_SUBSTITUTER_SUBSTITUTIONS` | | JSON encoded list of substitutions. For example `[{"name":"github.com","ips":["1.2.3.4"]}]`. You can also specify the `type`, `class` and `ttl`, where they default respectively to `A`/`AAAA`, `IN` and `300`. |
| `CHECK_DNS` | `on` | `on` or `off`. Check resolving github.com using `127.0.0.1:53` at start |
| `UPDATE_PERIOD` | `24h` | Period to update block lists and restart Unbound. Set to `0` to disable. |

Expand Down
12 changes: 11 additions & 1 deletion internal/config/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"time"

"github.com/qdm12/dns/v2/pkg/middlewares/substituter"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
"github.com/qdm12/gosettings/validate"
Expand All @@ -32,6 +33,7 @@ type Settings struct {
MiddlewareLog MiddlewareLog
Metrics Metrics
LocalDNS LocalDNS
Substituter substituter.Settings
CheckDNS *bool
UpdatePeriod *time.Duration
}
Expand All @@ -47,6 +49,7 @@ func (s *Settings) SetDefaults() {
s.MiddlewareLog.setDefaults()
s.Metrics.setDefaults()
s.LocalDNS.setDefault()
s.Substituter.SetDefaults()
s.CheckDNS = gosettings.DefaultPointer(s.CheckDNS, true)
const defaultUpdaterPeriod = 24 * time.Hour
s.UpdatePeriod = gosettings.DefaultPointer(s.UpdatePeriod, defaultUpdaterPeriod)
Expand Down Expand Up @@ -77,6 +80,7 @@ func (s *Settings) Validate() (err error) {
"middleware log": s.MiddlewareLog.validate,
"metrics": s.Metrics.validate,
"local DNS": s.LocalDNS.validate,
"substituter": s.Substituter.Validate,
}
for name, validate := range nameToValidate {
err = validate()
Expand Down Expand Up @@ -119,6 +123,7 @@ func (s *Settings) ToLinesNode() (node *gotree.Node) {
node.AppendNode(s.MiddlewareLog.ToLinesNode())
node.AppendNode(s.Metrics.ToLinesNode())
node.AppendNode(s.LocalDNS.ToLinesNode())
node.AppendNode(s.Substituter.ToLinesNode())
node.Appendf("Check DNS: %s", gosettings.BoolToYesNo(s.CheckDNS))

if *s.UpdatePeriod == 0 {
Expand All @@ -130,7 +135,7 @@ func (s *Settings) ToLinesNode() (node *gotree.Node) {
return node
}

func (s *Settings) Read(reader *reader.Reader, warner Warner) (err error) {
func (s *Settings) Read(reader *reader.Reader, warner Warner) (err error) { //nolint:cyclop
warnings := checkOutdatedEnv(reader)
for _, warning := range warnings {
warner.Warn(warning)
Expand Down Expand Up @@ -173,6 +178,11 @@ func (s *Settings) Read(reader *reader.Reader, warner Warner) (err error) {
return fmt.Errorf("local DNS settings: %w", err)
}

err = s.Substituter.Read(reader)
if err != nil {
return fmt.Errorf("substituter settings: %w", err)
}

s.CheckDNS, err = reader.BoolPtr("CHECK_DNS")
if err != nil {
return err
Expand Down
9 changes: 9 additions & 0 deletions internal/setup/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
cachemiddleware "github.com/qdm12/dns/v2/pkg/middlewares/cache"
filtermiddleware "github.com/qdm12/dns/v2/pkg/middlewares/filter"
"github.com/qdm12/dns/v2/pkg/middlewares/localdns"
"github.com/qdm12/dns/v2/pkg/middlewares/substituter"
"github.com/qdm12/log"
)

Expand Down Expand Up @@ -85,6 +86,14 @@ func setupMiddlewares(userSettings config.Settings, cache Cache,
// to catch filtered responses found from the cache.
middlewares = append(middlewares, filterMiddleware)

substituterMiddleware, err := substituter.New(userSettings.Substituter)
if err != nil {
return nil, fmt.Errorf("creating substituter middleware: %w", err)
}
// Place after cache middleware, since we want to avoid caching for substitutions
// that may change suddenly.
middlewares = append(middlewares, substituterMiddleware)

metricsMiddleware, err := middlewareMetrics(userSettings.Metrics.Type,
commonPrometheus)
if err != nil {
Expand Down
85 changes: 85 additions & 0 deletions pkg/middlewares/substituter/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package substituter

import (
"fmt"

"github.com/miekg/dns"
)

type Middleware struct {
mapping map[questionKey][]dns.RR
}

func New(settings Settings) (middleware *Middleware, err error) {
settings.SetDefaults()
err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("validating settings: %w", err)
}

mapping := make(map[questionKey][]dns.RR, len(settings.Substitutions))
for _, substitution := range settings.Substitutions {
substitution.setDefaults()
question := substitution.toQuestion()
key := makeKey(question)
mapping[key] = substitution.toRRs()
}

return &Middleware{
mapping: mapping,
}, nil
}

func (m *Middleware) String() string { return "substituter" }

// Wrap wraps the DNS handler with the middleware.
func (m *Middleware) Wrap(next dns.Handler) dns.Handler { //nolint:ireturn
if len(m.mapping) == 0 {
return next
}
return &handler{
mapping: m.mapping,
next: next,
}
}

func (m *Middleware) Stop() (err error) {
return nil
}

type handler struct {
mapping map[questionKey][]dns.RR
next dns.Handler
}

func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
for _, question := range r.Question {
key := makeKey(question)
rrs, found := h.mapping[key]
if !found {
continue
}
response := &dns.Msg{
Answer: rrs,
}
response.SetReply(r)
_ = w.WriteMsg(response)
return
}

h.next.ServeDNS(w, r)
}

func makeKey(question dns.Question) (key questionKey) {
return questionKey{
Name: question.Name,
Qtype: question.Qtype,
Qclass: question.Qclass,
}
}

type questionKey struct {
Name string
Qtype uint16
Qclass uint16
}
141 changes: 141 additions & 0 deletions pkg/middlewares/substituter/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package substituter

import (
net "net"
"net/netip"
"testing"

"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_New(t *testing.T) {
t.Parallel()

settings := Settings{
Substitutions: []Substitution{
{Name: "github.com", IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}},
},
}

middleware, err := New(settings)
require.NoError(t, err)

expectedMiddleware := &Middleware{
mapping: map[questionKey][]dns.RR{
{
Name: "github.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}: {&dns.A{
Hdr: dns.RR_Header{
Name: "github.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: net.IP{1, 2, 3, 4},
}},
},
}
assert.Equal(t, expectedMiddleware, middleware)

next := dns.HandlerFunc(func(rw dns.ResponseWriter, m *dns.Msg) {})
handler := middleware.Wrap(next)

request := &dns.Msg{Question: []dns.Question{
{Name: "github.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}}

ctrl := gomock.NewController(t)
writer := NewMockResponseWriter(ctrl)
substitutedResponse := &dns.Msg{
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "github.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: net.IP{1, 2, 3, 4},
},
},
}
substitutedResponse.SetReply(request)

writer.EXPECT().WriteMsg(substitutedResponse)

handler.ServeDNS(writer, request)

err = middleware.Stop()
require.NoError(t, err)
}

func Test_handler_ServeDNS(t *testing.T) {
t.Parallel()

request := &dns.Msg{
Question: []dns.Question{
{Name: "github.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
},
}
response := &dns.Msg{
Answer: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "github.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: net.IP{1, 2, 3, 4},
},
},
}
response.SetReply(request)

testCases := map[string]struct {
settings Settings
responseWriterBuilder func(ctrl *gomock.Controller) dns.ResponseWriter
}{
"no_substitution": {
responseWriterBuilder: func(ctrl *gomock.Controller) dns.ResponseWriter {
return NewMockResponseWriter(ctrl)
},
},
"substitution": {
settings: Settings{
Substitutions: []Substitution{
{Name: "github.com", IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}},
},
},
responseWriterBuilder: func(ctrl *gomock.Controller) dns.ResponseWriter {
writer := NewMockResponseWriter(ctrl)
writer.EXPECT().WriteMsg(response).Return(nil)
return writer
},
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)

middleware, err := New(testCase.settings)
require.NoError(t, err)

next := dns.HandlerFunc(func(rw dns.ResponseWriter, m *dns.Msg) {
assert.Equal(t, request, m)
})
handler := middleware.Wrap(next)

responseWriter := testCase.responseWriterBuilder(ctrl)

handler.ServeDNS(responseWriter, request)
})
}
}
Loading

0 comments on commit 156a695

Please sign in to comment.