Skip to content

Commit

Permalink
improve test logic
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Sep 13, 2024
1 parent 2d988ed commit 9bea0eb
Showing 1 changed file with 56 additions and 67 deletions.
123 changes: 56 additions & 67 deletions x/mongo/driver/topology/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1130,37 +1130,33 @@ func TestPool(t *testing.T) {
func TestBackgroundRead(t *testing.T) {
t.Parallel()

newBGReadCallback := func(errCh chan error) func(string, time.Time, time.Time, []error, bool) {
newBGReadCallback := func(errsCh chan []error) func(string, time.Time, time.Time, []error, bool) {
return func(_ string, _, _ time.Time, errs []error, _ bool) {
defer close(errCh)

for _, err := range errs {
errCh <- err
}
errsCh <- errs
close(errsCh)
}
}

t.Run("incomplete read of message header", func(t *testing.T) {
errCh := make(chan error)
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

wg := &sync.WaitGroup{}
wg.Add(1)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
t.Helper()
const timeout = 10 * time.Millisecond

cleanup := make(chan struct{})
defer close(cleanup)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
defer func() {
<-cleanup
_ = nc.Close()
wg.Done()
}()

_, err := nc.Write([]byte{10, 0, 0})
noerr(t, err)
time.Sleep(1500 * time.Millisecond)
})

p := newPool(
Expand All @@ -1171,48 +1167,44 @@ func TestBackgroundRead(t *testing.T) {
})
}),
)
defer p.close(context.Background())
err := p.ready()
noerr(t, err)

conn, err := p.checkOut(context.Background())
noerr(t, err)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil")
wg.Wait()
p.close(context.Background())
close(errCh)
close(errsCh) // this line causes a double close if BGReadCallback is ever called.
})
t.Run("timeout on reading the message header", func(t *testing.T) {
errCh := make(chan error)
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

wg := &sync.WaitGroup{}
wg.Add(1)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
t.Helper()
const timeout = 10 * time.Millisecond

cleanup := make(chan struct{})
defer close(cleanup)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
defer func() {
<-cleanup
_ = nc.Close()
wg.Done()
}()

time.Sleep(1500 * time.Millisecond)
time.Sleep(timeout * 2)
_, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0})
noerr(t, err)
time.Sleep(1500 * time.Millisecond)
})
go func(t *testing.T) {
}(t)

p := newPool(
poolConfig{},
Expand All @@ -1222,66 +1214,64 @@ func TestBackgroundRead(t *testing.T) {
})
}),
)
defer p.close(context.Background())
err := p.ready()
noerr(t, err)

conn, err := p.checkOut(context.Background())
noerr(t, err)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
noerr(t, err)
wg.Wait()
p.close(context.Background())
errs := []*regexp.Regexp{
wantErrs := []*regexp.Regexp{
regexp.MustCompile(
`^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
),
}
for i := 0; true; i++ {
err, ok := <-errCh
if !ok {
if i != len(errs) {
assert.Fail(t, "expected more errors")
}
break
} else if i < len(errs) {
assert.True(t, errs[i].MatchString(err.Error()), "mismatched err: %v", err)
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
for i, err := range bgErrs {
if i < len(wantErrs) {
assert.True(t, wantErrs[i].MatchString(err.Error()), "error %q does not match pattern %q", err, wantErrs[i])
} else {
assert.Fail(t, "unexpected error", "got unexpected error: %v", err)
}
}
})
t.Run("timeout on reading the full message", func(t *testing.T) {
errCh := make(chan error)
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

wg := &sync.WaitGroup{}
wg.Add(1)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
t.Helper()
const timeout = 10 * time.Millisecond

cleanup := make(chan struct{})
defer close(cleanup)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
defer func() {
<-cleanup
_ = nc.Close()
wg.Done()
}()

var err error
_, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1})
noerr(t, err)
time.Sleep(1500 * time.Millisecond)
time.Sleep(timeout * 2)
_, err = nc.Write([]byte{2, 3, 4})
noerr(t, err)
time.Sleep(1500 * time.Millisecond)
})

p := newPool(
Expand All @@ -1294,34 +1284,33 @@ func TestBackgroundRead(t *testing.T) {
})
}),
)
defer p.close(context.Background())
err := p.ready()
noerr(t, err)

conn, err := p.checkOut(context.Background())
noerr(t, err)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), 1*time.Second)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
noerr(t, err)
wg.Wait()
p.close(context.Background())
errs := []string{
wantErrs := []string{
"error discarding 3 byte message: EOF",
}
for i := 0; true; i++ {
err, ok := <-errCh
if !ok {
if i != len(errs) {
assert.Fail(t, "expected more errors")
}
break
} else if i < len(errs) {
assert.EqualError(t, err, errs[i], "mismatched err: %v", err)
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
for i, err := range bgErrs {
if i < len(wantErrs) {
assert.EqualError(t, err, wantErrs[i], "mismatched err: %v", err)
} else {
assert.Fail(t, "unexpected error", "got unexpected error: %v", err)
}
Expand Down

0 comments on commit 9bea0eb

Please sign in to comment.