diff --git a/checks/checks.go b/checks/checks.go index 5bb5389..1f1cd93 100644 --- a/checks/checks.go +++ b/checks/checks.go @@ -3,10 +3,13 @@ package checks import ( "net/http" "time" + + "github.com/xray-web/web-check-api/checks/store/legacyrank" ) type Checks struct { Carbon *Carbon + LegacyRank *LegacyRank Rank *Rank SocialTags *SocialTags Tls *Tls @@ -18,6 +21,7 @@ func NewChecks() *Checks { } return &Checks{ Carbon: NewCarbon(client), + LegacyRank: NewLegacyRank(legacyrank.NewInMemoryStore()), Rank: NewRank(client), SocialTags: NewSocialTags(client), Tls: NewTls(client), diff --git a/checks/legacy_rank.go b/checks/legacy_rank.go new file mode 100644 index 0000000..fa0d215 --- /dev/null +++ b/checks/legacy_rank.go @@ -0,0 +1,27 @@ +package checks + +import "github.com/xray-web/web-check-api/checks/store/legacyrank" + +type DomainRank struct { + Domain string `json:"domain"` + Rank int `json:"rank"` +} + +type LegacyRank struct { + data legacyrank.Getter +} + +func NewLegacyRank(lrg legacyrank.Getter) *LegacyRank { + return &LegacyRank{data: lrg} +} + +func (lr *LegacyRank) LegacyRank(domain string) (*DomainRank, error) { + rank, err := lr.data.GetLegacyRank(domain) + if err != nil { + return nil, err + } + return &DomainRank{ + Domain: domain, + Rank: rank, + }, nil +} diff --git a/checks/legacy_rank_test.go b/checks/legacy_rank_test.go new file mode 100644 index 0000000..1a732df --- /dev/null +++ b/checks/legacy_rank_test.go @@ -0,0 +1,23 @@ +package checks + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xray-web/web-check-api/checks/store/legacyrank" +) + +func TestLegacyRank(t *testing.T) { + t.Parallel() + + t.Run("get rank", func(t *testing.T) { + t.Parallel() + lr := NewLegacyRank(legacyrank.GetterFunc(func(domain string) (int, error) { + return 1, nil + })) + dr, err := lr.LegacyRank("example.com") + assert.NoError(t, err) + assert.Equal(t, 1, dr.Rank) + assert.Equal(t, "example.com", dr.Domain) + }) +} diff --git a/checks/store/legacyrank/legacy_rank.go b/checks/store/legacyrank/legacy_rank.go new file mode 100644 index 0000000..ca8938b --- /dev/null +++ b/checks/store/legacyrank/legacy_rank.go @@ -0,0 +1,99 @@ +package legacyrank + +import ( + "archive/zip" + "bytes" + "context" + "encoding/csv" + "errors" + "io" + "log" + "net/http" + "strconv" + "sync" + "time" +) + +var ErrNotFound = errors.New("domain not found") + +type Getter interface { + GetLegacyRank(domain string) (int, error) +} + +type GetterFunc func(domain string) (int, error) + +func (f GetterFunc) GetLegacyRank(domain string) (int, error) { + return f(domain) +} + +type InMemoryStore struct{} + +var once sync.Once +var data map[string]int //map of domain to rank + +func NewInMemoryStore() *InMemoryStore { + return &InMemoryStore{} +} + +func (s *InMemoryStore) GetLegacyRank(url string) (int, error) { + once.Do(func() { + var err error + data, err = load() + if err != nil { + log.Println(err) + } + }) + + rank, ok := data[url] + if !ok { + return -1, ErrNotFound + } + return rank, nil +} + +func load() (map[string]int, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://s3-us-west-1.amazonaws.com/umbrella-static/top-1m.csv.zip", nil) + if err != nil { + return nil, err + } + client := &http.Client{ + Timeout: time.Second * 10, + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + zf, err := zip.NewReader(bytes.NewReader(b), int64(len(b))) + if err != nil { + return nil, err + } + f, err := zf.Open("top-1m.csv") + if err != nil { + return nil, err + } + defer f.Close() + r := csv.NewReader(f) + data := make(map[string]int) + for { + record, err := r.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + rank, err := strconv.Atoi(record[0]) + if err != nil { + return nil, err + } + data[record[1]] = rank + } + return data, nil +} diff --git a/checks/store/legacyrank/legacy_rank_test.go b/checks/store/legacyrank/legacy_rank_test.go new file mode 100644 index 0000000..eb35e26 --- /dev/null +++ b/checks/store/legacyrank/legacy_rank_test.go @@ -0,0 +1,26 @@ +package legacyrank_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xray-web/web-check-api/checks/store/legacyrank" +) + +func TestInMemoryStore(t *testing.T) { + t.Parallel() + + t.Run("get google rank", func(t *testing.T) { + t.Parallel() + ims := legacyrank.NewInMemoryStore() + dr, err := ims.GetLegacyRank("google.com") + assert.NoError(t, err, dr) + }) + + t.Run("get microsoft rank", func(t *testing.T) { + t.Parallel() + ims := legacyrank.NewInMemoryStore() + dr, err := ims.GetLegacyRank("microsoft.com") + assert.NoError(t, err, dr) + }) +} diff --git a/handlers/legacy_rank.go b/handlers/legacy_rank.go index 64644c4..9f8d6a6 100644 --- a/handlers/legacy_rank.go +++ b/handlers/legacy_rank.go @@ -1,152 +1,12 @@ package handlers import ( - "archive/zip" - "encoding/csv" - "fmt" - "io" "net/http" - "net/url" - "os" - "path/filepath" - "strings" -) -const ( - fileURL = "https://s3-us-west-1.amazonaws.com/umbrella-static/top-1m.csv.zip" - tempFilePath = "/tmp/top-1m.csv" + "github.com/xray-web/web-check-api/checks" ) -type RankResponse struct { - Domain string `json:"domain"` - Rank string `json:"rank"` - IsFound bool `json:"isFound"` -} - -func checkLegacyRank(urlStr string) (RankResponse, error) { - var domain string - var err error - - // Parse the URL to extract the domain - u, err := url.Parse(urlStr) - if err != nil { - return RankResponse{}, fmt.Errorf("invalid URL") - } - - // Extract the domain from the parsed URL - if u.Host != "" { - domain = u.Host - } else { - // If Host is empty, try to extract the domain from the Path - parts := strings.Split(u.Path, "/") - if len(parts) > 0 { - domain = parts[0] - } else { - return RankResponse{}, fmt.Errorf("unable to extract domain from URL") - } - } - - // Download and unzip the file if not in cache - if _, err := os.Stat(tempFilePath); os.IsNotExist(err) { - if err := downloadAndUnzip(fileURL); err != nil { - return RankResponse{}, err - } - } - - // Parse the CSV and find the rank - file, err := os.Open(tempFilePath) - if err != nil { - return RankResponse{}, fmt.Errorf("error opening CSV file: %s", err) - } - defer file.Close() - - reader := csv.NewReader(file) - for { - record, err := reader.Read() - if err == io.EOF { - break - } - if err != nil { - return RankResponse{}, fmt.Errorf("error reading CSV record: %s", err) - } - - if record[1] == domain { - return RankResponse{ - Domain: domain, - Rank: record[0], - IsFound: true, - }, nil - } - } - - return RankResponse{ - Domain: domain, - IsFound: false, - }, nil -} - -func downloadAndUnzip(url string) error { - resp, err := http.Get(url) - if err != nil { - return fmt.Errorf("error downloading file: %s", err) - } - defer resp.Body.Close() - - zipFile, err := os.Create(tempFilePath + ".zip") - if err != nil { - return fmt.Errorf("error creating zip file: %s", err) - } - defer zipFile.Close() - - _, err = io.Copy(zipFile, resp.Body) - if err != nil { - return fmt.Errorf("error writing zip file: %s", err) - } - - err = unzip(tempFilePath+".zip", "/tmp") - if err != nil { - return fmt.Errorf("error unzipping file: %s", err) - } - - return nil -} - -func unzip(src, dest string) error { - r, err := zip.OpenReader(src) - if err != nil { - return err - } - defer r.Close() - - for _, f := range r.File { - rc, err := f.Open() - if err != nil { - return err - } - defer rc.Close() - - path := filepath.Join(dest, f.Name) - if f.FileInfo().IsDir() { - os.MkdirAll(path, f.Mode()) - } else { - os.MkdirAll(filepath.Dir(path), os.ModePerm) - f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) - if err != nil { - return err - } - defer f.Close() - - _, err = io.Copy(f, rc) - if err != nil { - return err - } - } - } - - return nil -} - -func HandleLegacyRank() http.Handler { +func HandleLegacyRank(l *checks.LegacyRank) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rawURL, err := extractURL(r) if err != nil { @@ -154,7 +14,7 @@ func HandleLegacyRank() http.Handler { return } - result, err := checkLegacyRank(rawURL.String()) + result, err := l.LegacyRank(rawURL.Hostname()) if err != nil { JSONError(w, err, http.StatusInternalServerError) return diff --git a/handlers/legacy_rank_test.go b/handlers/legacy_rank_test.go index 2d10977..62d9d40 100644 --- a/handlers/legacy_rank_test.go +++ b/handlers/legacy_rank_test.go @@ -1,7 +1,6 @@ package handlers import ( - "encoding/json" "net/http" "net/http/httptest" "testing" @@ -11,19 +10,15 @@ import ( func TestHandleLegacyRank(t *testing.T) { t.Parallel() - req := httptest.NewRequest("GET", "/legacy-rank?url=www.google.com", nil) - rec := httptest.NewRecorder() - HandleLegacyRank().ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) + t.Run("missing URL parameter", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodGet, "/legacy-rank", nil) + rec := httptest.NewRecorder() - var response RankResponse - err := json.Unmarshal(rec.Body.Bytes(), &response) - assert.NoError(t, err) + HandleBlockLists().ServeHTTP(rec, req) - assert.NotNil(t, response) - - assert.Equal(t, "www.google.com", response.Domain) - - assert.True(t, response.IsFound || !response.IsFound) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.JSONEq(t, `{"error": "missing URL parameter"}`, rec.Body.String()) + }) } diff --git a/server/server.go b/server/server.go index 4d49ffd..96ebf33 100644 --- a/server/server.go +++ b/server/server.go @@ -40,7 +40,7 @@ func (s *Server) routes() { s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders()) s.mux.Handle("GET /api/hsts", handlers.HandleHsts()) s.mux.Handle("GET /api/http-security", handlers.HandleHttpSecurity()) - s.mux.Handle("GET /api/legacy-rank", handlers.HandleLegacyRank()) + s.mux.Handle("GET /api/legacy-rank", handlers.HandleLegacyRank(s.checks.LegacyRank)) s.mux.Handle("GET /api/linked-pages", handlers.HandleGetLinks()) s.mux.Handle("GET /api/ports", handlers.HandleGetPorts()) s.mux.Handle("GET /api/quality", handlers.HandleGetQuality())