From 579e6315f59769eaf7669609582519f12889d107 Mon Sep 17 00:00:00 2001
From: Steven Lee <steven.lee@kynrai.com>
Date: Fri, 14 Jun 2024 23:18:08 +0100
Subject: [PATCH 1/2] FEAT: server graceful shutdown (#50)

- Wait for active connections to finish before server shutdown
- Test server startup
---
 main.go               | 30 ++++++++++++++++++++++++++++--
 server/server.go      | 12 +++++++++++-
 server/server_test.go | 41 +++++++++++++++++++++++++++++++++++++++++
 3 files changed, 80 insertions(+), 3 deletions(-)
 create mode 100644 server/server_test.go

diff --git a/main.go b/main.go
index cb4e83d..85470cc 100644
--- a/main.go
+++ b/main.go
@@ -1,13 +1,39 @@
 package main
 
 import (
+	"context"
 	"log"
+	"net/http"
+	"os"
+	"os/signal"
+	"syscall"
+	"time"
 
 	"github.com/xray-web/web-check-api/config"
 	"github.com/xray-web/web-check-api/server"
 )
 
 func main() {
-	s := server.New(config.New())
-	log.Println(s.Run())
+	srv := server.New(config.New())
+
+	done := make(chan os.Signal, 1)
+	signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
+
+	go func() {
+		if err := srv.Run(); err != nil && err != http.ErrServerClosed {
+			log.Fatalf("listen: %v\n", err)
+		}
+	}()
+
+	<-done
+
+	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer func() {
+		// extra handling here, databases etc
+		cancel()
+	}()
+
+	if err := srv.Shutdown(ctx); err != nil {
+		log.Fatalf("Server Shutdown Failed:%+v", err)
+	}
 }
diff --git a/server/server.go b/server/server.go
index 1b55c4b..47843ac 100644
--- a/server/server.go
+++ b/server/server.go
@@ -1,6 +1,7 @@
 package server
 
 import (
+	"context"
 	"fmt"
 	"log"
 	"net/http"
@@ -14,10 +15,12 @@ type Server struct {
 	conf   config.Config
 	mux    *http.ServeMux
 	checks *checks.Checks
+	srv    *http.Server
 }
 
 func New(conf config.Config) *Server {
 	return &Server{
+		srv:    &http.Server{},
 		conf:   conf,
 		mux:    http.NewServeMux(),
 		checks: checks.NewChecks(),
@@ -49,6 +52,8 @@ func (s *Server) routes() {
 	s.mux.Handle("GET /api/social-tags", handlers.HandleGetSocialTags(s.checks.SocialTags))
 	s.mux.Handle("GET /api/tls", handlers.HandleTLS(s.checks.Tls))
 	s.mux.Handle("GET /api/trace-route", handlers.HandleTraceRoute())
+
+	s.srv.Handler = s.CORS(s.mux)
 }
 
 func (s *Server) Run() error {
@@ -56,5 +61,10 @@ func (s *Server) Run() error {
 
 	addr := fmt.Sprintf("%s:%s", s.conf.Host, s.conf.Port)
 	log.Printf("Server started, listening on: %v\n", addr)
-	return http.ListenAndServe(addr, s.CORS(s.mux))
+	s.srv.Addr = addr
+	return s.srv.ListenAndServe()
+}
+
+func (s *Server) Shutdown(ctx context.Context) error {
+	return s.srv.Shutdown(ctx)
 }
diff --git a/server/server_test.go b/server/server_test.go
new file mode 100644
index 0000000..5fa34cb
--- /dev/null
+++ b/server/server_test.go
@@ -0,0 +1,41 @@
+package server
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/xray-web/web-check-api/config"
+	"golang.org/x/net/context"
+)
+
+func TestServer(t *testing.T) {
+	t.Parallel()
+
+	t.Run("start server", func(t *testing.T) {
+		t.Parallel()
+
+		srv := New(config.New())
+		srv.routes()
+		ts := httptest.NewServer(srv.CORS(srv.mux))
+		defer ts.Close()
+
+		// wait up tot 10 seconds for health check to return 200
+		ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(+10*time.Second))
+		defer cancel()
+		for {
+			req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+"/health", nil)
+			assert.NoError(t, err)
+			resp, err := http.DefaultClient.Do(req)
+			if err == nil && resp.StatusCode == http.StatusOK {
+				break
+			}
+		}
+		ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		err := srv.Shutdown(ctx)
+		assert.NoError(t, err)
+	})
+}

From 4fc478712c54868fb9e982a326b42e8a776f87b7 Mon Sep 17 00:00:00 2001
From: Samantha Wu <87901016+syywu@users.noreply.github.com>
Date: Sat, 15 Jun 2024 00:19:50 +0100
Subject: [PATCH 2/2] RF: headers (#49)

---
 checks/checks.go         |  2 ++
 checks/headers.go        | 36 ++++++++++++++++++++++++++++++++++++
 checks/headers_test.go   | 28 ++++++++++++++++++++++++++++
 handlers/headers.go      | 18 ++++--------------
 handlers/headers_test.go | 31 +++++--------------------------
 server/server.go         |  2 +-
 6 files changed, 76 insertions(+), 41 deletions(-)
 create mode 100644 checks/headers.go
 create mode 100644 checks/headers_test.go

diff --git a/checks/checks.go b/checks/checks.go
index 6c92238..1e4c550 100644
--- a/checks/checks.go
+++ b/checks/checks.go
@@ -9,6 +9,7 @@ import (
 
 type Checks struct {
 	Carbon     *Carbon
+	Headers    *Headers
 	IpAddress  *Ip
 	LegacyRank *LegacyRank
 	Rank       *Rank
@@ -22,6 +23,7 @@ func NewChecks() *Checks {
 	}
 	return &Checks{
 		Carbon:     NewCarbon(client),
+		Headers:    NewHeaders(client),
 		IpAddress:  NewIp(NewNetIp()),
 		LegacyRank: NewLegacyRank(legacyrank.NewInMemoryStore()),
 		Rank:       NewRank(client),
diff --git a/checks/headers.go b/checks/headers.go
new file mode 100644
index 0000000..242ea06
--- /dev/null
+++ b/checks/headers.go
@@ -0,0 +1,36 @@
+package checks
+
+import (
+	"context"
+	"net/http"
+)
+
+type Headers struct {
+	client *http.Client
+}
+
+func NewHeaders(client *http.Client) *Headers {
+	return &Headers{client: client}
+}
+
+func (h *Headers) List(ctx context.Context, url string) (map[string]string, error) {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := h.client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	responseHeaders := make(map[string]string)
+	for k, v := range resp.Header {
+		for _, s := range v {
+			responseHeaders[k] = s
+		}
+	}
+
+	return responseHeaders, nil
+}
diff --git a/checks/headers_test.go b/checks/headers_test.go
new file mode 100644
index 0000000..57725aa
--- /dev/null
+++ b/checks/headers_test.go
@@ -0,0 +1,28 @@
+package checks
+
+import (
+	"context"
+	"net/http"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/xray-web/web-check-api/testutils"
+)
+
+func TestList(t *testing.T) {
+	t.Parallel()
+
+	c := testutils.MockClient(&http.Response{
+		Header: http.Header{
+			"Cache-Control":    {"private, max-age=0"},
+			"X-Xss-Protection": {"0"},
+		},
+	})
+	h := NewHeaders(c)
+
+	actual, err := h.List(context.Background(), "example.com")
+	assert.NoError(t, err)
+
+	assert.Equal(t, "private, max-age=0", actual["Cache-Control"])
+	assert.Equal(t, "0", actual["X-Xss-Protection"])
+}
diff --git a/handlers/headers.go b/handlers/headers.go
index 9d24adc..0c9a291 100644
--- a/handlers/headers.go
+++ b/handlers/headers.go
@@ -2,9 +2,11 @@ package handlers
 
 import (
 	"net/http"
+
+	"github.com/xray-web/web-check-api/checks"
 )
 
-func HandleGetHeaders() http.Handler {
+func HandleGetHeaders(h *checks.Headers) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		rawURL, err := extractURL(r)
 		if err != nil {
@@ -12,21 +14,9 @@ func HandleGetHeaders() http.Handler {
 			return
 		}
 
-		resp, err := http.Get(rawURL.String())
+		headers, err := h.List(r.Context(), rawURL.String())
 		if err != nil {
 			JSONError(w, err, http.StatusInternalServerError)
-			return
-		}
-		defer resp.Body.Close()
-
-		// Copying headers from the response
-		headers := make(map[string]interface{})
-		for key, values := range resp.Header {
-			if len(values) > 1 {
-				headers[key] = values
-			} else {
-				headers[key] = values[0]
-			}
 		}
 
 		JSON(w, headers, http.StatusOK)
diff --git a/handlers/headers_test.go b/handlers/headers_test.go
index dd139f8..264d4a5 100644
--- a/handlers/headers_test.go
+++ b/handlers/headers_test.go
@@ -1,12 +1,12 @@
 package handlers
 
 import (
-	"encoding/json"
 	"net/http"
 	"net/http/httptest"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/xray-web/web-check-api/checks"
 )
 
 func TestHandleGetHeaders(t *testing.T) {
@@ -15,21 +15,13 @@ func TestHandleGetHeaders(t *testing.T) {
 	t.Run("url parameter is missing", func(t *testing.T) {
 		req := httptest.NewRequest(http.MethodGet, "/headers", nil)
 		rec := httptest.NewRecorder()
-		HandleGetHeaders().ServeHTTP(rec, req)
+		HandleGetHeaders(nil).ServeHTTP(rec, req)
 
 		assert.Equal(t, http.StatusBadRequest, rec.Code)
 		assert.JSONEq(t, `{"error": "missing URL parameter"}`, rec.Body.String())
 	})
 
 	t.Run("invalid url format", func(t *testing.T) {
-		req := httptest.NewRequest(http.MethodGet, "/headers?url=invalid-url", nil)
-		rec := httptest.NewRecorder()
-		HandleGetHeaders().ServeHTTP(rec, req)
-
-		assert.Equal(t, http.StatusInternalServerError, rec.Code)
-	})
-
-	t.Run("valid url", func(t *testing.T) {
 		mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			w.Header().Set("Content-Type", "application/json")
 			w.Header().Set("X-Custom-Header", "value")
@@ -37,23 +29,10 @@ func TestHandleGetHeaders(t *testing.T) {
 		}))
 		defer mockServer.Close()
 
-		req := httptest.NewRequest(http.MethodGet, "/headers?url="+mockServer.URL, nil)
+		req := httptest.NewRequest(http.MethodGet, "/headers?url=invalid-url", nil)
 		rec := httptest.NewRecorder()
-		HandleGetHeaders().ServeHTTP(rec, req)
-
-		assert.Equal(t, http.StatusOK, rec.Code)
-
-		var responseBody map[string]interface{}
-		err := json.Unmarshal(rec.Body.Bytes(), &responseBody)
-		assert.NoError(t, err)
+		HandleGetHeaders(checks.NewHeaders(mockServer.Client())).ServeHTTP(rec, req)
 
-		expectedHeaders := map[string]interface{}{
-			"Content-Type":    "application/json",
-			"X-Custom-Header": "value",
-		}
-
-		for key, expectedValue := range expectedHeaders {
-			assert.Equal(t, expectedValue, responseBody[key])
-		}
+		assert.Equal(t, http.StatusInternalServerError, rec.Code)
 	})
 }
diff --git a/server/server.go b/server/server.go
index 47843ac..102bdf3 100644
--- a/server/server.go
+++ b/server/server.go
@@ -40,7 +40,7 @@ func (s *Server) routes() {
 	s.mux.Handle("GET /api/dnssec", handlers.HandleDnsSec())
 	s.mux.Handle("GET /api/firewall", handlers.HandleFirewall())
 	s.mux.Handle("GET /api/get-ip", handlers.HandleGetIP(s.checks.IpAddress))
-	s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders())
+	s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders(s.checks.Headers))
 	s.mux.Handle("GET /api/hsts", handlers.HandleHsts())
 	s.mux.Handle("GET /api/http-security", handlers.HandleHttpSecurity())
 	s.mux.Handle("GET /api/legacy-rank", handlers.HandleLegacyRank(s.checks.LegacyRank))