Skip to content

Commit

Permalink
Merge branch 'main' into rf/linked-pages
Browse files Browse the repository at this point in the history
  • Loading branch information
kynrai authored Jun 15, 2024
2 parents 914c030 + 4fc4787 commit 1ccf8be
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 44 deletions.
2 changes: 2 additions & 0 deletions checks/checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

type Checks struct {
Carbon *Carbon
Headers *Headers
IpAddress *Ip
LegacyRank *LegacyRank
LinkedPages *LinkedPages
Expand All @@ -23,6 +24,7 @@ func NewChecks() *Checks {
}
return &Checks{
Carbon: NewCarbon(client),

Check warning on line 26 in checks/checks.go

View check run for this annotation

Codecov / codecov/patch

checks/checks.go#L26

Added line #L26 was not covered by tests
Headers: NewHeaders(client),
IpAddress: NewIp(NewNetIp()),
LegacyRank: NewLegacyRank(legacyrank.NewInMemoryStore()),
LinkedPages: NewLinkedPages(client),
Expand Down
36 changes: 36 additions & 0 deletions checks/headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package checks

import (
"context"
"net/http"
)

type Headers struct {
client *http.Client
}

func NewHeaders(client *http.Client) *Headers {
return &Headers{client: client}
}

func (h *Headers) List(ctx context.Context, url string) (map[string]string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}

resp, err := h.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

responseHeaders := make(map[string]string)
for k, v := range resp.Header {
for _, s := range v {
responseHeaders[k] = s
}
}

return responseHeaders, nil
}
28 changes: 28 additions & 0 deletions checks/headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package checks

import (
"context"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/xray-web/web-check-api/testutils"
)

func TestList(t *testing.T) {
t.Parallel()

c := testutils.MockClient(&http.Response{
Header: http.Header{
"Cache-Control": {"private, max-age=0"},
"X-Xss-Protection": {"0"},
},
})
h := NewHeaders(c)

actual, err := h.List(context.Background(), "example.com")
assert.NoError(t, err)

assert.Equal(t, "private, max-age=0", actual["Cache-Control"])
assert.Equal(t, "0", actual["X-Xss-Protection"])
}
18 changes: 4 additions & 14 deletions handlers/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,21 @@ package handlers

import (
"net/http"

"github.com/xray-web/web-check-api/checks"
)

func HandleGetHeaders() http.Handler {
func HandleGetHeaders(h *checks.Headers) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rawURL, err := extractURL(r)
if err != nil {
JSONError(w, ErrMissingURLParameter, http.StatusBadRequest)
return
}

resp, err := http.Get(rawURL.String())
headers, err := h.List(r.Context(), rawURL.String())
if err != nil {
JSONError(w, err, http.StatusInternalServerError)
return
}
defer resp.Body.Close()

// Copying headers from the response
headers := make(map[string]interface{})
for key, values := range resp.Header {
if len(values) > 1 {
headers[key] = values
} else {
headers[key] = values[0]
}
}

JSON(w, headers, http.StatusOK)
Expand Down
31 changes: 5 additions & 26 deletions handlers/headers_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package handlers

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/xray-web/web-check-api/checks"
)

func TestHandleGetHeaders(t *testing.T) {
Expand All @@ -15,45 +15,24 @@ func TestHandleGetHeaders(t *testing.T) {
t.Run("url parameter is missing", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/headers", nil)
rec := httptest.NewRecorder()
HandleGetHeaders().ServeHTTP(rec, req)
HandleGetHeaders(nil).ServeHTTP(rec, req)

assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.JSONEq(t, `{"error": "missing URL parameter"}`, rec.Body.String())
})

t.Run("invalid url format", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/headers?url=invalid-url", nil)
rec := httptest.NewRecorder()
HandleGetHeaders().ServeHTTP(rec, req)

assert.Equal(t, http.StatusInternalServerError, rec.Code)
})

t.Run("valid url", func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Custom-Header", "value")
w.WriteHeader(http.StatusOK)
}))
defer mockServer.Close()

