Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #119, panic gorilla ws when connection lost. #138

Merged
merged 2 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ import (
"github.com/surrealdb/surrealdb.go/pkg/marshal"
)

// Default consts and vars for testing
const (
defaultURL = "ws://localhost:8000/rpc"
)

var currentURL = os.Getenv("SURREALDB_URL")

//

// TestDBSuite is a test s for the DB struct
type SurrealDBTestSuite struct {
suite.Suite
Expand Down Expand Up @@ -112,13 +121,18 @@ func (t testUser) String() (str string, err error) {
return
}

// openConnection opens a new connection to the database
func (s *SurrealDBTestSuite) openConnection() *surrealdb.DB {
func (s *SurrealDBTestSuite) createTestDB() *surrealdb.DB {
url := os.Getenv("SURREALDB_URL")
if url == "" {
url = "ws://localhost:8000/rpc"
}
impl := s.connImplementations[s.name]
db := s.openConnection(url, impl)
return db
}

// openConnection opens a new connection to the database
func (s *SurrealDBTestSuite) openConnection(url string, impl conn.Connection) *surrealdb.DB {
require.NotNil(s.T(), impl)
db, err := surrealdb.New(url, impl)
s.Require().NoError(err)
Expand All @@ -127,7 +141,7 @@ func (s *SurrealDBTestSuite) openConnection() *surrealdb.DB {

// SetupSuite is called before the s starts running
func (s *SurrealDBTestSuite) SetupSuite() {
db := s.openConnection()
db := s.createTestDB()
s.Require().NotNil(db)
s.db = db
_ = signin(s)
Expand Down Expand Up @@ -766,6 +780,24 @@ func (s *SurrealDBTestSuite) TestConcurrentOperations() {
})
}

func (s *SurrealDBTestSuite) TestConnectionBreak() {
ws := gorilla.Create()
var url string
if currentURL == "" {
url = defaultURL
} else {
url = currentURL
}

db := s.openConnection(url, ws)
// Close the connection hard from ws
ws.Conn.Close()

// Needs to be return error when the connection is closed or broken
_, err := db.Select("users")
s.Require().Error(err)
}

// assertContains performs an assertion on a list, asserting that at least one element matches a provided condition.
// All the matching elements are returned from this function, which can be used as a filter.
func assertContains[K any](s *SurrealDBTestSuite, input []K, matcher func(K) bool) []K {
Expand Down
61 changes: 41 additions & 20 deletions pkg/conn/gorilla/gorilla.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"reflect"
"strconv"
Expand Down Expand Up @@ -43,13 +44,14 @@ type WebSocket struct {
notificationChannels map[string]chan model.Notification
notificationChannelsLock sync.RWMutex

close chan int
closeChan chan int
closeError error
}

func Create() *WebSocket {
return &WebSocket{
Conn: nil,
close: make(chan int),
closeChan: make(chan int),
responseChannels: make(map[string]chan rpc.RPCResponse),
notificationChannels: make(map[string]chan model.Notification),
Timeout: DefaultTimeout * time.Second,
Expand All @@ -73,7 +75,7 @@ func (ws *WebSocket) Connect(url string) (conn.Connection, error) {
}
}

ws.initialize()
go ws.initialize()
return ws, nil
}

Expand Down Expand Up @@ -107,7 +109,7 @@ func (ws *WebSocket) SetCompression(compress bool) *WebSocket {
func (ws *WebSocket) Close() error {
ws.connLock.Lock()
defer ws.connLock.Unlock()
close(ws.close)
close(ws.closeChan)
err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, ""))
if err != nil {
return err
Expand Down Expand Up @@ -179,6 +181,12 @@ func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) {
}

func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, error) {
select {
case <-ws.closeChan:
return nil, ws.closeError
default:
}

id := rand.String(RequestIDLength)
request := &rpc.RPCRequest{
ID: id,
Expand Down Expand Up @@ -235,25 +243,38 @@ func (ws *WebSocket) write(v interface{}) error {
}

func (ws *WebSocket) initialize() {
go func() {
for {
select {
case <-ws.close:
return
default:
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
ws.logger.Error(err.Error())
continue
for {
select {
case <-ws.closeChan:
return
default:
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
shouldExit := ws.handleError(err)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function should exit on any error returned from ws.Conn.ReadMessage. You can fix by inlining WebSocket.read here and exiting when the call to ws.Conn.ReadMessage returns an error. Because handleError does not accurately detect all websocket errors, this thing can loop endlessly and eventually panic (the gorilla detects useless application loops).

if shouldExit {
return
}
go ws.handleResponse(res)
continue
}
go ws.handleResponse(res)
}
}()
}
}

func (ws *WebSocket) handleError(err error) bool {
if errors.Is(err, net.ErrClosed) {
ws.closeError = net.ErrClosed
return true
}
if gorilla.IsUnexpectedCloseError(err) {
ws.closeError = io.ErrClosedPipe
<-ws.closeChan
return true
}

ws.logger.Error(err.Error())
return false
}

func (ws *WebSocket) handleResponse(res rpc.RPCResponse) {
Expand Down
Loading