Skip to content

Commit

Permalink
feat: internal reconnect on each method in datastream client
Browse files Browse the repository at this point in the history
  • Loading branch information
V-Staykov committed Oct 11, 2024
1 parent df99d8a commit d779f99
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 102 deletions.
164 changes: 93 additions & 71 deletions zk/datastream/client/stream_client.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package client

import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"reflect"
"sync/atomic"
Expand Down Expand Up @@ -105,21 +103,35 @@ func (c *StreamClient) GetEntryNumberLimit() uint64 {
// GetL2BlockByNumber queries the data stream by sending the L2 block start bookmark for the certain block number
// and streams the changes for that block (including the transactions).
// Note that this function is intended for on demand querying and it disposes the connection after it ends.
func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, int, error) {
if err := c.EnsureConnected(); err != nil {
return nil, -1, err
}

func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (fullBLock *types.FullL2Block, errorCode int, err error) {
var (
l2Block *types.FullL2Block
err error
isL2Block bool
socketErr error = nil
connected bool = true
)

for {
if connected {
if fullBLock, errorCode, err, socketErr = c.getL2BlockByNumber(blockNum); err != nil {
return nil, errorCode, err
}
}
if socketErr == nil {
break
}
time.Sleep(1 * time.Second)
connected = c.handleSocketError(socketErr)
}

return fullBLock, types.CmdErrOK, nil
}

func (c *StreamClient) getL2BlockByNumber(blockNum uint64) (l2Block *types.FullL2Block, errorCode int, err, socketErr error) {
var isL2Block bool

bookmark := types.NewBookmarkProto(blockNum, datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK)
bookmarkRaw, err := bookmark.Marshal()
if err != nil {
return nil, -1, err
return nil, -1, err, nil
}

re, err := c.initiateDownloadBookmark(bookmarkRaw)
Expand All @@ -128,7 +140,7 @@ func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block,
if re != nil {
errorCode = int(re.ErrorNum)
}
return nil, errorCode, err
return nil, errorCode, nil, err
}

for l2Block == nil {
Expand All @@ -138,13 +150,13 @@ func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block,
if re != nil {
errorCode = int(re.ErrorNum)
}
return l2Block, errorCode, nil
return l2Block, errorCode, nil, nil
default:
}

parsedEntry, _, err := ReadParsedProto(c)
if err != nil {
return nil, -1, err
return nil, -1, nil, err
}

l2Block, isL2Block = parsedEntry.(*types.FullL2Block)
Expand All @@ -154,51 +166,68 @@ func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block,
}

if l2Block.L2BlockNumber != blockNum {
return nil, -1, fmt.Errorf("expected block number %d but got %d", blockNum, l2Block.L2BlockNumber)
return nil, -1, fmt.Errorf("expected block number %d but got %d", blockNum, l2Block.L2BlockNumber), nil
}

return l2Block, types.CmdErrOK, nil
return l2Block, types.CmdErrOK, nil, nil
}

