From b7ca47c207441abb9f0a4818ff615537e8a8f227 Mon Sep 17 00:00:00 2001 From: Artem Date: Tue, 31 Oct 2023 12:46:33 +0100 Subject: [PATCH] Fix: notify closed client --- cmd/api/handler/websocket/client.go | 10 ++++++++++ cmd/api/handler/websocket/client_test.go | 14 ++++++++++++++ cmd/api/init.go | 12 ++++++------ 3 files changed, 30 insertions(+), 6 deletions(-) create mode 100644 cmd/api/handler/websocket/client_test.go diff --git a/cmd/api/handler/websocket/client.go b/cmd/api/handler/websocket/client.go index cee114fe..7ed2a61b 100644 --- a/cmd/api/handler/websocket/client.go +++ b/cmd/api/handler/websocket/client.go @@ -7,6 +7,7 @@ import ( "context" "io" "net" + "sync/atomic" "time" "github.com/dipdup-io/workerpool" @@ -43,14 +44,19 @@ type Client struct { filters *Filters ch chan any g workerpool.Group + + closed *atomic.Bool } func newClient(id uint64, manager *Manager) *Client { + closed := new(atomic.Bool) + closed.Store(false) return &Client{ id: id, manager: manager, ch: make(chan any, 1024), g: workerpool.NewGroup(), + closed: closed, } } @@ -102,11 +108,15 @@ func (c *Client) DetachFilters(msg Unsubscribe) error { } func (c *Client) Notify(msg any) { + if c.closed.Load() { + return + } c.ch <- msg } func (c *Client) Close() error { c.g.Wait() + c.closed.Store(true) close(c.ch) return nil } diff --git a/cmd/api/handler/websocket/client_test.go b/cmd/api/handler/websocket/client_test.go new file mode 100644 index 00000000..400d06c7 --- /dev/null +++ b/cmd/api/handler/websocket/client_test.go @@ -0,0 +1,14 @@ +package websocket + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNotifyClosedClient(t *testing.T) { + client := newClient(10, nil) + err := client.Close() + require.NoError(t, err, "closing client") + client.Notify("test") +} diff --git a/cmd/api/init.go b/cmd/api/init.go index 520f8f2f..78e1f299 100644 --- a/cmd/api/init.go +++ b/cmd/api/init.go @@ -288,7 +288,7 @@ func initHandlers(ctx context.Context, e *echo.Echo, cfg Config, db postgres.Sto v1.GET("/swagger/*", echoSwagger.WrapHandler) - // initWebsocket(ctx, db, v1) + initWebsocket(ctx, db, v1) log.Info().Msg("API routes:") for _, route := range e.Routes() { @@ -300,8 +300,8 @@ var ( wsManager *websocket.Manager ) -// func initWebsocket(ctx context.Context, db postgres.Storage, group *echo.Group) { -// wsManager = websocket.NewManager(db, db.Blocks, db.Tx) -// wsManager.Start(ctx) -// group.GET("/ws", wsManager.Handle) -// } +func initWebsocket(ctx context.Context, db postgres.Storage, group *echo.Group) { + wsManager = websocket.NewManager(db, db.Blocks, db.Tx) + wsManager.Start(ctx) + group.GET("/ws", wsManager.Handle) +}