diff --git a/db.go b/db.go index 9dc6924..f73f2da 100644 --- a/db.go +++ b/db.go @@ -37,8 +37,8 @@ func New(url string, connection conn.Connection) (*DB, error) { // -------------------------------------------------- // Close closes the underlying WebSocket connection. -func (db *DB) Close() { - _ = db.conn.Close() +func (db *DB) Close() error { + return db.conn.Close() } // -------------------------------------------------- diff --git a/db_test.go b/db_test.go index 554186c..874c567 100644 --- a/db_test.go +++ b/db_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" rawslog "log/slog" "os" "sync" @@ -31,6 +32,7 @@ type SurrealDBTestSuite struct { db *surrealdb.DB name string connImplementations map[string]conn.Connection + logBuffer *bytes.Buffer } // a simple user struct for testing @@ -55,20 +57,23 @@ func TestSurrealDBSuite(t *testing.T) { SurrealDBSuite.connImplementations = make(map[string]conn.Connection) // Without options - logData := createLogger(t) + buff := bytes.NewBufferString("") + logData := createLogger(t, buff) SurrealDBSuite.connImplementations["gorilla"] = gorilla.Create().Logger(logData) + SurrealDBSuite.logBuffer = buff // With options - logData = createLogger(t) - SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logData) + buffOpt := bytes.NewBufferString("") + logDataOpt := createLogger(t, buff) + SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logDataOpt) + SurrealDBSuite.logBuffer = buffOpt RunWsMap(t, SurrealDBSuite) } -func createLogger(t *testing.T) logger.Logger { +func createLogger(t *testing.T, writer io.Writer) logger.Logger { t.Helper() - buff := bytes.NewBuffer([]byte{}) - handler := rawslog.NewJSONHandler(buff, &rawslog.HandlerOptions{Level: rawslog.LevelDebug}) + handler := rawslog.NewJSONHandler(writer, &rawslog.HandlerOptions{Level: rawslog.LevelDebug}) return slog.New(handler) } @@ -86,11 +91,16 @@ func RunWsMap(t *testing.T, s *SurrealDBTestSuite) { func (s *SurrealDBTestSuite) TearDownTest() { _, err := s.db.Delete("users") s.Require().NoError(err) + + if s.logBuffer.Len() > 0 { + s.T().Logf("Log output:\n%s", s.logBuffer.String()) + } } // TearDownSuite is called after the s has finished running func (s *SurrealDBTestSuite) TearDownSuite() { - s.db.Close() + err := s.db.Close() + s.Require().NoError(err) } func (t testUser) String() (str string, err error) { diff --git a/pkg/conn/gorilla/gorilla.go b/pkg/conn/gorilla/gorilla.go index c2a8b4c..dcfbb37 100644 --- a/pkg/conn/gorilla/gorilla.go +++ b/pkg/conn/gorilla/gorilla.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "reflect" "strconv" "sync" @@ -104,11 +105,15 @@ func (ws *WebSocket) SetCompression(compress bool) *WebSocket { } func (ws *WebSocket) Close() error { - defer func() { - close(ws.close) - }() + ws.connLock.Lock() + defer ws.connLock.Unlock() + close(ws.close) + err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, "")) + if err != nil { + return err + } - return ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, "")) + return ws.Conn.Close() } func (ws *WebSocket) LiveNotifications(liveQueryID string) (chan model.Notification, error) { @@ -239,6 +244,9 @@ func (ws *WebSocket) initialize() { var res rpc.RPCResponse err := ws.read(&res) if err != nil { + if errors.Is(err, net.ErrClosed) { + break + } ws.logger.Error(err.Error()) continue }