Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle RPC errors #168

Merged
merged 8 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,3 +464,15 @@ func (s *SurrealDBTestSuite) TestQueryRaw() {
fmt.Println(created)
fmt.Println(selected)
}

func (s *SurrealDBTestSuite) TestRPCError() {
s.Run("Test valid query", func() {
_, err := surrealdb.Query[[]testUser](s.db, "SELECT * FROM users", map[string]interface{}{})
s.Require().NoError(err)
})

s.Run("Test invalid query", func() {
_, err := surrealdb.Query[[]testUser](s.db, "SELEC * FROM users", map[string]interface{}{})
s.Require().Error(err)
})
}
30 changes: 30 additions & 0 deletions pkg/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ type BaseConnection struct {
responseChannels map[string]chan []byte
responseChannelsLock sync.RWMutex

errorChannels map[string]chan error
errorChannelsLock sync.RWMutex

notificationChannels map[string]chan Notification
notificationChannelsLock sync.RWMutex
}
Expand All @@ -60,6 +63,20 @@ func (bc *BaseConnection) createResponseChannel(id string) (chan []byte, error)
return ch, nil
}

func (bc *BaseConnection) createErrorChannel(id string) (chan error, error) {
bc.errorChannelsLock.Lock()
defer bc.errorChannelsLock.Unlock()

if _, ok := bc.errorChannels[id]; ok {
return nil, fmt.Errorf("%w: %v", constants.ErrIDInUse, id)
}

ch := make(chan error)
bc.errorChannels[id] = ch

return ch, nil
}

func (bc *BaseConnection) createNotificationChannel(liveQueryID string) (chan Notification, error) {
bc.notificationChannelsLock.Lock()
defer bc.notificationChannelsLock.Unlock()
Expand All @@ -80,13 +97,26 @@ func (bc *BaseConnection) removeResponseChannel(id string) {
delete(bc.responseChannels, id)
}

func (bc *BaseConnection) removeErrorChannel(id string) {
bc.errorChannelsLock.Lock()
defer bc.errorChannelsLock.Unlock()
delete(bc.errorChannels, id)
}

func (bc *BaseConnection) getResponseChannel(id string) (chan []byte, bool) {
bc.responseChannelsLock.RLock()
defer bc.responseChannelsLock.RUnlock()
ch, ok := bc.responseChannels[id]
return ch, ok
}

func (bc *BaseConnection) getErrorChannel(id string) (chan error, bool) {
bc.errorChannelsLock.RLock()
defer bc.errorChannelsLock.RUnlock()
ch, ok := bc.errorChannels[id]
return ch, ok
}

func (bc *BaseConnection) getLiveChannel(id string) (chan Notification, bool) {
bc.notificationChannelsLock.RLock()
defer bc.notificationChannelsLock.RUnlock()
Expand Down
22 changes: 22 additions & 0 deletions pkg/connection/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func NewWebSocketConnection(p NewConnectionParams) *WebSocketConnection {
unmarshaler: p.Unmarshaler,

responseChannels: make(map[string]chan []byte),
errorChannels: make(map[string]chan error),
notificationChannels: make(map[string]chan Notification),
},

Expand Down Expand Up @@ -159,7 +160,12 @@ func (ws *WebSocketConnection) Send(dest interface{}, method string, params ...i
if err != nil {
return err
}
errorChan, err := ws.createErrorChannel(id)
if err != nil {
return err
}
defer ws.removeResponseChannel(id)
defer ws.removeErrorChannel(id)

if err := ws.write(request); err != nil {
return err
Expand All @@ -177,6 +183,11 @@ func (ws *WebSocketConnection) Send(dest interface{}, method string, params ...i
return ws.unmarshaler.Unmarshal(resBytes, dest)
}
return nil
case resErr, open := <-errorChan:
if !open {
return errors.New("error channel closed")
}
return resErr
}
}

Expand Down Expand Up @@ -234,6 +245,17 @@ func (ws *WebSocketConnection) handleResponse(res []byte) {
if rpcRes.Error != nil {
err := fmt.Errorf("rpc request err %w", rpcRes.Error)
ws.logger.Error(err.Error())

errChan, ok := ws.getErrorChannel(fmt.Sprintf("%v", rpcRes.ID))
if !ok {
err := fmt.Errorf("unavailable ErrorChannel %+v", rpcRes.ID)
ws.logger.Error(err.Error())
return
}

defer close(errChan)
errChan <- rpcRes.Error

return
}

Expand Down
Loading