Skip to content

Commit

Permalink
fix surrealdb#119, panic gorilla ws when connection lost.
Browse files Browse the repository at this point in the history
instead of panic it will return error when any send function call
  • Loading branch information
ElecTwix committed May 21, 2024
1 parent bc55e64 commit dda7c6b
Showing 1 changed file with 41 additions and 20 deletions.
61 changes: 41 additions & 20 deletions pkg/conn/gorilla/gorilla.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"reflect"
"strconv"
Expand Down Expand Up @@ -43,13 +44,14 @@ type WebSocket struct {
notificationChannels map[string]chan model.Notification
notificationChannelsLock sync.RWMutex

close chan int
closeChan chan int
closeError error
}

func Create() *WebSocket {
return &WebSocket{
Conn: nil,
close: make(chan int),
closeChan: make(chan int),
responseChannels: make(map[string]chan rpc.RPCResponse),
notificationChannels: make(map[string]chan model.Notification),
Timeout: DefaultTimeout * time.Second,
Expand All @@ -73,7 +75,7 @@ func (ws *WebSocket) Connect(url string) (conn.Connection, error) {
}
}

ws.initialize()
go ws.initialize()
return ws, nil
}

Expand Down Expand Up @@ -107,7 +109,7 @@ func (ws *WebSocket) SetCompression(compress bool) *WebSocket {
func (ws *WebSocket) Close() error {
ws.connLock.Lock()
defer ws.connLock.Unlock()
close(ws.close)
close(ws.closeChan)
err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, ""))
if err != nil {
return err
Expand Down Expand Up @@ -179,6 +181,12 @@ func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) {
}

func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, error) {
select {
case <-ws.closeChan:
return nil, ws.closeError
default:
}

id := rand.String(RequestIDLength)
request := &rpc.RPCRequest{
ID: id,
Expand Down Expand Up @@ -235,25 +243,38 @@ func (ws *WebSocket) write(v interface{}) error {
}

func (ws *WebSocket) initialize() {
go func() {
for {
select {
case <-ws.close:
return
default:
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
ws.logger.Error(err.Error())
continue
for {
select {
case <-ws.closeChan:
return
default:
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
shouldExit := ws.handleError(err)
if shouldExit {
return
}
go ws.handleResponse(res)
continue
}
go ws.handleResponse(res)
}
}()
}
}

func (ws *WebSocket) handleError(err error) bool {
if errors.Is(err, net.ErrClosed) {
ws.closeError = net.ErrClosed
return true
}
if gorilla.IsUnexpectedCloseError(err) {
ws.closeError = io.ErrClosedPipe
<-ws.closeChan
return true
}

ws.logger.Error(err.Error())
return false
}

func (ws *WebSocket) handleResponse(res rpc.RPCResponse) {
Expand Down

0 comments on commit dda7c6b

Please sign in to comment.