diff --git a/checks/checks.go b/checks/checks.go index 04d9a52..5bb5389 100644 --- a/checks/checks.go +++ b/checks/checks.go @@ -9,6 +9,7 @@ type Checks struct { Carbon *Carbon Rank *Rank SocialTags *SocialTags + Tls *Tls } func NewChecks() *Checks { @@ -19,5 +20,6 @@ func NewChecks() *Checks { Carbon: NewCarbon(client), Rank: NewRank(client), SocialTags: NewSocialTags(client), + Tls: NewTls(client), } } diff --git a/checks/tls.go b/checks/tls.go new file mode 100644 index 0000000..ec2e7bb --- /dev/null +++ b/checks/tls.go @@ -0,0 +1,78 @@ +package checks + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/url" + "strconv" + "strings" +) + +type Tls struct { + client *http.Client +} + +func NewTls(client *http.Client) *Tls { + return &Tls{client: client} +} + +func (t *Tls) initiateScan(ctx context.Context, domain string) (int, error) { + const scanUrl = "https://tls-observatory.services.mozilla.com/api/v1/scan" + + formData := url.Values{"target": {domain}} + req, err := http.NewRequestWithContext(ctx, http.MethodPost, scanUrl, strings.NewReader(formData.Encode())) + if err != nil { + return -1, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := t.client.Do(req) + if err != nil { + return -1, err + } + defer resp.Body.Close() + + var res struct { + ScanID int `json:"scan_id"` + } + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return -1, err + } + + if res.ScanID == 0 { + return res.ScanID, errors.New("failed to get scan_id from TLS Observatory") + } + + return res.ScanID, nil +} + +func (t *Tls) GetScanResults(ctx context.Context, domain string) (map[string]interface{}, error) { + scanID, err := t.initiateScan(ctx, domain) + if err != nil { + return nil, err + } + + const scanUrl = "https://tls-observatory.services.mozilla.com/api/v1/results" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, scanUrl, nil) + if err != nil { + return nil, err + } + q := req.URL.Query() + q.Add("id", strconv.Itoa(scanID)) + req.URL.RawQuery = q.Encode() + + resp, err := t.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + + return result, nil +} diff --git a/checks/tls_test.go b/checks/tls_test.go new file mode 100644 index 0000000..0a11f7d --- /dev/null +++ b/checks/tls_test.go @@ -0,0 +1,27 @@ +package checks + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xray-web/web-check-api/testutils" +) + +func TestTLS(t *testing.T) { + t.Parallel() + + t.Run("Valid URL with successful scan", func(t *testing.T) { + t.Parallel() + + client := testutils.MockClient( + testutils.Response(http.StatusOK, []byte(`{"scan_id": 12345}`)), + testutils.Response(http.StatusOK, []byte(`{"grade": "A+"}`)), + ) + + tls, err := NewTls(client).GetScanResults(context.TODO(), "example.com") + assert.NoError(t, err) + assert.Equal(t, "A+", tls["grade"]) + }) +} diff --git a/handlers/tls.go b/handlers/tls.go index d59a80b..ed8244c 100644 --- a/handlers/tls.go +++ b/handlers/tls.go @@ -1,55 +1,12 @@ package handlers import ( - "encoding/json" - "errors" - "fmt" "net/http" - "net/url" - "time" -) - -const MOZILLA_TLS_OBSERVATORY_API = "https://tls-observatory.services.mozilla.com/api/v1" - -type ScanResponse struct { - ScanID int `json:"scan_id"` -} - -func initiateScan(domain string) (*ScanResponse, error) { - resp, err := http.PostForm(fmt.Sprintf("%s/scan", MOZILLA_TLS_OBSERVATORY_API), url.Values{"target": {domain}}) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - var scanResponse ScanResponse - if err := json.NewDecoder(resp.Body).Decode(&scanResponse); err != nil { - return nil, err - } - - return &scanResponse, nil -} - -func getScanResults(scanID int) (map[string]interface{}, error) { - client := &http.Client{ - Timeout: 10 * time.Second, - } - - resp, err := client.Get(fmt.Sprintf("%s/results?id=%d", MOZILLA_TLS_OBSERVATORY_API, scanID)) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - var result map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, err - } - return result, nil -} + "github.com/xray-web/web-check-api/checks" +) -func HandleTLS() http.Handler { +func HandleTLS(t *checks.Tls) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rawURL, err := extractURL(r) if err != nil { @@ -57,18 +14,7 @@ func HandleTLS() http.Handler { return } - scanResponse, err := initiateScan(rawURL.Hostname()) - if err != nil { - JSONError(w, err, http.StatusInternalServerError) - return - } - - if scanResponse.ScanID == 0 { - JSONError(w, errors.New("failed to get scan_id from TLS Observatory"), http.StatusInternalServerError) - return - } - - result, err := getScanResults(scanResponse.ScanID) + result, err := t.GetScanResults(r.Context(), rawURL.Hostname()) if err != nil { JSONError(w, err, http.StatusInternalServerError) return diff --git a/handlers/tls_test.go b/handlers/tls_test.go index f9bcc56..7336403 100644 --- a/handlers/tls_test.go +++ b/handlers/tls_test.go @@ -7,78 +7,44 @@ import ( "testing" "github.com/stretchr/testify/assert" - "gopkg.in/h2non/gock.v1" + "github.com/xray-web/web-check-api/checks" + "github.com/xray-web/web-check-api/testutils" ) func TestHandleTLS(t *testing.T) { t.Parallel() - tests := []struct { - name string - urlParam string - mockScanResp string - mockScanStatus int - mockResultResp string - mockResultStatus int - expectedStatus int - expectedBody map[string]interface{} - }{ - { - name: "Missing URL parameter", - urlParam: "", - expectedStatus: http.StatusBadRequest, - expectedBody: map[string]interface{}{"error": "missing URL parameter"}, - }, - { - name: "Invalid URL", - urlParam: "http://invalid-url", - mockScanResp: `{"scan_id": 0}`, - mockScanStatus: http.StatusOK, - expectedStatus: http.StatusInternalServerError, - expectedBody: map[string]interface{}{"error": "failed to get scan_id from TLS Observatory"}, - }, - { - name: "Valid URL with successful scan", - urlParam: "http://example.com", - mockScanResp: `{"scan_id": 12345}`, - mockScanStatus: http.StatusOK, - mockResultResp: `{"grade": "A+"}`, - mockResultStatus: http.StatusOK, - expectedStatus: http.StatusOK, - expectedBody: map[string]interface{}{"grade": "A+"}, - }, - } - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - // t.Parallel() - defer gock.Off() + t.Run("Missing URL parameter", func(t *testing.T) { + t.Parallel() - if tc.urlParam != "" { - gock.New(MOZILLA_TLS_OBSERVATORY_API). - Post("/scan"). - Reply(tc.mockScanStatus). - BodyString(tc.mockScanResp) + req := httptest.NewRequest("GET", "/tls?url=", nil) + rec := httptest.NewRecorder() - if tc.mockScanStatus == http.StatusOK && tc.mockResultResp != "" { - gock.New(MOZILLA_TLS_OBSERVATORY_API). - Get("/results"). - MatchParam("id", "12345"). - Reply(tc.mockResultStatus). - BodyString(tc.mockResultResp) - } - } + HandleTLS(checks.NewTls(nil)).ServeHTTP(rec, req) - req := httptest.NewRequest("GET", "/tls?url="+tc.urlParam, nil) - rec := httptest.NewRecorder() - HandleTLS().ServeHTTP(rec, req) + assert.Equal(t, http.StatusBadRequest, rec.Code) + var responseBody map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &responseBody) + assert.NoError(t, err) + assert.Equal(t, map[string]interface{}{"error": "missing URL parameter"}, responseBody) + }) - assert.Equal(t, tc.expectedStatus, rec.Code) + t.Run("Invalid URL", func(t *testing.T) { + t.Parallel() + + client := testutils.MockClient( + testutils.Response(http.StatusOK, []byte(`{"scan_id": 0}`)), + ) + req := httptest.NewRequest("GET", "/tls?url=http://invalid-url", nil) + rec := httptest.NewRecorder() + + HandleTLS(checks.NewTls(client)).ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + var responseBody map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &responseBody) + assert.NoError(t, err) + assert.Equal(t, map[string]interface{}{"error": "failed to get scan_id from TLS Observatory"}, responseBody) + }) - var responseBody map[string]interface{} - err := json.Unmarshal(rec.Body.Bytes(), &responseBody) - assert.NoError(t, err) - assert.Equal(t, tc.expectedBody, responseBody) - }) - } } diff --git a/server/server.go b/server/server.go index 612b4c9..4d49ffd 100644 --- a/server/server.go +++ b/server/server.go @@ -47,7 +47,7 @@ func (s *Server) routes() { s.mux.Handle("GET /api/rank", handlers.HandleGetRank(s.checks.Rank)) s.mux.Handle("GET /api/redirects", handlers.HandleGetRedirects()) s.mux.Handle("GET /api/social-tags", handlers.HandleGetSocialTags(s.checks.SocialTags)) - s.mux.Handle("GET /api/tls", handlers.HandleTLS()) + s.mux.Handle("GET /api/tls", handlers.HandleTLS(s.checks.Tls)) s.mux.Handle("GET /api/trace-route", handlers.HandleTraceRoute()) }