Skip to content

Commit

Permalink
improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Sep 17, 2024
1 parent 03aa027 commit 39f3021
Showing 1 changed file with 157 additions and 58 deletions.
215 changes: 157 additions & 58 deletions x/mongo/driver/topology/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package topology
import (
"context"
"errors"
"io"
"net"
"regexp"
"sync"
Expand Down Expand Up @@ -1159,12 +1158,7 @@ func TestBackgroundRead(t *testing.T) {
})

p := newPool(
poolConfig{},
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return net.Dial("tcp", addr.String())
})
}),
poolConfig{Address: address.Address(addr.String())},
)
defer p.close(context.Background())
err := p.ready()
Expand All @@ -1182,7 +1176,7 @@ func TestBackgroundRead(t *testing.T) {
assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil")
close(errsCh) // this line causes a double close if BGReadCallback is ever called.
})
t.Run("timeout on reading the message header", func(t *testing.T) {
t.Run("timeout reading message header, successful background read", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
Expand All @@ -1192,26 +1186,19 @@ func TestBackgroundRead(t *testing.T) {

const timeout = 10 * time.Millisecond

cleanup := make(chan struct{})
defer close(cleanup)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
defer func() {
<-cleanup
_ = nc.Close()
}()

// Wait until the operation times out, then write an full message.
time.Sleep(timeout * 2)
_, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0})
_, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0, 0, 0})
noerr(t, err)
})

p := newPool(
poolConfig{},
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return net.Dial("tcp", addr.String())
})
}),
poolConfig{Address: address.Address(addr.String())},
)
defer p.close(context.Background())
err := p.ready()
Expand All @@ -1228,26 +1215,63 @@ func TestBackgroundRead(t *testing.T) {
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
noerr(t, err)
wantErrs := []*regexp.Regexp{
regexp.MustCompile(
`^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
),
}
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
for i, err := range bgErrs {
if i < len(wantErrs) {
assert.True(t, wantErrs[i].MatchString(err.Error()), "error %q does not match pattern %q", err, wantErrs[i])
} else {
assert.Fail(t, "unexpected error", "got unexpected error: %v", err)
}
require.Len(t, bgErrs, 0, "expected no error from bgRead()")
})
t.Run("timeout reading message header, incomplete head during background read", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

const timeout = 10 * time.Millisecond

addr := bootstrapConnections(t, 1, func(nc net.Conn) {
defer func() {
_ = nc.Close()
}()

// Wait until the operation times out, then write an incomplete head.
time.Sleep(timeout * 2)
_, err := nc.Write([]byte{10, 0, 0})
noerr(t, err)
})

p := newPool(
poolConfig{Address: address.Address(addr.String())},
)
defer p.close(context.Background())
err := p.ready()
noerr(t, err)

conn, err := p.checkOut(context.Background())
noerr(t, err)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
noerr(t, err)
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
require.Len(t, bgErrs, 1, "expected 1 error from bgRead()")
assert.EqualError(t, bgErrs[0], "error reading the message size: unexpected EOF")
})
t.Run("timeout on reading the full message", func(t *testing.T) {
t.Run("timeout reading message header, background read timeout", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
Expand All @@ -1265,23 +1289,69 @@ func TestBackgroundRead(t *testing.T) {
_ = nc.Close()
}()

// Wait until the operation times out, then write an incomplete
// message.
time.Sleep(timeout * 2)
_, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0})
noerr(t, err)
})

p := newPool(
poolConfig{Address: address.Address(addr.String())},
)
defer p.close(context.Background())
err := p.ready()
noerr(t, err)

conn, err := p.checkOut(context.Background())
noerr(t, err)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
noerr(t, err)
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
require.Len(t, bgErrs, 1, "expected 1 error from bgRead()")
wantErr := regexp.MustCompile(
`^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, wantErr.MatchString(bgErrs[0].Error()), "error %q does not match pattern %q", bgErrs[0], wantErr)
})
t.Run("timeout reading full message, successful background read", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

const timeout = 10 * time.Millisecond

addr := bootstrapConnections(t, 1, func(nc net.Conn) {
defer func() {
_ = nc.Close()
}()

var err error
_, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1})
noerr(t, err)
time.Sleep(timeout * 2)
// write a complete message
_, err = nc.Write([]byte{2, 3, 4})
noerr(t, err)
})

p := newPool(
poolConfig{},
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
conn, err := net.Dial("tcp", addr.String())
noerr(t, err)
return newLimitConn(conn, 10), nil
})
}),
poolConfig{Address: address.Address(addr.String())},
)
defer p.close(context.Background())
err := p.ready()
Expand All @@ -1298,36 +1368,65 @@ func TestBackgroundRead(t *testing.T) {
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
noerr(t, err)
wantErrs := []string{
"error discarding 3 byte message: EOF",
}
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
for i, err := range bgErrs {
if i < len(wantErrs) {
assert.EqualError(t, err, wantErrs[i], "mismatched err: %v", err)
} else {
assert.Fail(t, "unexpected error", "got unexpected error: %v", err)
}
}
require.Len(t, bgErrs, 0, "expected no error from bgRead()")
})
}
t.Run("timeout reading full message, background read EOF", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

type limitConn struct {
net.Conn
r io.Reader
}
const timeout = 10 * time.Millisecond

func newLimitConn(conn net.Conn, n int64) limitConn {
return limitConn{conn, io.LimitReader(conn, n)}
}
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
defer func() {
_ = nc.Close()
}()

var err error
_, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1})
noerr(t, err)
time.Sleep(timeout * 2)
// write an incomplete message
_, err = nc.Write([]byte{2})
noerr(t, err)
})

p := newPool(
poolConfig{Address: address.Address(addr.String())},
)
defer p.close(context.Background())
err := p.ready()
noerr(t, err)

func (lc limitConn) Read(b []byte) (n int, err error) {
return lc.r.Read(b)
conn, err := p.checkOut(context.Background())
noerr(t, err)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
noerr(t, err)
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
require.Len(t, bgErrs, 1, "expected 1 error from bgRead()")
assert.EqualError(t, bgErrs[0], "error discarding 3 byte message: EOF")
})
}

func assertConnectionsClosed(t *testing.T, dialer *dialer, count int) {
Expand Down

0 comments on commit 39f3021

Please sign in to comment.