Skip to content

Commit

Permalink
[ADDED] RequestMany func to get multiple responses
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Sep 26, 2024
1 parent f0c0194 commit a594231
Show file tree
Hide file tree
Showing 2 changed files with 259 additions and 0 deletions.
154 changes: 154 additions & 0 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package nats
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
Expand Down Expand Up @@ -4070,6 +4071,159 @@ func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg,
return nc.request(subj, nil, data, timeout)
}

type requestManyOpts struct {
maxWait time.Duration
gapTimeout *gapTimeoutOpts
count int
}

type gapTimeoutOpts struct {
firstMax time.Duration
gap time.Duration
// used to differenciate between default value and user-set value
custom bool
}

type RequestManyOpt func(*requestManyOpts) error

func WithRequestManyMaxWait(maxWait time.Duration) RequestManyOpt {
return func(opts *requestManyOpts) error {
if maxWait <= 0 {
return fmt.Errorf("%w: max wait has to be greater than 0", ErrInvalidArg)
}
opts.maxWait = maxWait
return nil
}
}

func WithRequestManyGapTimer(initialMaxWait, gapTime time.Duration) RequestManyOpt {
return func(opts *requestManyOpts) error {
if initialMaxWait <= 0 {
return fmt.Errorf("%w: initial wait time has to be greater than 0", ErrInvalidArg)
}
if gapTime <= 0 {
return fmt.Errorf("%w: gap time has to be greater than 0", ErrInvalidArg)
}
opts.gapTimeout = &gapTimeoutOpts{
firstMax: initialMaxWait,
gap: gapTime,
custom: true,
}
return nil
}
}

func WithRequestManyCount(count int) RequestManyOpt {
return func(opts *requestManyOpts) error {
if count <= 0 {
return fmt.Errorf("%w: expected request count has to be greater than 0", ErrInvalidArg)
}
opts.count = count
return nil
}
}

func (nc *Conn) RequestMany(subject string, data []byte, opts ...RequestManyOpt) (<-chan *Msg, error) {
reqOpts := &requestManyOpts{
gapTimeout: &gapTimeoutOpts{
firstMax: nc.Opts.Timeout,
gap: 300 * time.Millisecond,
},
}

for _, opt := range opts {
if err := opt(reqOpts); err != nil {
return nil, err
}
}
// if user set a custom maxWait and did not set gap timer, we don't want to
// use the defaults
if reqOpts.maxWait != 0 && !reqOpts.gapTimeout.custom {
reqOpts.gapTimeout = nil
}

inbox := nc.newRespInbox()
var mch chan *Msg
if reqOpts.count > 0 {
mch = make(chan *Msg, reqOpts.count)
} else {
mch = make(chan *Msg, nc.Opts.SubChanLen)
}

var maxWait *time.Timer
var gapTimer *time.Timer
if reqOpts.maxWait > 0 {
maxWait = globalTimerPool.Get(reqOpts.maxWait)
}
if reqOpts.gapTimeout != nil {
gapTimer = globalTimerPool.Get(reqOpts.gapTimeout.firstMax)
}
returnTimers := func() {
if maxWait != nil {
globalTimerPool.Put(maxWait)
}
if gapTimer != nil {
globalTimerPool.Put(gapTimer)
}
}
ctx, cancel := context.WithCancel(context.Background())
sub, err := nc.Subscribe(inbox, func(msg *Msg) {
if gapTimer != nil {
gapTimer.Reset(reqOpts.gapTimeout.gap)
}
mch <- msg
})
if err != nil {
cancel()
returnTimers()
return nil, err
}
sub.SetClosedHandler(func(subject string) {
returnTimers()
close(mch)
cancel()
})
if reqOpts.count != 0 {
sub.AutoUnsubscribe(reqOpts.count)
}
go func() {
if maxWait != nil && gapTimer != nil {
select {
case <-maxWait.C:
sub.Unsubscribe()
case <-gapTimer.C:
sub.Unsubscribe()
case <-ctx.Done():
sub.Unsubscribe()
}
} else if maxWait != nil {
fmt.Println("with max wait")
select {
case <-maxWait.C:
sub.Unsubscribe()
case <-ctx.Done():
sub.Unsubscribe()
}
} else if gapTimer != nil {
select {
case <-gapTimer.C:
sub.Unsubscribe()
case <-ctx.Done():
sub.Unsubscribe()
}
} else {
<-ctx.Done()
sub.Unsubscribe()
}
}()
err = nc.PublishRequest(subject, inbox, data)
if err != nil {
sub.Unsubscribe()
return nil, err
}
return mch, nil
}

func (nc *Conn) useOldRequestStyle() bool {
nc.mu.RLock()
r := nc.Opts.UseOldRequestStyle
Expand Down
105 changes: 105 additions & 0 deletions test/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,111 @@ func TestRequestCloseTimeout(t *testing.T) {
}
}

func TestRequestMany(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()

nc, err := nats.Connect(s.ClientURL(), nats.Timeout(400*time.Millisecond))
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer nc.Close()

response := []byte("I will help you")
for i := 0; i < 5; i++ {
sub, err := nc.Subscribe("foo", func(m *nats.Msg) {
nc.Publish(m.Reply, response)
})
if err != nil {
t.Fatalf("Received an error on subscribe: %s", err)
}
defer sub.Unsubscribe()
}

tests := []struct {
name string
subject string
opts []nats.RequestManyOpt
minTime time.Duration
expectedMsgs int
}{
{
name: "default",
subject: "foo",
opts: nil,
minTime: 300 * time.Millisecond,
expectedMsgs: 5,
},
{
name: "with max wait",
subject: "foo",
opts: []nats.RequestManyOpt{
nats.WithRequestManyMaxWait(500 * time.Millisecond),
},
minTime: 500 * time.Millisecond,
expectedMsgs: 5,
},
{
name: "with count reached",
subject: "foo",
opts: []nats.RequestManyOpt{
nats.WithRequestManyCount(3),
},
minTime: 0,
expectedMsgs: 3,
},
{
name: "with max wait and limit",
subject: "foo",
opts: []nats.RequestManyOpt{
nats.WithRequestManyMaxWait(500 * time.Millisecond),
nats.WithRequestManyCount(3),
},
minTime: 0,
expectedMsgs: 3,
},
{
name: "with count timeout",
subject: "foo",
opts: []nats.RequestManyOpt{
nats.WithRequestManyCount(6),
},
minTime: 300 * time.Millisecond,
expectedMsgs: 5,
},
{
name: "with no responses",
subject: "bar",
minTime: 300,
// no responders
expectedMsgs: 1,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
now := time.Now()
msgs, err := nc.RequestMany(test.subject, nil, test.opts...)
if err != nil {
t.Fatalf("Received an error on Request test: %s", err)
}

var i int
for msg := range msgs {
fmt.Println(string(msg.Data))
fmt.Println(msg.Header)
i++
}
if i != test.expectedMsgs {
t.Fatalf("Expected %d messages, got %d", test.expectedMsgs, i)
}
if time.Since(now) < test.minTime || time.Since(now) > test.minTime+100*time.Millisecond {
t.Fatalf("Expected to receive all messages between %v and %v, got %v", test.minTime, test.minTime+100*time.Millisecond, time.Since(now))
}
})
}
}

func TestFlushInCB(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()
Expand Down

0 comments on commit a594231

Please sign in to comment.