diff --git a/CHANGELOG.md b/CHANGELOG.md index faca07433..8ed1fb078 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The following emojis are used to highlight certain changes: ### Added - `routing/http/server`: added Prometheus instrumentation to http delegated routing endpoints. +- `routing/http/server`: added configurable routing timeout (`DefaultRoutingTimeout` being 30s) to prevent indefinite hangs during content/peer routing. Set custom duration via `WithRoutingTimeout`. ### Changed diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 6177da125..f6d0f1993 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -41,6 +41,7 @@ const ( DefaultRecordsLimit = 20 DefaultStreamingRecordsLimit = 0 + DefaultRoutingTimeout = 30 * time.Second ) var logger = logging.Logger("routing/http/server") @@ -132,11 +133,18 @@ func WithPrometheusRegistry(reg prometheus.Registerer) Option { } } +func WithRoutingTimeout(timeout time.Duration) Option { + return func(s *server) { + s.routingTimeout = timeout + } +} + func Handler(svc ContentRouter, opts ...Option) http.Handler { server := &server{ svc: svc, recordsLimit: DefaultRecordsLimit, streamingRecordsLimit: DefaultStreamingRecordsLimit, + routingTimeout: DefaultRoutingTimeout, } for _, opt := range opts { @@ -174,6 +182,7 @@ type server struct { recordsLimit int streamingRecordsLimit int promRegistry prometheus.Registerer + routingTimeout time.Duration } func (s *server) detectResponseType(r *http.Request) (string, error) { @@ -246,7 +255,10 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { recordsLimit = s.recordsLimit } - provIter, err := s.svc.FindProviders(httpReq.Context(), cid, recordsLimit) + ctx, cancel := context.WithTimeout(httpReq.Context(), s.routingTimeout) + defer cancel() + + provIter, err := s.svc.FindProviders(ctx, cid, recordsLimit) if err != nil { if errors.Is(err, routing.ErrNotFound) { // handlerFunc takes care of setting the 404 and necessary headers @@ -335,7 +347,11 @@ func (s *server) findPeers(w http.ResponseWriter, r *http.Request) { recordsLimit = s.recordsLimit } - provIter, err := s.svc.FindPeers(r.Context(), pid, recordsLimit) + // Add timeout to the routing operation + ctx, cancel := context.WithTimeout(r.Context(), s.routingTimeout) + defer cancel() + + provIter, err := s.svc.FindPeers(ctx, pid, recordsLimit) if err != nil { if errors.Is(err, routing.ErrNotFound) { // handlerFunc takes care of setting the 404 and necessary headers @@ -466,7 +482,10 @@ func (s *server) GetIPNS(w http.ResponseWriter, r *http.Request) { return } - record, err := s.svc.GetIPNS(r.Context(), name) + ctx, cancel := context.WithTimeout(r.Context(), s.routingTimeout) + defer cancel() + + record, err := s.svc.GetIPNS(ctx, name) if err != nil { if errors.Is(err, routing.ErrNotFound) { writeErr(w, "GetIPNS", http.StatusNotFound, fmt.Errorf("delegate error: %w", err)) @@ -550,7 +569,10 @@ func (s *server) PutIPNS(w http.ResponseWriter, r *http.Request) { return } - err = s.svc.PutIPNS(r.Context(), name, record) + ctx, cancel := context.WithTimeout(r.Context(), s.routingTimeout) + defer cancel() + + err = s.svc.PutIPNS(ctx, name, record) if err != nil { writeErr(w, "PutIPNS", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err)) return