// GetLatestL2Block queries the data stream by reading the header entry and based on total entries field,
// it retrieves the latest File entry that is of EntryTypeL2Block type.
// Note that this function is intended for on demand querying and it disposes the connection after it ends.
func (c *StreamClient) GetLatestL2Block() (l2Block *types.FullL2Block, err error) {
if err := c.EnsureConnected(); err != nil {
return nil, fmt.Errorf("failed to ensure connect: %w", err)
var (
socketErr error = nil
connected bool = true
)

for {
if connected {
if l2Block, err, socketErr = c.getLatestL2Block(); err != nil {
return nil, err
}
}
if socketErr == nil {
break
}
time.Sleep(1 * time.Second)
connected = c.handleSocketError(socketErr)
}
return l2Block, nil
}

func (c *StreamClient) getLatestL2Block() (l2Block *types.FullL2Block, err, socketErr error) {
h, err := c.GetHeader()
if err != nil {
return nil, fmt.Errorf("failed to get header: %w", err)
return nil, nil, fmt.Errorf("failed to get header: %w", err)
}

latestEntryNum := h.TotalEntries - 1

for l2Block == nil && latestEntryNum > 0 {
if err := c.sendEntryCmdWrapper(latestEntryNum); err != nil {
return nil, err
return nil, nil, err
}

entry, err := c.NextFileEntry()
if err != nil {
return nil, err
return nil, nil, err
}

if entry.EntryType == types.EntryTypeL2Block {
if l2Block, err = types.UnmarshalL2Block(entry.Data); err != nil {
return nil, err
return nil, err, nil
}
}

latestEntryNum--
}

if latestEntryNum == 0 {
return nil, errors.New("no block found")
return nil, errors.New("no block found"), nil
}

return l2Block, nil
return l2Block, nil, nil
}

func (c *StreamClient) GetLastWrittenTimeAtomic() *atomic.Int64 {
Expand Down Expand Up @@ -351,37 +380,44 @@ func (c *StreamClient) RenewEntryChannel() {
c.entryChan = make(chan interface{}, entryChannelSize)
}

func (c *StreamClient) connClosed() bool {
if c.conn == nil {
return true
}
func (c *StreamClient) ReadAllEntriesToChannel() (err error) {
var (
socketErr error = nil
connected bool = true
)

c.conn.SetReadDeadline(time.Now())
one := new(bytes.Buffer)
if _, err := io.CopyN(one, c.conn, 1); err == io.ErrClosedPipe {
c.conn = nil
return true
for {
if connected {
if err, socketErr = c.readAllEntriesToChannel(); err != nil {
return err
}
}
if socketErr == nil {
break
}

time.Sleep(1 * time.Second)
connected = c.handleSocketError(socketErr)
}

c.conn.SetReadDeadline(time.Now().Add(1 * c.checkTimeout))
return false
return nil
}

func (c *StreamClient) EnsureConnected() error {
if c.connClosed() {
if err := c.tryReConnect(); err != nil {
return fmt.Errorf("failed to reconnect the datastream client: %w", err)
}

c.RenewEntryChannel()
func (c *StreamClient) handleSocketError(socketErr error) bool {
log.Warn(fmt.Sprintf("Socket error: %s", socketErr))
if err := c.tryReConnect(); err != nil {
log.Warn(fmt.Sprintf("Failed to reconnect the datastream client: %s", err))
return false
}

return nil
c.RenewEntryChannel()

return true
}

// reads entries to the end of the stream
// at end will wait for new entries to arrive
func (c *StreamClient) ReadAllEntriesToChannel() error {
func (c *StreamClient) readAllEntriesToChannel() (err, socketErr error) {
c.streaming.Store(true)
c.stopReadingToChannel.Store(false)
defer c.streaming.Store(false)
Expand All @@ -396,28 +432,19 @@ func (c *StreamClient) ReadAllEntriesToChannel() error {

protoBookmark, err := bookmark.Marshal()
if err != nil {
return err
return err, nil
}

// send start command
if _, err := c.initiateDownloadBookmark(protoBookmark); err != nil {
return err
if _, socketErr := c.initiateDownloadBookmark(protoBookmark); err != nil {
return nil, socketErr
}

if err := c.readAllFullL2BlocksToChannel(); err != nil {
err2 := fmt.Errorf("%s read full L2 blocks error: %v", c.id, err)

if c.conn != nil {
if err2 := c.conn.Close(); err2 != nil {
log.Error("failed to close connection after error", "original-error", err, "new-error", err2)
}
c.conn = nil
}

return err2
return nil, fmt.Errorf("%s read full L2 blocks error: %v", c.id, err)
}

return nil
return nil, nil
}

// runs the prerequisites for entries download
Expand Down Expand Up @@ -511,20 +538,15 @@ LOOP:

func (c *StreamClient) tryReConnect() error {
var err error
for i := 0; i < 50; i++ {
if c.conn != nil {
if err := c.conn.Close(); err != nil {
log.Warn(fmt.Sprintf("[%d. iteration] failed to close the DS connection: %s", i+1, err))
return err
}
c.conn = nil
if c.conn != nil {
if err := c.conn.Close(); err != nil {
log.Warn(fmt.Sprintf("failed to close the DS connection: %s", err))
return err
}
if err = c.Start(); err != nil {
log.Warn(fmt.Sprintf("[%d. iteration] failed to start the DS connection: %s", i+1, err))
time.Sleep(5 * time.Second)
continue
}
return nil
c.conn = nil
}
if err = c.Start(); err != nil {
log.Warn(fmt.Sprintf("failed to start the DS connection: %s", err))
}

return err
Expand Down
19 changes: 15 additions & 4 deletions zk/datastream/client/stream_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ func TestStreamClientReadParsedProto(t *testing.T) {
c := NewClient(context.Background(), "", 0, 0, 0)
serverConn, clientConn := net.Pipe()
c.conn = clientConn
c.checkTimeout = 1 * time.Second

c.header = &types.HeaderEntry{
TotalEntries: 3,
}
defer func() {
serverConn.Close()
clientConn.Close()
Expand Down Expand Up @@ -273,7 +278,7 @@ func TestStreamClientGetLatestL2Block(t *testing.T) {

c := NewClient(context.Background(), "", 0, 0, 0)
c.conn = clientConn

c.checkTimeout = 1 * time.Second
expectedL2Block, _ := createL2BlockAndTransactions(t, 5, 0)
l2BlockProto := &types.L2BlockProto{L2Block: expectedL2Block}
l2BlockRaw, err := l2BlockProto.Marshal()
Expand Down Expand Up @@ -385,8 +390,11 @@ func TestStreamClientGetL2BlockByNumber(t *testing.T) {
}()

c := NewClient(context.Background(), "", 0, 0, 0)
c.header = &types.HeaderEntry{
TotalEntries: 4,
}
c.conn = clientConn

c.checkTimeout = 1 * time.Second
bookmark := types.NewBookmarkProto(blockNum, datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK)
bookmarkRaw, err := bookmark.Marshal()
require.NoError(t, err)
Expand Down Expand Up @@ -477,9 +485,12 @@ func TestStreamClientGetL2BlockByNumber(t *testing.T) {
require.NoError(t, err)
require.Equal(t, types.CmdErrOK, errCode)

serverErr := <-errCh
var serverErr error
select {
case serverErr = <-errCh:
default:
}
require.NoError(t, serverErr)

l2TxsProto := make([]types.L2TransactionProto, len(l2Txs))
for i, tx := range l2Txs {
l2TxProto := types.ConvertToL2TransactionProto(tx)
Expand Down
1 change: 0 additions & 1 deletion zk/stages/stage_batches.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ type DatastreamClient interface {
GetLatestL2Block() (*types.FullL2Block, error)
GetStreamingAtomic() *atomic.Bool
GetProgressAtomic() *atomic.Uint64
EnsureConnected() error
Start() error
Stop()
}
Expand Down
24 changes: 2 additions & 22 deletions zk/stages/stage_batches_datastream.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,8 @@ func (r *DatastreamClientRunner) StartRead() error {
r.isReading.Store(true)
defer r.isReading.Store(false)

for {
if r.stopRunner.Load() {
break
}

// this will download all blocks from datastream and push them in a channel
// if no error, break, else continue trying to get them
// Create bookmark
if err := r.dsClient.EnsureConnected(); err != nil {
log.Error(fmt.Sprintf("[%s] Error connecting to datastream", r.logPrefix), "error", err)
time.Sleep(10 * time.Millisecond)
continue
}

if err := r.dsClient.ReadAllEntriesToChannel(); err != nil {

log.Error(fmt.Sprintf("[%s] Error downloading blocks from datastream", r.logPrefix), "error", err)

time.Sleep(10 * time.Millisecond)
continue
}
break
if err := r.dsClient.ReadAllEntriesToChannel(); err != nil {
log.Warn(fmt.Sprintf("[%s] Error downloading blocks from datastream", r.logPrefix), "error", err)
}
}()

Expand Down
4 changes: 0 additions & 4 deletions zk/stages/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ func NewTestDatastreamClient(fullL2Blocks []types.FullL2Block, gerUpdates []type
return client
}

func (c *TestDatastreamClient) EnsureConnected() (bool, error) {
return true, nil
}

func (c *TestDatastreamClient) ReadAllEntriesToChannel() error {
c.streamingAtomic.Store(true)
defer c.streamingAtomic.Swap(false)
Expand Down

0 comments on commit d779f99

Please sign in to comment.