diff --git a/languages/go/go-server/sse.go b/languages/go/go-server/sse.go index 5cdf45cd..722ae8ed 100644 --- a/languages/go/go-server/sse.go +++ b/languages/go/go-server/sse.go @@ -52,6 +52,7 @@ func (controller *defaultSseController[T]) startStream() { controller.headersSent = true controller.pingTicker = time.NewTicker(time.Second * 10) go func() { + defer controller.pingTicker.Stop() for { select { case <-controller.pingTicker.C: diff --git a/playground/go/main.go b/playground/go/main.go index 90643988..4afefe82 100644 --- a/playground/go/main.go +++ b/playground/go/main.go @@ -28,6 +28,7 @@ type Message struct { func WatchMessages(params WatchMessagesParams, controller arri.SseController[Message], context arri.DefaultContext) arri.RpcError { // create ticker that fires each second t := time.NewTicker(time.Second) + defer t.Stop() msgCount := 0 for { select { @@ -40,8 +41,6 @@ func WatchMessages(params WatchMessagesParams, controller arri.SseController[Mes CreatedAt: time.Now(), }) case <-controller.Done(): - // cleanup when the connection is closed - t.Stop() return nil } } diff --git a/tests/servers/go/main.go b/tests/servers/go/main.go index 4eb6a542..8c04848a 100644 --- a/tests/servers/go/main.go +++ b/tests/servers/go/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "net/http" "strings" "time" @@ -295,6 +296,39 @@ func SendRecursiveUnion(params RecursiveUnion, _ AppContext) (RecursiveUnion, ar return params, nil } +type AutoReconnectParams struct { + MessageCount uint8 +} + +type AutoReconnectResponse struct { + Count uint8 + Message string +} + +func StreamAutoReconnect(params AutoReconnectParams, controller arri.SseController[AutoReconnectResponse], ctx AppContext) arri.RpcError { + t := time.NewTicker(time.Millisecond) + _, cancel := context.WithCancel(ctx.request.Context()) + defer t.Stop() + var msgCount uint8 = 0 + for { + select { + case <-t.C: + msgCount++ + controller.Push(AutoReconnectResponse{Count: msgCount, Message: "Hello World " + string(msgCount)}) + if msgCount == params.MessageCount { + cancel() + return nil + } + if msgCount > params.MessageCount { + panic("Request was not properly cancelled") + } + case <-controller.Done(): + return nil + } + + } +} + type ChatMessageParams struct { ChannelId string }