Skip to content

Commit

Permalink
Refactoring server.go (#1277)
Browse files Browse the repository at this point in the history
* made Stop context aware

* added error check

* context aware OnRequest

* linter fix

* fixed some flakiness in tests

* made DoGetRequest context aware

* this doesn't belong there and produces flakyness
  • Loading branch information
kwitsch authored Nov 28, 2023
1 parent fda2dbe commit 976d619
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 24 deletions.
2 changes: 1 addition & 1 deletion cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func startServer(_ *cobra.Command, _ []string) error {
select {
case <-signals:
log.Log().Infof("Terminating...")
util.LogOnError("can't stop server: ", srv.Stop())
util.LogOnError("can't stop server: ", srv.Stop(ctx))
done <- true

case err := <-errChan:
Expand Down
20 changes: 10 additions & 10 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,14 @@ var _ = Describe("Config", func() {

Describe("Creation of Config", func() {
When("Test config file will be parsed", func() {
It("should return a valid config struct", func() {
confFile := writeConfigYml(tmpDir)
var confFile *helpertest.TmpFile

BeforeEach(func() {
confFile = writeConfigYml(tmpDir)
Expect(confFile.Error).Should(Succeed())
})

It("should return a valid config struct", func() {
c, err = LoadConfig(confFile.Path, true)
Expect(err).Should(Succeed())

Expand All @@ -165,30 +169,26 @@ var _ = Describe("Config", func() {
})
})
When("Multiple config files are used", func() {
It("should return a valid config struct", func() {
BeforeEach(func() {
err = writeConfigDir(tmpDir)
Expect(err).Should(Succeed())
})

c, err = LoadConfig(tmpDir.Path, true)
It("should return a valid config struct", func() {
c, err := LoadConfig(tmpDir.Path, true)
Expect(err).Should(Succeed())

defaultTestFileConfig(c)
})

It("should ignore non YAML files", func() {
err = writeConfigDir(tmpDir)
Expect(err).Should(Succeed())

tmpDir.CreateStringFile("ignore-me.txt", "THIS SHOULD BE IGNORED!")

_, err := LoadConfig(tmpDir.Path, true)
Expect(err).Should(Succeed())
})

It("should ignore non regular files", func() {
err = writeConfigDir(tmpDir)
Expect(err).Should(Succeed())

tmpDir.CreateSubFolder("subfolder")
tmpDir.CreateSubFolder("subfolder.yml")

Expand Down
5 changes: 3 additions & 2 deletions helpertest/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package helpertest

import (
"bytes"
"context"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -68,10 +69,10 @@ func TestServer(data string) *httptest.Server {
}

// DoGetRequest performs a GET request
func DoGetRequest(url string,
func DoGetRequest(ctx context.Context, url string,
fn func(w http.ResponseWriter, r *http.Request),
) (*httptest.ResponseRecorder, *bytes.Buffer) {
r, _ := http.NewRequest(http.MethodGet, url, nil)
r, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)

rr := httptest.NewRecorder()
handler := http.HandlerFunc(fn)
Expand Down
1 change: 0 additions & 1 deletion resolver/parallel_best_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
WithDelay(timeout / 2)
DeferCleanup(slowTestUpstream.Close)
upstreams = []config.Upstream{{Host: "wrong"}, slowTestUpstream.Start()}
Expect(err).Should(Succeed())
})
It("Should use result from successful resolver", func() {
request := newRequest("example.com.", A)
Expand Down
18 changes: 11 additions & 7 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err

server.printConfiguration()

server.registerDNSHandlers()
server.registerDNSHandlers(ctx)
err = server.registerAPIEndpoints(httpRouter)

if err != nil {
Expand Down Expand Up @@ -415,10 +415,14 @@ func createQueryResolver(
return r, nil
}

func (s *Server) registerDNSHandlers() {
func (s *Server) registerDNSHandlers(ctx context.Context) {
wrappedOnRequest := func(w dns.ResponseWriter, request *dns.Msg) {
s.OnRequest(ctx, w, request)
}

for _, server := range s.dnsServers {
handler := server.Handler.(*dns.ServeMux)
handler.HandleFunc(".", s.OnRequest)
handler.HandleFunc(".", wrappedOnRequest)
handler.HandleFunc("healthcheck.blocky", s.OnHealthCheck)
}
}
Expand Down Expand Up @@ -534,11 +538,11 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
}

// Stop stops the server
func (s *Server) Stop() error {
func (s *Server) Stop(ctx context.Context) error {
logger().Info("Stopping server")

for _, server := range s.dnsServers {
if err := server.Shutdown(); err != nil {
if err := server.ShutdownContext(ctx); err != nil {
return fmt.Errorf("stop %s listener failed: %w", server.Net, err)
}
}
Expand Down Expand Up @@ -591,12 +595,12 @@ func newRequest(clientIP net.IP, protocol model.RequestProtocol,
}

// OnRequest will be executed if a new DNS request is received
func (s *Server) OnRequest(w dns.ResponseWriter, request *dns.Msg) {
func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) {
logger().Debug("new request")

r := createResolverRequest(w, request)

response, err := s.queryResolver.Resolve(context.Background(), r)
response, err := s.queryResolver.Resolve(ctx, r)

if err != nil {
logger().Error("error on processing request:", err)
Expand Down
6 changes: 3 additions & 3 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ var _ = BeforeSuite(func() {

// start server
go sut.Start(ctx, errChan)
DeferCleanup(sut.Stop)
DeferCleanup(func() { Expect(sut.Stop(ctx)).Should(Succeed()) })

Consistently(errChan, "1s").ShouldNot(Receive())
})
Expand Down Expand Up @@ -681,13 +681,13 @@ var _ = Describe("Running DNS server", func() {

time.Sleep(100 * time.Millisecond)

err = server.Stop()
err = server.Stop(ctx)

// stop server, should be ok
Expect(err).Should(Succeed())

// stop again, should raise error
err = server.Stop()
err = server.Stop(ctx)

Expect(err).Should(HaveOccurred())
})
Expand Down

0 comments on commit 976d619

Please sign in to comment.