Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace gorilla ws #143

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
test:
strategy:
matrix:
go-version: ['1.21']
go-version: ["1.21"]
permissions:
contents: read
pull-requests: read
Expand All @@ -28,7 +28,7 @@ jobs:
- name: download surrealdb
run: curl --proto '=https' --tlsv1.2 -sSf https://install.surrealdb.com | sh -s -- --nightly
- name: start surrealdb
run: surreal start memory -A --auth --user root --pass root &
run: surreal start memory -A --user root --pass root &
- name: test
run: go test -v -cover ./...
env:
Expand Down
27 changes: 14 additions & 13 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"os"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/surrealdb/surrealdb.go/pkg/logger/slog"
Expand All @@ -18,7 +17,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/surrealdb/surrealdb.go"
"github.com/surrealdb/surrealdb.go/pkg/conn/gorilla"
"github.com/surrealdb/surrealdb.go/pkg/conn/nhooyr"
"github.com/surrealdb/surrealdb.go/pkg/constants"

"github.com/surrealdb/surrealdb.go/pkg/conn"
Expand Down Expand Up @@ -65,17 +64,18 @@ func TestSurrealDBSuite(t *testing.T) {
SurrealDBSuite := new(SurrealDBTestSuite)
SurrealDBSuite.connImplementations = make(map[string]conn.Connection)

// // Nhooyr
// Without options
buff := bytes.NewBufferString("")
logData := createLogger(t, buff)
SurrealDBSuite.connImplementations["gorilla"] = gorilla.Create().Logger(logData)
SurrealDBSuite.logBuffer = buff
nbuff := bytes.NewBufferString("")
nlogData := createLogger(t, nbuff)
SurrealDBSuite.connImplementations["nhooyr"] = nhooyr.Create().Logger(nlogData)
SurrealDBSuite.logBuffer = nbuff

// With options
buffOpt := bytes.NewBufferString("")
logDataOpt := createLogger(t, buff)
SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logDataOpt)
SurrealDBSuite.logBuffer = buffOpt
nbuffOpt := bytes.NewBufferString("")
nlogDataOpt := createLogger(t, nbuffOpt)
SurrealDBSuite.connImplementations["nhooyr_opt"] = nhooyr.Create().Logger(nlogDataOpt)
SurrealDBSuite.logBuffer = nbuffOpt

RunWsMap(t, SurrealDBSuite)
}
Expand Down Expand Up @@ -781,7 +781,7 @@ func (s *SurrealDBTestSuite) TestConcurrentOperations() {
}

