diff --git a/db_test.go b/db_test.go index 874c567..def4dd1 100644 --- a/db_test.go +++ b/db_test.go @@ -26,6 +26,15 @@ import ( "github.com/surrealdb/surrealdb.go/pkg/marshal" ) +// Default consts and vars for testing +const ( + defaultURL = "ws://localhost:8000/rpc" +) + +var currentURL = os.Getenv("SURREALDB_URL") + +// + // TestDBSuite is a test s for the DB struct type SurrealDBTestSuite struct { suite.Suite @@ -112,13 +121,18 @@ func (t testUser) String() (str string, err error) { return } -// openConnection opens a new connection to the database -func (s *SurrealDBTestSuite) openConnection() *surrealdb.DB { +func (s *SurrealDBTestSuite) createTestDB() *surrealdb.DB { url := os.Getenv("SURREALDB_URL") if url == "" { url = "ws://localhost:8000/rpc" } impl := s.connImplementations[s.name] + db := s.openConnection(url, impl) + return db +} + +// openConnection opens a new connection to the database +func (s *SurrealDBTestSuite) openConnection(url string, impl conn.Connection) *surrealdb.DB { require.NotNil(s.T(), impl) db, err := surrealdb.New(url, impl) s.Require().NoError(err) @@ -127,7 +141,7 @@ func (s *SurrealDBTestSuite) openConnection() *surrealdb.DB { // SetupSuite is called before the s starts running func (s *SurrealDBTestSuite) SetupSuite() { - db := s.openConnection() + db := s.createTestDB() s.Require().NotNil(db) s.db = db _ = signin(s) @@ -766,6 +780,24 @@ func (s *SurrealDBTestSuite) TestConcurrentOperations() { }) } +func (s *SurrealDBTestSuite) TestConnectionBreak() { + ws := gorilla.Create() + var url string + if currentURL == "" { + url = defaultURL + } else { + url = currentURL + } + + db := s.openConnection(url, ws) + // Close the connection hard from ws + ws.Conn.Close() + + // Needs to be return error when the connection is closed or broken + _, err := db.Select("users") + s.Require().Error(err) +} + // assertContains performs an assertion on a list, asserting that at least one element matches a provided condition. // All the matching elements are returned from this function, which can be used as a filter. func assertContains[K any](s *SurrealDBTestSuite, input []K, matcher func(K) bool) []K { diff --git a/pkg/conn/gorilla/gorilla.go b/pkg/conn/gorilla/gorilla.go index dcfbb37..4dcbf2c 100644 --- a/pkg/conn/gorilla/gorilla.go +++ b/pkg/conn/gorilla/gorilla.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "reflect" "strconv" @@ -43,13 +44,14 @@ type WebSocket struct { notificationChannels map[string]chan model.Notification notificationChannelsLock sync.RWMutex - close chan int + closeChan chan int + closeError error } func Create() *WebSocket { return &WebSocket{ Conn: nil, - close: make(chan int), + closeChan: make(chan int), responseChannels: make(map[string]chan rpc.RPCResponse), notificationChannels: make(map[string]chan model.Notification), Timeout: DefaultTimeout * time.Second, @@ -73,7 +75,7 @@ func (ws *WebSocket) Connect(url string) (conn.Connection, error) { } } - ws.initialize() + go ws.initialize() return ws, nil } @@ -107,7 +109,7 @@ func (ws *WebSocket) SetCompression(compress bool) *WebSocket { func (ws *WebSocket) Close() error { ws.connLock.Lock() defer ws.connLock.Unlock() - close(ws.close) + close(ws.closeChan) err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, "")) if err != nil { return err @@ -179,6 +181,12 @@ func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) { } func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, error) { + select { + case <-ws.closeChan: + return nil, ws.closeError + default: + } + id := rand.String(RequestIDLength) request := &rpc.RPCRequest{ ID: id, @@ -235,25 +243,38 @@ func (ws *WebSocket) write(v interface{}) error { } func (ws *WebSocket) initialize() { - go func() { - for { - select { - case <-ws.close: - return - default: - var res rpc.RPCResponse - err := ws.read(&res) - if err != nil { - if errors.Is(err, net.ErrClosed) { - break - } - ws.logger.Error(err.Error()) - continue + for { + select { + case <-ws.closeChan: + return + default: + var res rpc.RPCResponse + err := ws.read(&res) + if err != nil { + shouldExit := ws.handleError(err) + if shouldExit { + return } - go ws.handleResponse(res) + continue } + go ws.handleResponse(res) } - }() + } +} + +func (ws *WebSocket) handleError(err error) bool { + if errors.Is(err, net.ErrClosed) { + ws.closeError = net.ErrClosed + return true + } + if gorilla.IsUnexpectedCloseError(err) { + ws.closeError = io.ErrClosedPipe + <-ws.closeChan + return true + } + + ws.logger.Error(err.Error()) + return false } func (ws *WebSocket) handleResponse(res rpc.RPCResponse) {