Skip to content

Commit

Permalink
Merge pull request #42 from Tantalor93/options
Browse files Browse the repository at this point in the history
adjust Client construction API
  • Loading branch information
Tantalor93 authored Oct 27, 2024
2 parents 5df5b91 + 9b6a00b commit 6e1532e
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 39 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ go get github.com/tantalor93/doh-go

## Examples
```
// create client with default http.Client
c := doh.NewClient(nil)
// create client with default settings
c := doh.NewClient("https://1.1.1.1/dns-query")
// prepare payload
msg := dns.Msg{}
msg.SetQuestion("google.com.", dns.TypeA)
// send DNS query to Cloudflare Server over DoH using POST method
r, err := c.SendViaPost(context.Background(), "https://1.1.1.1/dns-query", &msg)
r, err := c.SendViaPost(context.Background(), &msg)
if err != nil {
panic(err)
}
Expand Down
37 changes: 21 additions & 16 deletions doh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,55 +12,60 @@ import (

// Client encapsulates and provides logic for querying DNS servers over DoH.
type Client struct {
c *http.Client
addr string
client *http.Client
}

// NewClient creates new Client instance with standard net/http client. If nil, default http.Client is used.
func NewClient(c *http.Client) *Client {
if c == nil {
c = &http.Client{}
// NewClient creates new Client instance with standard net/http client.
func NewClient(addr string, opts ...Option) *Client {
client := &Client{
addr: addr,
client: &http.Client{},
}
return &Client{c}
for _, opt := range opts {
opt.apply(client)
}
return client
}

// SendViaPost sends DNS message to the given DNS server over DoH using POST, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1
func (dc *Client) SendViaPost(ctx context.Context, server string, msg *dns.Msg) (*dns.Msg, error) {
// SendViaPost sends DNS message using HTTP POST method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1
func (c *Client) SendViaPost(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
pack, err := msg.Pack()
if err != nil {
return nil, err
}

request, err := http.NewRequest("POST", server, bytes.NewReader(pack))
request, err := http.NewRequest("POST", c.addr, bytes.NewReader(pack))
if err != nil {
return nil, err
}
request = request.WithContext(ctx)
request.Header.Set("Accept", "application/dns-message")
request.Header.Set("content-type", "application/dns-message")

return dc.send(request)
return c.send(request)
}

// SendViaGet sends DNS message to the given DNS server over DoH using GET, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1
func (dc *Client) SendViaGet(ctx context.Context, server string, msg *dns.Msg) (*dns.Msg, error) {
// SendViaGet sends DNS message using HTTP GET method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1
func (c *Client) SendViaGet(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
pack, err := msg.Pack()
if err != nil {
return nil, err
}

url := fmt.Sprint(server, "?dns=", base64.RawURLEncoding.EncodeToString(pack))
url := fmt.Sprint(c.addr, "?dns=", base64.RawURLEncoding.EncodeToString(pack))
request, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
request = request.WithContext(ctx)
request.Header.Set("Accept", "application/dns-message")

return dc.send(request)
return c.send(request)
}

func (dc *Client) send(r *http.Request) (*dns.Msg, error) {
resp, err := dc.c.Do(r)
func (c *Client) send(r *http.Request) (*dns.Msg, error) {
resp, err := c.client.Do(r)
if err != nil {
return nil, err
}
Expand Down
32 changes: 12 additions & 20 deletions doh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,37 +58,33 @@ func Test_SendViaPost(t *testing.T) {
}))
defer ts.Close()

type args struct {
server string
msg *dns.Msg
}
tests := []struct {
name string
args args
msg *dns.Msg
wantRcode int
wantErr error
}{
{
name: "NOERROR DNS resolution",
args: args{server: ts.URL, msg: question(existingDomain)},
msg: question(existingDomain),
wantRcode: dns.RcodeSuccess,
},
{
name: "NXDOMAIN DNS resolution",
args: args{server: ts.URL, msg: question(notExistingDomain)},
msg: question(notExistingDomain),
wantRcode: dns.RcodeNameError,
},
{
name: "bad upstream HTTP response",
args: args{server: ts.URL, msg: question(badStatusDomain)},
msg: question(badStatusDomain),
wantErr: &doh.UnexpectedServerHTTPStatusError{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := doh.NewClient(nil)
client := doh.NewClient(ts.URL)

got, err := client.SendViaPost(context.Background(), tt.args.server, tt.args.msg)
got, err := client.SendViaPost(context.Background(), tt.msg)

if tt.wantErr != nil {
require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error")
Expand Down Expand Up @@ -142,37 +138,33 @@ func Test_SendViaGet(t *testing.T) {
}))
defer ts.Close()

type args struct {
server string
msg *dns.Msg
}
tests := []struct {
name string
args args
msg *dns.Msg
wantRcode int
wantErr error
}{
{
name: "NOERROR DNS resolution",
args: args{server: ts.URL, msg: question(existingDomain)},
msg: question(existingDomain),
wantRcode: dns.RcodeSuccess,
},
{
name: "NXDOMAIN DNS resolution",
args: args{server: ts.URL, msg: question(notExistingDomain)},
msg: question(notExistingDomain),
wantRcode: dns.RcodeNameError,
},
{
name: "bad upstream HTTP response",
args: args{server: ts.URL, msg: question(badStatusDomain)},
msg: question(badStatusDomain),
wantErr: &doh.UnexpectedServerHTTPStatusError{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := doh.NewClient(nil)
client := doh.NewClient(ts.URL)

got, err := client.SendViaGet(context.Background(), tt.args.server, tt.args.msg)
got, err := client.SendViaGet(context.Background(), tt.msg)

if tt.wantErr != nil {
require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error")
Expand Down
23 changes: 23 additions & 0 deletions doh/opts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package doh

import "net/http"

// Option represents configuration options for doh.Client.
type Option interface {
apply(c *Client)
}

type httpClientOption struct {
client *http.Client
}

func (o *httpClientOption) apply(c *Client) {
c.client = o.client
}

// WithHTTPClient is a configuration option that overrides default http.Client instance used by the doh.Client.
func WithHTTPClient(c *http.Client) Option {
return &httpClientOption{
client: c,
}
}

0 comments on commit 6e1532e

Please sign in to comment.