From dda7c6ba975c57c354b2e3a362bc7cb13282b384 Mon Sep 17 00:00:00 2001 From: ElecTwix Date: Tue, 21 May 2024 20:44:46 +0300 Subject: [PATCH] fix #119, panic gorilla ws when connection lost. instead of panic it will return error when any send function call --- pkg/conn/gorilla/gorilla.go | 61 +++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/pkg/conn/gorilla/gorilla.go b/pkg/conn/gorilla/gorilla.go index dcfbb37..4dcbf2c 100644 --- a/pkg/conn/gorilla/gorilla.go +++ b/pkg/conn/gorilla/gorilla.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "reflect" "strconv" @@ -43,13 +44,14 @@ type WebSocket struct { notificationChannels map[string]chan model.Notification notificationChannelsLock sync.RWMutex - close chan int + closeChan chan int + closeError error } func Create() *WebSocket { return &WebSocket{ Conn: nil, - close: make(chan int), + closeChan: make(chan int), responseChannels: make(map[string]chan rpc.RPCResponse), notificationChannels: make(map[string]chan model.Notification), Timeout: DefaultTimeout * time.Second, @@ -73,7 +75,7 @@ func (ws *WebSocket) Connect(url string) (conn.Connection, error) { } } - ws.initialize() + go ws.initialize() return ws, nil } @@ -107,7 +109,7 @@ func (ws *WebSocket) SetCompression(compress bool) *WebSocket { func (ws *WebSocket) Close() error { ws.connLock.Lock() defer ws.connLock.Unlock() - close(ws.close) + close(ws.closeChan) err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, "")) if err != nil { return err @@ -179,6 +181,12 @@ func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) { } func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, error) { + select { + case <-ws.closeChan: + return nil, ws.closeError + default: + } + id := rand.String(RequestIDLength) request := &rpc.RPCRequest{ ID: id, @@ -235,25 +243,38 @@ func (ws *WebSocket) write(v interface{}) error { } func (ws *WebSocket) initialize() { - go func() { - for { - select { - case <-ws.close: - return - default: - var res rpc.RPCResponse - err := ws.read(&res) - if err != nil { - if errors.Is(err, net.ErrClosed) { - break - } - ws.logger.Error(err.Error()) - continue + for { + select { + case <-ws.closeChan: + return + default: + var res rpc.RPCResponse + err := ws.read(&res) + if err != nil { + shouldExit := ws.handleError(err) + if shouldExit { + return } - go ws.handleResponse(res) + continue } + go ws.handleResponse(res) } - }() + } +} + +func (ws *WebSocket) handleError(err error) bool { + if errors.Is(err, net.ErrClosed) { + ws.closeError = net.ErrClosed + return true + } + if gorilla.IsUnexpectedCloseError(err) { + ws.closeError = io.ErrClosedPipe + <-ws.closeChan + return true + } + + ws.logger.Error(err.Error()) + return false } func (ws *WebSocket) handleResponse(res rpc.RPCResponse) {