From d6f59f95b0247c7f36432421404d9fe5c248d3f9 Mon Sep 17 00:00:00 2001 From: PotatoCloud <60210021+PotatoCloud@users.noreply.github.com> Date: Tue, 18 Jun 2024 20:00:37 +0800 Subject: [PATCH 1/2] Add Handshake.Header field --- dialer.go | 3 +++ server.go | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/dialer.go b/dialer.go index e66678e..d4d80a5 100644 --- a/dialer.go +++ b/dialer.go @@ -31,6 +31,9 @@ type Handshake struct { // Extensions is the list of negotiated extensions. Extensions []httphead.Option + + // Header all request headers obtained during handshake + Header http.Header } // Errors used by the websocket client. diff --git a/server.go b/server.go index 863bb22..ef3016a 100644 --- a/server.go +++ b/server.go @@ -241,6 +241,10 @@ func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net. if h := u.Header; h != nil { header[0] = HandshakeHeaderHTTP(h) } + + // set handshake header + hs.Header = r.Header + if err == nil { httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo) err = rw.Writer.Flush() @@ -498,6 +502,10 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { nonce = make([]byte, nonceSize) ) + + // init handshake headers + hs.Header = make(http.Header) + for err == nil { line, e := readLine(br) if e != nil { @@ -514,6 +522,9 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { break } + // copy and add header + hs.Header.Add(btsToString(bytes.Clone(k)), btsToString(bytes.Clone(v))) + switch btsToString(k) { case headerHostCanonical: headerSeen |= headerSeenHost From e3196f96a146120805c88e2cc808017e66101705 Mon Sep 17 00:00:00 2001 From: PotatoCloud <60210021+PotatoCloud@users.noreply.github.com> Date: Tue, 18 Jun 2024 21:27:45 +0800 Subject: [PATCH 2/2] Add `HTTPUpgrader.CopyHeadersToHandshake`, `Upgrader.CopyHeadersToHandshake` field. --- server.go | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/server.go b/server.go index ef3016a..edb4ff1 100644 --- a/server.go +++ b/server.go @@ -109,6 +109,10 @@ func Upgrade(conn io.ReadWriter) (Handshake, error) { // HTTPUpgrader contains options for upgrading connection to websocket from // net/http Handler arguments. type HTTPUpgrader struct { + // CopyHeadersToHandshake setting specifies whether headers should be preserved during the handshake process. + // If enabled, the headers will be copied to Handshake.Header. + CopyHeadersToHandshake bool + // Timeout is the maximum amount of time an Upgrade() will spent while // writing handshake response. // @@ -242,8 +246,10 @@ func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net. header[0] = HandshakeHeaderHTTP(h) } - // set handshake header - hs.Header = r.Header + if u.CopyHeadersToHandshake { + // set handshake header + hs.Header = r.Header + } if err == nil { httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo) @@ -266,6 +272,10 @@ func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net. // Upgrader contains options for upgrading connection to websocket. type Upgrader struct { + // CopyHeadersToHandshake setting specifies whether headers should be preserved during the handshake process. + // If enabled, the headers will be copied to Handshake.Header. + CopyHeadersToHandshake bool + // ReadBufferSize and WriteBufferSize is an I/O buffer sizes. // They used to read and write http data while upgrading to WebSocket. // Allocated buffers are pooled with sync.Pool to avoid extra allocations. @@ -503,8 +513,10 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { nonce = make([]byte, nonceSize) ) - // init handshake headers - hs.Header = make(http.Header) + if u.CopyHeadersToHandshake { + // init handshake headers + hs.Header = make(http.Header) + } for err == nil { line, e := readLine(br) @@ -522,8 +534,10 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { break } - // copy and add header - hs.Header.Add(btsToString(bytes.Clone(k)), btsToString(bytes.Clone(v))) + if u.CopyHeadersToHandshake { + // copy and add header + hs.Header.Add(btsToString(bytes.Clone(k)), btsToString(bytes.Clone(v))) + } switch btsToString(k) { case headerHostCanonical: