Skip to content

Commit

Permalink
Add a user-agent across all feeds. (#438)
Browse files Browse the repository at this point in the history
* Add a user-agent across all feeds.

Signed-off-by: Caleb Brown <calebbrown@google.com>

* Fix linter errors

Signed-off-by: Caleb Brown <calebbrown@google.com>

* Add user-agent to maven central

Signed-off-by: Caleb Brown <calebbrown@google.com>

* Fix maven lint and conform it to the style of other feeds.

Signed-off-by: Caleb Brown <calebbrown@google.com>

* Fix last lint error

Signed-off-by: Caleb Brown <calebbrown@google.com>

---------

Signed-off-by: Caleb Brown <calebbrown@google.com>
  • Loading branch information
calebbrown authored Feb 25, 2024
1 parent 975d2ec commit a5f3088
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 17 deletions.
4 changes: 3 additions & 1 deletion pkg/feeds/crates/crates.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/ossf/package-feeds/pkg/events"
"github.com/ossf/package-feeds/pkg/feeds"
"github.com/ossf/package-feeds/pkg/useragent"
"github.com/ossf/package-feeds/pkg/utils"
)

Expand All @@ -18,7 +19,8 @@ const (
)

var httpClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &useragent.RoundTripper{UserAgent: feeds.DefaultUserAgent},
Timeout: 10 * time.Second,
}

type crates struct {
Expand Down
6 changes: 5 additions & 1 deletion pkg/feeds/feed.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import (
"time"
)

const schemaVer = "1.1"
const (
schemaVer = "1.1"

DefaultUserAgent = "package-feeds (github.com/ossf/package-feeds)"
)

var ErrNoPackagesPolled = errors.New("no packages were successfully polled")

Expand Down
6 changes: 5 additions & 1 deletion pkg/feeds/goproxy/goproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/ossf/package-feeds/pkg/feeds"
"github.com/ossf/package-feeds/pkg/useragent"
"github.com/ossf/package-feeds/pkg/utils"
)

Expand All @@ -17,7 +18,10 @@ const (
indexPath = "/index"
)

var httpClient = &http.Client{Timeout: 10 * time.Second}
var httpClient = &http.Client{
Transport: &useragent.RoundTripper{UserAgent: feeds.DefaultUserAgent},
Timeout: 10 * time.Second,
}

type PackageJSON struct {
Path string `json:"Path"`
Expand Down
22 changes: 19 additions & 3 deletions pkg/feeds/maven/maven.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"time"

"github.com/ossf/package-feeds/pkg/feeds"
"github.com/ossf/package-feeds/pkg/useragent"
)

