Skip to content

Commit

Permalink
Improve RequestMany
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Sep 30, 2024
1 parent a594231 commit f80e715
Show file tree
Hide file tree
Showing 2 changed files with 378 additions and 181 deletions.
243 changes: 156 additions & 87 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package nats
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
Expand Down Expand Up @@ -3994,8 +3993,6 @@ func (nc *Conn) respHandler(m *Msg) {
rt := nc.respToken(m.Subject)
if rt != _EMPTY_ {
mch = nc.respMap[rt]
// Delete the key regardless, one response only.
delete(nc.respMap, rt)
} else if len(nc.respMap) == 1 {
// If the server has rewritten the subject, the response token (rt)
// will not match (could be the case with JetStream). If that is the
Expand Down Expand Up @@ -4071,22 +4068,54 @@ func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg,
return nc.request(subj, nil, data, timeout)
}

// RequestMany will send a request payload and return a channel to receive multiple responses.
// By default, the number of messages received is constrained by the client's timeout.
//
// Use the RequestManyOpt functions to further configure this method's behavior.
// - [RequestManyMaxWait] sets the maximum time to wait for responses (defaults to client's timeout).
// - [RequestManyStallTimer] sets the stall timer, which waits a certain amount of time for the
// first message before starting the stall timer (reset on each delivered message).
// - [RequestManyMaxMessages] sets the maximum number of messages to receive.
// - [RequestManySentinel] stops returning responses once a message with an empty payload is received.
func (nc *Conn) RequestMany(subject string, data []byte, opts ...RequestManyOpt) (*RequestManyResponse, error) {
return nc.requestMany(subject, nil, data, opts...)
}

// RequestManyMsg will send a Msg request and return a channel to receive multiple responses.
// By default, the number of messages received is constrained by the client's timeout.
//
// Use the RequestManyOpt functions to further configure this method's behavior.
// - [RequestManyMaxWait] sets the maximum time to wait for responses (defaults to client's timeout).
// - [RequestManyStallTimer] sets the stall timer, which waits a certain amount of time for the
// first message before starting the stall timer (reset on each delivered message).
// - [RequestManyMaxMessages] sets the maximum number of messages to receive.
// - [RequestManySentinel] stops returning responses once a message with an empty payload is received.
func (nc *Conn) RequestManyMsg(msg *Msg, opts ...RequestManyOpt) (*RequestManyResponse, error) {
if msg == nil {
return nil, ErrInvalidMsg
}
hdr, err := msg.headerBytes()
if err != nil {
return nil, err
}
return nc.requestMany(msg.Subject, hdr, msg.Data, opts...)
}

type requestManyOpts struct {
maxWait time.Duration
gapTimeout *gapTimeoutOpts
stallTimer *stallTimer
count int
sentinel bool
}

type gapTimeoutOpts struct {
type stallTimer struct {
firstMax time.Duration
gap time.Duration
// used to differenciate between default value and user-set value
custom bool
stall time.Duration
}

type RequestManyOpt func(*requestManyOpts) error

func WithRequestManyMaxWait(maxWait time.Duration) RequestManyOpt {
func RequestManyMaxWait(maxWait time.Duration) RequestManyOpt {
return func(opts *requestManyOpts) error {
if maxWait <= 0 {
return fmt.Errorf("%w: max wait has to be greater than 0", ErrInvalidArg)
Expand All @@ -4096,24 +4125,23 @@ func WithRequestManyMaxWait(maxWait time.Duration) RequestManyOpt {
}
}

func WithRequestManyGapTimer(initialMaxWait, gapTime time.Duration) RequestManyOpt {
func RequestManyStallTimer(waitForFirstMsg, stall time.Duration) RequestManyOpt {
return func(opts *requestManyOpts) error {
if initialMaxWait <= 0 {
if waitForFirstMsg <= 0 {
return fmt.Errorf("%w: initial wait time has to be greater than 0", ErrInvalidArg)
}
if gapTime <= 0 {
return fmt.Errorf("%w: gap time has to be greater than 0", ErrInvalidArg)
if stall <= 0 {
return fmt.Errorf("%w: stall time has to be greater than 0", ErrInvalidArg)
}
opts.gapTimeout = &gapTimeoutOpts{
firstMax: initialMaxWait,
gap: gapTime,
custom: true,
opts.stallTimer = &stallTimer{
firstMax: waitForFirstMsg,
stall: stall,
}
return nil
}
}

func WithRequestManyCount(count int) RequestManyOpt {
func RequestManyMaxMessages(count int) RequestManyOpt {
return func(opts *requestManyOpts) error {
if count <= 0 {
return fmt.Errorf("%w: expected request count has to be greater than 0", ErrInvalidArg)
Expand All @@ -4123,105 +4151,143 @@ func WithRequestManyCount(count int) RequestManyOpt {
}
}

func (nc *Conn) RequestMany(subject string, data []byte, opts ...RequestManyOpt) (<-chan *Msg, error) {
func RequestManySentinel() RequestManyOpt {
return func(opts *requestManyOpts) error {
opts.sentinel = true
return nil
}
}

type RequestManyResponse struct {
Msgs chan *Msg
Err error
stop chan struct{}
}

func (r RequestManyResponse) Stop() {
if r.stop == nil {
return
}
close(r.stop)
}

func (nc *Conn) requestMany(subject string, hdr, data []byte, opts ...RequestManyOpt) (*RequestManyResponse, error) {
reqOpts := &requestManyOpts{
gapTimeout: &gapTimeoutOpts{
firstMax: nc.Opts.Timeout,
gap: 300 * time.Millisecond,
},
maxWait: nc.Opts.Timeout,
}

for _, opt := range opts {
if err := opt(reqOpts); err != nil {
return nil, err
}
}
// if user set a custom maxWait and did not set gap timer, we don't want to
// use the defaults
if reqOpts.maxWait != 0 && !reqOpts.gapTimeout.custom {
reqOpts.gapTimeout = nil
}

inbox := nc.newRespInbox()
var mch chan *Msg
var respCh chan *Msg
if reqOpts.count > 0 {
mch = make(chan *Msg, reqOpts.count)
respCh = make(chan *Msg, reqOpts.count)
} else {
mch = make(chan *Msg, nc.Opts.SubChanLen)
respCh = make(chan *Msg, nc.Opts.SubChanLen)
}

resp := &RequestManyResponse{
Msgs: respCh,
stop: make(chan struct{}),
}
mch, token, err := nc.createNewRequestAndSend(subject, hdr, data)
if err != nil {
return nil, err
}

var maxWait *time.Timer
var gapTimer *time.Timer
var stallTimer *time.Timer
if reqOpts.maxWait > 0 {
maxWait = globalTimerPool.Get(reqOpts.maxWait)
}
if reqOpts.gapTimeout != nil {
gapTimer = globalTimerPool.Get(reqOpts.gapTimeout.firstMax)
if reqOpts.stallTimer != nil {
stallTimer = globalTimerPool.Get(reqOpts.stallTimer.firstMax)
}
returnTimers := func() {
cleanup := func() {
close(respCh)
resp.stop = nil
if maxWait != nil {
globalTimerPool.Put(maxWait)
}
if gapTimer != nil {
globalTimerPool.Put(gapTimer)
if stallTimer != nil {
globalTimerPool.Put(stallTimer)
}
nc.mu.Lock()
delete(nc.respMap, token)
nc.mu.Unlock()
}
ctx, cancel := context.WithCancel(context.Background())
sub, err := nc.Subscribe(inbox, func(msg *Msg) {
if gapTimer != nil {
gapTimer.Reset(reqOpts.gapTimeout.gap)

handleMsg := func(msg *Msg) bool {
if reqOpts.sentinel && len(msg.Data) == 0 {
return false
}
mch <- msg
})
if err != nil {
cancel()
returnTimers()
return nil, err
}
sub.SetClosedHandler(func(subject string) {
returnTimers()
close(mch)
cancel()
})
if reqOpts.count != 0 {
sub.AutoUnsubscribe(reqOpts.count)
respCh <- msg
if reqOpts.count > 0 {
reqOpts.count--
if reqOpts.count == 0 {
return false
}
}
return true
}

go func() {
if maxWait != nil && gapTimer != nil {
select {
case <-maxWait.C:
sub.Unsubscribe()
case <-gapTimer.C:
sub.Unsubscribe()
case <-ctx.Done():
sub.Unsubscribe()
}
} else if maxWait != nil {
fmt.Println("with max wait")
select {
case <-maxWait.C:
sub.Unsubscribe()
case <-ctx.Done():
sub.Unsubscribe()
}
} else if gapTimer != nil {
select {
case <-gapTimer.C:
sub.Unsubscribe()
case <-ctx.Done():
sub.Unsubscribe()
if stallTimer != nil {
for {
select {
case msg, ok := <-mch:
if !ok {
cleanup()
resp.Err = ErrConnectionClosed
return
}
if stallTimer != nil {
stallTimer.Reset(reqOpts.stallTimer.stall)
}
if !handleMsg(msg) {
cleanup()
return
}

case <-maxWait.C:
cleanup()
return
case <-stallTimer.C:
cleanup()
return
case <-resp.stop:
cleanup()
return
}
}
} else {
<-ctx.Done()
sub.Unsubscribe()
for {
select {
case msg, ok := <-mch:
if !ok {
cleanup()
resp.Err = ErrConnectionClosed
return
}
if !handleMsg(msg) {
cleanup()
return
}
case <-maxWait.C:
cleanup()
return
case <-resp.stop:
cleanup()
return
}
}
}
}()
err = nc.PublishRequest(subject, inbox, data)
if err != nil {
sub.Unsubscribe()
return nil, err
}
return mch, nil

return resp, nil
}

func (nc *Conn) useOldRequestStyle() bool {
Expand Down Expand Up @@ -4266,6 +4332,9 @@ func (nc *Conn) newRequest(subj string, hdr, data []byte, timeout time.Duration)

select {
case msg, ok = <-mch:
nc.mu.Lock()
delete(nc.respMap, token)
nc.mu.Unlock()
if !ok {
return nil, ErrConnectionClosed
}
Expand Down
Loading

0 comments on commit f80e715

Please sign in to comment.