Skip to content

Commit

Permalink
proper
Browse files Browse the repository at this point in the history
  • Loading branch information
matti committed Aug 7, 2024
1 parent a66be3b commit 0ddfa02
Showing 1 changed file with 51 additions and 9 deletions.
60 changes: 51 additions & 9 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io/ioutil"
"log"
"math/rand"
"net"
"os"
"os/signal"
Expand Down Expand Up @@ -54,20 +55,44 @@ func event(upstream string, name string) {
}

func harder(id string, question dns.Question, recursionDesired bool, currentUpstreams []string) *dns.Msg {
stop := false
responses := make(chan *dns.Msg, len(currentUpstreams))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

shuffledUpstreams := make([]string, len(currentUpstreams))
copy(shuffledUpstreams, currentUpstreams)

rand.Shuffle(len(shuffledUpstreams), func(i, j int) {
shuffledUpstreams[i], shuffledUpstreams[j] = shuffledUpstreams[j], shuffledUpstreams[i]
})

responses := make(chan *dns.Msg, len(shuffledUpstreams))

for i, upstream := range shuffledUpstreams {
go func(index int, upstream string, question dns.Question) {
if concurrencyDelay > 0 {
myConcurrencyDelay := time.Duration(concurrencyDelay) * time.Duration(index)
select {
case <-ctx.Done():
return
case <-time.After(myConcurrencyDelay):
}
}

for _, upstream := range currentUpstreams {
go func(upstream string, question dns.Question) {
try := 0
currentNet := netMode
for try < tries {
if stop {
select {
case <-ctx.Done():
return
default:
}

response, rtt, err := resolve(upstream, question, recursionDesired, currentNet)
if stop {

select {
case <-ctx.Done():
return
default:
}

if err == nil {
Expand Down Expand Up @@ -115,20 +140,32 @@ func harder(id string, question dns.Question, recursionDesired bool, currentUpst
logger(id, "ERROR", question, upstream, fmt.Sprintf("%v", err)+" "+rtt.String())
}

select {
case <-ctx.Done():
return
default:
}

try = try + 1
// retry truncated instantly
if response == nil {
time.Sleep(delay)
}

select {
case <-ctx.Done():
return
default:
}

if currentNet == "udp" {
currentNet = "tcp"
}
logger(id, "RETRY", question, upstream, currentNet, strconv.Itoa(try))
}

responses <- nil
}(upstream, question)
}(i, upstream, question)
}

received := 0
Expand All @@ -140,12 +177,11 @@ func harder(id string, question dns.Question, recursionDesired bool, currentUpst
break
}

if received == len(currentUpstreams) {
if received == len(shuffledUpstreams) {
break
}
}

stop = true
return final
}

Expand Down Expand Up @@ -254,8 +290,10 @@ func handleDnsRequest(w dns.ResponseWriter, request *dns.Msg) {
} else {
currentUpstreams = upstreams
}

logger(id, "QUERY", question, "recursion", strconv.FormatBool(request.RecursionDesired))
response := harder(id, question, request.RecursionDesired, currentUpstreams)

if response != nil {
final = response
} else {
Expand All @@ -282,6 +320,7 @@ var readTimeout time.Duration
var writeTimeout time.Duration

var delay time.Duration
var concurrencyDelay time.Duration
var tries int
var retry bool
var netMode string
Expand Down Expand Up @@ -325,6 +364,8 @@ func main() {
writeTimeoutMs := flag.Int("writeTimeout", 500, "writeTimeout")

delayMs := flag.Int("delay", 10, "delay in ms")
concurrencyDelayMs := flag.Int("concurrencyDelay", 0, "concurrency delay in ms, first upstream immediately and then add every delay")

flag.IntVar(&tries, "tries", 3, "tries")
flag.BoolVar(&retry, "retry", false, "retry")
flag.StringVar(&netMode, "netMode", "udp", "udp, tcp, tcp-tls")
Expand Down Expand Up @@ -355,6 +396,7 @@ func main() {
writeTimeout = time.Millisecond * time.Duration(*writeTimeoutMs)

delay = time.Millisecond * time.Duration(*delayMs)
concurrencyDelay = time.Millisecond * time.Duration(*concurrencyDelayMs)
statsDelay := time.Second * time.Duration(stats)

upstreams = flag.Args()
Expand Down

0 comments on commit 0ddfa02

Please sign in to comment.