const (
Expand All @@ -21,7 +23,14 @@ type Feed struct {
options feeds.FeedOptions
}

var ErrMaxRetriesReached = errors.New("maximum retries reached due to rate limiting")
var (
httpClient = &http.Client{
Transport: &useragent.RoundTripper{UserAgent: feeds.DefaultUserAgent},
Timeout: 10 * time.Second,
}

ErrMaxRetriesReached = errors.New("maximum retries reached due to rate limiting")
)

func New(feedOptions feeds.FeedOptions) (*Feed, error) {
if feedOptions.Packages != nil {
Expand All @@ -31,7 +40,7 @@ func New(feedOptions feeds.FeedOptions) (*Feed, error) {
}
}
return &Feed{
baseURL: "https://central.sonatype.com/" + indexPath,
baseURL: "https://central.sonatype.com",
options: feedOptions,
}, nil
}
Expand All @@ -55,6 +64,12 @@ type Response struct {

// fetchPackages fetches packages from Sonatype API for the given page.
func (feed Feed) fetchPackages(page int) ([]Package, error) {
indexURL, err := url.JoinPath(feed.baseURL, indexPath)
if err != nil {
return nil, err
}
indexURL += "?repository=maven-central"

maxRetries := 5
retryDelay := 5 * time.Second

Expand All @@ -71,9 +86,10 @@ func (feed Feed) fetchPackages(page int) ([]Package, error) {
if err != nil {
return nil, fmt.Errorf("error encoding JSON: %w", err)
}
body := bytes.NewReader(jsonPayload)

// Send POST request to Sonatype API.
resp, err := http.Post(feed.baseURL+"?repository=maven-central", "application/json", bytes.NewBuffer(jsonPayload))
resp, err := httpClient.Post(indexURL, "application/json", body)
if err != nil {
// Check if maximum retries have been reached
if attempt == maxRetries {
Expand Down
6 changes: 3 additions & 3 deletions pkg/feeds/maven/maven_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ func TestMavenLatest(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create Maven feed: %v", err)
}
feed.baseURL = srv.URL + "/api/internal/browse/components"
feed.baseURL = srv.URL

cutoff := time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC)
pkgs, gotCutoff, errs := feed.Latest(cutoff)

if len(errs) != 0 {
t.Fatalf("feed.Latest returned error: %v", err)
t.Fatalf("feed.Latest returned error: %v", errs)
}

// Returned cutoff should match the newest package creation time of packages retrieved.
Expand Down Expand Up @@ -61,7 +61,7 @@ func TestMavenNotFound(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create Maven feed: %v", err)
}
feed.baseURL = srv.URL + "/api/internal/browse/components"
feed.baseURL = srv.URL

cutoff := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)

Expand Down
8 changes: 6 additions & 2 deletions pkg/feeds/npm/npm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"github.com/ossf/package-feeds/pkg/events"
"github.com/ossf/package-feeds/pkg/feeds"
"github.com/ossf/package-feeds/pkg/useragent"
"github.com/ossf/package-feeds/pkg/utils"
)

Expand Down Expand Up @@ -383,8 +384,11 @@ func New(feedOptions feeds.FeedOptions, eventHandler *events.Handler) (*Feed, er
baseURL: "https://registry.npmjs.org/",
options: feedOptions,
client: &http.Client{
Transport: tr,
Timeout: 45 * time.Second,
Transport: &useragent.RoundTripper{
UserAgent: feeds.DefaultUserAgent,
Parent: tr,
},
Timeout: 45 * time.Second,
},
cache: cache,
}, nil
Expand Down
6 changes: 4 additions & 2 deletions pkg/feeds/nuget/nuget.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/ossf/package-feeds/pkg/feeds"
"github.com/ossf/package-feeds/pkg/useragent"
"github.com/ossf/package-feeds/pkg/utils"
)

Expand All @@ -19,8 +20,9 @@ const (
)

var (
httpClient = http.Client{
Timeout: 10 * time.Second,
httpClient = &http.Client{
Transport: &useragent.RoundTripper{UserAgent: feeds.DefaultUserAgent},
Timeout: 10 * time.Second,
}
errCatalogService = errors.New("error fetching catalog service")
)
Expand Down
4 changes: 3 additions & 1 deletion pkg/feeds/packagist/packagist.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ import (
"time"

"github.com/ossf/package-feeds/pkg/feeds"
"github.com/ossf/package-feeds/pkg/useragent"
"github.com/ossf/package-feeds/pkg/utils"
)

const FeedName = "packagist"

var httpClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &useragent.RoundTripper{UserAgent: feeds.DefaultUserAgent},
Timeout: 10 * time.Second,
}

type response struct {
Expand Down
4 changes: 3 additions & 1 deletion pkg/feeds/pypi/pypi.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/ossf/package-feeds/pkg/events"
"github.com/ossf/package-feeds/pkg/feeds"
"github.com/ossf/package-feeds/pkg/useragent"
"github.com/ossf/package-feeds/pkg/utils"
)

Expand All @@ -22,7 +23,8 @@ const (

var (
httpClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &useragent.RoundTripper{UserAgent: feeds.DefaultUserAgent},
Timeout: 10 * time.Second,
}
errInvalidLinkForPackage = errors.New("invalid link provided by pypi API")
)
Expand Down
3 changes: 2 additions & 1 deletion pkg/feeds/pypi/pypi_artifacts.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/kolo/xmlrpc"

"github.com/ossf/package-feeds/pkg/feeds"
"github.com/ossf/package-feeds/pkg/useragent"
)

const (
Expand All @@ -30,7 +31,7 @@ func NewArtifactFeed(feedOptions feeds.FeedOptions) (*ArtifactFeed, error) {
}

func (feed ArtifactFeed) Latest(cutoff time.Time) ([]*feeds.Package, time.Time, []error) {
client, err := xmlrpc.NewClient(feed.baseURL, nil)
client, err := xmlrpc.NewClient(feed.baseURL, &useragent.RoundTripper{UserAgent: feeds.DefaultUserAgent})
if err != nil {
return nil, cutoff, []error{err}
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/feeds/rubygems/rubygems.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/ossf/package-feeds/pkg/events"
"github.com/ossf/package-feeds/pkg/feeds"
"github.com/ossf/package-feeds/pkg/useragent"
"github.com/ossf/package-feeds/pkg/utils"
)

Expand All @@ -18,7 +19,8 @@ const (
)

var httpClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &useragent.RoundTripper{UserAgent: feeds.DefaultUserAgent},
Timeout: 10 * time.Second,
}

type Package struct {
Expand Down
21 changes: 21 additions & 0 deletions pkg/useragent/useragent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package useragent

import "net/http"

type RoundTripper struct {
UserAgent string
Parent http.RoundTripper
}

func (rt *RoundTripper) RoundTrip(ireq *http.Request) (*http.Response, error) {
req := ireq.Clone(ireq.Context())
req.Header.Set("User-Agent", rt.UserAgent)
return rt.parent().RoundTrip(req)
}

func (rt *RoundTripper) parent() http.RoundTripper {
if rt.Parent != nil {
return rt.Parent
}
return http.DefaultTransport
}
75 changes: 75 additions & 0 deletions pkg/useragent/useragent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package useragent_test

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/ossf/package-feeds/pkg/useragent"
)

func TestRoundTripper(t *testing.T) {
t.Parallel()
want := "test user agent string"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got := r.Header.Get("user-agent")
if got != want {
t.Errorf("User Agent = %q, want %q", got, want)
}
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()

c := http.Client{
Transport: &useragent.RoundTripper{UserAgent: want},
}
resp, err := c.Get(ts.URL)
if err != nil {
t.Fatalf("Get() = %v; want no error", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Get() status = %v; want 200", resp.StatusCode)
}
}

type roundTripperFunc func(*http.Request) (*http.Response, error)

func (rt roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return rt(r)
}

func TestRoundTripper_Parent(t *testing.T) {
t.Parallel()
want := "test user agent string"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got := r.Header.Get("user-agent")
if got != want {
t.Errorf("User Agent = %q, want %q", got, want)
}
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()

calledParent := false
c := http.Client{
Transport: &useragent.RoundTripper{
UserAgent: want,
Parent: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
calledParent = true
return http.DefaultTransport.RoundTrip(r)
}),
},
}
resp, err := c.Get(ts.URL)
if err != nil {
t.Fatalf("Get() = %v; want no error", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Get() status = %v; want 200", resp.StatusCode)
}
if !calledParent {
t.Errorf("Failed to call Parent RoundTripper")
}
}

0 comments on commit a5f3088

Please sign in to comment.