diff --git a/wrap.go b/wrap.go index a4790d4..4336767 100644 --- a/wrap.go +++ b/wrap.go @@ -86,9 +86,6 @@ func WithErrorHandler(f ErrorHandler) WrapOption { // Wrap takes a request handler function and returns a http.HandlerFunc for use with net/http. // The given handler is expected to take arguments like context.Context, *http.Request and return a convreq.HttpResponse or an error. func Wrap(f interface{}, opts ...WrapOption) http.HandlerFunc { - t := reflect.TypeOf(f) - v := reflect.ValueOf(f) - wo := wrapOptions{ extractors: map[reflect.Type]extractor{}, handlers: map[reflect.Type]reflect.Value{}, @@ -103,6 +100,28 @@ func Wrap(f interface{}, opts ...WrapOption) http.HandlerFunc { o(&wo) } + if fun, ok := f.(func(context.Context, *http.Request) HttpResponse); ok { + // Fast path without reflection for this common signature. + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if len(wo.contextWrappers) > 0 { + var cancel func() + for _, cw := range wo.contextWrappers { + ctx, cancel = cw(ctx) + if cancel != nil { + defer cancel() + } + } + r = r.WithContext(ctx) + } + hr := fun(ctx, r) + internal.DoRespond(w, r, hr) + } + } + + t := reflect.TypeOf(f) + v := reflect.ValueOf(f) + if t.Kind() != reflect.Func { panic(fmt.Errorf("convreq: %s: not a function", v.String())) }