Skip to content

Commit

Permalink
websocket: don't call OnClose when ws conn is fast closed before the …
Browse files Browse the repository at this point in the history
…upgrade succeeds
  • Loading branch information
lesismal committed Apr 23, 2024
1 parent 76145b5 commit 3059984
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 25 deletions.
10 changes: 10 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ func (c *Conn) IsUnix() bool {
return c.typ == ConnTypeUnix
}

// Session returns user session.
func (c *Conn) Session() interface{} {
return c.session
}

// SetSession sets user session.
func (c *Conn) SetSession(session interface{}) {
c.session = session
}

// OnData registers Conn's data handler.
// Notice:
// 1. The data readed by the poller is not handled by this Conn's data
Expand Down
10 changes: 0 additions & 10 deletions conn_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,16 +437,6 @@ func (c *Conn) SetLinger(onoff int32, linger int32) error {
return nil
}

// Session returns user session.
func (c *Conn) Session() interface{} {
return c.session
}

// SetSession sets user session.
func (c *Conn) SetSession(session interface{}) {
c.session = session
}

func newConn(conn net.Conn) *Conn {
c := &Conn{}
addr := conn.LocalAddr().String()
Expand Down
10 changes: 0 additions & 10 deletions conn_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,16 +618,6 @@ func (c *Conn) SetLinger(onoff int32, linger int32) error {
)
}

// Session returns user session.
func (c *Conn) Session() interface{} {
return c.session
}

// SetSession sets user session.
func (c *Conn) SetSession(session interface{}) {
c.session = session
}

// sets writing event.
func (c *Conn) modWrite() {
if !c.closed && !c.isWAdded {
Expand Down
4 changes: 1 addition & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,4 @@ go 1.16

require github.com/lesismal/llib v1.1.13

retract (
v1.5.4 // Contains body length parsing bug.
)
retract v1.5.4 // Contains body length parsing bug.
17 changes: 15 additions & 2 deletions nbhttp/websocket/upgrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
var engine = u.Engine
var parser *nbhttp.Parser
var transferConn = u.BlockingModTrasferConnToPoller

if len(args) > 0 {
var b bool
b, ok = args[0].(bool)
Expand All @@ -279,6 +280,14 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}
}

clearNBCWSSession := func() {
if nbc != nil {
if _, ok = nbc.Session().(*Conn); ok {
nbc.SetSession(nil)
}
}
}

var underLayerConn net.Conn
nbhttpConn, isReadingByParser := conn.(*nbhttp.Conn)
if isReadingByParser {
Expand All @@ -291,14 +300,15 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
switch vt := underLayerConn.(type) {
case *nbio.Conn:
// Scenario 1: *nbio.Conn, handled by nbhttp.Engine.
parser, ok = vt.Session().(*nbhttp.Parser)
nbc = vt
parser, ok = nbc.Session().(*nbhttp.Parser)
if !ok {
return nil, u.returnError(w, r, http.StatusInternalServerError, err)
}
wsc = NewServerConn(u, conn, subprotocol, compress, false)
wsc.Engine = parser.Engine
wsc.Execute = parser.Execute
vt.SetSession(wsc)
nbc.SetSession(wsc)
if nbhttpConn != nil {
nbhttpConn.Parser = nil
}
Expand Down Expand Up @@ -367,6 +377,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
vt.ResetConn(nbc, nonblock)
err = engine.AddTransferredConn(nbc)
if err != nil {
clearNBCWSSession()
return nil, u.returnError(w, r, http.StatusInternalServerError, err)
}
} else {
Expand Down Expand Up @@ -437,6 +448,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
})
err = engine.AddTransferredConn(nbc)
if err != nil {
clearNBCWSSession()
return nil, u.returnError(w, r, http.StatusInternalServerError, err)
}
} else {
Expand Down Expand Up @@ -468,6 +480,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade

err = u.commResponse(wsc.Conn, responseHeader, challengeKey, subprotocol, compress)
if err != nil {
clearNBCWSSession()
return nil, err
}

Expand Down

0 comments on commit 3059984

Please sign in to comment.