diff --git a/hijack.go b/hijack.go index 477dd992..a02d4f33 100644 --- a/hijack.go +++ b/hijack.go @@ -356,6 +356,31 @@ func (ctx *HijackResponse) Headers() http.Header { // SetHeader of the payload via key-value pairs. func (ctx *HijackResponse) SetHeader(pairs ...string) *HijackResponse { + headerIndex := make(map[string]int, len(ctx.payload.ResponseHeaders)) + for i, header := range ctx.payload.ResponseHeaders { + headerIndex[header.Name] = i + } + + for i := 0; i < len(pairs); i += 2 { + name := pairs[i] + value := pairs[i+1] + + if idx, exists := headerIndex[name]; exists { + ctx.payload.ResponseHeaders[idx].Value = value + } else { + ctx.payload.ResponseHeaders = append(ctx.payload.ResponseHeaders, &proto.FetchHeaderEntry{ + Name: name, + Value: value, + }) + headerIndex[name] = len(ctx.payload.ResponseHeaders) - 1 + } + } + return ctx +} + +// AddHeader appends key-value pairs to the end of the response headers. +// Duplicate keys will be preserved. +func (ctx *HijackResponse) AddHeader(pairs ...string) *HijackResponse { for i := 0; i < len(pairs); i += 2 { ctx.payload.ResponseHeaders = append(ctx.payload.ResponseHeaders, &proto.FetchHeaderEntry{ Name: pairs[i], diff --git a/hijack_test.go b/hijack_test.go index a5b114ab..470fd8ce 100644 --- a/hijack_test.go +++ b/hijack_test.go @@ -75,6 +75,8 @@ func TestHijack(t *testing.T) { g.Has(ctx.Response.Headers().Get("Content-Type"), "text/html; charset=utf-8") // override response header + ctx.Response.AddHeader("Set-Cookie", "key=val1") + // This should override the previous one ctx.Response.SetHeader("Set-Cookie", "key=val") // override response body