From d779f998139ccf0fd46e6bb2d34e28f39c3d129d Mon Sep 17 00:00:00 2001 From: Valentin Staykov Date: Fri, 11 Oct 2024 08:18:03 +0000 Subject: [PATCH] feat: internal reconnect on each method in datastream client --- zk/datastream/client/stream_client.go | 164 ++++++++++++--------- zk/datastream/client/stream_client_test.go | 19 ++- zk/stages/stage_batches.go | 1 - zk/stages/stage_batches_datastream.go | 24 +-- zk/stages/test_utils.go | 4 - 5 files changed, 110 insertions(+), 102 deletions(-) diff --git a/zk/datastream/client/stream_client.go b/zk/datastream/client/stream_client.go index ac89fab1dd3..22853f0174e 100644 --- a/zk/datastream/client/stream_client.go +++ b/zk/datastream/client/stream_client.go @@ -1,12 +1,10 @@ package client import ( - "bytes" "context" "encoding/binary" "errors" "fmt" - "io" "net" "reflect" "sync/atomic" @@ -105,21 +103,35 @@ func (c *StreamClient) GetEntryNumberLimit() uint64 { // GetL2BlockByNumber queries the data stream by sending the L2 block start bookmark for the certain block number // and streams the changes for that block (including the transactions). // Note that this function is intended for on demand querying and it disposes the connection after it ends. -func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, int, error) { - if err := c.EnsureConnected(); err != nil { - return nil, -1, err - } - +func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (fullBLock *types.FullL2Block, errorCode int, err error) { var ( - l2Block *types.FullL2Block - err error - isL2Block bool + socketErr error = nil + connected bool = true ) + for { + if connected { + if fullBLock, errorCode, err, socketErr = c.getL2BlockByNumber(blockNum); err != nil { + return nil, errorCode, err + } + } + if socketErr == nil { + break + } + time.Sleep(1 * time.Second) + connected = c.handleSocketError(socketErr) + } + + return fullBLock, types.CmdErrOK, nil +} + +func (c *StreamClient) getL2BlockByNumber(blockNum uint64) (l2Block *types.FullL2Block, errorCode int, err, socketErr error) { + var isL2Block bool + bookmark := types.NewBookmarkProto(blockNum, datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK) bookmarkRaw, err := bookmark.Marshal() if err != nil { - return nil, -1, err + return nil, -1, err, nil } re, err := c.initiateDownloadBookmark(bookmarkRaw) @@ -128,7 +140,7 @@ func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, if re != nil { errorCode = int(re.ErrorNum) } - return nil, errorCode, err + return nil, errorCode, nil, err } for l2Block == nil { @@ -138,13 +150,13 @@ func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, if re != nil { errorCode = int(re.ErrorNum) } - return l2Block, errorCode, nil + return l2Block, errorCode, nil, nil default: } parsedEntry, _, err := ReadParsedProto(c) if err != nil { - return nil, -1, err + return nil, -1, nil, err } l2Block, isL2Block = parsedEntry.(*types.FullL2Block) @@ -154,40 +166,57 @@ func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, } if l2Block.L2BlockNumber != blockNum { - return nil, -1, fmt.Errorf("expected block number %d but got %d", blockNum, l2Block.L2BlockNumber) + return nil, -1, fmt.Errorf("expected block number %d but got %d", blockNum, l2Block.L2BlockNumber), nil } - return l2Block, types.CmdErrOK, nil + return l2Block, types.CmdErrOK, nil, nil } // GetLatestL2Block queries the data stream by reading the header entry and based on total entries field, // it retrieves the latest File entry that is of EntryTypeL2Block type. // Note that this function is intended for on demand querying and it disposes the connection after it ends. func (c *StreamClient) GetLatestL2Block() (l2Block *types.FullL2Block, err error) { - if err := c.EnsureConnected(); err != nil { - return nil, fmt.Errorf("failed to ensure connect: %w", err) + var ( + socketErr error = nil + connected bool = true + ) + + for { + if connected { + if l2Block, err, socketErr = c.getLatestL2Block(); err != nil { + return nil, err + } + } + if socketErr == nil { + break + } + time.Sleep(1 * time.Second) + connected = c.handleSocketError(socketErr) } + return l2Block, nil +} +func (c *StreamClient) getLatestL2Block() (l2Block *types.FullL2Block, err, socketErr error) { h, err := c.GetHeader() if err != nil { - return nil, fmt.Errorf("failed to get header: %w", err) + return nil, nil, fmt.Errorf("failed to get header: %w", err) } latestEntryNum := h.TotalEntries - 1 for l2Block == nil && latestEntryNum > 0 { if err := c.sendEntryCmdWrapper(latestEntryNum); err != nil { - return nil, err + return nil, nil, err } entry, err := c.NextFileEntry() if err != nil { - return nil, err + return nil, nil, err } if entry.EntryType == types.EntryTypeL2Block { if l2Block, err = types.UnmarshalL2Block(entry.Data); err != nil { - return nil, err + return nil, err, nil } } @@ -195,10 +224,10 @@ func (c *StreamClient) GetLatestL2Block() (l2Block *types.FullL2Block, err error } if latestEntryNum == 0 { - return nil, errors.New("no block found") + return nil, errors.New("no block found"), nil } - return l2Block, nil + return l2Block, nil, nil } func (c *StreamClient) GetLastWrittenTimeAtomic() *atomic.Int64 { @@ -351,37 +380,44 @@ func (c *StreamClient) RenewEntryChannel() { c.entryChan = make(chan interface{}, entryChannelSize) } -func (c *StreamClient) connClosed() bool { - if c.conn == nil { - return true - } +func (c *StreamClient) ReadAllEntriesToChannel() (err error) { + var ( + socketErr error = nil + connected bool = true + ) - c.conn.SetReadDeadline(time.Now()) - one := new(bytes.Buffer) - if _, err := io.CopyN(one, c.conn, 1); err == io.ErrClosedPipe { - c.conn = nil - return true + for { + if connected { + if err, socketErr = c.readAllEntriesToChannel(); err != nil { + return err + } + } + if socketErr == nil { + break + } + + time.Sleep(1 * time.Second) + connected = c.handleSocketError(socketErr) } - c.conn.SetReadDeadline(time.Now().Add(1 * c.checkTimeout)) - return false + return nil } -func (c *StreamClient) EnsureConnected() error { - if c.connClosed() { - if err := c.tryReConnect(); err != nil { - return fmt.Errorf("failed to reconnect the datastream client: %w", err) - } - - c.RenewEntryChannel() +func (c *StreamClient) handleSocketError(socketErr error) bool { + log.Warn(fmt.Sprintf("Socket error: %s", socketErr)) + if err := c.tryReConnect(); err != nil { + log.Warn(fmt.Sprintf("Failed to reconnect the datastream client: %s", err)) + return false } - return nil + c.RenewEntryChannel() + + return true } // reads entries to the end of the stream // at end will wait for new entries to arrive -func (c *StreamClient) ReadAllEntriesToChannel() error { +func (c *StreamClient) readAllEntriesToChannel() (err, socketErr error) { c.streaming.Store(true) c.stopReadingToChannel.Store(false) defer c.streaming.Store(false) @@ -396,28 +432,19 @@ func (c *StreamClient) ReadAllEntriesToChannel() error { protoBookmark, err := bookmark.Marshal() if err != nil { - return err + return err, nil } // send start command - if _, err := c.initiateDownloadBookmark(protoBookmark); err != nil { - return err + if _, socketErr := c.initiateDownloadBookmark(protoBookmark); err != nil { + return nil, socketErr } if err := c.readAllFullL2BlocksToChannel(); err != nil { - err2 := fmt.Errorf("%s read full L2 blocks error: %v", c.id, err) - - if c.conn != nil { - if err2 := c.conn.Close(); err2 != nil { - log.Error("failed to close connection after error", "original-error", err, "new-error", err2) - } - c.conn = nil - } - - return err2 + return nil, fmt.Errorf("%s read full L2 blocks error: %v", c.id, err) } - return nil + return nil, nil } // runs the prerequisites for entries download @@ -511,20 +538,15 @@ LOOP: func (c *StreamClient) tryReConnect() error { var err error - for i := 0; i < 50; i++ { - if c.conn != nil { - if err := c.conn.Close(); err != nil { - log.Warn(fmt.Sprintf("[%d. iteration] failed to close the DS connection: %s", i+1, err)) - return err - } - c.conn = nil + if c.conn != nil { + if err := c.conn.Close(); err != nil { + log.Warn(fmt.Sprintf("failed to close the DS connection: %s", err)) + return err } - if err = c.Start(); err != nil { - log.Warn(fmt.Sprintf("[%d. iteration] failed to start the DS connection: %s", i+1, err)) - time.Sleep(5 * time.Second) - continue - } - return nil + c.conn = nil + } + if err = c.Start(); err != nil { + log.Warn(fmt.Sprintf("failed to start the DS connection: %s", err)) } return err diff --git a/zk/datastream/client/stream_client_test.go b/zk/datastream/client/stream_client_test.go index 30a96339652..fb0eefb14b0 100644 --- a/zk/datastream/client/stream_client_test.go +++ b/zk/datastream/client/stream_client_test.go @@ -207,6 +207,11 @@ func TestStreamClientReadParsedProto(t *testing.T) { c := NewClient(context.Background(), "", 0, 0, 0) serverConn, clientConn := net.Pipe() c.conn = clientConn + c.checkTimeout = 1 * time.Second + + c.header = &types.HeaderEntry{ + TotalEntries: 3, + } defer func() { serverConn.Close() clientConn.Close() @@ -273,7 +278,7 @@ func TestStreamClientGetLatestL2Block(t *testing.T) { c := NewClient(context.Background(), "", 0, 0, 0) c.conn = clientConn - + c.checkTimeout = 1 * time.Second expectedL2Block, _ := createL2BlockAndTransactions(t, 5, 0) l2BlockProto := &types.L2BlockProto{L2Block: expectedL2Block} l2BlockRaw, err := l2BlockProto.Marshal() @@ -385,8 +390,11 @@ func TestStreamClientGetL2BlockByNumber(t *testing.T) { }() c := NewClient(context.Background(), "", 0, 0, 0) + c.header = &types.HeaderEntry{ + TotalEntries: 4, + } c.conn = clientConn - + c.checkTimeout = 1 * time.Second bookmark := types.NewBookmarkProto(blockNum, datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK) bookmarkRaw, err := bookmark.Marshal() require.NoError(t, err) @@ -477,9 +485,12 @@ func TestStreamClientGetL2BlockByNumber(t *testing.T) { require.NoError(t, err) require.Equal(t, types.CmdErrOK, errCode) - serverErr := <-errCh + var serverErr error + select { + case serverErr = <-errCh: + default: + } require.NoError(t, serverErr) - l2TxsProto := make([]types.L2TransactionProto, len(l2Txs)) for i, tx := range l2Txs { l2TxProto := types.ConvertToL2TransactionProto(tx) diff --git a/zk/stages/stage_batches.go b/zk/stages/stage_batches.go index e68da573842..0041e038d60 100644 --- a/zk/stages/stage_batches.go +++ b/zk/stages/stage_batches.go @@ -68,7 +68,6 @@ type DatastreamClient interface { GetLatestL2Block() (*types.FullL2Block, error) GetStreamingAtomic() *atomic.Bool GetProgressAtomic() *atomic.Uint64 - EnsureConnected() error Start() error Stop() } diff --git a/zk/stages/stage_batches_datastream.go b/zk/stages/stage_batches_datastream.go index 7a17544ff5f..dfa2b9c3202 100644 --- a/zk/stages/stage_batches_datastream.go +++ b/zk/stages/stage_batches_datastream.go @@ -40,28 +40,8 @@ func (r *DatastreamClientRunner) StartRead() error { r.isReading.Store(true) defer r.isReading.Store(false) - for { - if r.stopRunner.Load() { - break - } - - // this will download all blocks from datastream and push them in a channel - // if no error, break, else continue trying to get them - // Create bookmark - if err := r.dsClient.EnsureConnected(); err != nil { - log.Error(fmt.Sprintf("[%s] Error connecting to datastream", r.logPrefix), "error", err) - time.Sleep(10 * time.Millisecond) - continue - } - - if err := r.dsClient.ReadAllEntriesToChannel(); err != nil { - - log.Error(fmt.Sprintf("[%s] Error downloading blocks from datastream", r.logPrefix), "error", err) - - time.Sleep(10 * time.Millisecond) - continue - } - break + if err := r.dsClient.ReadAllEntriesToChannel(); err != nil { + log.Warn(fmt.Sprintf("[%s] Error downloading blocks from datastream", r.logPrefix), "error", err) } }() diff --git a/zk/stages/test_utils.go b/zk/stages/test_utils.go index be0c633fb0f..2b6997561dd 100644 --- a/zk/stages/test_utils.go +++ b/zk/stages/test_utils.go @@ -29,10 +29,6 @@ func NewTestDatastreamClient(fullL2Blocks []types.FullL2Block, gerUpdates []type return client } -func (c *TestDatastreamClient) EnsureConnected() (bool, error) { - return true, nil -} - func (c *TestDatastreamClient) ReadAllEntriesToChannel() error { c.streamingAtomic.Store(true) defer c.streamingAtomic.Swap(false)