Skip to content

Commit

Permalink
Add an API to hook to Write() and Read() of the Conn
Browse files Browse the repository at this point in the history
New struct TrafficRecorder added for enabling recording
and mocking single Connection flow.
  • Loading branch information
sylwiaszunejko committed Oct 15, 2024
1 parent 8af5e34 commit 6af5165
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 9 deletions.
90 changes: 82 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ var TimeoutLimit int64 = 0
// queries, but users are usually advised to use a more reliable, higher
// level API.
type Conn struct {
conn net.Conn
r *bufio.Reader
w contextWriter
connWrapper *ConnWrapper
r *bufio.Reader
w contextWriter

timeout time.Duration
writeTimeout time.Duration
Expand Down Expand Up @@ -212,6 +212,49 @@ type Conn struct {
tabletsRoutingV1 int32
}

type ConnWrapper struct {
conn net.Conn
recorder *TrafficRecorder
mock bool // flag to indicate if we're in mock mode
}

func NewConnWrapper(conn net.Conn) *ConnWrapper {
return &ConnWrapper{
conn: conn,
recorder: NewTrafficRecorder(),
mock: false,
}
}

func (cw *ConnWrapper) SetMockMode(mock bool) {
cw.mock = mock
}

type TrafficRecorder struct {
mu sync.Mutex
recordedWrites [][]byte
recordedReads [][]byte
}

func NewTrafficRecorder() *TrafficRecorder {
return &TrafficRecorder{
recordedWrites: make([][]byte, 0),
recordedReads: make([][]byte, 0),
}
}

func (r *TrafficRecorder) RecordWrite(p []byte) {
r.mu.Lock()
defer r.mu.Unlock()
r.recordedWrites = append(r.recordedWrites, p)
}

func (r *TrafficRecorder) RecordRead(p []byte) {
r.mu.Lock()
defer r.mu.Unlock()
r.recordedReads = append(r.recordedReads, p)
}

// connect establishes a connection to a Cassandra node using session's connection config.
func (s *Session) connect(ctx context.Context, host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
return s.dial(ctx, host, s.connCfg, errorHandler)
Expand Down Expand Up @@ -280,7 +323,7 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *

ctx, cancel := context.WithCancel(ctx)
c := &Conn{
conn: dialedHost.Conn,
connWrapper: NewConnWrapper(dialedHost.Conn),
r: bufio.NewReader(dialedHost.Conn),
cfg: cfg,
calls: make(map[int]*callReq),
Expand Down Expand Up @@ -347,7 +390,7 @@ func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error {

// dont coalesce startup frames
if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce && !dialedHost.DisableCoalesce {
c.w = newWriteCoalescer(c.conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
c.w = newWriteCoalescer(c.connWrapper.conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
}

go c.serve(ctx)
Expand All @@ -357,20 +400,51 @@ func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error {
}

func (c *Conn) Write(p []byte) (n int, err error) {
if c.connWrapper != nil && c.connWrapper.mock {
c.connWrapper.recorder.mu.Lock()
defer c.connWrapper.recorder.mu.Unlock()
// In mock mode, simulate successful write based on recorded data
if len(c.connWrapper.recorder.recordedWrites) > 0 {
recordedWrite := c.connWrapper.recorder.recordedWrites[0]
c.connWrapper.recorder.recordedWrites = c.connWrapper.recorder.recordedWrites[1:]
return len(recordedWrite), nil
}
return 0, io.EOF // No more recorded data
}

if c.connWrapper != nil && c.connWrapper.recorder != nil {
c.connWrapper.recorder.RecordWrite(p)
}
return c.w.writeContext(context.Background(), p)
}

func (c *Conn) Read(p []byte) (n int, err error) {
if c.connWrapper != nil && c.connWrapper.mock {
c.connWrapper.recorder.mu.Lock()
defer c.connWrapper.recorder.mu.Unlock()
// In mock mode, simulate reading recorded data
if len(c.connWrapper.recorder.recordedReads) > 0 {
recordedRead := c.connWrapper.recorder.recordedReads[0]
c.connWrapper.recorder.recordedReads = c.connWrapper.recorder.recordedReads[1:]
copy(p, recordedRead)
return len(recordedRead), nil
}
return 0, io.EOF // No more recorded data
}

const maxAttempts = 5

for i := 0; i < maxAttempts; i++ {
var nn int
if c.timeout > 0 {
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
c.connWrapper.conn.SetReadDeadline(time.Now().Add(c.timeout))
}

nn, err = io.ReadFull(c.r, p[n:])
n += nn
if c.connWrapper != nil && c.connWrapper.recorder != nil && nn > 0 {
c.connWrapper.recorder.RecordRead(p[:n])
}
if err == nil {
break
}
Expand Down Expand Up @@ -617,7 +691,7 @@ func (c *Conn) setTabletSupported(val bool) {
}

func (c *Conn) close() error {
return c.conn.Close()
return c.connWrapper.conn.Close()
}

func (c *Conn) Close() {
Expand Down Expand Up @@ -708,7 +782,7 @@ func (c *Conn) recv(ctx context.Context) error {
// read a full header, ignore timeouts, as this is being ran in a loop
// TODO: TCP level deadlines? or just query level deadlines?
if c.timeout > 0 {
c.conn.SetReadDeadline(time.Time{})
c.connWrapper.conn.SetReadDeadline(time.Time{})
}

headStartTime := time.Now()
Expand Down
75 changes: 75 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1394,3 +1394,78 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) {

return framer, nil
}

func NewMockNetConn() net.Conn {
return &mockNetConn{}
}

type mockNetConn struct{}

func (m *mockNetConn) Read(p []byte) (n int, err error) {
data := []byte("test receive")
n = copy(p, data)
return n, nil
}

func (m *mockNetConn) Write(p []byte) (n int, err error) {
return len(p), nil
}

func (m *mockNetConn) Close() error {
return nil
}

func (m *mockNetConn) LocalAddr() net.Addr {
return nil
}

func (m *mockNetConn) RemoteAddr() net.Addr {
return nil
}

func (m *mockNetConn) SetDeadline(t time.Time) error {
return nil
}

func (m *mockNetConn) SetReadDeadline(t time.Time) error {
return nil
}

func (m *mockNetConn) SetWriteDeadline(t time.Time) error {
return nil
}

func TestConnTrafficRecorder(t *testing.T) {
conn := &Conn{
connWrapper: NewConnWrapper(NewMockNetConn()),
w: &deadlineContextWriter{
w: NewMockNetConn(),
timeout: 0,
semaphore: make(chan struct{}, 1),
quit: make(chan struct{}),
},
timeout: 0,
r: bufio.NewReader(NewMockNetConn()),
}

sendData := []byte("test send")
_, err := conn.Write(sendData)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

if len(conn.connWrapper.recorder.recordedWrites) != 1 || !bytes.Equal(conn.connWrapper.recorder.recordedWrites[0], sendData) {
t.Fatalf("expected sent data to be recorded")
}

// Simulate receiving data
receiveData := []byte("test receive")
_, err = conn.Read(receiveData)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

if len(conn.connWrapper.recorder.recordedReads) != 1 || !bytes.Equal(conn.connWrapper.recorder.recordedReads[0], receiveData) {
t.Fatalf("expected received data to be recorded")
}
}
2 changes: 1 addition & 1 deletion control.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func (c *controlConn) setupConn(conn *Conn) error {
// we need up-to-date host info for the filterHost call below
iter := conn.querySystemLocal(context.TODO())
defaultPort := 9042
if tcpAddr, ok := conn.conn.RemoteAddr().(*net.TCPAddr); ok {
if tcpAddr, ok := conn.connWrapper.conn.RemoteAddr().(*net.TCPAddr); ok {
defaultPort = tcpAddr.Port
}
host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress, defaultPort)
Expand Down

0 comments on commit 6af5165

Please sign in to comment.