req := httptest.NewRequest(http.MethodGet, "/headers?url="+mockServer.URL, nil)
req := httptest.NewRequest(http.MethodGet, "/headers?url=invalid-url", nil)
rec := httptest.NewRecorder()
HandleGetHeaders().ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)

var responseBody map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &responseBody)
assert.NoError(t, err)
HandleGetHeaders(checks.NewHeaders(mockServer.Client())).ServeHTTP(rec, req)

expectedHeaders := map[string]interface{}{
"Content-Type": "application/json",
"X-Custom-Header": "value",
}

for key, expectedValue := range expectedHeaders {
assert.Equal(t, expectedValue, responseBody[key])
}
assert.Equal(t, http.StatusInternalServerError, rec.Code)
})
}
30 changes: 28 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,39 @@
package main

import (
"context"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"

"github.com/xray-web/web-check-api/config"
"github.com/xray-web/web-check-api/server"
)

func main() {
s := server.New(config.New())
log.Println(s.Run())
srv := server.New(config.New())

done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)

go func() {
if err := srv.Run(); err != nil && err != http.ErrServerClosed {
log.Fatalf("listen: %v\n", err)
}
}()

<-done

ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer func() {
// extra handling here, databases etc
cancel()
}()

if err := srv.Shutdown(ctx); err != nil {
log.Fatalf("Server Shutdown Failed:%+v", err)
}
}
14 changes: 12 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"fmt"
"log"
"net/http"
Expand All @@ -14,10 +15,12 @@ type Server struct {
conf config.Config
mux *http.ServeMux
checks *checks.Checks
srv *http.Server
}

func New(conf config.Config) *Server {
return &Server{
srv: &http.Server{},
conf: conf,
mux: http.NewServeMux(),
checks: checks.NewChecks(),
Expand All @@ -37,7 +40,7 @@ func (s *Server) routes() {
s.mux.Handle("GET /api/dnssec", handlers.HandleDnsSec())
s.mux.Handle("GET /api/firewall", handlers.HandleFirewall())
s.mux.Handle("GET /api/get-ip", handlers.HandleGetIP(s.checks.IpAddress))
s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders())
s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders(s.checks.Headers))
s.mux.Handle("GET /api/hsts", handlers.HandleHsts())
s.mux.Handle("GET /api/http-security", handlers.HandleHttpSecurity())
s.mux.Handle("GET /api/legacy-rank", handlers.HandleLegacyRank(s.checks.LegacyRank))
Expand All @@ -49,12 +52,19 @@ func (s *Server) routes() {
s.mux.Handle("GET /api/social-tags", handlers.HandleGetSocialTags(s.checks.SocialTags))
s.mux.Handle("GET /api/tls", handlers.HandleTLS(s.checks.Tls))
s.mux.Handle("GET /api/trace-route", handlers.HandleTraceRoute())

s.srv.Handler = s.CORS(s.mux)
}

func (s *Server) Run() error {
s.routes()

addr := fmt.Sprintf("%s:%s", s.conf.Host, s.conf.Port)
log.Printf("Server started, listening on: %v\n", addr)
return http.ListenAndServe(addr, s.CORS(s.mux))
s.srv.Addr = addr
return s.srv.ListenAndServe()
}

func (s *Server) Shutdown(ctx context.Context) error {
return s.srv.Shutdown(ctx)
}
41 changes: 41 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package server

import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/xray-web/web-check-api/config"
"golang.org/x/net/context"
)

func TestServer(t *testing.T) {
t.Parallel()

t.Run("start server", func(t *testing.T) {
t.Parallel()

srv := New(config.New())
srv.routes()
ts := httptest.NewServer(srv.CORS(srv.mux))
defer ts.Close()

// wait up tot 10 seconds for health check to return 200
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(+10*time.Second))
defer cancel()
for {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+"/health", nil)
assert.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
if err == nil && resp.StatusCode == http.StatusOK {
break
}
}
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := srv.Shutdown(ctx)
assert.NoError(t, err)
})
}

0 comments on commit 1ccf8be

Please sign in to comment.