From 579e6315f59769eaf7669609582519f12889d107 Mon Sep 17 00:00:00 2001 From: Steven Lee <steven.lee@kynrai.com> Date: Fri, 14 Jun 2024 23:18:08 +0100 Subject: [PATCH 1/2] FEAT: server graceful shutdown (#50) - Wait for active connections to finish before server shutdown - Test server startup --- main.go | 30 ++++++++++++++++++++++++++++-- server/server.go | 12 +++++++++++- server/server_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 server/server_test.go diff --git a/main.go b/main.go index cb4e83d..85470cc 100644 --- a/main.go +++ b/main.go @@ -1,13 +1,39 @@ package main import ( + "context" "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" "github.com/xray-web/web-check-api/config" "github.com/xray-web/web-check-api/server" ) func main() { - s := server.New(config.New()) - log.Println(s.Run()) + srv := server.New(config.New()) + + done := make(chan os.Signal, 1) + signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + go func() { + if err := srv.Run(); err != nil && err != http.ErrServerClosed { + log.Fatalf("listen: %v\n", err) + } + }() + + <-done + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer func() { + // extra handling here, databases etc + cancel() + }() + + if err := srv.Shutdown(ctx); err != nil { + log.Fatalf("Server Shutdown Failed:%+v", err) + } } diff --git a/server/server.go b/server/server.go index 1b55c4b..47843ac 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "log" "net/http" @@ -14,10 +15,12 @@ type Server struct { conf config.Config mux *http.ServeMux checks *checks.Checks + srv *http.Server } func New(conf config.Config) *Server { return &Server{ + srv: &http.Server{}, conf: conf, mux: http.NewServeMux(), checks: checks.NewChecks(), @@ -49,6 +52,8 @@ func (s *Server) routes() { s.mux.Handle("GET /api/social-tags", handlers.HandleGetSocialTags(s.checks.SocialTags)) s.mux.Handle("GET /api/tls", handlers.HandleTLS(s.checks.Tls)) s.mux.Handle("GET /api/trace-route", handlers.HandleTraceRoute()) + + s.srv.Handler = s.CORS(s.mux) } func (s *Server) Run() error { @@ -56,5 +61,10 @@ func (s *Server) Run() error { addr := fmt.Sprintf("%s:%s", s.conf.Host, s.conf.Port) log.Printf("Server started, listening on: %v\n", addr) - return http.ListenAndServe(addr, s.CORS(s.mux)) + s.srv.Addr = addr + return s.srv.ListenAndServe() +} + +func (s *Server) Shutdown(ctx context.Context) error { + return s.srv.Shutdown(ctx) } diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..5fa34cb --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,41 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/xray-web/web-check-api/config" + "golang.org/x/net/context" +) + +func TestServer(t *testing.T) { + t.Parallel() + + t.Run("start server", func(t *testing.T) { + t.Parallel() + + srv := New(config.New()) + srv.routes() + ts := httptest.NewServer(srv.CORS(srv.mux)) + defer ts.Close() + + // wait up tot 10 seconds for health check to return 200 + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(+10*time.Second)) + defer cancel() + for { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+"/health", nil) + assert.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + if err == nil && resp.StatusCode == http.StatusOK { + break + } + } + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := srv.Shutdown(ctx) + assert.NoError(t, err) + }) +} From 4fc478712c54868fb9e982a326b42e8a776f87b7 Mon Sep 17 00:00:00 2001 From: Samantha Wu <87901016+syywu@users.noreply.github.com> Date: Sat, 15 Jun 2024 00:19:50 +0100 Subject: [PATCH 2/2] RF: headers (#49) --- checks/checks.go | 2 ++ checks/headers.go | 36 ++++++++++++++++++++++++++++++++++++ checks/headers_test.go | 28 ++++++++++++++++++++++++++++ handlers/headers.go | 18 ++++-------------- handlers/headers_test.go | 31 +++++-------------------------- server/server.go | 2 +- 6 files changed, 76 insertions(+), 41 deletions(-) create mode 100644 checks/headers.go create mode 100644 checks/headers_test.go diff --git a/checks/checks.go b/checks/checks.go index 6c92238..1e4c550 100644 --- a/checks/checks.go +++ b/checks/checks.go @@ -9,6 +9,7 @@ import ( type Checks struct { Carbon *Carbon + Headers *Headers IpAddress *Ip LegacyRank *LegacyRank Rank *Rank @@ -22,6 +23,7 @@ func NewChecks() *Checks { } return &Checks{ Carbon: NewCarbon(client), + Headers: NewHeaders(client), IpAddress: NewIp(NewNetIp()), LegacyRank: NewLegacyRank(legacyrank.NewInMemoryStore()), Rank: NewRank(client), diff --git a/checks/headers.go b/checks/headers.go new file mode 100644 index 0000000..242ea06 --- /dev/null +++ b/checks/headers.go @@ -0,0 +1,36 @@ +package checks + +import ( + "context" + "net/http" +) + +type Headers struct { + client *http.Client +} + +func NewHeaders(client *http.Client) *Headers { + return &Headers{client: client} +} + +func (h *Headers) List(ctx context.Context, url string) (map[string]string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + resp, err := h.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + responseHeaders := make(map[string]string) + for k, v := range resp.Header { + for _, s := range v { + responseHeaders[k] = s + } + } + + return responseHeaders, nil +} diff --git a/checks/headers_test.go b/checks/headers_test.go new file mode 100644 index 0000000..57725aa --- /dev/null +++ b/checks/headers_test.go @@ -0,0 +1,28 @@ +package checks + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xray-web/web-check-api/testutils" +) + +func TestList(t *testing.T) { + t.Parallel() + + c := testutils.MockClient(&http.Response{ + Header: http.Header{ + "Cache-Control": {"private, max-age=0"}, + "X-Xss-Protection": {"0"}, + }, + }) + h := NewHeaders(c) + + actual, err := h.List(context.Background(), "example.com") + assert.NoError(t, err) + + assert.Equal(t, "private, max-age=0", actual["Cache-Control"]) + assert.Equal(t, "0", actual["X-Xss-Protection"]) +} diff --git a/handlers/headers.go b/handlers/headers.go index 9d24adc..0c9a291 100644 --- a/handlers/headers.go +++ b/handlers/headers.go @@ -2,9 +2,11 @@ package handlers import ( "net/http" + + "github.com/xray-web/web-check-api/checks" ) -func HandleGetHeaders() http.Handler { +func HandleGetHeaders(h *checks.Headers) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rawURL, err := extractURL(r) if err != nil { @@ -12,21 +14,9 @@ func HandleGetHeaders() http.Handler { return } - resp, err := http.Get(rawURL.String()) + headers, err := h.List(r.Context(), rawURL.String()) if err != nil { JSONError(w, err, http.StatusInternalServerError) - return - } - defer resp.Body.Close() - - // Copying headers from the response - headers := make(map[string]interface{}) - for key, values := range resp.Header { - if len(values) > 1 { - headers[key] = values - } else { - headers[key] = values[0] - } } JSON(w, headers, http.StatusOK) diff --git a/handlers/headers_test.go b/handlers/headers_test.go index dd139f8..264d4a5 100644 --- a/handlers/headers_test.go +++ b/handlers/headers_test.go @@ -1,12 +1,12 @@ package handlers import ( - "encoding/json" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + "github.com/xray-web/web-check-api/checks" ) func TestHandleGetHeaders(t *testing.T) { @@ -15,21 +15,13 @@ func TestHandleGetHeaders(t *testing.T) { t.Run("url parameter is missing", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/headers", nil) rec := httptest.NewRecorder() - HandleGetHeaders().ServeHTTP(rec, req) + HandleGetHeaders(nil).ServeHTTP(rec, req) assert.Equal(t, http.StatusBadRequest, rec.Code) assert.JSONEq(t, `{"error": "missing URL parameter"}`, rec.Body.String()) }) t.Run("invalid url format", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/headers?url=invalid-url", nil) - rec := httptest.NewRecorder() - HandleGetHeaders().ServeHTTP(rec, req) - - assert.Equal(t, http.StatusInternalServerError, rec.Code) - }) - - t.Run("valid url", func(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("X-Custom-Header", "value") @@ -37,23 +29,10 @@ func TestHandleGetHeaders(t *testing.T) { })) defer mockServer.Close() - req := httptest.NewRequest(http.MethodGet, "/headers?url="+mockServer.URL, nil) + req := httptest.NewRequest(http.MethodGet, "/headers?url=invalid-url", nil) rec := httptest.NewRecorder() - HandleGetHeaders().ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - - var responseBody map[string]interface{} - err := json.Unmarshal(rec.Body.Bytes(), &responseBody) - assert.NoError(t, err) + HandleGetHeaders(checks.NewHeaders(mockServer.Client())).ServeHTTP(rec, req) - expectedHeaders := map[string]interface{}{ - "Content-Type": "application/json", - "X-Custom-Header": "value", - } - - for key, expectedValue := range expectedHeaders { - assert.Equal(t, expectedValue, responseBody[key]) - } + assert.Equal(t, http.StatusInternalServerError, rec.Code) }) } diff --git a/server/server.go b/server/server.go index 47843ac..102bdf3 100644 --- a/server/server.go +++ b/server/server.go @@ -40,7 +40,7 @@ func (s *Server) routes() { s.mux.Handle("GET /api/dnssec", handlers.HandleDnsSec()) s.mux.Handle("GET /api/firewall", handlers.HandleFirewall()) s.mux.Handle("GET /api/get-ip", handlers.HandleGetIP(s.checks.IpAddress)) - s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders()) + s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders(s.checks.Headers)) s.mux.Handle("GET /api/hsts", handlers.HandleHsts()) s.mux.Handle("GET /api/http-security", handlers.HandleHttpSecurity()) s.mux.Handle("GET /api/legacy-rank", handlers.HandleLegacyRank(s.checks.LegacyRank))