diff --git a/examples/go.mod b/examples/go.mod
index c7c389626..16b6faa3e 100644
--- a/examples/go.mod
+++ b/examples/go.mod
@@ -43,6 +43,8 @@ require (
 	github.com/flynn/noise v1.1.0 // indirect
 	github.com/francoispqt/gojay v1.2.13 // indirect
 	github.com/gabriel-vasile/mimetype v1.4.6 // indirect
+	github.com/gammazero/chanqueue v1.0.0 // indirect
+	github.com/gammazero/deque v1.0.0 // indirect
 	github.com/go-logr/logr v1.4.2 // indirect
 	github.com/go-logr/stdr v1.2.2 // indirect
 	github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
diff --git a/examples/go.sum b/examples/go.sum
index 8d7be4a96..401e43184 100644
--- a/examples/go.sum
+++ b/examples/go.sum
@@ -77,6 +77,10 @@ github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z
 github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
 github.com/gabriel-vasile/mimetype v1.4.6 h1:3+PzJTKLkvgjeTbts6msPJt4DixhT4YtFNf1gtGe3zc=
 github.com/gabriel-vasile/mimetype v1.4.6/go.mod h1:JX1qVKqZd40hUPpAfiNTe0Sne7hdfKSbOqqmkq8GCXc=
+github.com/gammazero/chanqueue v1.0.0 h1:FER/sMailGFA3DDvFooEkipAMU+3c9Bg3bheloPSz6o=
+github.com/gammazero/chanqueue v1.0.0/go.mod h1:fMwpwEiuUgpab0sH4VHiVcEoji1pSi+EIzeG4TPeKPc=
+github.com/gammazero/deque v1.0.0 h1:LTmimT8H7bXkkCy6gZX7zNLtkbz4NdS2z8LZuor3j34=
+github.com/gammazero/deque v1.0.0/go.mod h1:iflpYvtGfM3U8S8j+sZEKIak3SAKYpA5/SQewgfXDKo=
 github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
 github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
 github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
diff --git a/go.mod b/go.mod
index db188a203..292fbb95a 100644
--- a/go.mod
+++ b/go.mod
@@ -12,6 +12,8 @@ require (
 	github.com/cskr/pubsub v1.0.2
 	github.com/dustin/go-humanize v1.0.1
 	github.com/gabriel-vasile/mimetype v1.4.6
+	github.com/gammazero/chanqueue v1.0.0
+	github.com/gammazero/deque v1.0.0
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/uuid v1.6.0
 	github.com/gorilla/mux v1.8.1
diff --git a/go.sum b/go.sum
index f3c5da72b..2e41a5014 100644
--- a/go.sum
+++ b/go.sum
@@ -77,6 +77,10 @@ github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z
 github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
 github.com/gabriel-vasile/mimetype v1.4.6 h1:3+PzJTKLkvgjeTbts6msPJt4DixhT4YtFNf1gtGe3zc=
 github.com/gabriel-vasile/mimetype v1.4.6/go.mod h1:JX1qVKqZd40hUPpAfiNTe0Sne7hdfKSbOqqmkq8GCXc=
+github.com/gammazero/chanqueue v1.0.0 h1:FER/sMailGFA3DDvFooEkipAMU+3c9Bg3bheloPSz6o=
+github.com/gammazero/chanqueue v1.0.0/go.mod h1:fMwpwEiuUgpab0sH4VHiVcEoji1pSi+EIzeG4TPeKPc=
+github.com/gammazero/deque v1.0.0 h1:LTmimT8H7bXkkCy6gZX7zNLtkbz4NdS2z8LZuor3j34=
+github.com/gammazero/deque v1.0.0/go.mod h1:iflpYvtGfM3U8S8j+sZEKIak3SAKYpA5/SQewgfXDKo=
 github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
 github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
 github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
diff --git a/routing/providerquerymanager/providerquerymanager.go b/routing/providerquerymanager/providerquerymanager.go
index d9005020e..55880ecd9 100644
--- a/routing/providerquerymanager/providerquerymanager.go
+++ b/routing/providerquerymanager/providerquerymanager.go
@@ -5,6 +5,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/gammazero/chanqueue"
+	"github.com/gammazero/deque"
 	"github.com/ipfs/go-cid"
 	logging "github.com/ipfs/go-log/v2"
 	peer "github.com/libp2p/go-libp2p/core/peer"
@@ -17,7 +19,7 @@ import (
 var log = logging.Logger("routing/provqrymgr")
 
 const (
-	defaultMaxInProcessRequests = 6
+	defaultMaxInProcessRequests = 16
 	defaultMaxProviders         = 0
 	defaultTimeout              = 10 * time.Second
 )
@@ -82,15 +84,13 @@ type cancelRequestMessage struct {
 // - ensure two findprovider calls for the same block don't run concurrently
 // - manage timeouts
 type ProviderQueryManager struct {
-	ctx                          context.Context
-	dialer                       ProviderQueryDialer
-	router                       ProviderQueryRouter
-	providerQueryMessages        chan providerQueryMessage
-	providerRequestsProcessing   chan *findProviderRequest
-	incomingFindProviderRequests chan *findProviderRequest
+	ctx                        context.Context
+	dialer                     ProviderQueryDialer
+	router                     ProviderQueryRouter
+	providerQueryMessages      chan providerQueryMessage
+	providerRequestsProcessing *chanqueue.ChanQueue[*findProviderRequest]
 
 	findProviderTimeout time.Duration
-	timeoutMutex        sync.RWMutex
 
 	maxProviders         int
 	maxInProcessRequests int
@@ -108,7 +108,9 @@ func WithMaxTimeout(timeout time.Duration) Option {
 	}
 }
 
-// WithMaxInProcessRequests is the maximum number of requests that can be processed in parallel
+// WithMaxInProcessRequests sets the maximum number of requests that can be
+// processed in parallel. If this is 0, then the number is unlimited. The
+// default is defaultMaxInProcessRequests (16).
 func WithMaxInProcessRequests(count int) Option {
 	return func(mgr *ProviderQueryManager) error {
 		mgr.maxInProcessRequests = count
@@ -117,7 +119,7 @@
 }
 
 // WithMaxProviders is the maximum number of providers that will be looked up
-// per query. We only return providers that we can connect to.  Defaults to 0,
+// per query. We only return providers that we can connect to. Defaults to 0,
 // which means unbounded.
 func WithMaxProviders(count int) Option {
 	return func(mgr *ProviderQueryManager) error {
@@ -130,16 +132,13 @@
 // network provider.
 func New(ctx context.Context, dialer ProviderQueryDialer, router ProviderQueryRouter, opts ...Option) (*ProviderQueryManager, error) {
 	pqm := &ProviderQueryManager{
-		ctx:                          ctx,
-		dialer:                       dialer,
-		router:                       router,
-		providerQueryMessages:        make(chan providerQueryMessage, 16),
-		providerRequestsProcessing:   make(chan *findProviderRequest),
-		incomingFindProviderRequests: make(chan *findProviderRequest),
-		inProgressRequestStatuses:    make(map[cid.Cid]*inProgressRequestStatus),
-		findProviderTimeout:          defaultTimeout,
-		maxInProcessRequests:         defaultMaxInProcessRequests,
-		maxProviders:                 defaultMaxProviders,
+		ctx:                   ctx,
+		dialer:                dialer,
+		router:                router,
+		providerQueryMessages: make(chan providerQueryMessage),
+		findProviderTimeout:   defaultTimeout,
+		maxInProcessRequests:  defaultMaxInProcessRequests,
+		maxProviders:          defaultMaxProviders,
 	}
 
 	for _, o := range opts {
@@ -161,13 +160,6 @@ type inProgressRequest struct {
 	incoming       chan peer.AddrInfo
 }
 
-// setFindProviderTimeout changes the timeout for finding providers
-func (pqm *ProviderQueryManager) setFindProviderTimeout(findProviderTimeout time.Duration) {
-	pqm.timeoutMutex.Lock()
-	pqm.findProviderTimeout = findProviderTimeout
-	pqm.timeoutMutex.Unlock()
-}
-
 // FindProvidersAsync finds providers for the given block. The max parameter
 // controls how many will be returned at most. For a provider to be returned,
 // we must have successfully connected to it. Setting max to 0 will use the
@@ -216,32 +208,36 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context,
 }
 
 func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, max int, receivedInProgressRequest inProgressRequest, onCloseFn func()) <-chan peer.AddrInfo {
-	// maintains an unbuffered queue for incoming providers for given request for a given session
-	// essentially, as a provider comes in, for a given CID, we want to immediately broadcast to all
-	// sessions that queried that CID, without worrying about whether the client code is actually
-	// reading from the returned channel -- so that the broadcast never blocks
-	// based on: https://medium.com/capital-one-tech/building-an-unbounded-channel-in-go-789e175cd2cd
+	// maintains an unbuffered queue for incoming providers for a given
+	// request for a given session. Essentially, as a provider comes in, for a
+	// given CID, immediately broadcast to all sessions that queried that CID,
+	// without worrying about whether the client code is actually reading from
+	// the returned channel -- so that the broadcast never blocks.
 	returnedProviders := make(chan peer.AddrInfo)
-	receivedProviders := append([]peer.AddrInfo(nil), receivedInProgressRequest.providersSoFar[0:]...)
+	var receivedProviders deque.Deque[peer.AddrInfo]
+	receivedProviders.Grow(len(receivedInProgressRequest.providersSoFar))
+	for _, addrInfo := range receivedInProgressRequest.providersSoFar {
+		receivedProviders.PushBack(addrInfo)
+	}
 	incomingProviders := receivedInProgressRequest.incoming
 
 	// count how many providers we received from our workers etc.
 	// these providers should be peers we managed to connect to.
-	total := len(receivedProviders)
+	total := receivedProviders.Len()
 	go func() {
 		defer close(returnedProviders)
 		defer onCloseFn()
 		outgoingProviders := func() chan<- peer.AddrInfo {
-			if len(receivedProviders) == 0 {
+			if receivedProviders.Len() == 0 {
 				return nil
 			}
 			return returnedProviders
 		}
 		nextProvider := func() peer.AddrInfo {
-			if len(receivedProviders) == 0 {
+			if receivedProviders.Len() == 0 {
 				return peer.AddrInfo{}
 			}
-			return receivedProviders[0]
+			return receivedProviders.Front()
 		}
 
 		stopWhenMaxReached := func() {
@@ -258,7 +254,7 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
 		// need.
 		stopWhenMaxReached()
 
-		for len(receivedProviders) > 0 || incomingProviders != nil {
+		for receivedProviders.Len() > 0 || incomingProviders != nil {
 			select {
 			case <-pqm.ctx.Done():
 				return
@@ -271,7 +267,7 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
 				if !ok {
 					incomingProviders = nil
 				} else {
-					receivedProviders = append(receivedProviders, provider)
+					receivedProviders.PushBack(provider)
 					total++
 					stopWhenMaxReached()
 					// we do not return, we will loop on
@@ -281,7 +277,7 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
 					// via returnedProviders
 				}
 			case outgoingProviders() <- nextProvider():
-				receivedProviders = receivedProviders[1:]
+				receivedProviders.PopFront()
 			}
 		}
 	}()
@@ -310,27 +306,42 @@ func (pqm *ProviderQueryManager) cancelProviderRequest(ctx context.Context, k ci
 	}
 }
 
+// findProviderWorker cycles through incoming provider queries one at a time.
 func (pqm *ProviderQueryManager) findProviderWorker() {
-	// findProviderWorker just cycles through incoming provider queries one
-	// at a time. We have six of these workers running at once
-	// to let requests go in parallel but keep them rate limited
-	for {
-		select {
-		case fpr, ok := <-pqm.providerRequestsProcessing:
-			if !ok {
+	var findSem chan struct{}
+	// If limiting the number of concurrent requests, create a counting
+	// semaphore to enforce this limit.
+	if pqm.maxInProcessRequests > 0 {
+		findSem = make(chan struct{}, pqm.maxInProcessRequests)
+	}
+
+	// Read find provider requests until the channel is closed. The channel
+	// is closed as soon as pqm.ctx is canceled, so there is no need to
+	// select on that context here.
+	for fpr := range pqm.providerRequestsProcessing.Out() {
+		if findSem != nil {
+			select {
+			case findSem <- struct{}{}:
+			case <-pqm.ctx.Done():
 				return
 			}
-			k := fpr.k
+		}
+
+		go func(ctx context.Context, k cid.Cid) {
+			if findSem != nil {
+				defer func() {
+					<-findSem
+				}()
+			}
+
 			log.Debugf("Beginning Find Provider Request for cid: %s", k.String())
-			pqm.timeoutMutex.RLock()
-			findProviderCtx, cancel := context.WithTimeout(fpr.ctx, pqm.findProviderTimeout)
-			pqm.timeoutMutex.RUnlock()
+			findProviderCtx, cancel := context.WithTimeout(ctx, pqm.findProviderTimeout)
 			span := trace.SpanFromContext(findProviderCtx)
 			span.AddEvent("StartFindProvidersAsync")
-			// We set count == 0. We will cancel the query
-			// manually once we have enough. This assumes the
-			// ContentDiscovery implementation does that, which a
-			// requirement per the libp2p/core/routing interface.
+			// We set count == 0. We will cancel the query manually once we
+			// have enough. This assumes the ContentDiscovery
+			// implementation does that, which is a requirement per the
+			// libp2p/core/routing interface.
 			providers := pqm.router.FindProvidersAsync(findProviderCtx, k, 0)
 			wg := &sync.WaitGroup{}
 			for p := range providers {
@@ -347,7 +358,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
 					span.AddEvent("ConnectedToProvider", trace.WithAttributes(attribute.Stringer("peer", p.ID)))
 					select {
 					case pqm.providerQueryMessages <- &receivedProviderMessage{
-						ctx: fpr.ctx,
+						ctx: ctx,
 						k:   k,
 						p:   p,
 					}:
@@ -360,48 +371,12 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
 			cancel()
 			select {
 			case pqm.providerQueryMessages <- &finishedProviderQueryMessage{
-				ctx: fpr.ctx,
+				ctx: ctx,
 				k:   k,
 			}:
 			case <-pqm.ctx.Done():
 			}
-		case <-pqm.ctx.Done():
-			return
-		}
-	}
-}
-
-func (pqm *ProviderQueryManager) providerRequestBufferWorker() {
-	// the provider request buffer worker just maintains an unbounded
-	// buffer for incoming provider queries and dispatches to the find
-	// provider workers as they become available
-	// based on: https://medium.com/capital-one-tech/building-an-unbounded-channel-in-go-789e175cd2cd
-	var providerQueryRequestBuffer []*findProviderRequest
-	nextProviderQuery := func() *findProviderRequest {
-		if len(providerQueryRequestBuffer) == 0 {
-			return nil
-		}
-		return providerQueryRequestBuffer[0]
-	}
-	outgoingRequests := func() chan<- *findProviderRequest {
-		if len(providerQueryRequestBuffer) == 0 {
-			return nil
-		}
-		return pqm.providerRequestsProcessing
-	}
-
-	for {
-		select {
-		case incomingRequest, ok := <-pqm.incomingFindProviderRequests:
-			if !ok {
-				return
-			}
-			providerQueryRequestBuffer = append(providerQueryRequestBuffer, incomingRequest)
-		case outgoingRequests() <- nextProviderQuery():
-			providerQueryRequestBuffer = providerQueryRequestBuffer[1:]
-		case <-pqm.ctx.Done():
-			return
-		}
-	}
+		}(fpr.ctx, fpr.k)
 	}
 }
 
@@ -417,10 +392,10 @@
 func (pqm *ProviderQueryManager) run() {
 	defer pqm.cleanupInProcessRequests()
 
-	go pqm.providerRequestBufferWorker()
-	for i := 0; i < pqm.maxInProcessRequests; i++ {
-		go pqm.findProviderWorker()
-	}
+	pqm.providerRequestsProcessing = chanqueue.New[*findProviderRequest]()
+	defer pqm.providerRequestsProcessing.Shutdown()
+
+	go pqm.findProviderWorker()
 
 	for {
 		select {
@@ -469,6 +444,9 @@ func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) {
 		close(listener)
 	}
 	delete(pqm.inProgressRequestStatuses, fpqm.k)
+	if len(pqm.inProgressRequestStatuses) == 0 {
+		pqm.inProgressRequestStatuses = nil
+	}
 	requestStatus.cancelFn()
 }
 
@@ -492,10 +470,13 @@ func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
 			cancelFn:   cancelFn,
 		}
 
+		if pqm.inProgressRequestStatuses == nil {
+			pqm.inProgressRequestStatuses = make(map[cid.Cid]*inProgressRequestStatus)
+		}
 		pqm.inProgressRequestStatuses[npqm.k] = requestStatus
 
 		select {
-		case pqm.incomingFindProviderRequests <- &findProviderRequest{
+		case pqm.providerRequestsProcessing.In() <- &findProviderRequest{
 			k:   npqm.k,
 			ctx: ctx,
 		}:
@@ -536,6 +517,9 @@ func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
 	close(crm.incomingProviders)
 	if len(requestStatus.listeners) == 0 {
 		delete(pqm.inProgressRequestStatuses, crm.k)
+		if len(pqm.inProgressRequestStatuses) == 0 {
+			pqm.inProgressRequestStatuses = nil
+		}
 		requestStatus.cancelFn()
 	}
 }
diff --git a/routing/providerquerymanager/providerquerymanager_test.go b/routing/providerquerymanager/providerquerymanager_test.go
index b55c1debc..7369231de 100644
--- a/routing/providerquerymanager/providerquerymanager_test.go
+++ b/routing/providerquerymanager/providerquerymanager_test.go
@@ -263,6 +263,8 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) {
 }
 
 func TestRateLimitingRequests(t *testing.T) {
+	const maxInProcessRequests = 6
+
 	peers := random.Peers(10)
 	fpd := &fakeProviderDialer{}
 	fpn := &fakeProviderDiscovery{
@@ -272,31 +274,73 @@ func TestRateLimitingRequests(t *testing.T) {
 	ctx := context.Background()
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
+	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxInProcessRequests(maxInProcessRequests)))
 	providerQueryManager.Startup()
-	keys := random.Cids(providerQueryManager.maxInProcessRequests + 1)
+	keys := random.Cids(maxInProcessRequests + 1)
 
 	sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
 	defer cancel()
 	var requestChannels []<-chan peer.AddrInfo
-	for i := 0; i < providerQueryManager.maxInProcessRequests+1; i++ {
+	for i := 0; i < maxInProcessRequests+1; i++ {
 		requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], 0))
 	}
 	time.Sleep(20 * time.Millisecond)
 	fpn.queriesMadeMutex.Lock()
-	if fpn.liveQueries != providerQueryManager.maxInProcessRequests {
+	if fpn.liveQueries != maxInProcessRequests {
 		t.Logf("Queries made: %d\n", fpn.liveQueries)
 		t.Fatal("Did not limit parallel requests to rate limit")
 	}
 	fpn.queriesMadeMutex.Unlock()
-	for i := 0; i < providerQueryManager.maxInProcessRequests+1; i++ {
+	for i := 0; i < maxInProcessRequests+1; i++ {
+		for range requestChannels[i] {
+		}
+	}
+
+	fpn.queriesMadeMutex.Lock()
+	defer fpn.queriesMadeMutex.Unlock()
+	if fpn.queriesMade != maxInProcessRequests+1 {
+		t.Logf("Queries made: %d\n", fpn.queriesMade)
+		t.Fatal("Did not make all separate requests")
+	}
+}
+
+func TestUnlimitedRequests(t *testing.T) {
+	const inProcessRequests = 11
+
+	peers := random.Peers(10)
+	fpd := &fakeProviderDialer{}
+	fpn := &fakeProviderDiscovery{
+		peersFound: peers,
+		delay:      5 * time.Millisecond,
+	}
+	ctx := context.Background()
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxInProcessRequests(0)))
+	providerQueryManager.Startup()
+
+	keys := random.Cids(inProcessRequests)
+	sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	defer cancel()
+	var requestChannels []<-chan peer.AddrInfo
+	for i := 0; i < inProcessRequests; i++ {
+		requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], 0))
+	}
+	time.Sleep(20 * time.Millisecond)
+	fpn.queriesMadeMutex.Lock()
+	if fpn.liveQueries != inProcessRequests {
+		t.Logf("Queries made: %d\n", fpn.liveQueries)
+		t.Fatal("Parallel requests appear to be rate limited")
+	}
+	fpn.queriesMadeMutex.Unlock()
+	for i := 0; i < inProcessRequests; i++ {
 		for range requestChannels[i] {
 		}
 	}
 
 	fpn.queriesMadeMutex.Lock()
 	defer fpn.queriesMadeMutex.Unlock()
-	if fpn.queriesMade != providerQueryManager.maxInProcessRequests+1 {
+	if fpn.queriesMade != inProcessRequests {
 		t.Logf("Queries made: %d\n", fpn.queriesMade)
 		t.Fatal("Did not make all separate requests")
 	}
@@ -310,9 +354,8 @@ func TestFindProviderTimeout(t *testing.T) {
 		delay:      10 * time.Millisecond,
 	}
 	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
+	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxTimeout(2*time.Millisecond)))
 	providerQueryManager.Startup()
-	providerQueryManager.setFindProviderTimeout(2 * time.Millisecond)
 	keys := random.Cids(1)
 
 	sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
@@ -335,9 +378,8 @@ func TestFindProviderPreCanceled(t *testing.T) {
 		delay:      1 * time.Millisecond,
 	}
 	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
+	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxTimeout(100*time.Millisecond)))
 	providerQueryManager.Startup()
-	providerQueryManager.setFindProviderTimeout(100 * time.Millisecond)
 	keys := random.Cids(1)
 
 	sessionCtx, cancel := context.WithCancel(ctx)
@@ -361,9 +403,8 @@ func TestCancelFindProvidersAfterCompletion(t *testing.T) {
 		delay:      1 * time.Millisecond,
 	}
 	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
+	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxTimeout(100*time.Millisecond)))
 	providerQueryManager.Startup()
-	providerQueryManager.setFindProviderTimeout(100 * time.Millisecond)
 	keys := random.Cids(1)
 
 	sessionCtx, cancel := context.WithCancel(ctx)
@@ -395,9 +436,8 @@ func TestLimitedProviders(t *testing.T) {
 		delay:      1 * time.Millisecond,
 	}
 	ctx := context.Background()
-	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxProviders(max)))
+	providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxProviders(max), WithMaxTimeout(100*time.Millisecond)))
 	providerQueryManager.Startup()
-	providerQueryManager.setFindProviderTimeout(100 * time.Millisecond)
 	keys := random.Cids(1)
 
 	providersChan := providerQueryManager.FindProvidersAsync(ctx, keys[0], 0)
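Notes:

The chanqueue.ChanQueue replaces the hand-rolled providerRequestBufferWorker goroutine: In() returns a send-only channel, Out() returns a receive-only channel, and an unbounded FIFO sits between them so sends never block. A minimal sketch of those semantics, using only the New/In/Out/Shutdown calls that appear in the diff:

	package main

	import (
		"fmt"

		"github.com/gammazero/chanqueue"
	)

	func main() {
		q := chanqueue.New[int]()
		defer q.Shutdown() // stop the queue when done, as run() does via defer

		// Producer: sends to In() never block, because items accumulate in
		// the queue's internal FIFO instead of in channel buffer space.
		for i := 0; i < 5; i++ {
			q.In() <- i
		}

		// Consumer: Out() yields the queued items in FIFO order.
		for i := 0; i < 5; i++ {
			fmt.Println(<-q.Out())
		}
	}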
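The new findProviderWorker starts one goroutine per request but gates admission with an optional counting semaphore: a buffered channel when maxInProcessRequests > 0, or a nil channel when it is 0, since the nil checks skip acquire and release entirely. A self-contained sketch of that shape (runLimited and jobs are hypothetical names, not part of the change):

	package main

	import (
		"context"
		"sync"
	)

	func runLimited(ctx context.Context, maxInFlight int, jobs <-chan func()) {
		// A nil semaphore disables limiting, mirroring maxInProcessRequests == 0.
		var sem chan struct{}
		if maxInFlight > 0 {
			sem = make(chan struct{}, maxInFlight) // counting semaphore
		}

		var wg sync.WaitGroup
		for job := range jobs {
			if sem != nil {
				select {
				case sem <- struct{}{}: // acquire a slot before spawning
				case <-ctx.Done():
					return
				}
			}
			wg.Add(1)
			go func(job func()) {
				defer wg.Done()
				if sem != nil {
					defer func() { <-sem }() // release the slot when finished
				}
				job()
			}(job)
		}
		wg.Wait()
	}

	func main() {
		jobs := make(chan func(), 3)
		for i := 0; i < 3; i++ {
			jobs <- func() {} // no-op jobs, just to exercise the loop
		}
		close(jobs)
		runLimited(context.Background(), 2, jobs)
	}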
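receiveProviders now buffers pending providers in a deque.Deque instead of a slice. One motivation for the switch: popping a slice with receivedProviders = receivedProviders[1:] keeps the whole backing array reachable while the slice lives, whereas a ring-buffer deque recycles its storage and gives O(1) PushBack and PopFront. A small sketch using only the methods the diff relies on (Grow, PushBack, Len, Front, PopFront):

	package main

	import (
		"fmt"

		"github.com/gammazero/deque"
	)

	func main() {
		// The zero value is ready to use, as with the diff's
		// `var receivedProviders deque.Deque[peer.AddrInfo]`.
		var buf deque.Deque[string]
		buf.Grow(2) // optional pre-sizing when the initial length is known

		buf.PushBack("provider-1")
		buf.PushBack("provider-2")

		for buf.Len() > 0 {
			fmt.Println(buf.Front()) // peek at the oldest element
			buf.PopFront()           // then remove it in O(1)
		}
	}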
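Both message handlers now drop inProgressRequestStatuses entirely once it empties and re-allocate it lazily on the next new query. Go's runtime does not shrink a map's bucket storage, so a long-lived manager that once held many in-progress requests would otherwise keep that memory; nil-ing the field releases it, and reads of a nil map remain safe. A compact sketch of the same bookkeeping (the map's key and value types are simplified here):

	package main

	import "fmt"

	func main() {
		var statuses map[string]int // nil until the first request arrives

		// Lazy allocation on first insert, as in newProvideQueryMessage.handle.
		if statuses == nil {
			statuses = make(map[string]int)
		}
		statuses["query-1"] = 1

		// On completion or cancellation, delete the entry and drop the map
		// itself when it becomes empty so the GC can reclaim its buckets.
		delete(statuses, "query-1")
		if len(statuses) == 0 {
			statuses = nil
		}

		fmt.Println(statuses["query-1"]) // reading a nil map is safe: prints 0
	}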
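On the API side, the tests now set timeouts with the exported WithMaxTimeout option at construction time, replacing the unexported setFindProviderTimeout and the timeoutMutex that guarded it. Assuming this package lives at its boxo import path, constructing a manager with the options this change touches would look roughly like the following (the dialer and router parameters are placeholders supplied by the caller):

	package example

	import (
		"context"
		"time"

		pqm "github.com/ipfs/boxo/routing/providerquerymanager"
	)

	func newManager(ctx context.Context, d pqm.ProviderQueryDialer, r pqm.ProviderQueryRouter) (*pqm.ProviderQueryManager, error) {
		mgr, err := pqm.New(ctx, d, r,
			pqm.WithMaxTimeout(10*time.Second),   // per-query timeout
			pqm.WithMaxInProcessRequests(16),     // 0 means unlimited
			pqm.WithMaxProviders(10),             // 0 means unbounded
		)
		if err != nil {
			return nil, err
		}
		mgr.Startup() // start the run loop, as the tests do
		return mgr, nil
	}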