diff --git a/nats.go b/nats.go index 234d006ba..6fefda7d8 100644 --- a/nats.go +++ b/nats.go @@ -17,7 +17,6 @@ package nats import ( "bufio" "bytes" - "context" "crypto/tls" "crypto/x509" "encoding/base64" @@ -3994,8 +3993,6 @@ func (nc *Conn) respHandler(m *Msg) { rt := nc.respToken(m.Subject) if rt != _EMPTY_ { mch = nc.respMap[rt] - // Delete the key regardless, one response only. - delete(nc.respMap, rt) } else if len(nc.respMap) == 1 { // If the server has rewritten the subject, the response token (rt) // will not match (could be the case with JetStream). If that is the @@ -4071,22 +4068,54 @@ func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg, return nc.request(subj, nil, data, timeout) } +// RequestMany will send a request payload and return a channel to receive multiple responses. +// By default, the number of messages received is constrained by the client's timeout. +// +// Use the RequestManyOpt functions to further configure this method's behavior. +// - [RequestManyMaxWait] sets the maximum time to wait for responses (defaults to client's timeout). +// - [RequestManyStallTimer] sets the stall timer, which waits a certain amount of time for the +// first message before starting the stall timer (reset on each delivered message). +// - [RequestManyMaxMessages] sets the maximum number of messages to receive. +// - [RequestManySentinel] stops returning responses once a message with an empty payload is received. +func (nc *Conn) RequestMany(subject string, data []byte, opts ...RequestManyOpt) (*RequestManyResponse, error) { + return nc.requestMany(subject, nil, data, opts...) +} + +// RequestManyMsg will send a Msg request and return a channel to receive multiple responses. +// By default, the number of messages received is constrained by the client's timeout. +// +// Use the RequestManyOpt functions to further configure this method's behavior. +// - [RequestManyMaxWait] sets the maximum time to wait for responses (defaults to client's timeout). +// - [RequestManyStallTimer] sets the stall timer, which waits a certain amount of time for the +// first message before starting the stall timer (reset on each delivered message). +// - [RequestManyMaxMessages] sets the maximum number of messages to receive. +// - [RequestManySentinel] stops returning responses once a message with an empty payload is received. +func (nc *Conn) RequestManyMsg(msg *Msg, opts ...RequestManyOpt) (*RequestManyResponse, error) { + if msg == nil { + return nil, ErrInvalidMsg + } + hdr, err := msg.headerBytes() + if err != nil { + return nil, err + } + return nc.requestMany(msg.Subject, hdr, msg.Data, opts...) +} + type requestManyOpts struct { maxWait time.Duration - gapTimeout *gapTimeoutOpts + stallTimer *stallTimer count int + sentinel bool } -type gapTimeoutOpts struct { +type stallTimer struct { firstMax time.Duration - gap time.Duration - // used to differenciate between default value and user-set value - custom bool + stall time.Duration } type RequestManyOpt func(*requestManyOpts) error -func WithRequestManyMaxWait(maxWait time.Duration) RequestManyOpt { +func RequestManyMaxWait(maxWait time.Duration) RequestManyOpt { return func(opts *requestManyOpts) error { if maxWait <= 0 { return fmt.Errorf("%w: max wait has to be greater than 0", ErrInvalidArg) @@ -4096,24 +4125,23 @@ func WithRequestManyMaxWait(maxWait time.Duration) RequestManyOpt { } } -func WithRequestManyGapTimer(initialMaxWait, gapTime time.Duration) RequestManyOpt { +func RequestManyStallTimer(waitForFirstMsg, stall time.Duration) RequestManyOpt { return func(opts *requestManyOpts) error { - if initialMaxWait <= 0 { + if waitForFirstMsg <= 0 { return fmt.Errorf("%w: initial wait time has to be greater than 0", ErrInvalidArg) } - if gapTime <= 0 { - return fmt.Errorf("%w: gap time has to be greater than 0", ErrInvalidArg) + if stall <= 0 { + return fmt.Errorf("%w: stall time has to be greater than 0", ErrInvalidArg) } - opts.gapTimeout = &gapTimeoutOpts{ - firstMax: initialMaxWait, - gap: gapTime, - custom: true, + opts.stallTimer = &stallTimer{ + firstMax: waitForFirstMsg, + stall: stall, } return nil } } -func WithRequestManyCount(count int) RequestManyOpt { +func RequestManyMaxMessages(count int) RequestManyOpt { return func(opts *requestManyOpts) error { if count <= 0 { return fmt.Errorf("%w: expected request count has to be greater than 0", ErrInvalidArg) @@ -4123,12 +4151,29 @@ func WithRequestManyCount(count int) RequestManyOpt { } } -func (nc *Conn) RequestMany(subject string, data []byte, opts ...RequestManyOpt) (<-chan *Msg, error) { +func RequestManySentinel() RequestManyOpt { + return func(opts *requestManyOpts) error { + opts.sentinel = true + return nil + } +} + +type RequestManyResponse struct { + Msgs chan *Msg + Err error + stop chan struct{} +} + +func (r RequestManyResponse) Stop() { + if r.stop == nil { + return + } + close(r.stop) +} + +func (nc *Conn) requestMany(subject string, hdr, data []byte, opts ...RequestManyOpt) (*RequestManyResponse, error) { reqOpts := &requestManyOpts{ - gapTimeout: &gapTimeoutOpts{ - firstMax: nc.Opts.Timeout, - gap: 300 * time.Millisecond, - }, + maxWait: nc.Opts.Timeout, } for _, opt := range opts { @@ -4136,92 +4181,113 @@ func (nc *Conn) RequestMany(subject string, data []byte, opts ...RequestManyOpt) return nil, err } } - // if user set a custom maxWait and did not set gap timer, we don't want to - // use the defaults - if reqOpts.maxWait != 0 && !reqOpts.gapTimeout.custom { - reqOpts.gapTimeout = nil - } - inbox := nc.newRespInbox() - var mch chan *Msg + var respCh chan *Msg if reqOpts.count > 0 { - mch = make(chan *Msg, reqOpts.count) + respCh = make(chan *Msg, reqOpts.count) } else { - mch = make(chan *Msg, nc.Opts.SubChanLen) + respCh = make(chan *Msg, nc.Opts.SubChanLen) + } + + resp := &RequestManyResponse{ + Msgs: respCh, + stop: make(chan struct{}), + } + mch, token, err := nc.createNewRequestAndSend(subject, hdr, data) + if err != nil { + return nil, err } var maxWait *time.Timer - var gapTimer *time.Timer + var stallTimer *time.Timer if reqOpts.maxWait > 0 { maxWait = globalTimerPool.Get(reqOpts.maxWait) } - if reqOpts.gapTimeout != nil { - gapTimer = globalTimerPool.Get(reqOpts.gapTimeout.firstMax) + if reqOpts.stallTimer != nil { + stallTimer = globalTimerPool.Get(reqOpts.stallTimer.firstMax) } - returnTimers := func() { + cleanup := func() { + close(respCh) + resp.stop = nil if maxWait != nil { globalTimerPool.Put(maxWait) } - if gapTimer != nil { - globalTimerPool.Put(gapTimer) + if stallTimer != nil { + globalTimerPool.Put(stallTimer) } + nc.mu.Lock() + delete(nc.respMap, token) + nc.mu.Unlock() } - ctx, cancel := context.WithCancel(context.Background()) - sub, err := nc.Subscribe(inbox, func(msg *Msg) { - if gapTimer != nil { - gapTimer.Reset(reqOpts.gapTimeout.gap) + + handleMsg := func(msg *Msg) bool { + if reqOpts.sentinel && len(msg.Data) == 0 { + return false } - mch <- msg - }) - if err != nil { - cancel() - returnTimers() - return nil, err - } - sub.SetClosedHandler(func(subject string) { - returnTimers() - close(mch) - cancel() - }) - if reqOpts.count != 0 { - sub.AutoUnsubscribe(reqOpts.count) + respCh <- msg + if reqOpts.count > 0 { + reqOpts.count-- + if reqOpts.count == 0 { + return false + } + } + return true } + go func() { - if maxWait != nil && gapTimer != nil { - select { - case <-maxWait.C: - sub.Unsubscribe() - case <-gapTimer.C: - sub.Unsubscribe() - case <-ctx.Done(): - sub.Unsubscribe() - } - } else if maxWait != nil { - fmt.Println("with max wait") - select { - case <-maxWait.C: - sub.Unsubscribe() - case <-ctx.Done(): - sub.Unsubscribe() - } - } else if gapTimer != nil { - select { - case <-gapTimer.C: - sub.Unsubscribe() - case <-ctx.Done(): - sub.Unsubscribe() + if stallTimer != nil { + for { + select { + case msg, ok := <-mch: + if !ok { + cleanup() + resp.Err = ErrConnectionClosed + return + } + if stallTimer != nil { + stallTimer.Reset(reqOpts.stallTimer.stall) + } + if !handleMsg(msg) { + cleanup() + return + } + + case <-maxWait.C: + cleanup() + return + case <-stallTimer.C: + cleanup() + return + case <-resp.stop: + cleanup() + return + } } } else { - <-ctx.Done() - sub.Unsubscribe() + for { + select { + case msg, ok := <-mch: + if !ok { + cleanup() + resp.Err = ErrConnectionClosed + return + } + if !handleMsg(msg) { + cleanup() + return + } + case <-maxWait.C: + cleanup() + return + case <-resp.stop: + cleanup() + return + } + } } }() - err = nc.PublishRequest(subject, inbox, data) - if err != nil { - sub.Unsubscribe() - return nil, err - } - return mch, nil + + return resp, nil } func (nc *Conn) useOldRequestStyle() bool { @@ -4266,6 +4332,9 @@ func (nc *Conn) newRequest(subj string, hdr, data []byte, timeout time.Duration) select { case msg, ok = <-mch: + nc.mu.Lock() + delete(nc.respMap, token) + nc.mu.Unlock() if !ok { return nil, ErrConnectionClosed } diff --git a/test/basic_test.go b/test/basic_test.go index 51e19b7da..7b065013d 100644 --- a/test/basic_test.go +++ b/test/basic_test.go @@ -1,4 +1,4 @@ -// Copyright 2012-2023 The NATS Authors +// Copyright 2012-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -16,6 +16,7 @@ package test import ( "bytes" "context" + "errors" "fmt" "math" "regexp" @@ -799,105 +800,232 @@ func TestRequestCloseTimeout(t *testing.T) { } func TestRequestMany(t *testing.T) { - s := RunDefaultServer() - defer s.Shutdown() + f := []string{"RequestMany", "RequestManyMsg"} - nc, err := nats.Connect(s.ClientURL(), nats.Timeout(400*time.Millisecond)) - if err != nil { - t.Fatalf("Failed to connect: %v", err) - } - defer nc.Close() + for _, name := range f { + t.Run(name, func(t *testing.T) { - response := []byte("I will help you") - for i := 0; i < 5; i++ { - sub, err := nc.Subscribe("foo", func(m *nats.Msg) { - nc.Publish(m.Reply, response) - }) - if err != nil { - t.Fatalf("Received an error on subscribe: %s", err) - } - defer sub.Unsubscribe() - } + s := RunDefaultServer() + defer s.Shutdown() - tests := []struct { - name string - subject string - opts []nats.RequestManyOpt - minTime time.Duration - expectedMsgs int - }{ - { - name: "default", - subject: "foo", - opts: nil, - minTime: 300 * time.Millisecond, - expectedMsgs: 5, - }, - { - name: "with max wait", - subject: "foo", - opts: []nats.RequestManyOpt{ - nats.WithRequestManyMaxWait(500 * time.Millisecond), - }, - minTime: 500 * time.Millisecond, - expectedMsgs: 5, - }, - { - name: "with count reached", - subject: "foo", - opts: []nats.RequestManyOpt{ - nats.WithRequestManyCount(3), - }, - minTime: 0, - expectedMsgs: 3, - }, - { - name: "with max wait and limit", - subject: "foo", - opts: []nats.RequestManyOpt{ - nats.WithRequestManyMaxWait(500 * time.Millisecond), - nats.WithRequestManyCount(3), - }, - minTime: 0, - expectedMsgs: 3, - }, - { - name: "with count timeout", - subject: "foo", - opts: []nats.RequestManyOpt{ - nats.WithRequestManyCount(6), - }, - minTime: 300 * time.Millisecond, - expectedMsgs: 5, - }, - { - name: "with no responses", - subject: "bar", - minTime: 300, - // no responders - expectedMsgs: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - now := time.Now() - msgs, err := nc.RequestMany(test.subject, nil, test.opts...) + nc, err := nats.Connect(s.ClientURL(), nats.Timeout(400*time.Millisecond)) if err != nil { - t.Fatalf("Received an error on Request test: %s", err) + t.Fatalf("Failed to connect: %v", err) } - - var i int - for msg := range msgs { - fmt.Println(string(msg.Data)) - fmt.Println(msg.Header) - i++ - } - if i != test.expectedMsgs { - t.Fatalf("Expected %d messages, got %d", test.expectedMsgs, i) + defer nc.Close() + + tests := []struct { + name string + subject string + opts []nats.RequestManyOpt + minTime time.Duration + expectedMsgs int + withError error + }{ + { + name: "default, only max wait", + subject: "foo", + opts: nil, + minTime: 400 * time.Millisecond, + expectedMsgs: 6, + }, + { + name: "with stall, short circuit", + subject: "foo", + opts: []nats.RequestManyOpt{ + nats.RequestManyStallTimer(100*time.Millisecond, 50*time.Millisecond), + }, + minTime: 50 * time.Millisecond, + expectedMsgs: 5, + }, + { + name: "with custom wait", + subject: "foo", + opts: []nats.RequestManyOpt{ + nats.RequestManyMaxWait(500 * time.Millisecond), + }, + minTime: 500 * time.Millisecond, + expectedMsgs: 6, + }, + { + name: "with count reached", + subject: "foo", + opts: []nats.RequestManyOpt{ + nats.RequestManyMaxMessages(3), + }, + minTime: 0, + expectedMsgs: 3, + }, + { + name: "with max wait and limit", + subject: "foo", + opts: []nats.RequestManyOpt{ + nats.RequestManyMaxWait(500 * time.Millisecond), + nats.RequestManyMaxMessages(3), + }, + minTime: 0, + expectedMsgs: 3, + }, + { + name: "with count timeout", + subject: "foo", + opts: []nats.RequestManyOpt{ + nats.RequestManyMaxMessages(10), + }, + minTime: 400 * time.Millisecond, + expectedMsgs: 6, + }, + { + name: "sentinel", + subject: "foo", + opts: []nats.RequestManyOpt{ + nats.RequestManySentinel(), + }, + minTime: 100 * time.Millisecond, + expectedMsgs: 5, + }, + { + name: "all options provided, stall timer short circuit", + subject: "foo", + opts: []nats.RequestManyOpt{ + nats.RequestManyMaxWait(500 * time.Millisecond), + nats.RequestManyStallTimer(100*time.Millisecond, 50*time.Millisecond), + nats.RequestManyMaxMessages(10), + nats.RequestManySentinel(), + }, + minTime: 50 * time.Millisecond, + expectedMsgs: 5, + }, + { + name: "all options provided, msg count short circuit", + subject: "foo", + opts: []nats.RequestManyOpt{ + nats.RequestManyMaxWait(500 * time.Millisecond), + nats.RequestManyStallTimer(100*time.Millisecond, 50*time.Millisecond), + nats.RequestManyMaxMessages(3), + nats.RequestManySentinel(), + }, + minTime: 0, + expectedMsgs: 3, + }, + { + name: "all options provided, max wait short circuit", + subject: "foo", + opts: []nats.RequestManyOpt{ + nats.RequestManyMaxWait(50 * time.Millisecond), + nats.RequestManyStallTimer(100*time.Millisecond, 100*time.Millisecond), + nats.RequestManyMaxMessages(10), + nats.RequestManySentinel(), + }, + minTime: 0, + expectedMsgs: 5, + }, + { + name: "all options provided, sentinel short circuit", + subject: "foo", + opts: []nats.RequestManyOpt{ + nats.RequestManyMaxWait(500 * time.Millisecond), + nats.RequestManyStallTimer(100*time.Millisecond, 150*time.Millisecond), + nats.RequestManyMaxMessages(10), + nats.RequestManySentinel(), + }, + minTime: 100 * time.Millisecond, + expectedMsgs: 5, + }, + { + name: "with no responses", + subject: "bar", + minTime: 400 * time.Millisecond, + // no responders + expectedMsgs: 1, + }, + { + name: "invalid options - max wait", + opts: []nats.RequestManyOpt{ + nats.RequestManyMaxWait(-1), + }, + subject: "foo", + withError: nats.ErrInvalidArg, + }, + { + name: "invalid options - stall timer initial", + opts: []nats.RequestManyOpt{ + nats.RequestManyStallTimer(-1, 100*time.Millisecond), + }, + subject: "foo", + withError: nats.ErrInvalidArg, + }, + { + name: "invalid options - stall timer interval", + opts: []nats.RequestManyOpt{ + nats.RequestManyStallTimer(100*time.Millisecond, -1), + }, + subject: "foo", + withError: nats.ErrInvalidArg, + }, + { + name: "invalid options - max messages", + opts: []nats.RequestManyOpt{ + nats.RequestManyMaxMessages(-1), + }, + subject: "foo", + withError: nats.ErrInvalidArg, + }, } - if time.Since(now) < test.minTime || time.Since(now) > test.minTime+100*time.Millisecond { - t.Fatalf("Expected to receive all messages between %v and %v, got %v", test.minTime, test.minTime+100*time.Millisecond, time.Since(now)) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + response := []byte("I will help you") + for i := 0; i < 5; i++ { + sub, err := nc.Subscribe("foo", func(m *nats.Msg) { + nc.Publish(m.Reply, response) + }) + if err != nil { + t.Fatalf("Received an error on subscribe: %s", err) + } + defer sub.Unsubscribe() + } + // after short delay, send sentinel (should come in last) + sub, err := nc.Subscribe("foo", func(m *nats.Msg) { + time.Sleep(100 * time.Millisecond) + nc.Publish(m.Reply, []byte("")) + }) + if err != nil { + t.Fatalf("Received an error on subscribe: %s", err) + } + defer sub.Unsubscribe() + + now := time.Now() + var msgs *nats.RequestManyResponse + if name == "RequestMany" { + msgs, err = nc.RequestMany(test.subject, nil, test.opts...) + } else { + msgs, err = nc.RequestManyMsg(&nats.Msg{Subject: test.subject}, test.opts...) + } + if test.withError != nil { + if !errors.Is(err, test.withError) { + t.Fatalf("Expected error %v, got %v", test.withError, err) + } + return + } + if err != nil { + t.Fatalf("Received an error on Request test: %s", err) + } + + var i int + for _ = range msgs.Msgs { + i++ + } + if msgs.Err != nil { + t.Fatalf("Received an error on Request test: %s", msgs.Err) + } + if i != test.expectedMsgs { + t.Fatalf("Expected %d messages, got %d", test.expectedMsgs, i) + } + if time.Since(now) < test.minTime || time.Since(now) > test.minTime+100*time.Millisecond { + t.Fatalf("Expected to receive all messages between %v and %v, got %v", test.minTime, test.minTime+100*time.Millisecond, time.Since(now)) + } + }) } }) }