func (s *SurrealDBTestSuite) TestConnectionBreak() {
ws := gorilla.Create()
ws := nhooyr.Create()
var url string
if currentURL == "" {
url = defaultURL
Expand All @@ -791,10 +791,11 @@ func (s *SurrealDBTestSuite) TestConnectionBreak() {

db := s.openConnection(url, ws)
// Close the connection hard from ws
ws.Conn.Close()
err := ws.Conn.CloseNow()
s.Require().NoError(err)

// Needs to be return error when the connection is closed or broken
_, err := db.Select("users")
_, err = db.Select("users")
s.Require().Error(err)
}

Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ module github.com/surrealdb/surrealdb.go
go 1.20

require (
github.com/coder/websocket v1.8.12
github.com/gorilla/websocket v1.5.0
github.com/stretchr/testify v1.8.4
nhooyr.io/websocket v1.8.17
)

require (
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand All @@ -23,3 +25,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q=
nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y=
nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
98 changes: 36 additions & 62 deletions pkg/conn/gorilla/gorilla.go → pkg/conn/nhooyr/nhooyr.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
package gorilla
package nhooyr

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"reflect"
"strconv"
"sync"
"time"

coderws "github.com/coder/websocket"
"github.com/surrealdb/surrealdb.go/pkg/model"

gorilla "github.com/gorilla/websocket"
"github.com/surrealdb/surrealdb.go/internal/rpc"
"github.com/surrealdb/surrealdb.go/pkg/conn"
"github.com/surrealdb/surrealdb.go/pkg/logger"
Expand All @@ -32,7 +33,7 @@ const (
type Option func(ws *WebSocket) error

type WebSocket struct {
Conn *gorilla.Conn
Conn *coderws.Conn
connLock sync.Mutex
Timeout time.Duration
Option []Option
Expand All @@ -44,29 +45,33 @@ type WebSocket struct {
notificationChannels map[string]chan model.Notification
notificationChannelsLock sync.RWMutex

closeChan chan int
closeError error
close chan int
}

func Create() *WebSocket {
return &WebSocket{
Conn: nil,
closeChan: make(chan int),
close: make(chan int),
responseChannels: make(map[string]chan rpc.RPCResponse),
notificationChannels: make(map[string]chan model.Notification),
Timeout: DefaultTimeout * time.Second,
}
}

func (ws *WebSocket) Connect(url string) (conn.Connection, error) {
dialer := gorilla.DefaultDialer
dialer.EnableCompression = true
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

connection, _, err := dialer.Dial(url, nil)
connection, resp, err := coderws.Dial(ctx, url, nil)
if err != nil {
return nil, err
}

if resp.StatusCode != http.StatusSwitchingProtocols {
resp.Body.Close()
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

ws.Conn = connection

for _, option := range ws.Option {
Expand All @@ -75,7 +80,7 @@ func (ws *WebSocket) Connect(url string) (conn.Connection, error) {
}
}

go ws.initialize()
ws.initialize()
return ws, nil
}

Expand All @@ -98,24 +103,12 @@ func (ws *WebSocket) RawLogger(logData logger.Logger) *WebSocket {
return ws
}

func (ws *WebSocket) SetCompression(compress bool) *WebSocket {
ws.Option = append(ws.Option, func(ws *WebSocket) error {
ws.Conn.EnableWriteCompression(compress)
return nil
})
return ws
}

func (ws *WebSocket) Close() error {
ws.connLock.Lock()
defer ws.connLock.Unlock()
close(ws.closeChan)
err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, ""))
if err != nil {
return err
}
close(ws.close)

return ws.Conn.Close()
return ws.Conn.Close(coderws.StatusNormalClosure, "")
}

func (ws *WebSocket) LiveNotifications(liveQueryID string) (chan model.Notification, error) {
Expand Down Expand Up @@ -181,12 +174,6 @@ func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) {
}

func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, error) {
select {
case <-ws.closeChan:
return nil, ws.closeError
default:
}

id := rand.String(RequestIDLength)
request := &rpc.RPCRequest{
ID: id,
Expand Down Expand Up @@ -224,7 +211,7 @@ func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, err
}

func (ws *WebSocket) read(v interface{}) error {
_, data, err := ws.Conn.ReadMessage()
_, data, err := ws.Conn.Read(context.Background())
if err != nil {
return err
}
Expand All @@ -239,42 +226,29 @@ func (ws *WebSocket) write(v interface{}) error {

ws.connLock.Lock()
defer ws.connLock.Unlock()
return ws.Conn.WriteMessage(gorilla.TextMessage, data)
return ws.Conn.Write(context.Background(), coderws.MessageText, data)
}

func (ws *WebSocket) initialize() {
for {
select {
case <-ws.closeChan:
return
default:
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
shouldExit := ws.handleError(err)
if shouldExit {
return
go func() {
for {
select {
case <-ws.close:
return
default:
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
ws.logger.Error(err.Error())
continue
}
continue
go ws.handleResponse(res)
}
go ws.handleResponse(res)
}
}
}

func (ws *WebSocket) handleError(err error) bool {
if errors.Is(err, net.ErrClosed) {
ws.closeError = net.ErrClosed
return true
}
if gorilla.IsUnexpectedCloseError(err) {
ws.closeError = io.ErrClosedPipe
<-ws.closeChan
return true
}

ws.logger.Error(err.Error())
return false
}()
}

func (ws *WebSocket) handleResponse(res rpc.RPCResponse) {
Expand Down
Loading