Skip to content

Commit

Permalink
Merge pull request #792 from twmb/769
Browse files Browse the repository at this point in the history
kgo: allow record ctx cancelation to propagate a bit more
  • Loading branch information
twmb authored Jul 29, 2024
2 parents 718591a + 305d8dc commit afcb32b
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 29 deletions.
17 changes: 14 additions & 3 deletions pkg/kgo/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func isRetryableBrokerErr(err error) bool {
}
// We could have a retryable producer ID failure, which then bubbled up
// as errProducerIDLoadFail so as to be retried later.
if errors.Is(err, errProducerIDLoadFail) {
if pe := (*errProducerIDLoadFail)(nil); errors.As(err, &pe) {
return true
}
// We could have chosen a broker, and then a concurrent metadata update
Expand Down Expand Up @@ -139,8 +139,6 @@ var (
// restart a new connection ourselves.
errSaslReauthLoop = errors.New("the broker is repeatedly giving us sasl lifetimes that are too short to write a request")

errProducerIDLoadFail = errors.New("unable to initialize a producer ID due to request failures")

// A temporary error returned when Kafka replies with a different
// correlation ID than we were expecting for the request the client
// issued.
Expand Down Expand Up @@ -224,6 +222,19 @@ type ErrFirstReadEOF struct {
err error
}

// errProducerIDLoadFail reports that a producer ID could not be initialized.
// The err field optionally carries the underlying failure.
type errProducerIDLoadFail struct {
	err error
}

// Error returns a fixed description of the failure, followed by the wrapped
// error when one is present.
func (e *errProducerIDLoadFail) Error() string {
	if cause := e.err; cause != nil {
		return fmt.Sprintf("unable to initialize a producer ID due to request failures: %v", cause)
	}
	return "unable to initialize a producer ID due to request failures"
}

// Unwrap exposes the wrapped error (possibly nil) so that errors.Is and
// errors.As can inspect it.
func (e *errProducerIDLoadFail) Unwrap() error {
	return e.err
}

const (
firstReadSASL uint8 = iota
firstReadTLS
Expand Down
26 changes: 26 additions & 0 deletions pkg/kgo/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,32 @@ var (
npartitionsAt int64
)

// slowConn wraps a net.Conn and pauses for 100ms before every Read and Write,
// then forwards the call to the embedded connection.
type slowConn struct {
	net.Conn
}

// Write sleeps 100ms and then writes p to the embedded connection.
func (s *slowConn) Write(p []byte) (int, error) {
	const delay = 100 * time.Millisecond
	time.Sleep(delay)
	n, err := s.Conn.Write(p)
	return n, err
}

// Read sleeps 100ms and then reads into p from the embedded connection.
func (s *slowConn) Read(p []byte) (int, error) {
	const delay = 100 * time.Millisecond
	time.Sleep(delay)
	n, err := s.Conn.Read(p)
	return n, err
}

// slowDialer dials with an embedded net.Dialer and wraps every successful
// connection in a slowConn.
type slowDialer struct {
	d net.Dialer
}

// DialContext dials host on network using the wrapped dialer. On success the
// connection is returned wrapped in a slowConn; on failure the dial error is
// returned with a nil connection.
func (s *slowDialer) DialContext(ctx context.Context, network, host string) (net.Conn, error) {
	conn, err := s.d.DialContext(ctx, network, host)
	if err == nil {
		return &slowConn{Conn: conn}, nil
	}
	return nil, err
}

func init() {
var err error
if n, _ := strconv.Atoi(os.Getenv("KGO_TEST_RF")); n > 0 {
Expand Down
124 changes: 124 additions & 0 deletions pkg/kgo/produce_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"
"sync/atomic"
"testing"
"time"

"github.com/twmb/franz-go/pkg/kbin"
"github.com/twmb/franz-go/pkg/kmsg"
Expand Down Expand Up @@ -71,6 +72,129 @@ func TestClient_Produce(t *testing.T) {
}
}

// TestIssue769 checks that a record whose own context is already canceled is
// failed promptly at each point a record can wait before a produce request is
// issued: while its topic is unknown, while the producer ID is being loaded,
// and once both are known.
//
// The canceled produces below actually SUCCEED if the code for 769 is not
// working correctly. 769 is about a hanging produce not obeying a record
// cancelation, but we can simulate the same thing.
func TestIssue769(t *testing.T) {
	t.Parallel()

	topic, cleanup := tmpTopic(t)
	defer cleanup()

	// NOTE(review): the second return value of newTestClient is discarded
	// here; confirm it never signals a failure that would leave cl unusable.
	cl, _ := newTestClient(
		DefaultProduceTopic(topic),
		UnknownTopicRetries(-1),
		Dialer(new(slowDialer).DialContext),
	)
	defer cl.Close()

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	canceled := &Record{Value: []byte("foo"), Context: ctx}
	okay := &Record{Value: []byte("foo")}

	// produce hands r to the client and returns a channel that is closed
	// once r's promise has run, plus the location the promise stores its
	// error in. The error must only be read after the channel is closed.
	produce := func(r *Record) (<-chan struct{}, *error) {
		done := make(chan struct{})
		rerr := new(error)
		cl.Produce(context.Background(), r, func(_ *Record, err error) {
			defer close(done)
			*rerr = err
		})
		return done, rerr
	}

	// produceWithin3s produces r and returns the promise's error, aborting
	// the test if the promise has not run within three seconds.
	produceWithin3s := func(r *Record) error {
		t.Helper()
		done, rerr := produce(r)
		timer := time.NewTimer(3 * time.Second)
		defer timer.Stop() // do not leave the timer pending once the promise has run
		select {
		case <-done:
		case <-timer.C:
			t.Fatal("expected record to fail within 3s")
		}
		return *rerr
	}

	// First check: ensure that an already-canceled record bails right
	// away. This actually bails in the unknown-topic bit of logic,
	// although there is no way to surface that to the end user.
	if err := produceWithin3s(canceled); !errors.Is(err, context.Canceled) {
		t.Errorf("got %v != exp context.Canceled", err)
	}

	// We have to produce one record successfully to ensure the topic is
	// known, then we modify the guts of the client to forget the loaded
	// producer ID.
	{
		done, rerr := produce(okay)
		<-done
		if *rerr != nil {
			t.Fatalf("unexpected error on the first produce: %v", *rerr)
		}
		cl.producer.id.Store(&producerID{
			id:    -1,
			epoch: -1,
			err:   errReloadProducerID,
		})
	}

	// With a loaded topic but forgotten producer ID, we now ensure that a
	// canceled record fails in the producer ID portion.
	{
		err := produceWithin3s(canceled)
		var pe *errProducerIDLoadFail
		// pe.err is checked for nil first so that an unexpectedly empty
		// wrapper is reported as a test failure rather than a panic.
		if !errors.As(err, &pe) || pe.err == nil || !(errors.Is(pe.err, context.Canceled) || strings.Contains(pe.err.Error(), "canceled")) {
			t.Errorf("got %v != exp errProducerIDLoadFail{context.Canceled}", err)
		}
	}

	// We now produce successfully again to ensure the next attempt fails
	// after the producer ID stage.
	{
		done, rerr := produce(okay)
		cl.Flush(context.Background())
		<-done
		if *rerr != nil {
			t.Fatalf("unexpected error on the second produce: %v", *rerr)
		}
	}

	// This fails before the produce request is issued, which is the furthest we
	// can take the test. We do not use record contexts in issued produce requests.
	{
		err := produceWithin3s(canceled)
		if pe := (*errProducerIDLoadFail)(nil); errors.As(err, &pe) {
			t.Error("unexpectedly got errProducerIDLoadFail")
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("got %v != context.Canceled", err)
		}
	}
}

// This file contains golden tests against kmsg AppendTo's to ensure our custom
// encoding is correct.

Expand Down
26 changes: 17 additions & 9 deletions pkg/kgo/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,11 @@ func (cl *Client) TryProduce(
// retries. If any of these conditions are hit and it is currently safe to fail
// records, all buffered records for the relevant partition are failed. Only
// the first record's context in a batch is considered when determining whether
// the batch should be canceled.
// the batch should be canceled. A record is not safe to fail if the client
// is idempotently producing and a request has been sent; in this case, the
// client cannot know if the broker actually processed the request (if so, then
// removing the records from the client will create errors the next time you
// produce).
//
// If the client is transactional and a transaction has not been begun, the
// promise is immediately called with an error corresponding to not being in a
Expand Down Expand Up @@ -679,7 +683,7 @@ func (cl *Client) ProducerID(ctx context.Context) (int64, int16, error) {

go func() {
defer close(done)
id, epoch, err = cl.producerID()
id, epoch, err = cl.producerID(ctx2fn(ctx))
}()

select {
Expand All @@ -701,7 +705,7 @@ var errReloadProducerID = errors.New("producer id needs reloading")
// producerID initializes the client's producer ID for idempotent
// producing only (no transactions, which are more special). After the first
// load, this clears all buffered unknown topics.
func (cl *Client) producerID() (int64, int16, error) {
func (cl *Client) producerID(ctxFn func() context.Context) (int64, int16, error) {
p := &cl.producer

id := p.id.Load().(*producerID)
Expand Down Expand Up @@ -730,7 +734,7 @@ func (cl *Client) producerID() (int64, int16, error) {
}
p.id.Store(id)
} else {
newID, keep := cl.doInitProducerID(id.id, id.epoch)
newID, keep := cl.doInitProducerID(ctxFn, id.id, id.epoch)
if keep {
id = newID
// Whenever we have a new producer ID, we need
Expand All @@ -748,7 +752,7 @@ func (cl *Client) producerID() (int64, int16, error) {
id = &producerID{
id: id.id,
epoch: id.epoch,
err: errProducerIDLoadFail,
err: &errProducerIDLoadFail{newID.err},
}
}
}
Expand Down Expand Up @@ -825,7 +829,7 @@ func (cl *Client) failProducerID(id int64, epoch int16, err error) {

// doInitProducerID inits the idempotent ID and potentially the transactional
// producer epoch, returning whether to keep the result.
func (cl *Client) doInitProducerID(lastID int64, lastEpoch int16) (*producerID, bool) {
func (cl *Client) doInitProducerID(ctxFn func() context.Context, lastID int64, lastEpoch int16) (*producerID, bool) {
cl.cfg.logger.Log(LogLevelInfo, "initializing producer id")
req := kmsg.NewPtrInitProducerIDRequest()
req.TransactionalID = cl.cfg.txnID
Expand All @@ -835,7 +839,8 @@ func (cl *Client) doInitProducerID(lastID int64, lastEpoch int16) (*producerID,
req.TransactionTimeoutMillis = int32(cl.cfg.txnTimeout.Milliseconds())
}

resp, err := req.RequestWith(cl.ctx, cl)
ctx := ctxFn()
resp, err := req.RequestWith(ctx, cl)
if err != nil {
if errors.Is(err, errUnknownRequestKey) || errors.Is(err, errBrokerTooOld) {
cl.cfg.logger.Log(LogLevelInfo, "unable to initialize a producer id because the broker is too old or the client is pinned to an old version, continuing without a producer id")
Expand Down Expand Up @@ -940,13 +945,14 @@ func (cl *Client) addUnknownTopicRecord(pr promisedRec) {
}
unknown.buffered = append(unknown.buffered, pr)
if len(unknown.buffered) == 1 {
go cl.waitUnknownTopic(pr.ctx, pr.Topic, unknown)
go cl.waitUnknownTopic(pr.ctx, pr.Record.Context, pr.Topic, unknown)
}
}

// waitUnknownTopic waits for a notification
func (cl *Client) waitUnknownTopic(
rctx context.Context,
pctx context.Context, // context passed to Produce
rctx context.Context, // context on the record itself
topic string,
unknown *unknownTopicProduces,
) {
Expand Down Expand Up @@ -974,6 +980,8 @@ func (cl *Client) waitUnknownTopic(

for err == nil {
select {
case <-pctx.Done():
err = pctx.Err()
case <-rctx.Done():
err = rctx.Err()
case <-cl.ctx.Done():
Expand Down
Loading

0 comments on commit afcb32b

Please sign in to comment.