Skip to content

Commit

Permalink
RF: TLS checker (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
vleong99 authored Jun 11, 2024
1 parent b00c48a commit eaf9a31
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 123 deletions.
2 changes: 2 additions & 0 deletions checks/checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type Checks struct {
Carbon *Carbon
Rank *Rank
SocialTags *SocialTags
Tls *Tls
}

func NewChecks() *Checks {
Expand All @@ -19,5 +20,6 @@ func NewChecks() *Checks {
Carbon: NewCarbon(client),
Rank: NewRank(client),
SocialTags: NewSocialTags(client),
Tls: NewTls(client),
}
}
78 changes: 78 additions & 0 deletions checks/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package checks

import (
"context"
"encoding/json"
"errors"
"net/http"
"net/url"
"strconv"
"strings"
)

type Tls struct {
client *http.Client
}

func NewTls(client *http.Client) *Tls {
return &Tls{client: client}
}

func (t *Tls) initiateScan(ctx context.Context, domain string) (int, error) {
const scanUrl = "https://tls-observatory.services.mozilla.com/api/v1/scan"

formData := url.Values{"target": {domain}}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, scanUrl, strings.NewReader(formData.Encode()))
if err != nil {
return -1, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := t.client.Do(req)
if err != nil {
return -1, err
}
defer resp.Body.Close()

var res struct {
ScanID int `json:"scan_id"`
}
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
return -1, err
}

if res.ScanID == 0 {
return res.ScanID, errors.New("failed to get scan_id from TLS Observatory")
}

return res.ScanID, nil
}

func (t *Tls) GetScanResults(ctx context.Context, domain string) (map[string]interface{}, error) {
scanID, err := t.initiateScan(ctx, domain)
if err != nil {
return nil, err
}

const scanUrl = "https://tls-observatory.services.mozilla.com/api/v1/results"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, scanUrl, nil)
if err != nil {
return nil, err
}
q := req.URL.Query()
q.Add("id", strconv.Itoa(scanID))
req.URL.RawQuery = q.Encode()

resp, err := t.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}

return result, nil
}
27 changes: 27 additions & 0 deletions checks/tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package checks

import (
"context"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/xray-web/web-check-api/testutils"
)

func TestTLS(t *testing.T) {
t.Parallel()

t.Run("Valid URL with successful scan", func(t *testing.T) {
t.Parallel()

client := testutils.MockClient(
testutils.Response(http.StatusOK, []byte(`{"scan_id": 12345}`)),
testutils.Response(http.StatusOK, []byte(`{"grade": "A+"}`)),
)

tls, err := NewTls(client).GetScanResults(context.TODO(), "example.com")
assert.NoError(t, err)
assert.Equal(t, "A+", tls["grade"])
})
}
62 changes: 4 additions & 58 deletions handlers/tls.go
Original file line number Diff line number Diff line change
@@ -1,74 +1,20 @@
package handlers

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"time"
)

const MOZILLA_TLS_OBSERVATORY_API = "https://tls-observatory.services.mozilla.com/api/v1"

type ScanResponse struct {
ScanID int `json:"scan_id"`
}

func initiateScan(domain string) (*ScanResponse, error) {
resp, err := http.PostForm(fmt.Sprintf("%s/scan", MOZILLA_TLS_OBSERVATORY_API), url.Values{"target": {domain}})
if err != nil {
return nil, err
}
defer resp.Body.Close()

var scanResponse ScanResponse
if err := json.NewDecoder(resp.Body).Decode(&scanResponse); err != nil {
return nil, err
}

return &scanResponse, nil
}

func getScanResults(scanID int) (map[string]interface{}, error) {
client := &http.Client{
Timeout: 10 * time.Second,
}

resp, err := client.Get(fmt.Sprintf("%s/results?id=%d", MOZILLA_TLS_OBSERVATORY_API, scanID))
if err != nil {
return nil, err
}
defer resp.Body.Close()

var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}

return result, nil
}
"github.com/xray-web/web-check-api/checks"
)

