diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7d6d7aa..04ea0fc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,7 +13,7 @@ jobs: test: strategy: matrix: - go-version: ['1.21'] + go-version: ["1.21"] permissions: contents: read pull-requests: read @@ -28,7 +28,7 @@ jobs: - name: download surrealdb run: curl --proto '=https' --tlsv1.2 -sSf https://install.surrealdb.com | sh -s -- --nightly - name: start surrealdb - run: surreal start memory -A --auth --user root --pass root & + run: surreal start memory -A --user root --pass root & - name: test run: go test -v -cover ./... env: diff --git a/db_test.go b/db_test.go index def4dd1..7c31d50 100644 --- a/db_test.go +++ b/db_test.go @@ -9,7 +9,6 @@ import ( "os" "sync" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/surrealdb/surrealdb.go/pkg/logger/slog" @@ -18,7 +17,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/surrealdb/surrealdb.go" - "github.com/surrealdb/surrealdb.go/pkg/conn/gorilla" + "github.com/surrealdb/surrealdb.go/pkg/conn/nhooyr" "github.com/surrealdb/surrealdb.go/pkg/constants" "github.com/surrealdb/surrealdb.go/pkg/conn" @@ -65,17 +64,18 @@ func TestSurrealDBSuite(t *testing.T) { SurrealDBSuite := new(SurrealDBTestSuite) SurrealDBSuite.connImplementations = make(map[string]conn.Connection) + // // Nhooyr // Without options - buff := bytes.NewBufferString("") - logData := createLogger(t, buff) - SurrealDBSuite.connImplementations["gorilla"] = gorilla.Create().Logger(logData) - SurrealDBSuite.logBuffer = buff + nbuff := bytes.NewBufferString("") + nlogData := createLogger(t, nbuff) + SurrealDBSuite.connImplementations["nhooyr"] = nhooyr.Create().Logger(nlogData) + SurrealDBSuite.logBuffer = nbuff // With options - buffOpt := bytes.NewBufferString("") - logDataOpt := createLogger(t, buff) - SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logDataOpt) - SurrealDBSuite.logBuffer = buffOpt + nbuffOpt := bytes.NewBufferString("") + nlogDataOpt := createLogger(t, nbuffOpt) + SurrealDBSuite.connImplementations["nhooyr_opt"] = nhooyr.Create().Logger(nlogDataOpt) + SurrealDBSuite.logBuffer = nbuffOpt RunWsMap(t, SurrealDBSuite) } @@ -781,7 +781,7 @@ func (s *SurrealDBTestSuite) TestConcurrentOperations() { } func (s *SurrealDBTestSuite) TestConnectionBreak() { - ws := gorilla.Create() + ws := nhooyr.Create() var url string if currentURL == "" { url = defaultURL @@ -791,10 +791,11 @@ func (s *SurrealDBTestSuite) TestConnectionBreak() { db := s.openConnection(url, ws) // Close the connection hard from ws - ws.Conn.Close() + err := ws.Conn.CloseNow() + s.Require().NoError(err) // Needs to be return error when the connection is closed or broken - _, err := db.Select("users") + _, err = db.Select("users") s.Require().Error(err) } diff --git a/go.mod b/go.mod index 43ef31e..e49abcd 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,10 @@ module github.com/surrealdb/surrealdb.go go 1.20 require ( + github.com/coder/websocket v1.8.12 github.com/gorilla/websocket v1.5.0 github.com/stretchr/testify v1.8.4 + nhooyr.io/websocket v1.8.17 ) require ( diff --git a/go.sum b/go.sum index ab9f4d8..e2df07c 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -23,3 +25,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= +nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= +nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y= +nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/pkg/conn/gorilla/gorilla.go b/pkg/conn/nhooyr/nhooyr.go similarity index 84% rename from pkg/conn/gorilla/gorilla.go rename to pkg/conn/nhooyr/nhooyr.go index 4dcbf2c..f4521e2 100644 --- a/pkg/conn/gorilla/gorilla.go +++ b/pkg/conn/nhooyr/nhooyr.go @@ -1,19 +1,20 @@ -package gorilla +package nhooyr import ( + "context" "encoding/json" "errors" "fmt" - "io" "net" + "net/http" "reflect" "strconv" "sync" "time" + coderws "github.com/coder/websocket" "github.com/surrealdb/surrealdb.go/pkg/model" - gorilla "github.com/gorilla/websocket" "github.com/surrealdb/surrealdb.go/internal/rpc" "github.com/surrealdb/surrealdb.go/pkg/conn" "github.com/surrealdb/surrealdb.go/pkg/logger" @@ -32,7 +33,7 @@ const ( type Option func(ws *WebSocket) error type WebSocket struct { - Conn *gorilla.Conn + Conn *coderws.Conn connLock sync.Mutex Timeout time.Duration Option []Option @@ -44,14 +45,13 @@ type WebSocket struct { notificationChannels map[string]chan model.Notification notificationChannelsLock sync.RWMutex - closeChan chan int - closeError error + close chan int } func Create() *WebSocket { return &WebSocket{ Conn: nil, - closeChan: make(chan int), + close: make(chan int), responseChannels: make(map[string]chan rpc.RPCResponse), notificationChannels: make(map[string]chan model.Notification), Timeout: DefaultTimeout * time.Second, @@ -59,14 +59,19 @@ func Create() *WebSocket { } func (ws *WebSocket) Connect(url string) (conn.Connection, error) { - dialer := gorilla.DefaultDialer - dialer.EnableCompression = true + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() - connection, _, err := dialer.Dial(url, nil) + connection, resp, err := coderws.Dial(ctx, url, nil) if err != nil { return nil, err } + if resp.StatusCode != http.StatusSwitchingProtocols { + resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + ws.Conn = connection for _, option := range ws.Option { @@ -75,7 +80,7 @@ func (ws *WebSocket) Connect(url string) (conn.Connection, error) { } } - go ws.initialize() + ws.initialize() return ws, nil } @@ -98,24 +103,12 @@ func (ws *WebSocket) RawLogger(logData logger.Logger) *WebSocket { return ws } -func (ws *WebSocket) SetCompression(compress bool) *WebSocket { - ws.Option = append(ws.Option, func(ws *WebSocket) error { - ws.Conn.EnableWriteCompression(compress) - return nil - }) - return ws -} - func (ws *WebSocket) Close() error { ws.connLock.Lock() defer ws.connLock.Unlock() - close(ws.closeChan) - err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, "")) - if err != nil { - return err - } + close(ws.close) - return ws.Conn.Close() + return ws.Conn.Close(coderws.StatusNormalClosure, "") } func (ws *WebSocket) LiveNotifications(liveQueryID string) (chan model.Notification, error) { @@ -181,12 +174,6 @@ func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) { } func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, error) { - select { - case <-ws.closeChan: - return nil, ws.closeError - default: - } - id := rand.String(RequestIDLength) request := &rpc.RPCRequest{ ID: id, @@ -224,7 +211,7 @@ func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, err } func (ws *WebSocket) read(v interface{}) error { - _, data, err := ws.Conn.ReadMessage() + _, data, err := ws.Conn.Read(context.Background()) if err != nil { return err } @@ -239,42 +226,29 @@ func (ws *WebSocket) write(v interface{}) error { ws.connLock.Lock() defer ws.connLock.Unlock() - return ws.Conn.WriteMessage(gorilla.TextMessage, data) + return ws.Conn.Write(context.Background(), coderws.MessageText, data) } func (ws *WebSocket) initialize() { - for { - select { - case <-ws.closeChan: - return - default: - var res rpc.RPCResponse - err := ws.read(&res) - if err != nil { - shouldExit := ws.handleError(err) - if shouldExit { - return + go func() { + for { + select { + case <-ws.close: + return + default: + var res rpc.RPCResponse + err := ws.read(&res) + if err != nil { + if errors.Is(err, net.ErrClosed) { + break + } + ws.logger.Error(err.Error()) + continue } - continue + go ws.handleResponse(res) } - go ws.handleResponse(res) } - } -} - -func (ws *WebSocket) handleError(err error) bool { - if errors.Is(err, net.ErrClosed) { - ws.closeError = net.ErrClosed - return true - } - if gorilla.IsUnexpectedCloseError(err) { - ws.closeError = io.ErrClosedPipe - <-ws.closeChan - return true - } - - ws.logger.Error(err.Error()) - return false + }() } func (ws *WebSocket) handleResponse(res rpc.RPCResponse) {