func HandleTLS() http.Handler {
func HandleTLS(t *checks.Tls) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rawURL, err := extractURL(r)
if err != nil {
JSONError(w, ErrMissingURLParameter, http.StatusBadRequest)
return
}

scanResponse, err := initiateScan(rawURL.Hostname())
if err != nil {
JSONError(w, err, http.StatusInternalServerError)
return
}

if scanResponse.ScanID == 0 {
JSONError(w, errors.New("failed to get scan_id from TLS Observatory"), http.StatusInternalServerError)
return
}

result, err := getScanResults(scanResponse.ScanID)
result, err := t.GetScanResults(r.Context(), rawURL.Hostname())
if err != nil {
JSONError(w, err, http.StatusInternalServerError)
return
Expand Down
94 changes: 30 additions & 64 deletions handlers/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,78 +7,44 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"gopkg.in/h2non/gock.v1"
"github.com/xray-web/web-check-api/checks"
"github.com/xray-web/web-check-api/testutils"
)

func TestHandleTLS(t *testing.T) {
t.Parallel()
tests := []struct {
name string
urlParam string
mockScanResp string
mockScanStatus int
mockResultResp string
mockResultStatus int
expectedStatus int
expectedBody map[string]interface{}
}{
{
name: "Missing URL parameter",
urlParam: "",
expectedStatus: http.StatusBadRequest,
expectedBody: map[string]interface{}{"error": "missing URL parameter"},
},
{
name: "Invalid URL",
urlParam: "http://invalid-url",
mockScanResp: `{"scan_id": 0}`,
mockScanStatus: http.StatusOK,
expectedStatus: http.StatusInternalServerError,
expectedBody: map[string]interface{}{"error": "failed to get scan_id from TLS Observatory"},
},
{
name: "Valid URL with successful scan",
urlParam: "http://example.com",
mockScanResp: `{"scan_id": 12345}`,
mockScanStatus: http.StatusOK,
mockResultResp: `{"grade": "A+"}`,
mockResultStatus: http.StatusOK,
expectedStatus: http.StatusOK,
expectedBody: map[string]interface{}{"grade": "A+"},
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// t.Parallel()
defer gock.Off()
t.Run("Missing URL parameter", func(t *testing.T) {
t.Parallel()

if tc.urlParam != "" {
gock.New(MOZILLA_TLS_OBSERVATORY_API).
Post("/scan").
Reply(tc.mockScanStatus).
BodyString(tc.mockScanResp)
req := httptest.NewRequest("GET", "/tls?url=", nil)
rec := httptest.NewRecorder()

if tc.mockScanStatus == http.StatusOK && tc.mockResultResp != "" {
gock.New(MOZILLA_TLS_OBSERVATORY_API).
Get("/results").
MatchParam("id", "12345").
Reply(tc.mockResultStatus).
BodyString(tc.mockResultResp)
}
}
HandleTLS(checks.NewTls(nil)).ServeHTTP(rec, req)

req := httptest.NewRequest("GET", "/tls?url="+tc.urlParam, nil)
rec := httptest.NewRecorder()
HandleTLS().ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
var responseBody map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &responseBody)
assert.NoError(t, err)
assert.Equal(t, map[string]interface{}{"error": "missing URL parameter"}, responseBody)
})

assert.Equal(t, tc.expectedStatus, rec.Code)
t.Run("Invalid URL", func(t *testing.T) {
t.Parallel()

client := testutils.MockClient(
testutils.Response(http.StatusOK, []byte(`{"scan_id": 0}`)),
)
req := httptest.NewRequest("GET", "/tls?url=http://invalid-url", nil)
rec := httptest.NewRecorder()

HandleTLS(checks.NewTls(client)).ServeHTTP(rec, req)

assert.Equal(t, http.StatusInternalServerError, rec.Code)
var responseBody map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &responseBody)
assert.NoError(t, err)
assert.Equal(t, map[string]interface{}{"error": "failed to get scan_id from TLS Observatory"}, responseBody)
})

var responseBody map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &responseBody)
assert.NoError(t, err)
assert.Equal(t, tc.expectedBody, responseBody)
})
}
}
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (s *Server) routes() {
s.mux.Handle("GET /api/rank", handlers.HandleGetRank(s.checks.Rank))
s.mux.Handle("GET /api/redirects", handlers.HandleGetRedirects())
s.mux.Handle("GET /api/social-tags", handlers.HandleGetSocialTags(s.checks.SocialTags))
s.mux.Handle("GET /api/tls", handlers.HandleTLS())
s.mux.Handle("GET /api/tls", handlers.HandleTLS(s.checks.Tls))
s.mux.Handle("GET /api/trace-route", handlers.HandleTraceRoute())
}

Expand Down

0 comments on commit eaf9a31

Please sign in to comment.