From 6e8775f714f848780f0782b9a0538aad23c7d73e Mon Sep 17 00:00:00 2001
From: Valentin Staykov <79150443+V-Staykov@users.noreply.github.com>
Date: Thu, 7 Nov 2024 17:35:40 +0200
Subject: [PATCH] fix(stage_batches): rpc resequence stop stage on unwind (#1297)

* fix(stage_batches): rpc resequence stop stage on unwind
* fix: tests
* fix: datastream channel write blocking
* fix: datastream blocking test
* fix: add wait on the datastream connect loop
* fix: merge problems
* fix: blockhash comparison in stage batches processor
* fix: download entries till reaching the amount in header
* fix: add go sum package
* feat: internal reconnect on each method in datastream client
* fix: do not disconnect on stage batches end
* feat: add ctx close in datastream reconnections
* fix: send stop command after normal stop of reading
* feat: retry a fixed number of times in stage batches
* fix: return error on ctx done
* fix: reverse datastream server version
* feat: print ds found block
* feat: added more logs in stage batches
* fix: check for sync limit in stage batches
* fix: sync limit in stage batches
* refactor: make unwind test errors a bit more readable
* refactor: make unwind tests errors more readable
* refactor(ds_client): wrap connection read and write to set timeout
* fix: add timeout to test clients
* fix: stage batches limit
* feat: up datastream server version
* fix: up datastream server version
* fix: go sum
* fix: add error handling for set timeouts in datastream client
* fix: handle zero checkTimeout value
* fix: remove flag setting for datastream timeout
* fix: ci config
* fix: resequence test timeout
* fix: remove timeout from pre-london ci config
* refactor: error handling
* fix: stop stage on unwind
* fix: missing id in client
* fix: tests
* fix: tests
* fix: finish processing blocks on last entry reached
* feat: send stop command at start of new cycle to not get timed out by server
* fix: remove accidental commit folder
* fix: remove unneeded commit
* fix: tests
* fix: remove unneeded return
* fix: get correct parent block hash
* fix: read correct blockhash
* fix: unwind on ds block unwind
* refactor: error handling in datastream and stage batches
* fix: remove unneeded sleep
* fix: add a small sleep interval in the entry loop
* fix: stop streaming on querying new stuff from ds client
* fix: buffer clear before new reads
* fix: sleep more in resequence test
* fix: cast call
* fix: remove wrong flag on cast
* fix: cast json flags in test
* feat: added wait time for block to be available on sync node
* fix: resequence block check test
* Fix 'client already started' error on finding common ancestor
* Add timeout

---------

Co-authored-by: Scott Fairclough <70711990+hexoscott@users.noreply.github.com>
Co-authored-by: Jerry
---
 .github/scripts/test_resequence.sh | 37 ++
 .github/workflows/ci_zkevm.yml | 8 +
 .github/workflows/test-resequence.yml | 1 +
 cmd/utils/flags.go | 2 +-
 go.mod | 9 +-
 go.sum | 12 +
 zk/datastream/client/commands.go | 52 +-
 zk/datastream/client/stream_client.go | 538 ++++++++++++++-------
 zk/datastream/client/stream_client_test.go | 84 ++--
 zk/datastream/client/utils.go | 39 +-
 zk/datastream/client/utils_test.go | 95 ++--
 zk/datastream/server/data_stream_server.go | 6 +-
 zk/datastream/types/result.go | 8 +
 zk/stages/stage_batches.go | 172 ++++---
 zk/stages/stage_batches_datastream.go | 69 +--
 zk/stages/stage_batches_processor.go | 174 ++++---
 zk/stages/stage_batches_test.go | 1 -
 zk/stages/test_utils.go | 30 +-
 zk/tests/unwinds/unwind.sh | 32 +-
 19 files changed, 815
insertions(+), 554 deletions(-) diff --git a/.github/scripts/test_resequence.sh b/.github/scripts/test_resequence.sh index e80cf7188a4..b36bc878236 100755 --- a/.github/scripts/test_resequence.sh +++ b/.github/scripts/test_resequence.sh @@ -46,6 +46,7 @@ wait_for_l1_batch() { fi if [ "$batch_type" = "virtual" ]; then + current_batch=$(cast logs --rpc-url "$(kurtosis port print cdk-v1 el-1-geth-lighthouse rpc)" --address 0x1Fe038B54aeBf558638CA51C91bC8cCa06609e91 --from-block 0 --json | jq -r '.[] | select(.topics[0] == "0x3e54d0825ed78523037d00a81759237eb436ce774bd546993ee67a1b67b6e766") | .topics[1]' | tail -n 1 | sed 's/^0x//') current_batch=$((16#$current_batch)) elif [ "$batch_type" = "verified" ]; then @@ -70,6 +71,33 @@ wait_for_l1_batch() { done } + +wait_for_l2_block_number() { + local block_number=$1 + local node_url=$2 + local latest_block=0 + local tries=0 + + #while latest_block lower than block_number + #if more than 5 attempts - throw error + while [ "$latest_block" -lt "$block_number" ]; do + latest_block=$(cast block latest --rpc-url "$node_url" | grep "number" | awk '{print $2}') + if [[ $? -ne 0 ]]; then + echo "Error: Failed to get latest block number" >&2 + return 1 + fi + + if [ "$tries" -ge 5 ]; then + echo "Error: Failed to get block number $block_number" >&2 + return 1 + fi + tries=$((tries + 1)) + + echo "Current block number on $node_url: $latest_block, needed: $block_number. Waiting to try again." + sleep 60 + done +} + stop_cdk_erigon_sequencer() { echo "Stopping cdk-erigon" kurtosis service exec cdk-v1 cdk-erigon-sequencer-001 "pkill -SIGTRAP proc-runner.sh" || true @@ -139,9 +167,18 @@ echo "Calculating comparison block number" comparison_block=$((latest_block - 10)) echo "Block number to compare (10 blocks behind): $comparison_block" +echo "Waiting some time for the syncer to catch up" +sleep 30 + echo "Getting block hash from sequencer" sequencer_hash=$(cast block $comparison_block --rpc-url "$(kurtosis port print cdk-v1 cdk-erigon-sequencer-001 rpc)" | grep "hash" | awk '{print $2}') +# wait for block to be available on sync node +if ! 
wait_for_l2_block_number $comparison_block "$(kurtosis port print cdk-v1 cdk-erigon-node-001 rpc)"; then + echo "Failed to wait for batch verification" + exit 1 +fi + echo "Getting block hash from node" node_hash=$(cast block $comparison_block --rpc-url "$(kurtosis port print cdk-v1 cdk-erigon-node-001 rpc)" | grep "hash" | awk '{print $2}') diff --git a/.github/workflows/ci_zkevm.yml b/.github/workflows/ci_zkevm.yml index f1803011bfd..877f30f08a0 100644 --- a/.github/workflows/ci_zkevm.yml +++ b/.github/workflows/ci_zkevm.yml @@ -105,6 +105,12 @@ jobs: sed -i '/zkevm.sequencer-batch-seal-time:/d' templates/cdk-erigon/config.yml sed -i '/zkevm.sequencer-non-empty-batch-seal-time:/d' templates/cdk-erigon/config.yml sed -i '/sentry.drop-useless-peers:/d' templates/cdk-erigon/config.yml + sed -i '/zkevm.l2-datastreamer-timeout:/d' templates/cdk-erigon/config.yml + - name: Configure Kurtosis CDK + working-directory: ./kurtosis-cdk + run: | + /usr/local/bin/yq -i '.args.data_availability_mode = "${{ matrix.da-mode }}"' params.yml + /usr/local/bin/yq -i '.args.cdk_erigon_node_image = "cdk-erigon:local"' params.yml - name: Deploy Kurtosis CDK package working-directory: ./kurtosis-cdk @@ -224,6 +230,8 @@ jobs: sed -i '/sentry.drop-useless-peers:/d' templates/cdk-erigon/config.yml sed -i '/zkevm\.pool-manager-url/d' ./templates/cdk-erigon/config.yml sed -i '$a\zkevm.disable-virtual-counters: true' ./templates/cdk-erigon/config.yml + sed -i '/zkevm.l2-datastreamer-timeout:/d' templates/cdk-erigon/config.yml + - name: Configure Kurtosis CDK working-directory: ./kurtosis-cdk diff --git a/.github/workflows/test-resequence.yml b/.github/workflows/test-resequence.yml index 3b273d7cec4..63029d21c56 100644 --- a/.github/workflows/test-resequence.yml +++ b/.github/workflows/test-resequence.yml @@ -55,6 +55,7 @@ jobs: sed -i '/zkevm.sequencer-non-empty-batch-seal-time:/d' templates/cdk-erigon/config.yml sed -i '/sentry.drop-useless-peers:/d' templates/cdk-erigon/config.yml sed -i '/zkevm.pool-manager-url/d' templates/cdk-erigon/config.yml + sed -i '/zkevm.l2-datastreamer-timeout:/d' templates/cdk-erigon/config.yml - name: Configure Kurtosis CDK working-directory: ./kurtosis-cdk run: | diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 96fe7d2042f..1c8ece0b9c6 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -408,7 +408,7 @@ var ( L2DataStreamerTimeout = cli.StringFlag{ Name: "zkevm.l2-datastreamer-timeout", Usage: "The time to wait for data to arrive from the stream before reporting an error (0s doesn't check)", - Value: "0s", + Value: "3s", } L1SyncStartBlock = cli.Uint64Flag{ Name: "zkevm.l1-sync-start-block", diff --git a/go.mod b/go.mod index 4be1aea6a83..8e2906a27bb 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ replace github.com/ledgerwatch/erigon-lib => ./erigon-lib require ( gfx.cafe/util/go/generic v0.0.0-20230721185457-c559e86c829c - github.com/0xPolygonHermez/zkevm-data-streamer v0.2.5 + github.com/0xPolygonHermez/zkevm-data-streamer v0.2.7 github.com/99designs/gqlgen v0.17.40 github.com/Giulio2002/bls v0.0.0-20240315151443-652e18a3d188 github.com/Masterminds/sprig/v3 v3.2.3 @@ -62,7 +62,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/hashicorp/golang-lru/arc/v2 v2.0.6 github.com/hashicorp/golang-lru/v2 v2.0.7 - github.com/holiman/uint256 v1.2.4 + github.com/holiman/uint256 v1.3.1 github.com/huandu/xstrings v1.4.0 github.com/huin/goupnp v1.2.0 github.com/iden3/go-iden3-crypto v0.0.15 @@ -110,11 +110,11 @@ require ( golang.org/x/exp 
v0.0.0-20231226003508-02704c960a9b golang.org/x/net v0.24.0 golang.org/x/sync v0.7.0 - golang.org/x/sys v0.19.0 + golang.org/x/sys v0.20.0 golang.org/x/time v0.5.0 google.golang.org/grpc v1.63.2 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0 - google.golang.org/protobuf v1.33.0 + google.golang.org/protobuf v1.34.2 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v2 v2.4.0 @@ -174,6 +174,7 @@ require ( github.com/francoispqt/gojay v1.2.13 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/garslo/gogen v0.0.0-20170307003452-d6ebae628c7c // indirect + github.com/go-delve/delve v1.21.2 // indirect github.com/go-llsqlite/adapter v0.0.0-20230927005056-7f5ce7f0c916 // indirect github.com/go-llsqlite/crawshaw v0.4.0 // indirect github.com/go-logr/logr v1.2.4 // indirect diff --git a/go.sum b/go.sum index cd90cd9206b..2dc4ef179ba 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,10 @@ gfx.cafe/util/go/generic v0.0.0-20230721185457-c559e86c829c/go.mod h1:WvSX4JsCRB git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/0xPolygonHermez/zkevm-data-streamer v0.2.5 h1:p0epAhai44c34G+nzX0CZ67q3vkJtOXlO07lbhAEe9g= github.com/0xPolygonHermez/zkevm-data-streamer v0.2.5/go.mod h1:RC6ouyNsUtJrv5aGPcM6Dm5xhXN209tRSzcsJsaOtZI= +github.com/0xPolygonHermez/zkevm-data-streamer v0.2.6 h1:BSO1uu6dmLQ5kKb3uyDvsUxbnIoyumKvlwr0OtpTYMo= +github.com/0xPolygonHermez/zkevm-data-streamer v0.2.6/go.mod h1:RC6ouyNsUtJrv5aGPcM6Dm5xhXN209tRSzcsJsaOtZI= +github.com/0xPolygonHermez/zkevm-data-streamer v0.2.7 h1:73sYxRQ9cOmtYBEyHePgEwrVULR+YruSQxVXCt/SmzU= +github.com/0xPolygonHermez/zkevm-data-streamer v0.2.7/go.mod h1:7nM7Ihk+fTG1TQPwdZoGOYd3wprqqyIyjtS514uHzWE= github.com/99designs/gqlgen v0.17.40 h1:/l8JcEVQ93wqIfmH9VS1jsAkwm6eAF1NwQn3N+SDqBY= github.com/99designs/gqlgen v0.17.40/go.mod h1:b62q1USk82GYIVjC60h02YguAZLqYZtvWml8KkhJps4= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= @@ -327,6 +331,8 @@ github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s= github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= +github.com/go-delve/delve v1.21.2 h1:eaS+ziJo+660mi3D2q/VP8RxW5GcF4Y1zyKSi82alsU= +github.com/go-delve/delve v1.21.2/go.mod h1:FgTAiRUe43RS5EexL06RPyMtP8AMZVL/t9Qqgy3qUe4= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -490,6 +496,8 @@ github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSo github.com/holiman/uint256 v1.2.0/go.mod h1:y4ga/t+u+Xwd7CpDgZESaRcWy0I7XMlTMA25ApIH5Jw= github.com/holiman/uint256 v1.2.4 h1:jUc4Nk8fm9jZabQuqr2JzednajVmBpC+oiTiXZJEApU= github.com/holiman/uint256 v1.2.4/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E= +github.com/holiman/uint256 v1.3.1 h1:JfTzmih28bittyHM8z360dCjIA9dbPIBlcTI6lmctQs= +github.com/holiman/uint256 v1.3.1/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= 
github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo= github.com/huandu/xstrings v1.2.0/go.mod h1:DvyZB1rfVYsBIigL8HwpZgxHwXozlTgGqn63UyNX5k4= @@ -1333,6 +1341,8 @@ golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1547,6 +1557,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y= gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI= diff --git a/zk/datastream/client/commands.go b/zk/datastream/client/commands.go index 8676a2807eb..d871fb798ab 100644 --- a/zk/datastream/client/commands.go +++ b/zk/datastream/client/commands.go @@ -1,26 +1,19 @@ package client -import "fmt" - const ( // Commands - CmdUnknown Command = 0 - CmdStart Command = 1 - CmdStop Command = 2 - CmdHeader Command = 3 - CmdStartBookmark Command = 4 // CmdStartBookmark for the start from bookmark TCP client command - CmdEntry Command = 5 // CmdEntry for the get entry TCP client command - CmdBookmark Command = 6 // CmdBookmark for the get bookmark TCP client command + CmdUnknown Command = iota + CmdStart + CmdStop + CmdHeader + CmdStartBookmark // CmdStartBookmark for the start from bookmark TCP client command + CmdEntry // CmdEntry for the get entry TCP client command + CmdBookmark // CmdBookmark for the get bookmark TCP client command ) // sendHeaderCmd sends the header command to the server. func (c *StreamClient) sendHeaderCmd() error { - err := c.sendCommand(CmdHeader) - if err != nil { - return fmt.Errorf("%s %v", c.id, err) - } - - return nil + return c.sendCommand(CmdHeader) } // sendBookmarkCmd sends either CmdStartBookmark or CmdBookmark for the provided bookmark value. 
@@ -38,24 +31,23 @@ func (c *StreamClient) sendBookmarkCmd(bookmark []byte, streaming bool) error { } // Send bookmark length - if err := writeFullUint32ToConn(c.conn, uint32(len(bookmark))); err != nil { + if err := c.writeToConn(uint32(len(bookmark))); err != nil { return err } // Send the bookmark to retrieve - return writeBytesToConn(c.conn, bookmark) + return c.writeToConn(bookmark) } // sendStartCmd sends a start command to the server, indicating // that the client wishes to start streaming from the given entry number. func (c *StreamClient) sendStartCmd(from uint64) error { - err := c.sendCommand(CmdStart) - if err != nil { + if err := c.sendCommand(CmdStart); err != nil { return err } // Send starting/from entry number - return writeFullUint64ToConn(c.conn, from) + return c.writeToConn(from) } // sendEntryCmd sends the get data stream entry by number command to a TCP connection @@ -66,29 +58,21 @@ func (c *StreamClient) sendEntryCmd(entryNum uint64) error { } // Send entry number - return writeFullUint64ToConn(c.conn, entryNum) + return c.writeToConn(entryNum) } // sendHeaderCmd sends the header command to the server. func (c *StreamClient) sendStopCmd() error { - err := c.sendCommand(CmdStop) - if err != nil { - return fmt.Errorf("%s %v", c.id, err) - } - - return nil + return c.sendCommand(CmdStop) } func (c *StreamClient) sendCommand(cmd Command) error { + // Send command - if err := writeFullUint64ToConn(c.conn, uint64(cmd)); err != nil { - return fmt.Errorf("%s %v", c.id, err) + if err := c.writeToConn(uint64(cmd)); err != nil { + return err } // Send stream type - if err := writeFullUint64ToConn(c.conn, uint64(c.streamType)); err != nil { - return fmt.Errorf("%s %v", c.id, err) - } - - return nil + return c.writeToConn(uint64(c.streamType)) } diff --git a/zk/datastream/client/stream_client.go b/zk/datastream/client/stream_client.go index ba85dd8c8ff..e8e96ed9b29 100644 --- a/zk/datastream/client/stream_client.go +++ b/zk/datastream/client/stream_client.go @@ -41,9 +41,10 @@ type StreamClient struct { version int streamType StreamType conn net.Conn - id string // Client id checkTimeout time.Duration // time to wait for data before reporting an error + header *types.HeaderEntry + // atomic lastWrittenTime atomic.Int64 streaming atomic.Bool @@ -78,7 +79,6 @@ func NewClient(ctx context.Context, server string, version int, checkTimeout tim server: server, version: version, streamType: StSequencer, - id: "", entryChan: make(chan interface{}, 100000), currentFork: uint64(latestDownloadedForkId), } @@ -94,50 +94,80 @@ func (c *StreamClient) GetEntryChan() *chan interface{} { return &c.entryChan } +func (c *StreamClient) GetEntryNumberLimit() uint64 { + return c.header.TotalEntries +} + +var ( + ErrFailedAttempts = errors.New("failed to get the L2 block within 5 attempts") +) + // GetL2BlockByNumber queries the data stream by sending the L2 block start bookmark for the certain block number // and streams the changes for that block (including the transactions). // Note that this function is intended for on demand querying and it disposes the connection after it ends. 
-func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, int, error) { - if _, err := c.EnsureConnected(); err != nil { - return nil, -1, err - } - defer c.Stop() - +func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (fullBLock *types.FullL2Block, err error) { var ( - l2Block *types.FullL2Block - err error - isL2Block bool + connected bool = c.conn != nil ) + count := 0 + for { + select { + case <-c.ctx.Done(): + return nil, fmt.Errorf("context done - stopping") + + default: + } + if count > 5 { + return nil, ErrFailedAttempts + } + if connected { + if err := c.stopStreamingIfStarted(); err != nil { + return nil, fmt.Errorf("stopStreamingIfStarted: %w", err) + } + + if fullBLock, err = c.getL2BlockByNumber(blockNum); err == nil { + break + } + + if errors.Is(err, types.ErrAlreadyStarted) { + // if the client is already started, we can stop the client and try again + c.Stop() + } else if !errors.Is(err, ErrSocket) { + return nil, fmt.Errorf("getL2BlockByNumber: %w", err) + } + + } + time.Sleep(1 * time.Second) + connected = c.handleSocketError(err) + count++ + } + + return fullBLock, nil +} + +func (c *StreamClient) getL2BlockByNumber(blockNum uint64) (l2Block *types.FullL2Block, err error) { + var isL2Block bool bookmark := types.NewBookmarkProto(blockNum, datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK) bookmarkRaw, err := bookmark.Marshal() if err != nil { - return nil, -1, err + return nil, fmt.Errorf("bookmark.Marshal: %w", err) } - re, err := c.initiateDownloadBookmark(bookmarkRaw) - if err != nil { - errorCode := -1 - if re != nil { - errorCode = int(re.ErrorNum) - } - return nil, errorCode, err + if _, err := c.initiateDownloadBookmark(bookmarkRaw); err != nil { + return nil, fmt.Errorf("initiateDownloadBookmark: %w", err) } for l2Block == nil { select { case <-c.ctx.Done(): - errorCode := -1 - if re != nil { - errorCode = int(re.ErrorNum) - } - return l2Block, errorCode, nil + return l2Block, fmt.Errorf("context done - stopping") default: } - parsedEntry, err := ReadParsedProto(c) + parsedEntry, _, err := ReadParsedProto(c) if err != nil { - return nil, -1, err + return nil, fmt.Errorf("ReadParsedProto: %w", err) } l2Block, isL2Block = parsedEntry.(*types.FullL2Block) @@ -147,41 +177,88 @@ func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, } if l2Block.L2BlockNumber != blockNum { - return nil, -1, fmt.Errorf("expected block number %d but got %d", blockNum, l2Block.L2BlockNumber) + return nil, fmt.Errorf("expected block number %d but got %d", blockNum, l2Block.L2BlockNumber) } - return l2Block, types.CmdErrOK, nil + return l2Block, nil } // GetLatestL2Block queries the data stream by reading the header entry and based on total entries field, // it retrieves the latest File entry that is of EntryTypeL2Block type. // Note that this function is intended for on demand querying and it disposes the connection after it ends. 
func (c *StreamClient) GetLatestL2Block() (l2Block *types.FullL2Block, err error) { - if _, err := c.EnsureConnected(); err != nil { - return nil, err + var ( + connected bool = c.conn != nil + ) + count := 0 + for { + select { + case <-c.ctx.Done(): + return nil, errors.New("context done - stopping") + default: + } + if count > 5 { + return nil, ErrFailedAttempts + } + if connected { + if err := c.stopStreamingIfStarted(); err != nil { + return nil, fmt.Errorf("stopStreamingIfStarted: %w", err) + } + + if l2Block, err = c.getLatestL2Block(); err == nil { + break + } + if !errors.Is(err, ErrSocket) { + return nil, fmt.Errorf("getLatestL2Block: %w", err) + } + } + + time.Sleep(1 * time.Second) + connected = c.handleSocketError(err) + count++ + } + return l2Block, nil +} + +// don't check for errors here, we just need to empty the socket for next reads +func (c *StreamClient) stopStreamingIfStarted() error { + if c.streaming.Load() { + c.sendStopCmd() + c.streaming.Store(false) + } + + // empty the socket buffer + for { + c.conn.SetReadDeadline(time.Now().Add(100)) + if _, err := c.readBuffer(100); err != nil { + break + } } - defer c.Stop() + return nil +} + +func (c *StreamClient) getLatestL2Block() (l2Block *types.FullL2Block, err error) { h, err := c.GetHeader() if err != nil { - return nil, err + return nil, fmt.Errorf("GetHeader: %w", err) } latestEntryNum := h.TotalEntries - 1 for l2Block == nil && latestEntryNum > 0 { if err := c.sendEntryCmdWrapper(latestEntryNum); err != nil { - return nil, err + return nil, fmt.Errorf("sendEntryCmdWrapper: %w", err) } entry, err := c.NextFileEntry() if err != nil { - return nil, err + return nil, fmt.Errorf("NextFileEntry: %w", err) } if entry.EntryType == types.EntryTypeL2Block { if l2Block, err = types.UnmarshalL2Block(entry.Data); err != nil { - return nil, err + return nil, fmt.Errorf("UnmarshalL2Block: %w", err) } } @@ -189,7 +266,7 @@ func (c *StreamClient) GetLatestL2Block() (l2Block *types.FullL2Block, err error } if latestEntryNum == 0 { - return nil, errors.New("failed to retrieve the latest block from the data stream") + return nil, errors.New("no block found") } return l2Block, nil @@ -198,9 +275,7 @@ func (c *StreamClient) GetLatestL2Block() (l2Block *types.FullL2Block, err error func (c *StreamClient) GetLastWrittenTimeAtomic() *atomic.Int64 { return &c.lastWrittenTime } -func (c *StreamClient) GetStreamingAtomic() *atomic.Bool { - return &c.streaming -} + func (c *StreamClient) GetProgressAtomic() *atomic.Uint64 { return &c.progress } @@ -211,11 +286,9 @@ func (c *StreamClient) Start() error { var err error c.conn, err = net.Dial("tcp", c.server) if err != nil { - return fmt.Errorf("error connecting to server %s: %v", c.server, err) + return fmt.Errorf("connecting to server %s: %w", c.server, err) } - c.id = c.conn.LocalAddr().String() - return nil } @@ -224,61 +297,59 @@ func (c *StreamClient) Stop() { return } if err := c.sendStopCmd(); err != nil { - log.Warn(fmt.Sprintf("Failed to send the stop command to the data stream server: %s", err)) + log.Warn(fmt.Sprintf("send stop command: %v", err)) } - c.conn.Close() - c.conn = nil - - c.clearEntryCHannel() + // c.conn.Close() + // c.conn = nil } // Command header: Get status // Returns the current status of the header. // If started, terminate the connection. 
func (c *StreamClient) GetHeader() (*types.HeaderEntry, error) { + if err := c.stopStreamingIfStarted(); err != nil { + return nil, fmt.Errorf("stopStreamingIfStarted: %w", err) + } + if err := c.sendHeaderCmd(); err != nil { - return nil, fmt.Errorf("%s send header error: %v", c.id, err) + return nil, fmt.Errorf("sendHeaderCmd: %w", err) } // Read packet - packet, err := readBuffer(c.conn, 1) + packet, err := c.readBuffer(1) if err != nil { - return nil, fmt.Errorf("%s read buffer: %v", c.id, err) + return nil, fmt.Errorf("readBuffer: %w", err) } // Check packet type if packet[0] != PtResult { - return nil, fmt.Errorf("%s error expecting result packet type %d and received %d", c.id, PtResult, packet[0]) + return nil, fmt.Errorf("expecting result packet type %d and received %d", PtResult, packet[0]) } // Read server result entry for the command - r, err := c.readResultEntry(packet) - if err != nil { - return nil, fmt.Errorf("%s read result entry error: %v", c.id, err) - } - if err := r.GetError(); err != nil { - return nil, fmt.Errorf("%s got Result error code %d: %v", c.id, r.ErrorNum, err) + if _, err := c.readResultEntry(packet); err != nil { + return nil, fmt.Errorf("readResultEntry: %w", err) } // Read header entry h, err := c.readHeaderEntry() if err != nil { - return nil, fmt.Errorf("%s read header entry error: %v", c.id, err) + return nil, fmt.Errorf("readHeaderEntry: %w", err) } + c.header = h + return h, nil } // sendEntryCmdWrapper sends CmdEntry command and reads packet type and decodes result entry. func (c *StreamClient) sendEntryCmdWrapper(entryNum uint64) error { if err := c.sendEntryCmd(entryNum); err != nil { - return err + return fmt.Errorf("sendEntryCmd: %w", err) } - if re, err := c.readPacketAndDecodeResultEntry(); err != nil { - return fmt.Errorf("failed to retrieve the result entry: %w", err) - } else if err := re.GetError(); err != nil { - return err + if _, err := c.readPacketAndDecodeResultEntry(); err != nil { + return fmt.Errorf("readPacketAndDecodeResultEntry: %w", err) } return nil @@ -288,16 +359,16 @@ func (c *StreamClient) ExecutePerFile(bookmark *types.BookmarkProto, function fu // Get header from server header, err := c.GetHeader() if err != nil { - return fmt.Errorf("%s get header error: %v", c.id, err) + return fmt.Errorf("GetHeader: %w", err) } protoBookmark, err := bookmark.Marshal() if err != nil { - return fmt.Errorf("failed to marshal bookmark: %v", err) + return fmt.Errorf("bookmark.Marshal: %w", err) } if _, err := c.initiateDownloadBookmark(protoBookmark); err != nil { - return err + return fmt.Errorf("initiateDownloadBookmark: %w", err) } count := uint64(0) logTicker := time.NewTicker(10 * time.Second) @@ -313,10 +384,10 @@ func (c *StreamClient) ExecutePerFile(bookmark *types.BookmarkProto, function fu } file, err := c.NextFileEntry() if err != nil { - return fmt.Errorf("reading file entry: %v", err) + return fmt.Errorf("NextFileEntry: %w", err) } if err := function(file); err != nil { - return fmt.Errorf("executing function: %v", err) + return fmt.Errorf("execute function: %w", err) } count++ @@ -326,45 +397,83 @@ func (c *StreamClient) ExecutePerFile(bookmark *types.BookmarkProto, function fu } func (c *StreamClient) clearEntryCHannel() { - select { - case <-c.entryChan: - close(c.entryChan) + defer func() { for range c.entryChan { } - default: - } + }() + defer func() { + if r := recover(); r != nil { + log.Warn("[datastream_client] Channel is already closed") + } + }() + + close(c.entryChan) } // close old entry chan and read all elements 
before opening a new one -func (c *StreamClient) renewEntryChannel() { +func (c *StreamClient) RenewEntryChannel() { c.clearEntryCHannel() c.entryChan = make(chan interface{}, entryChannelSize) } -func (c *StreamClient) EnsureConnected() (bool, error) { - if c.conn == nil { - if err := c.tryReConnect(); err != nil { - return false, fmt.Errorf("failed to reconnect the datastream client: %w", err) +func (c *StreamClient) ReadAllEntriesToChannel() (err error) { + var ( + connected bool = c.conn != nil + ) + count := 0 + for { + select { + case <-c.ctx.Done(): + return fmt.Errorf("context done - stopping") + default: + } + if connected { + if err := c.stopStreamingIfStarted(); err != nil { + return fmt.Errorf("stopStreamingIfStarted: %w", err) + } + + if err = c.readAllEntriesToChannel(); err == nil { + break + } + if !errors.Is(err, ErrSocket) { + return fmt.Errorf("readAllEntriesToChannel: %w", err) + } } - c.renewEntryChannel() + time.Sleep(1 * time.Second) + connected = c.handleSocketError(err) + count++ } - return true, nil + return nil +} + +func (c *StreamClient) handleSocketError(socketErr error) bool { + if socketErr != nil { + log.Warn(fmt.Sprintf("%v", socketErr)) + } + if err := c.tryReConnect(); err != nil { + log.Warn(fmt.Sprintf("try reconnect: %v", err)) + return false + } + + c.RenewEntryChannel() + + return true } // reads entries to the end of the stream // at end will wait for new entries to arrive -func (c *StreamClient) ReadAllEntriesToChannel() error { +func (c *StreamClient) readAllEntriesToChannel() (err error) { c.streaming.Store(true) - defer c.streaming.Store(false) + c.stopReadingToChannel.Store(false) var bookmark *types.BookmarkProto progress := c.progress.Load() if progress == 0 { bookmark = types.NewBookmarkProto(0, datastream.BookmarkType_BOOKMARK_TYPE_BATCH) } else { - bookmark = types.NewBookmarkProto(progress, datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK) + bookmark = types.NewBookmarkProto(progress+1, datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK) } protoBookmark, err := bookmark.Marshal() @@ -374,64 +483,50 @@ func (c *StreamClient) ReadAllEntriesToChannel() error { // send start command if _, err := c.initiateDownloadBookmark(protoBookmark); err != nil { - return err + return fmt.Errorf("initiateDownloadBookmark: %w", err) } if err := c.readAllFullL2BlocksToChannel(); err != nil { - err2 := fmt.Errorf("%s read full L2 blocks error: %v", c.id, err) - - if c.conn != nil { - if err2 := c.conn.Close(); err2 != nil { - log.Error("failed to close connection after error", "original-error", err, "new-error", err2) - } - c.conn = nil - } - - return err2 + return fmt.Errorf("readAllFullL2BlocksToChannel: %w", err) } - return nil + return } // runs the prerequisites for entries download func (c *StreamClient) initiateDownloadBookmark(bookmark []byte) (*types.ResultEntry, error) { // send CmdStartBookmark command if err := c.sendBookmarkCmd(bookmark, true); err != nil { - return nil, err + return nil, fmt.Errorf("sendBookmarkCmd: %w", err) } re, err := c.afterStartCommand() if err != nil { - return re, fmt.Errorf("after start command error: %v", err) + return re, fmt.Errorf("afterStartCommand: %w", err) } return re, nil } func (c *StreamClient) afterStartCommand() (*types.ResultEntry, error) { - re, err := c.readPacketAndDecodeResultEntry() - if err != nil { - return nil, err - } - - if err := re.GetError(); err != nil { - return re, fmt.Errorf("got Result error code %d: %v", re.ErrorNum, err) - } - - return re, nil + return c.readPacketAndDecodeResultEntry() } 
// reads all entries from the server and sends them to a channel // sends the parsed FullL2Blocks with transactions to a channel -func (c *StreamClient) readAllFullL2BlocksToChannel() error { - var err error - +func (c *StreamClient) readAllFullL2BlocksToChannel() (err error) { + readNewProto := true + entryNum := uint64(0) + parsedProto := interface{}(nil) LOOP: for { select { default: case <-c.ctx.Done(): - log.Warn("[Datastream client] Context done - stopping") + return fmt.Errorf("context done - stopping") + } + + if c.stopReadingToChannel.Load() { break LOOP } @@ -439,52 +534,70 @@ LOOP: c.conn.SetReadDeadline(time.Now().Add(c.checkTimeout)) } - parsedProto, localErr := ReadParsedProto(c) - if localErr != nil { - err = localErr - break + if readNewProto { + if parsedProto, entryNum, err = ReadParsedProto(c); err != nil { + return err + } + readNewProto = false } c.lastWrittenTime.Store(time.Now().UnixNano()) switch parsedProto := parsedProto.(type) { case *types.BookmarkProto: + readNewProto = true continue case *types.BatchStart: c.currentFork = parsedProto.ForkId - c.entryChan <- parsedProto case *types.GerUpdate: - c.entryChan <- parsedProto case *types.BatchEnd: - c.entryChan <- parsedProto case *types.FullL2Block: parsedProto.ForkId = c.currentFork - log.Trace("writing block to channel", "blockNumber", parsedProto.L2BlockNumber, "batchNumber", parsedProto.BatchNumber) - c.entryChan <- parsedProto + log.Trace("[Datastream client] writing block to channel", "blockNumber", parsedProto.L2BlockNumber, "batchNumber", parsedProto.BatchNumber) + default: + return fmt.Errorf("unexpected entry type: %v", parsedProto) + } + select { + case c.entryChan <- parsedProto: + readNewProto = true default: - err = fmt.Errorf("unexpected entry type: %v", parsedProto) + time.Sleep(10 * time.Microsecond) + } + + if c.header.TotalEntries == entryNum+1 { + log.Trace("[Datastream client] reached the current end of the stream", "header_totalEntries", c.header.TotalEntries, "entryNum", entryNum) + + retries := 0 + INTERNAL_LOOP: + for { + select { + case c.entryChan <- nil: + break INTERNAL_LOOP + default: + if retries > 5 { + return errors.New("[Datastream client] failed to write final entry to channel after 5 retries") + } + retries++ + log.Warn("[Datastream client] Channel is full, waiting to write nil and end stream client read") + time.Sleep(1 * time.Second) + } + } break LOOP } } - return err + return nil } -func (c *StreamClient) tryReConnect() error { - var err error - for i := 0; i < 50; i++ { - if c.conn != nil { - if err := c.conn.Close(); err != nil { - log.Warn(fmt.Sprintf("[%d. iteration] failed to close the DS connection: %s", i+1, err)) - return err - } - c.conn = nil - } - if err = c.Start(); err != nil { - log.Warn(fmt.Sprintf("[%d. 
iteration] failed to start the DS connection: %s", i+1, err)) - time.Sleep(5 * time.Second) - continue +func (c *StreamClient) tryReConnect() (err error) { + if c.conn != nil { + if err := c.conn.Close(); err != nil { + log.Warn(fmt.Sprintf("close DS connection: %v", err)) + return err } - return nil + c.conn = nil + } + if err = c.Start(); err != nil { + log.Warn(fmt.Sprintf("start DS connection: %v", err)) } return err @@ -496,21 +609,24 @@ func (c *StreamClient) StopReadingToChannel() { type FileEntryIterator interface { NextFileEntry() (*types.FileEntry, error) + GetEntryNumberLimit() uint64 } func ReadParsedProto(iterator FileEntryIterator) ( parsedEntry interface{}, + entryNum uint64, err error, ) { file, err := iterator.NextFileEntry() if err != nil { - err = fmt.Errorf("read file entry error: %w", err) + err = fmt.Errorf("NextFileEntry: %w", err) return } if file == nil { - return nil, nil + return } + entryNum = file.EntryNum switch file.EntryType { case types.BookmarkEntryType: @@ -524,6 +640,7 @@ func ReadParsedProto(iterator FileEntryIterator) ( case types.EntryTypeL2Block: var l2Block *types.FullL2Block if l2Block, err = types.UnmarshalL2Block(file.Data); err != nil { + err = fmt.Errorf("UnmarshalL2Block: %w", err) return } @@ -534,17 +651,20 @@ func ReadParsedProto(iterator FileEntryIterator) ( LOOP: for { if innerFile, err = iterator.NextFileEntry(); err != nil { + err = fmt.Errorf("NextFileEntry: %w", err) return } - + entryNum = innerFile.EntryNum if innerFile.IsL2Tx() { if l2Tx, err = types.UnmarshalTx(innerFile.Data); err != nil { + err = fmt.Errorf("UnmarshalTx: %w", err) return } txs = append(txs, *l2Tx) } else if innerFile.IsL2BlockEnd() { var l2BlockEnd *types.L2BlockEndProto if l2BlockEnd, err = types.UnmarshalL2BlockEnd(innerFile.Data); err != nil { + err = fmt.Errorf("UnmarshalL2BlockEnd: %w", err) return } if l2BlockEnd.GetBlockNumber() != l2Block.L2BlockNumber { @@ -555,6 +675,11 @@ func ReadParsedProto(iterator FileEntryIterator) ( } else if innerFile.IsBookmark() { var bookmark *types.BookmarkProto if bookmark, err = types.UnmarshalBookmark(innerFile.Data); err != nil || bookmark == nil { + if err != nil { + err = fmt.Errorf("UnmarshalBookmark: %w", err) + } else { + err = fmt.Errorf("unexpected nil bookmark") + } return } if bookmark.BookmarkType() == datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK { @@ -565,6 +690,7 @@ func ReadParsedProto(iterator FileEntryIterator) ( } } else if innerFile.IsBatchEnd() { if _, err = types.UnmarshalBatchEnd(file.Data); err != nil { + err = fmt.Errorf("UnmarshalBatchEnd: %w", err) return } break LOOP @@ -572,6 +698,9 @@ func ReadParsedProto(iterator FileEntryIterator) ( err = fmt.Errorf("unexpected entry type inside a block: %d", innerFile.EntryType) return } + if entryNum == iterator.GetEntryNumberLimit() { + break LOOP + } } l2Block.L2Txs = txs @@ -579,12 +708,17 @@ func ReadParsedProto(iterator FileEntryIterator) ( return case types.EntryTypeL2BlockEnd: log.Debug(fmt.Sprintf("retrieved EntryTypeL2BlockEnd: %+v", file)) + parsedEntry, err = types.UnmarshalL2BlockEnd(file.Data) + if err != nil { + err = fmt.Errorf("UnmarshalL2BlockEnd: %w", err) + } return case types.EntryTypeL2Tx: err = errors.New("unexpected L2 tx entry, found outside of block") default: err = fmt.Errorf("unexpected entry type: %d", file.EntryType) } + return } @@ -592,21 +726,17 @@ func ReadParsedProto(iterator FileEntryIterator) ( // returns the parsed FileEntry func (c *StreamClient) NextFileEntry() (file *types.FileEntry, err error) { // Read packet 
type - packet, err := readBuffer(c.conn, 1) + packet, err := c.readBuffer(1) if err != nil { - return file, fmt.Errorf("failed to read packet type: %v", err) + return file, fmt.Errorf("readBuffer: %w", err) } packetType := packet[0] // Check packet type if packetType == PtResult { // Read server result entry for the command - r, err := c.readResultEntry(packet) - if err != nil { - return file, err - } - if err := r.GetError(); err != nil { - return file, fmt.Errorf("got Result error code %d: %v", r.ErrorNum, err) + if _, err := c.readResultEntry(packet); err != nil { + return file, fmt.Errorf("readResultEntry: %w", err) } return file, nil } else if packetType != PtData && packetType != PtDataRsp { @@ -614,9 +744,9 @@ func (c *StreamClient) NextFileEntry() (file *types.FileEntry, err error) { } // Read the rest of fixed size fields - buffer, err := readBuffer(c.conn, types.FileEntryMinSize-1) + buffer, err := c.readBuffer(types.FileEntryMinSize - 1) if err != nil { - return file, fmt.Errorf("error reading file bytes: %v", err) + return file, fmt.Errorf("reading file bytes: readBuffer: %w", err) } if packetType != PtData { @@ -627,19 +757,19 @@ func (c *StreamClient) NextFileEntry() (file *types.FileEntry, err error) { // Read variable field (data) length := binary.BigEndian.Uint32(buffer[1:5]) if length < types.FileEntryMinSize { - return file, errors.New("error reading data entry: wrong data length") + return file, errors.New("reading data entry: wrong data length") } // Read rest of the file data - bufferAux, err := readBuffer(c.conn, length-types.FileEntryMinSize) + bufferAux, err := c.readBuffer(length - types.FileEntryMinSize) if err != nil { - return file, fmt.Errorf("error reading file data bytes: %v", err) + return file, fmt.Errorf("reading file data bytes: readBuffer: %w", err) } buffer = append(buffer, bufferAux...) // Decode binary data to data entry struct if file, err = types.DecodeFileEntry(buffer); err != nil { - return file, fmt.Errorf("decode file entry error: %v", err) + return file, fmt.Errorf("DecodeFileEntry: %w", err) } if file.EntryType == types.EntryTypeNotFound { @@ -654,24 +784,24 @@ func (c *StreamClient) NextFileEntry() (file *types.FileEntry, err error) { func (c *StreamClient) readHeaderEntry() (h *types.HeaderEntry, err error) { // Read header stream bytes - binaryHeader, err := readBuffer(c.conn, types.HeaderSizePreEtrog) + binaryHeader, err := c.readBuffer(types.HeaderSizePreEtrog) if err != nil { - return h, fmt.Errorf("failed to read header bytes %v", err) + return h, fmt.Errorf("read header bytes: readBuffer: %w", err) } headLength := binary.BigEndian.Uint32(binaryHeader[1:5]) if headLength == types.HeaderSize { // Read the rest of fixed size fields - buffer, err := readBuffer(c.conn, types.HeaderSize-types.HeaderSizePreEtrog) + buffer, err := c.readBuffer(types.HeaderSize - types.HeaderSizePreEtrog) if err != nil { - return h, fmt.Errorf("failed to read header bytes %v", err) + return h, fmt.Errorf("read header bytes: readBuffer: %w", err) } binaryHeader = append(binaryHeader, buffer...) 
} // Decode bytes stream to header entry struct if h, err = types.DecodeHeaderEntry(binaryHeader); err != nil { - return h, fmt.Errorf("error decoding binary header: %v", err) + return h, fmt.Errorf("DecodeHeaderEntry: %w", err) } return @@ -685,28 +815,45 @@ func (c *StreamClient) readResultEntry(packet []byte) (re *types.ResultEntry, er } // Read the rest of fixed size fields - buffer, err := readBuffer(c.conn, types.ResultEntryMinSize-1) + buffer, err := c.readBuffer(types.ResultEntryMinSize - 1) if err != nil { - return re, fmt.Errorf("failed to read main result bytes %v", err) + return re, fmt.Errorf("read main result bytes: readBuffer: %w", err) } buffer = append(packet, buffer...) // Read variable field (errStr) length := binary.BigEndian.Uint32(buffer[1:5]) if length < types.ResultEntryMinSize { - return re, fmt.Errorf("%s Error reading result entry", c.id) + return re, errors.New("failed reading result entry") } // read the rest of the result - bufferAux, err := readBuffer(c.conn, length-types.ResultEntryMinSize) + bufferAux, err := c.readBuffer(length - types.ResultEntryMinSize) if err != nil { - return re, fmt.Errorf("failed to read result errStr bytes %v", err) + return re, fmt.Errorf("read result errStr bytes: readBuffer: %w", err) } buffer = append(buffer, bufferAux...) // Decode binary entry result if re, err = types.DecodeResultEntry(buffer); err != nil { - return re, fmt.Errorf("decode result entry error: %v", err) + return re, fmt.Errorf("DecodeResultEntry: %w", err) + } + + if !re.IsOk() { + switch re.ErrorNum { + case types.CmdErrAlreadyStarted: + return re, fmt.Errorf("%w: %s", types.ErrAlreadyStarted, re.ErrorStr) + case types.CmdErrAlreadyStopped: + return re, fmt.Errorf("%w: %s", types.ErrAlreadyStopped, re.ErrorStr) + case types.CmdErrBadFromEntry: + return re, fmt.Errorf("%w: %s", types.ErrBadFromEntry, re.ErrorStr) + case types.CmdErrBadFromBookmark: + return re, fmt.Errorf("%w: %s", types.ErrBadFromBookmark, re.ErrorStr) + case types.CmdErrInvalidCommand: + return re, fmt.Errorf("%w: %s", types.ErrInvalidCommand, re.ErrorStr) + default: + return re, fmt.Errorf("unknown error code: %s", re.ErrorStr) + } } return re, nil @@ -715,16 +862,71 @@ func (c *StreamClient) readResultEntry(packet []byte) (re *types.ResultEntry, er // readPacketAndDecodeResultEntry reads the packet from the connection and tries to decode the ResultEntry from it. 
func (c *StreamClient) readPacketAndDecodeResultEntry() (*types.ResultEntry, error) { // Read packet - packet, err := readBuffer(c.conn, 1) + packet, err := c.readBuffer(1) if err != nil { - return nil, fmt.Errorf("read buffer error: %w", err) + return nil, fmt.Errorf("read buffer: %w", err) } // Read server result entry for the command r, err := c.readResultEntry(packet) if err != nil { - return nil, fmt.Errorf("read result entry error: %w", err) + return nil, fmt.Errorf("readResultEntry: %w", err) } return r, nil } + +func (c *StreamClient) readBuffer(amount uint32) ([]byte, error) { + if err := c.resetReadTimeout(); err != nil { + return nil, fmt.Errorf("resetReadTimeout: %w", err) + } + return readBuffer(c.conn, amount) +} + +func (c *StreamClient) writeToConn(data interface{}) error { + if err := c.resetWriteTimeout(); err != nil { + return fmt.Errorf("resetWriteTimeout: %w", err) + } + switch parsed := data.(type) { + case []byte: + if err := writeBytesToConn(c.conn, parsed); err != nil { + return fmt.Errorf("writeBytesToConn: %w", err) + } + case uint32: + if err := writeFullUint32ToConn(c.conn, parsed); err != nil { + return fmt.Errorf("writeFullUint32ToConn: %w", err) + } + case uint64: + if err := writeFullUint64ToConn(c.conn, parsed); err != nil { + return fmt.Errorf("writeFullUint64ToConn: %w", err) + } + default: + return errors.New("unexpected write type") + } + + return nil +} + +func (c *StreamClient) resetWriteTimeout() error { + if c.checkTimeout == 0 { + return nil + } + + if err := c.conn.SetWriteDeadline(time.Now().Add(c.checkTimeout)); err != nil { + return fmt.Errorf("%w: conn.SetWriteDeadline: %v", ErrSocket, err) + } + + return nil +} + +func (c *StreamClient) resetReadTimeout() error { + if c.checkTimeout == 0 { + return nil + } + + if err := c.conn.SetReadDeadline(time.Now().Add(c.checkTimeout)); err != nil { + return fmt.Errorf("%w: conn.SetReadDeadline: %v", ErrSocket, err) + } + + return nil +} diff --git a/zk/datastream/client/stream_client_test.go b/zk/datastream/client/stream_client_test.go index 05e0cd80149..f8078889e6b 100644 --- a/zk/datastream/client/stream_client_test.go +++ b/zk/datastream/client/stream_client_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/binary" - "errors" "fmt" "net" "sync" @@ -27,7 +26,7 @@ func TestStreamClientReadHeaderEntry(t *testing.T) { name string input []byte expectedResult *types.HeaderEntry - expectedError error + expectedError string } testCases := []testCase{ { @@ -40,18 +39,18 @@ func TestStreamClientReadHeaderEntry(t *testing.T) { TotalLength: 24, TotalEntries: 64, }, - expectedError: nil, + expectedError: "", }, { name: "Invalid byte array length", input: []byte{20, 21, 22, 23, 24, 20}, expectedResult: nil, - expectedError: errors.New("failed to read header bytes reading from server: unexpected EOF"), + expectedError: "read header bytes: readBuffer: socket error: io.ReadFull: unexpected EOF", }, } for _, testCase := range testCases { - c := NewClient(context.Background(), "", 0, 0, 0) + c := NewClient(context.Background(), "", 0, 2*time.Second, 0) server, conn := net.Pipe() defer server.Close() defer c.Stop() @@ -64,7 +63,11 @@ func TestStreamClientReadHeaderEntry(t *testing.T) { }() header, err := c.readHeaderEntry() - require.Equal(t, testCase.expectedError, err) + if testCase.expectedError != "" { + require.EqualError(t, err, testCase.expectedError) + } else { + require.NoError(t, err) + } assert.DeepEqual(t, testCase.expectedResult, header) }) } @@ -75,7 +78,7 @@ func 
TestStreamClientReadResultEntry(t *testing.T) { name string input []byte expectedResult *types.ResultEntry - expectedError error + expectedError string } testCases := []testCase{ { @@ -87,7 +90,7 @@ func TestStreamClientReadResultEntry(t *testing.T) { ErrorNum: 0, ErrorStr: []byte{}, }, - expectedError: nil, + expectedError: "", }, { name: "Happy path - error str length", @@ -98,24 +101,24 @@ func TestStreamClientReadResultEntry(t *testing.T) { ErrorNum: 0, ErrorStr: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, }, - expectedError: nil, + expectedError: "", }, { name: "Invalid byte array length", input: []byte{20, 21, 22, 23, 24, 20}, expectedResult: nil, - expectedError: errors.New("failed to read main result bytes reading from server: unexpected EOF"), + expectedError: "read main result bytes: readBuffer: socket error: io.ReadFull: unexpected EOF", }, { name: "Invalid error length", input: []byte{0, 0, 0, 12, 0, 0, 0, 0, 20, 21}, expectedResult: nil, - expectedError: errors.New("failed to read result errStr bytes reading from server: unexpected EOF"), + expectedError: "read result errStr bytes: readBuffer: socket error: io.ReadFull: unexpected EOF", }, } for _, testCase := range testCases { - c := NewClient(context.Background(), "", 0, 0, 0) + c := NewClient(context.Background(), "", 0, 2*time.Second, 0) server, conn := net.Pipe() defer server.Close() defer c.Stop() @@ -128,7 +131,11 @@ func TestStreamClientReadResultEntry(t *testing.T) { }() result, err := c.readResultEntry([]byte{1}) - require.Equal(t, testCase.expectedError, err) + if testCase.expectedError != "" { + require.EqualError(t, err, testCase.expectedError) + } else { + require.NoError(t, err) + } assert.DeepEqual(t, testCase.expectedResult, result) }) } @@ -139,7 +146,7 @@ func TestStreamClientReadFileEntry(t *testing.T) { name string input []byte expectedResult *types.FileEntry - expectedError error + expectedError string } testCases := []testCase{ { @@ -152,7 +159,7 @@ func TestStreamClientReadFileEntry(t *testing.T) { EntryNum: 45, Data: []byte{0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 64}, }, - expectedError: nil, + expectedError: "", }, { name: "Happy path - no data", input: []byte{2, 0, 0, 0, 17, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 45}, @@ -163,28 +170,28 @@ func TestStreamClientReadFileEntry(t *testing.T) { EntryNum: 45, Data: []byte{}, }, - expectedError: nil, + expectedError: "", }, { name: "Invalid packet type", input: []byte{5, 0, 0, 0, 17, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 45}, expectedResult: nil, - expectedError: errors.New("expected data packet type 2 or 254 and received 5"), + expectedError: "expected data packet type 2 or 254 and received 5", }, { name: "Invalid byte array length", input: []byte{2, 21, 22, 23, 24, 20}, expectedResult: nil, - expectedError: errors.New("error reading file bytes: reading from server: unexpected EOF"), + expectedError: "reading file bytes: readBuffer: socket error: io.ReadFull: unexpected EOF", }, { name: "Invalid data length", input: []byte{2, 0, 0, 0, 31, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 45, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 64}, expectedResult: nil, - expectedError: errors.New("error reading file data bytes: reading from server: unexpected EOF"), + expectedError: "reading file data bytes: readBuffer: socket error: io.ReadFull: unexpected EOF", }, } for _, testCase := range testCases { - c := NewClient(context.Background(), "", 0, 0, 0) + c := NewClient(context.Background(), "", 0, 2*time.Second, 0) server, conn := net.Pipe() defer c.Stop() defer server.Close() @@ -197,16 +204,25 @@ func 
TestStreamClientReadFileEntry(t *testing.T) { }() result, err := c.NextFileEntry() - require.Equal(t, testCase.expectedError, err) + if testCase.expectedError != "" { + require.EqualError(t, err, testCase.expectedError) + } else { + require.NoError(t, err) + } assert.DeepEqual(t, testCase.expectedResult, result) }) } } func TestStreamClientReadParsedProto(t *testing.T) { - c := NewClient(context.Background(), "", 0, 0, 0) + c := NewClient(context.Background(), "", 0, 2*time.Second, 0) serverConn, clientConn := net.Pipe() c.conn = clientConn + c.checkTimeout = 1 * time.Second + + c.header = &types.HeaderEntry{ + TotalEntries: 3, + } defer func() { serverConn.Close() clientConn.Close() @@ -253,7 +269,7 @@ func TestStreamClientReadParsedProto(t *testing.T) { close(errCh) }() - parsedEntry, err := ReadParsedProto(c) + parsedEntry, entryNum, err := ReadParsedProto(c) require.NoError(t, err) serverErr := <-errCh require.NoError(t, serverErr) @@ -261,6 +277,7 @@ func TestStreamClientReadParsedProto(t *testing.T) { expectedL2Block := types.ConvertToFullL2Block(l2Block) expectedL2Block.L2Txs = append(expectedL2Block.L2Txs, *expectedL2Tx) require.Equal(t, expectedL2Block, parsedEntry) + require.Equal(t, uint64(3), entryNum) } func TestStreamClientGetLatestL2Block(t *testing.T) { @@ -270,9 +287,9 @@ func TestStreamClientGetLatestL2Block(t *testing.T) { clientConn.Close() }() - c := NewClient(context.Background(), "", 0, 0, 0) + c := NewClient(context.Background(), "", 0, 2*time.Second, 0) c.conn = clientConn - + c.checkTimeout = 1 * time.Second expectedL2Block, _ := createL2BlockAndTransactions(t, 5, 0) l2BlockProto := &types.L2BlockProto{L2Block: expectedL2Block} l2BlockRaw, err := l2BlockProto.Marshal() @@ -383,9 +400,12 @@ func TestStreamClientGetL2BlockByNumber(t *testing.T) { clientConn.Close() }() - c := NewClient(context.Background(), "", 0, 0, 0) + c := NewClient(context.Background(), "", 0, 2*time.Second, 0) + c.header = &types.HeaderEntry{ + TotalEntries: 4, + } c.conn = clientConn - + c.checkTimeout = 1 * time.Second bookmark := types.NewBookmarkProto(blockNum, datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK) bookmarkRaw, err := bookmark.Marshal() require.NoError(t, err) @@ -472,13 +492,15 @@ func TestStreamClientGetL2BlockByNumber(t *testing.T) { go createServerResponses(t, serverConn, bookmarkRaw, l2BlockRaw, l2TxsRaw, l2BlockEndRaw, errCh) - l2Block, errCode, err := c.GetL2BlockByNumber(blockNum) + l2Block, err := c.GetL2BlockByNumber(blockNum) require.NoError(t, err) - require.Equal(t, types.CmdErrOK, errCode) - serverErr := <-errCh + var serverErr error + select { + case serverErr = <-errCh: + default: + } require.NoError(t, serverErr) - l2TxsProto := make([]types.L2TransactionProto, len(l2Txs)) for i, tx := range l2Txs { l2TxProto := types.ConvertToL2TransactionProto(tx) diff --git a/zk/datastream/client/utils.go b/zk/datastream/client/utils.go index e96e623e202..e27323865fe 100644 --- a/zk/datastream/client/utils.go +++ b/zk/datastream/client/utils.go @@ -8,32 +8,33 @@ import ( "net" ) -// writeFullUint64ToConn writes a uint64 to a connection +var ( + ErrSocket = errors.New("socket error") + ErrNilConnection = errors.New("nil connection") +) + func writeFullUint64ToConn(conn net.Conn, value uint64) error { buffer := make([]byte, 8) binary.BigEndian.PutUint64(buffer, value) if conn == nil { - return errors.New("error nil connection") + return fmt.Errorf("%w: %w", ErrSocket, ErrNilConnection) } - _, err := conn.Write(buffer) - if err != nil { - return fmt.Errorf("%s Error sending to 
server: %v", conn.RemoteAddr().String(), err) + if _, err := conn.Write(buffer); err != nil { + return fmt.Errorf("%w: conn.Write: %v", ErrSocket, err) } return nil } -// writeFullUint64ToConn writes a uint64 to a connection func writeBytesToConn(conn net.Conn, value []byte) error { if conn == nil { - return errors.New("error nil connection") + return fmt.Errorf("%w: %w", ErrSocket, ErrNilConnection) } - _, err := conn.Write(value) - if err != nil { - return fmt.Errorf("%s Error sending to server: %v", conn.RemoteAddr().String(), err) + if _, err := conn.Write(value); err != nil { + return fmt.Errorf("%w: conn.Write: %w", ErrSocket, err) } return nil @@ -45,12 +46,11 @@ func writeFullUint32ToConn(conn net.Conn, value uint32) error { binary.BigEndian.PutUint32(buffer, value) if conn == nil { - return errors.New("error nil connection") + return fmt.Errorf("%w: %w", ErrSocket, ErrNilConnection) } - _, err := conn.Write(buffer) - if err != nil { - return fmt.Errorf("%s Error sending to server: %v", conn.RemoteAddr().String(), err) + if _, err := conn.Write(buffer); err != nil { + return fmt.Errorf("%w: conn.Write: %w", ErrSocket, err) } return nil @@ -61,7 +61,7 @@ func readBuffer(conn net.Conn, n uint32) ([]byte, error) { buffer := make([]byte, n) rbc, err := io.ReadFull(conn, buffer) if err != nil { - return []byte{}, parseIoReadError(err) + return []byte{}, fmt.Errorf("%w: io.ReadFull: %w", ErrSocket, err) } if uint32(rbc) != n { @@ -70,12 +70,3 @@ func readBuffer(conn net.Conn, n uint32) ([]byte, error) { return buffer, nil } - -// parseIoReadError parses an error returned from io.ReadFull and returns a more concrete one -func parseIoReadError(err error) error { - if err == io.EOF { - return errors.New("server close connection") - } else { - return fmt.Errorf("reading from server: %v", err) - } -} diff --git a/zk/datastream/client/utils_test.go b/zk/datastream/client/utils_test.go index 3047bc25557..89ec1613137 100644 --- a/zk/datastream/client/utils_test.go +++ b/zk/datastream/client/utils_test.go @@ -1,8 +1,6 @@ package client import ( - "errors" - "fmt" "io" "net" "testing" @@ -27,10 +25,10 @@ func Test_WriteFullUint64ToConn(t *testing.T) { expectedError: nil, }, { - name: "happy path", + name: "nil connection", input: 10, shouldOpenConn: false, - expectedError: errors.New("error nil connection"), + expectedError: ErrNilConnection, }, } @@ -48,8 +46,9 @@ func Test_WriteFullUint64ToConn(t *testing.T) { err = writeFullUint64ToConn(client, testCase.input) } else { err = writeFullUint64ToConn(nil, testCase.input) + require.ErrorIs(t, err, ErrSocket) } - require.Equal(t, testCase.expectedError, err) + require.ErrorIs(t, err, testCase.expectedError) }) } } @@ -70,10 +69,10 @@ func Test_WriteFullUint32ToConn(t *testing.T) { expectedError: nil, }, { - name: "happy path", + name: "nil connection", input: 10, shouldOpenConn: false, - expectedError: errors.New("error nil connection"), + expectedError: ErrNilConnection, }, } @@ -91,8 +90,53 @@ func Test_WriteFullUint32ToConn(t *testing.T) { err = writeFullUint32ToConn(client, testCase.input) } else { err = writeFullUint32ToConn(nil, testCase.input) + require.ErrorIs(t, err, ErrSocket) + } + require.ErrorIs(t, err, testCase.expectedError) + }) + } +} + +func Test_WriteBytesToConn(t *testing.T) { + type testCase struct { + name string + input []byte + shouldOpenConn bool + expectedError error + } + + testCases := []testCase{ + { + name: "happy path", + input: []byte{1, 2, 3, 4, 5}, + shouldOpenConn: true, + expectedError: nil, + }, + { + name: "nil 
connection", + input: []byte{1, 2, 3, 4, 5}, + shouldOpenConn: false, + expectedError: ErrNilConnection, + }, + } + + for _, testCase := range testCases { + server, client := net.Pipe() + defer server.Close() + t.Run(testCase.name, func(t *testing.T) { + go func() { + buffer := make([]byte, len(testCase.input)) + io.ReadFull(server, buffer) + }() + + var err error + if testCase.shouldOpenConn { + err = writeBytesToConn(client, testCase.input) + } else { + err = writeBytesToConn(nil, testCase.input) + require.ErrorIs(t, err, ErrSocket) } - require.Equal(t, testCase.expectedError, err) + require.ErrorIs(t, err, testCase.expectedError) }) } } @@ -122,7 +166,7 @@ func Test_ReadBuffer(t *testing.T) { name: "test error", input: 6, expectedResult: []byte{}, - expectedError: fmt.Errorf("reading from server: %v", io.ErrUnexpectedEOF), + expectedError: io.ErrUnexpectedEOF, }, } @@ -136,36 +180,11 @@ func Test_ReadBuffer(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { result, err := readBuffer(client, testCase.input) - require.Equal(t, testCase.expectedError, err) + require.ErrorIs(t, err, testCase.expectedError) + if testCase.expectedError != nil { + require.ErrorIs(t, err, ErrSocket) + } assert.DeepEqual(t, testCase.expectedResult, result) }) } } - -func Test_ParseIoReadError(t *testing.T) { - type testCase struct { - name string - input error - expectedError error - } - - testCases := []testCase{ - { - name: "io error", - input: io.EOF, - expectedError: errors.New("server close connection"), - }, - { - name: "test error", - input: errors.New("test error"), - expectedError: errors.New("reading from server: test error"), - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - result := parseIoReadError(testCase.input) - require.Equal(t, testCase.expectedError, result) - }) - } -} diff --git a/zk/datastream/server/data_stream_server.go b/zk/datastream/server/data_stream_server.go index 9c559d2da3c..93eb3c6c27c 100644 --- a/zk/datastream/server/data_stream_server.go +++ b/zk/datastream/server/data_stream_server.go @@ -624,6 +624,10 @@ func newDataStreamServerIterator(stream *datastreamer.StreamServer, start uint64 } } +func (it *dataStreamServerIterator) GetEntryNumberLimit() uint64 { + return it.header + 1 +} + func (it *dataStreamServerIterator) NextFileEntry() (entry *types.FileEntry, err error) { if it.curEntryNum > it.header { return nil, nil @@ -669,7 +673,7 @@ func ReadBatches(iterator client.FileEntryIterator, start uint64, end uint64) ([ LOOP_ENTRIES: for { - parsedProto, err := client.ReadParsedProto(iterator) + parsedProto, _, err := client.ReadParsedProto(iterator) if err != nil { return nil, err } diff --git a/zk/datastream/types/result.go b/zk/datastream/types/result.go index 1e6652cbb9d..6414acdf057 100644 --- a/zk/datastream/types/result.go +++ b/zk/datastream/types/result.go @@ -20,6 +20,14 @@ const ( CmdErrInvalidCommand = 9 // CmdErrInvalidCommand for invalid/unknown command error ) +var ( + ErrAlreadyStarted = errors.New("client already started") + ErrAlreadyStopped = errors.New("client already stopped") + ErrBadFromEntry = errors.New("invalid starting entry number") + ErrBadFromBookmark = errors.New("invalid starting bookmark") + ErrInvalidCommand = errors.New("invalid/unknown command") +) + type ResultEntry struct { PacketType uint8 // 0xff:Result Length uint32 diff --git a/zk/stages/stage_batches.go b/zk/stages/stage_batches.go index 2c0644e44c2..ed2b8291fa4 100644 --- a/zk/stages/stage_batches.go +++ b/zk/stages/stage_batches.go 
@@ -26,7 +26,6 @@ import ( "github.com/ledgerwatch/erigon/core/rawdb" "github.com/ledgerwatch/erigon/core/state" "github.com/ledgerwatch/erigon/eth/ethconfig" - "github.com/ledgerwatch/erigon/zk/datastream/client" "github.com/ledgerwatch/log/v3" ) @@ -59,13 +58,13 @@ type HermezDb interface { } type DatastreamClient interface { + RenewEntryChannel() ReadAllEntriesToChannel() error + StopReadingToChannel() GetEntryChan() *chan interface{} - GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, int, error) + GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, error) GetLatestL2Block() (*types.FullL2Block, error) - GetStreamingAtomic() *atomic.Bool GetProgressAtomic() *atomic.Uint64 - EnsureConnected() (bool, error) Start() error Stop() } @@ -73,7 +72,6 @@ type DatastreamClient interface { type DatastreamReadRunner interface { StartRead() StopRead() - RestartReadFromBlock(fromBlock uint64) } type dsClientCreatorHandler func(context.Context, *ethconfig.Zk, uint64) (DatastreamClient, error) @@ -153,6 +151,15 @@ func SpawnStageBatches( //// BISECT //// if cfg.zkCfg.DebugLimit > 0 && stageProgressBlockNo > cfg.zkCfg.DebugLimit { + log.Info(fmt.Sprintf("[%s] Debug limit reached", logPrefix), "stageProgressBlockNo", stageProgressBlockNo, "debugLimit", cfg.zkCfg.DebugLimit) + time.Sleep(2 * time.Second) + return nil + } + + // the sync limit block number is exclusive, so we sync up to limit-1 + if cfg.zkCfg.SyncLimit > 0 && stageProgressBlockNo+1 >= cfg.zkCfg.SyncLimit { + log.Info(fmt.Sprintf("[%s] Sync limit reached", logPrefix), "stageProgressBlockNo", stageProgressBlockNo, "syncLimit", cfg.zkCfg.SyncLimit) + time.Sleep(2 * time.Second) return nil } @@ -169,26 +176,55 @@ func SpawnStageBatches( return err } - dsQueryClient, err := newStreamClient(ctx, cfg, latestForkId) + dsQueryClient, stopDsClient, err := newStreamClient(ctx, cfg, latestForkId) if err != nil { log.Warn(fmt.Sprintf("[%s] %s", logPrefix, err)) return err } - defer dsQueryClient.Stop() + defer stopDsClient() - highestDSL2Block, err := dsQueryClient.GetLatestL2Block() - if err != nil { - return fmt.Errorf("failed to retrieve the latest datastream l2 block: %w", err) - } + var highestDSL2Block *types.FullL2Block + newBlockCheckStartTime := time.Now() + for { + select { + case <-ctx.Done(): + return nil + default: + } + highestDSL2Block, err = dsQueryClient.GetLatestL2Block() + if err != nil { + // if we return an error, the stage will replay and block all other stages + log.Warn(fmt.Sprintf("[%s] Failed to get latest l2 block from datastream: %v", logPrefix, err)) + return nil + } - if highestDSL2Block.L2BlockNumber < stageProgressBlockNo { - stageProgressBlockNo = highestDSL2Block.L2BlockNumber + // a lower block should also break the loop because that means the datastream was unwound + // thus we should unwind as well and continue from there + if highestDSL2Block.L2BlockNumber != stageProgressBlockNo { + log.Info(fmt.Sprintf("[%s] Highest block in datastream", logPrefix), "datastreamBlock", highestDSL2Block.L2BlockNumber, "stageProgressBlockNo", stageProgressBlockNo) + break + } + if time.Since(newBlockCheckStartTime) > 10*time.Second { + log.Info(fmt.Sprintf("[%s] Waiting for at least one new block in datastream", logPrefix), "datastreamBlock", highestDSL2Block.L2BlockNumber, "last processed block", stageProgressBlockNo) + newBlockCheckStartTime = time.Now() + } + time.Sleep(1 * time.Second) } log.Debug(fmt.Sprintf("[%s] Highest block in db and datastream", logPrefix), "datastreamBlock", highestDSL2Block.L2BlockNumber, "dbBlock",
stageProgressBlockNo) + unwindFn := func(unwindBlock uint64) (uint64, error) { + return rollback(logPrefix, eriDb, hermezDb, dsQueryClient, unwindBlock, tx, u) + } + if highestDSL2Block.L2BlockNumber < stageProgressBlockNo { + log.Info(fmt.Sprintf("[%s] Datastream behind, unwinding", logPrefix)) + if _, err := unwindFn(highestDSL2Block.L2BlockNumber); err != nil { + return err + } + return nil + } - dsClientProgress := cfg.dsClient.GetProgressAtomic() - dsClientProgress.Store(stageProgressBlockNo) + dsClientProgress := dsQueryClient.GetProgressAtomic() + dsClientProgress.Swap(stageProgressBlockNo) // start a routine to print blocks written progress progressChan, stopProgressPrinter := zk.ProgressPrinterWithoutTotal(fmt.Sprintf("[%s] Downloaded blocks from datastream progress", logPrefix)) @@ -212,25 +248,25 @@ func SpawnStageBatches( log.Info(fmt.Sprintf("[%s] Reading blocks from the datastream.", logPrefix)) - unwindFn := func(unwindBlock uint64) error { - return rollback(logPrefix, eriDb, hermezDb, dsQueryClient, unwindBlock, tx, u) + lastProcessedBlockHash, err := eriDb.ReadCanonicalHash(stageProgressBlockNo) + if err != nil { + return fmt.Errorf("failed to read canonical hash for block %d: %w", stageProgressBlockNo, err) } - batchProcessor, err := NewBatchesProcessor(ctx, logPrefix, tx, hermezDb, eriDb, cfg.zkCfg.SyncLimit, cfg.zkCfg.DebugLimit, cfg.zkCfg.DebugStepAfter, cfg.zkCfg.DebugStep, stageProgressBlockNo, stageProgressBatchNo, dsQueryClient, progressChan, cfg.chainConfig, cfg.miningConfig, unwindFn) + batchProcessor, err := NewBatchesProcessor(ctx, logPrefix, tx, hermezDb, eriDb, cfg.zkCfg.SyncLimit, cfg.zkCfg.DebugLimit, cfg.zkCfg.DebugStepAfter, cfg.zkCfg.DebugStep, stageProgressBlockNo, stageProgressBatchNo, lastProcessedBlockHash, dsQueryClient, progressChan, cfg.chainConfig, cfg.miningConfig, unwindFn) if err != nil { return err } // start routine to download blocks and push them in a channel - dsClientRunner := NewDatastreamClientRunner(cfg.dsClient, logPrefix) + dsClientRunner := NewDatastreamClientRunner(dsQueryClient, logPrefix) dsClientRunner.StartRead() defer dsClientRunner.StopRead() - entryChan := cfg.dsClient.GetEntryChan() + entryChan := dsQueryClient.GetEntryChan() - prevAmountBlocksWritten, restartDatastreamBlock := uint64(0), uint64(0) + prevAmountBlocksWritten := uint64(0) endLoop := false - unwound := false for { // get batch start and use to update forkid @@ -240,40 +276,19 @@ func SpawnStageBatches( // if both download routine stopped and channel empty - stop loop select { case entry := <-*entryChan: - if restartDatastreamBlock, endLoop, unwound, err = batchProcessor.ProcessEntry(entry); err != nil { + if endLoop, err = batchProcessor.ProcessEntry(entry); err != nil { + // if we triggered an unwind somewhere we need to return from the stage + if err == ErrorTriggeredUnwind { + return nil + } return err } dsClientProgress.Store(batchProcessor.LastBlockHeight()) - - if restartDatastreamBlock > 0 { - if err = dsClientRunner.RestartReadFromBlock(restartDatastreamBlock); err != nil { - return err - } - } - - // if we triggered an unwind somewhere we need to return from the stage - if unwound { - return nil - } case <-ctx.Done(): log.Warn(fmt.Sprintf("[%s] Context done", logPrefix)) endLoop = true default: - time.Sleep(1 * time.Second) - } - - // if ds end reached check again for new blocks in the stream - // if there are too many new blocks get them as well before ending stage - if batchProcessor.LastBlockHeight() >= highestDSL2Block.L2BlockNumber { - 
newLatestDSL2Block, err := dsQueryClient.GetLatestL2Block() - if err != nil { - return fmt.Errorf("failed to retrieve the latest datastream l2 block: %w", err) - } - if newLatestDSL2Block.L2BlockNumber > highestDSL2Block.L2BlockNumber+NEW_BLOCKS_ON_DS_LIMIT { - highestDSL2Block = newLatestDSL2Block - } else { - endLoop = true - } + time.Sleep(10 * time.Millisecond) } if endLoop { @@ -606,25 +621,33 @@ func PruneBatchesStage(s *stagedsync.PruneState, tx kv.RwTx, cfg BatchesCfg, ctx // 1. queries the latest common ancestor for datastream and db, // 2. resolves the unwind block (as the latest block in the previous batch, comparing to the found ancestor block) // 3. triggers the unwinding -func rollback(logPrefix string, eriDb *erigon_db.ErigonDb, hermezDb *hermez_db.HermezDb, - dsQueryClient DatastreamClient, latestDSBlockNum uint64, tx kv.RwTx, u stagedsync.Unwinder) error { +func rollback( + logPrefix string, + eriDb *erigon_db.ErigonDb, + hermezDb *hermez_db.HermezDb, + dsQueryClient DatastreamClient, + latestDSBlockNum uint64, + tx kv.RwTx, + u stagedsync.Unwinder, +) (uint64, error) { ancestorBlockNum, ancestorBlockHash, err := findCommonAncestor(eriDb, hermezDb, dsQueryClient, latestDSBlockNum) if err != nil { - return err + return 0, err } log.Debug(fmt.Sprintf("[%s] The common ancestor for datastream and db is block %d (%s)", logPrefix, ancestorBlockNum, ancestorBlockHash)) unwindBlockNum, unwindBlockHash, batchNum, err := getUnwindPoint(eriDb, hermezDb, ancestorBlockNum, ancestorBlockHash) if err != nil { - return err + return 0, err } if err = stages.SaveStageProgress(tx, stages.HighestSeenBatchNumber, batchNum-1); err != nil { - return err + return 0, err } log.Warn(fmt.Sprintf("[%s] Unwinding to block %d (%s)", logPrefix, unwindBlockNum, unwindBlockHash)) + u.UnwindTo(unwindBlockNum, stagedsync.BadBlock(unwindBlockHash, fmt.Errorf("unwind to block %d", unwindBlockNum))) - return nil + return unwindBlockNum, nil } // findCommonAncestor searches the latest common ancestor block number and hash between the data stream and the local db. 
@@ -651,21 +674,21 @@ func findCommonAncestor( } midBlockNum := (startBlockNum + endBlockNum) / 2 - midBlockDataStream, errCode, err := dsClient.GetL2BlockByNumber(midBlockNum) + midBlockDataStream, err := dsClient.GetL2BlockByNumber(midBlockNum) if err != nil && // the required block might not be in the data stream, so ignore that error - errCode != types.CmdErrBadFromBookmark { - return 0, emptyHash, err + !errors.Is(err, types.ErrBadFromBookmark) { + return 0, emptyHash, fmt.Errorf("GetL2BlockByNumber: failed to get l2 block %d from datastream: %w", midBlockNum, err) } midBlockDbHash, err := db.ReadCanonicalHash(midBlockNum) if err != nil { - return 0, emptyHash, err + return 0, emptyHash, fmt.Errorf("ReadCanonicalHash: failed to get canonical hash for block %d: %w", midBlockNum, err) } dbBatchNum, err := hermezDb.GetBatchNoByL2Block(midBlockNum) if err != nil { - return 0, emptyHash, err + return 0, emptyHash, fmt.Errorf("GetBatchNoByL2Block: failed to get batch number for block %d: %w", midBlockNum, err) } if midBlockDataStream != nil && @@ -701,37 +724,34 @@ func getUnwindPoint(eriDb erigon_db.ReadOnlyErigonDb, hermezDb state.ReadOnlyHer unwindBlockNum, _, err := hermezDb.GetHighestBlockInBatch(batchNum - 1) if err != nil { - return 0, emptyHash, 0, err + return 0, emptyHash, 0, fmt.Errorf("GetHighestBlockInBatch: batch %d: %w", batchNum-1, err) } unwindBlockHash, err := eriDb.ReadCanonicalHash(unwindBlockNum) if err != nil { - return 0, emptyHash, 0, err + return 0, emptyHash, 0, fmt.Errorf("ReadCanonicalHash: block %d: %w", unwindBlockNum, err) } return unwindBlockNum, unwindBlockHash, batchNum, nil } // newStreamClient instantiates new datastreamer client and starts it. -func newStreamClient(ctx context.Context, cfg BatchesCfg, latestForkId uint64) (DatastreamClient, error) { - var ( - dsClient DatastreamClient - err error - ) - +func newStreamClient(ctx context.Context, cfg BatchesCfg, latestForkId uint64) (dsClient DatastreamClient, stopFn func(), err error) { if cfg.dsQueryClientCreator != nil { dsClient, err = cfg.dsQueryClientCreator(ctx, cfg.zkCfg, latestForkId) if err != nil { - return nil, fmt.Errorf("failed to create a datastream client. Reason: %w", err) + return nil, nil, fmt.Errorf("dsQueryClientCreator: %w", err) + } + if err := dsClient.Start(); err != nil { + return nil, nil, fmt.Errorf("dsClient.Start: %w", err) + } + stopFn = func() { + dsClient.Stop() } } else { - zkCfg := cfg.zkCfg - dsClient = client.NewClient(ctx, zkCfg.L2DataStreamerUrl, zkCfg.DatastreamVersion, zkCfg.L2DataStreamerTimeout, uint16(latestForkId)) - } - - if err := dsClient.Start(); err != nil { - return nil, fmt.Errorf("failed to start a datastream client. 
Reason: %w", err) + dsClient = cfg.dsClient + stopFn = func() {} } - return dsClient, nil + return dsClient, stopFn, nil } diff --git a/zk/stages/stage_batches_datastream.go b/zk/stages/stage_batches_datastream.go index fefd7c17188..a1f9926e067 100644 --- a/zk/stages/stage_batches_datastream.go +++ b/zk/stages/stage_batches_datastream.go @@ -4,7 +4,6 @@ import ( "fmt" "math/rand" "sync/atomic" - "time" "github.com/ledgerwatch/log/v3" ) @@ -24,10 +23,13 @@ func NewDatastreamClientRunner(dsClient DatastreamClient, logPrefix string) *Dat } func (r *DatastreamClientRunner) StartRead() error { + r.dsClient.RenewEntryChannel() if r.isReading.Load() { return fmt.Errorf("tried starting datastream client runner thread while another is running") } + r.stopRunner.Store(false) + go func() { routineId := rand.Intn(1000000) @@ -37,27 +39,8 @@ func (r *DatastreamClientRunner) StartRead() error { r.isReading.Store(true) defer r.isReading.Store(false) - for { - if r.stopRunner.Load() { - log.Info(fmt.Sprintf("[%s] Downloading L2Blocks routine stopped intentionally", r.logPrefix)) - break - } - - // start routine to download blocks and push them in a channel - if !r.dsClient.GetStreamingAtomic().Load() { - log.Info(fmt.Sprintf("[%s] Starting stream", r.logPrefix)) - // this will download all blocks from datastream and push them in a channel - // if no error, break, else continue trying to get them - // Create bookmark - - if err := r.connectDatastream(); err != nil { - log.Error(fmt.Sprintf("[%s] Error connecting to datastream", r.logPrefix), "error", err) - } - - if err := r.dsClient.ReadAllEntriesToChannel(); err != nil { - log.Error(fmt.Sprintf("[%s] Error downloading blocks from datastream", r.logPrefix), "error", err) - } - } + if err := r.dsClient.ReadAllEntriesToChannel(); err != nil { + log.Warn(fmt.Sprintf("[%s] Error downloading blocks from datastream", r.logPrefix), "error", err) } }() @@ -65,44 +48,6 @@ func (r *DatastreamClientRunner) StartRead() error { } func (r *DatastreamClientRunner) StopRead() { - r.stopRunner.Store(true) -} - -func (r *DatastreamClientRunner) RestartReadFromBlock(fromBlock uint64) error { - r.StopRead() - - //wait for the old routine to be finished before continuing - counter := 0 - for { - if !r.isReading.Load() { - break - } - counter++ - if counter > 100 { - return fmt.Errorf("failed to stop reader routine correctly") - } - time.Sleep(100 * time.Millisecond) - } - - // set new block - r.dsClient.GetProgressAtomic().Store(fromBlock) - - log.Info(fmt.Sprintf("[%s] Restarting datastream from block %d", r.logPrefix, fromBlock)) - - return r.StartRead() -} - -func (r *DatastreamClientRunner) connectDatastream() (err error) { - var connected bool - for i := 0; i < 5; i++ { - if connected, err = r.dsClient.EnsureConnected(); err != nil { - log.Error(fmt.Sprintf("[%s] Error connecting to datastream", r.logPrefix), "error", err) - continue - } - if connected { - return nil - } - } - - return fmt.Errorf("failed to connect to datastream") + r.stopRunner.Swap(true) + r.dsClient.StopReadingToChannel() } diff --git a/zk/stages/stage_batches_processor.go b/zk/stages/stage_batches_processor.go index 404812f9320..0cb20853111 100644 --- a/zk/stages/stage_batches_processor.go +++ b/zk/stages/stage_batches_processor.go @@ -6,7 +6,6 @@ import ( "fmt" "math/big" "sync/atomic" - "time" "github.com/ledgerwatch/erigon-lib/chain" "github.com/ledgerwatch/erigon-lib/common" @@ -21,6 +20,11 @@ import ( "github.com/ledgerwatch/log/v3" ) +var ( + ErrorTriggeredUnwind = errors.New("triggered 
unwind") + ErrorSkippedBlock = errors.New("skipped block") +) + type ProcessorErigonDb interface { WriteHeader(batchNo *big.Int, blockHash common.Hash, stateRoot, txHash, parentHash common.Hash, coinbase common.Address, ts, gasLimit uint64, chainConfig *chain.Config) (*ethTypes.Header, error) WriteBody(batchNo *big.Int, headerHash common.Hash, txs []ethTypes.Transaction) error @@ -34,6 +38,7 @@ type ProcessorHermezDb interface { WriteEffectiveGasPricePercentage(txHash common.Hash, effectiveGasPricePercentage uint8) error WriteStateRoot(l2BlockNumber uint64, rpcRoot common.Hash) error + GetStateRoot(l2BlockNumber uint64) (common.Hash, error) CheckGlobalExitRootWritten(ger common.Hash) (bool, error) WriteBlockGlobalExitRoot(l2BlockNo uint64, ger common.Hash) error @@ -51,7 +56,7 @@ type ProcessorHermezDb interface { } type DsQueryClient interface { - GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, int, error) + GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, error) GetProgressAtomic() *atomic.Uint64 } @@ -66,20 +71,21 @@ type BatchesProcessor struct { debugStepAfter, debugStep, stageProgressBlockNo, - lastForkId, highestHashableL2BlockNo, - lastBlockHeight, + lastForkId uint64 + highestL1InfoTreeIndex uint32 + dsQueryClient DsQueryClient + progressChan chan uint64 + unwindFn func(uint64) (uint64, error) + highestSeenBatchNo, + lastBlockHeight, blocksWritten, highestVerifiedBatch uint64 - highestL1InfoTreeIndex uint32 lastBlockRoot, lastBlockHash common.Hash - dsQueryClient DsQueryClient - progressChan chan uint64 - unwindFn func(uint64) error - chainConfig *chain.Config - miningConfig *params.MiningConfig + chainConfig *chain.Config + miningConfig *params.MiningConfig } func NewBatchesProcessor( @@ -89,11 +95,12 @@ func NewBatchesProcessor( hermezDb ProcessorHermezDb, eriDb ProcessorErigonDb, syncBlockLimit, debugBlockLimit, debugStepAfter, debugStep, stageProgressBlockNo, stageProgressBatchNo uint64, + lastProcessedBlockHash common.Hash, dsQueryClient DsQueryClient, progressChan chan uint64, chainConfig *chain.Config, miningConfig *params.MiningConfig, - unwindFn func(uint64) error, + unwindFn func(uint64) (uint64, error), ) (*BatchesProcessor, error) { highestVerifiedBatch, err := stages.GetStageProgress(tx, stages.L1VerificationsBatchNo) if err != nil { @@ -121,7 +128,7 @@ func NewBatchesProcessor( highestVerifiedBatch: highestVerifiedBatch, dsQueryClient: dsQueryClient, progressChan: progressChan, - lastBlockHash: emptyHash, + lastBlockHash: lastProcessedBlockHash, lastBlockRoot: emptyHash, lastForkId: lastForkId, unwindFn: unwindFn, @@ -130,18 +137,20 @@ func NewBatchesProcessor( }, nil } -func (p *BatchesProcessor) ProcessEntry(entry interface{}) (rollbackBlock uint64, endLoop bool, unwound bool, err error) { +func (p *BatchesProcessor) ProcessEntry(entry interface{}) (endLoop bool, err error) { switch entry := entry.(type) { case *types.BatchStart: - return 0, false, false, p.processBatchStartEntry(entry) + return false, p.processBatchStartEntry(entry) case *types.BatchEnd: - return 0, false, false, p.processBatchEndEntry(entry) + return false, p.processBatchEndEntry(entry) case *types.FullL2Block: return p.processFullBlock(entry) case *types.GerUpdate: - return 0, false, false, p.processGerUpdate(entry) + return false, p.processGerUpdate(entry) + case nil: // we use nil to indicate the end of stream read + return true, nil default: - return 0, false, false, fmt.Errorf("unknown entry type: %T", entry) + return false, fmt.Errorf("unknown entry type: %T", entry) } } @@ 
-153,7 +162,7 @@ func (p *BatchesProcessor) processGerUpdate(gerUpdate *types.GerUpdate) error { // NB: we won't get these post Etrog (fork id 7) if err := p.hermezDb.WriteBatchGlobalExitRoot(gerUpdate.BatchNumber, gerUpdate); err != nil { - return fmt.Errorf("write batch global exit root error: %v", err) + return fmt.Errorf("write batch global exit root error: %w", err) } return nil @@ -187,113 +196,102 @@ func (p *BatchesProcessor) processBatchStartEntry(batchStart *types.BatchStart) return nil } -func (p *BatchesProcessor) processFullBlock(blockEntry *types.FullL2Block) (restartStreamFromBlock uint64, endLoop bool, unwound bool, err error) { +func (p *BatchesProcessor) unwind(blockNum uint64) (uint64, error) { + unwindBlock, err := p.unwindFn(blockNum) + if err != nil { + return 0, err + } + + return unwindBlock, nil +} + +func (p *BatchesProcessor) processFullBlock(blockEntry *types.FullL2Block) (endLoop bool, err error) { log.Debug(fmt.Sprintf("[%s] Retrieved %d (%s) block from stream", p.logPrefix, blockEntry.L2BlockNumber, blockEntry.L2Blockhash.String())) if p.syncBlockLimit > 0 && blockEntry.L2BlockNumber >= p.syncBlockLimit { // stop the node going into a crazy loop - time.Sleep(2 * time.Second) - return 0, true, false, nil + log.Info(fmt.Sprintf("[%s] Sync block limit reached, stopping stage", p.logPrefix), "blockLimit", p.syncBlockLimit, "block", blockEntry.L2BlockNumber) + return true, nil } - // handle batch boundary changes - we do this here instead of reading the batch start channel because - // channels can be read in random orders which then creates problems in detecting fork changes during - // execution if blockEntry.BatchNumber > p.highestSeenBatchNo && p.lastForkId < blockEntry.ForkId { if blockEntry.ForkId >= uint64(chain.ImpossibleForkId) { - message := fmt.Sprintf("unsupported fork id %v received from the data stream", blockEntry.ForkId) + message := fmt.Sprintf("unsupported fork id %d received from the data stream", blockEntry.ForkId) panic(message) } if err = stages.SaveStageProgress(p.tx, stages.ForkId, blockEntry.ForkId); err != nil { - return 0, false, false, fmt.Errorf("save stage progress error: %v", err) + return false, fmt.Errorf("save stage progress error: %w", err) } p.lastForkId = blockEntry.ForkId if err = p.hermezDb.WriteForkId(blockEntry.BatchNumber, blockEntry.ForkId); err != nil { - return 0, false, false, fmt.Errorf("write fork id error: %v", err) + return false, fmt.Errorf("write fork id error: %w", err) } // NOTE (RPC): avoided use of 'writeForkIdBlockOnce' by reading instead batch by forkId, and then lowest block number in batch } // ignore genesis or a repeat of the last block if blockEntry.L2BlockNumber == 0 { - return 0, false, false, nil + return false, nil } // skip but warn on already processed blocks if blockEntry.L2BlockNumber <= p.stageProgressBlockNo { - if blockEntry.L2BlockNumber < p.stageProgressBlockNo { + dbBatchNum, err := p.hermezDb.GetBatchNoByL2Block(blockEntry.L2BlockNumber) + if err != nil { + return false, err + } + + if blockEntry.L2BlockNumber == p.stageProgressBlockNo && dbBatchNum == blockEntry.BatchNumber { // only warn if the block is very old, we expect the very latest block to be requested // when the stage is fired up for the first time log.Warn(fmt.Sprintf("[%s] Skipping block %d, already processed", p.logPrefix, blockEntry.L2BlockNumber)) + return false, nil } - dbBatchNum, err := p.hermezDb.GetBatchNoByL2Block(blockEntry.L2BlockNumber) - if err != nil { - return 0, false, false, err - } - - if 
blockEntry.BatchNumber > dbBatchNum { - // if the batch number is higher than the one we know about, it means that we need to trigger an unwinding of blocks - log.Warn(fmt.Sprintf("[%s] Batch number mismatch detected. Triggering unwind...", p.logPrefix), - "block", blockEntry.L2BlockNumber, "ds batch", blockEntry.BatchNumber, "db batch", dbBatchNum) - if err := p.unwindFn(blockEntry.L2BlockNumber); err != nil { - return blockEntry.L2BlockNumber, false, false, err - } + // if the block is older or the batch number is different, we need to unwind because the block has definitely changed + log.Warn(fmt.Sprintf("[%s] Block already processed. Triggering unwind...", p.logPrefix), + "block", blockEntry.L2BlockNumber, "ds batch", blockEntry.BatchNumber, "db batch", dbBatchNum) + if _, err := p.unwind(blockEntry.L2BlockNumber); err != nil { + return false, err } - return 0, false, false, nil + return false, ErrorTriggeredUnwind } var dbParentBlockHash common.Hash if blockEntry.L2BlockNumber > 1 { dbParentBlockHash, err = p.eriDb.ReadCanonicalHash(p.lastBlockHeight) if err != nil { - return 0, false, false, fmt.Errorf("failed to retrieve parent block hash for datastream block %d: %w", + return false, fmt.Errorf("failed to retrieve parent block hash for datastream block %d: %w", blockEntry.L2BlockNumber, err) } } - dsParentBlockHash := p.lastBlockHash - dsBlockNumber := p.lastBlockHeight - if dsParentBlockHash == emptyHash { - parentBlockDS, _, err := p.dsQueryClient.GetL2BlockByNumber(blockEntry.L2BlockNumber - 1) - if err != nil { - return 0, false, false, err - } - - if parentBlockDS != nil { - dsParentBlockHash = parentBlockDS.L2Blockhash - if parentBlockDS.L2BlockNumber > 0 { - dsBlockNumber = parentBlockDS.L2BlockNumber - } - } - } - - if blockEntry.L2BlockNumber > 1 && dbParentBlockHash != dsParentBlockHash { + if p.lastBlockHeight > 0 && dbParentBlockHash != p.lastBlockHash { // unwind/rollback blocks until the latest common ancestor block log.Warn(fmt.Sprintf("[%s] Parent block hashes mismatch on block %d. 
Triggering unwind...", p.logPrefix, blockEntry.L2BlockNumber), "db parent block hash", dbParentBlockHash, - "ds parent block number", dsBlockNumber, - "ds parent block hash", dsParentBlockHash, + "ds parent block number", p.lastBlockHeight, + "ds parent block hash", p.lastBlockHash, "ds parent block number", blockEntry.L2BlockNumber-1, ) // parent blockhash is wrong, so unwind to it, then restart the stream from it to get the correct one - if err = p.unwindFn(blockEntry.L2BlockNumber - 1); err != nil { - return 0, false, false, err + if _, err := p.unwind(blockEntry.L2BlockNumber - 1); err != nil { + return false, err } - return blockEntry.L2BlockNumber - 1, false, true, nil + return false, ErrorTriggeredUnwind } - // unwind if we already have this block - could be a re-sequence event + // unwind if we already have this block if blockEntry.L2BlockNumber < p.lastBlockHeight+1 { - log.Warn(fmt.Sprintf("[%s] Skipping block %d, already processed, triggering unwind...", p.logPrefix, blockEntry.L2BlockNumber)) - if err = p.unwindFn(blockEntry.L2BlockNumber); err != nil { - return 0, false, false, err + log.Warn(fmt.Sprintf("[%s] Block %d already processed, unwinding...", p.logPrefix, blockEntry.L2BlockNumber)) + if _, err := p.unwind(blockEntry.L2BlockNumber); err != nil { + return false, err } - return blockEntry.L2BlockNumber, false, true, nil + + return false, ErrorTriggeredUnwind } // check for sequential block numbers if blockEntry.L2BlockNumber > p.lastBlockHeight+1 { - log.Warn(fmt.Sprintf("[%s] Stream skipped ahead, restarting datastream to block %d", p.logPrefix, blockEntry.L2BlockNumber)) - return p.lastBlockHeight + 1, false, false, nil + return false, ErrorSkippedBlock } // batch boundary - record the highest hashable block number (last block in last full batch) @@ -329,13 +327,13 @@ func (p *BatchesProcessor) processFullBlock(blockEntry *types.FullL2Block) (rest // first block in the loop so read the parent hash previousHash, err := p.eriDb.ReadCanonicalHash(blockEntry.L2BlockNumber - 1) if err != nil { - return 0, false, false, fmt.Errorf("failed to get genesis header: %v", err) + return false, fmt.Errorf("failed to get genesis header: %w", err) } blockEntry.ParentHash = previousHash } if err := p.writeL2Block(blockEntry); err != nil { - return 0, false, false, fmt.Errorf("writeL2Block error: %v", err) + return false, fmt.Errorf("writeL2Block error: %w", err) } p.dsQueryClient.GetProgressAtomic().Store(blockEntry.L2BlockNumber) @@ -355,7 +353,7 @@ func (p *BatchesProcessor) processFullBlock(blockEntry *types.FullL2Block) (rest if p.debugBlockLimit == 0 { endLoop = false } - return 0, endLoop, false, nil + return endLoop, nil } // writeL2Block writes L2Block to ErigonDb and HermezDb @@ -366,20 +364,20 @@ func (p *BatchesProcessor) writeL2Block(l2Block *types.FullL2Block) error { for _, transaction := range l2Block.L2Txs { ltx, _, err := txtype.DecodeTx(transaction.Encoded, transaction.EffectiveGasPricePercentage, l2Block.ForkId) if err != nil { - return fmt.Errorf("decode tx error: %v", err) + return fmt.Errorf("decode tx error: %w", err) } txs = append(txs, ltx) if err := p.hermezDb.WriteEffectiveGasPricePercentage(ltx.Hash(), transaction.EffectiveGasPricePercentage); err != nil { - return fmt.Errorf("write effective gas price percentage error: %v", err) + return fmt.Errorf("write effective gas price percentage error: %w", err) } if err := p.hermezDb.WriteStateRoot(l2Block.L2BlockNumber, transaction.IntermediateStateRoot); err != nil { - return fmt.Errorf("write rpc root error: %v", err) 
+ return fmt.Errorf("write rpc root error: %w", err) } if err := p.hermezDb.WriteIntermediateTxStateRoot(l2Block.L2BlockNumber, ltx.Hash(), transaction.IntermediateStateRoot); err != nil { - return fmt.Errorf("write rpc root error: %v", err) + return fmt.Errorf("write rpc root error: %w", err) } } txCollection := ethTypes.Transactions(txs) @@ -393,7 +391,7 @@ func (p *BatchesProcessor) writeL2Block(l2Block *types.FullL2Block) error { } if _, err := p.eriDb.WriteHeader(bn, l2Block.L2Blockhash, l2Block.StateRoot, txHash, l2Block.ParentHash, l2Block.Coinbase, uint64(l2Block.Timestamp), gasLimit, p.chainConfig); err != nil { - return fmt.Errorf("write header error: %v", err) + return fmt.Errorf("write header error: %w", err) } didStoreGer := false @@ -402,16 +400,16 @@ func (p *BatchesProcessor) writeL2Block(l2Block *types.FullL2Block) error { if l2Block.GlobalExitRoot != emptyHash { gerWritten, err := p.hermezDb.CheckGlobalExitRootWritten(l2Block.GlobalExitRoot) if err != nil { - return fmt.Errorf("get global exit root error: %v", err) + return fmt.Errorf("get global exit root error: %w", err) } if !gerWritten { if err := p.hermezDb.WriteBlockGlobalExitRoot(l2Block.L2BlockNumber, l2Block.GlobalExitRoot); err != nil { - return fmt.Errorf("write block global exit root error: %v", err) + return fmt.Errorf("write block global exit root error: %w", err) } if err := p.hermezDb.WriteGlobalExitRoot(l2Block.GlobalExitRoot); err != nil { - return fmt.Errorf("write global exit root error: %v", err) + return fmt.Errorf("write global exit root error: %w", err) } didStoreGer = true } @@ -419,7 +417,7 @@ func (p *BatchesProcessor) writeL2Block(l2Block *types.FullL2Block) error { if l2Block.L1BlockHash != emptyHash { if err := p.hermezDb.WriteBlockL1BlockHash(l2Block.L2BlockNumber, l2Block.L1BlockHash); err != nil { - return fmt.Errorf("write block global exit root error: %v", err) + return fmt.Errorf("write block global exit root error: %w", err) } } @@ -457,19 +455,19 @@ func (p *BatchesProcessor) writeL2Block(l2Block *types.FullL2Block) error { } if err := p.eriDb.WriteBody(bn, l2Block.L2Blockhash, txs); err != nil { - return fmt.Errorf("write body error: %v", err) + return fmt.Errorf("write body error: %w", err) } if err := p.hermezDb.WriteForkId(l2Block.BatchNumber, l2Block.ForkId); err != nil { - return fmt.Errorf("write block batch error: %v", err) + return fmt.Errorf("write block batch error: %w", err) } if err := p.hermezDb.WriteForkIdBlockOnce(l2Block.ForkId, l2Block.L2BlockNumber); err != nil { - return fmt.Errorf("write fork id block error: %v", err) + return fmt.Errorf("write fork id block error: %w", err) } if err := p.hermezDb.WriteBlockBatch(l2Block.L2BlockNumber, l2Block.BatchNumber); err != nil { - return fmt.Errorf("write block batch error: %v", err) + return fmt.Errorf("write block batch error: %w", err) } return nil diff --git a/zk/stages/stage_batches_test.go b/zk/stages/stage_batches_test.go index 9f2578ae5dc..6299f75cc39 100644 --- a/zk/stages/stage_batches_test.go +++ b/zk/stages/stage_batches_test.go @@ -79,7 +79,6 @@ func TestUnwindBatches(t *testing.T) { err = SpawnStageBatches(s, u, ctx, tx, cfg) require.NoError(t, err) tx.Commit() - tx2 := memdb.BeginRw(t, db1) // unwind to zero and check if there is any data in the tables diff --git a/zk/stages/test_utils.go b/zk/stages/test_utils.go index af9ec190587..f24557522b4 100644 --- a/zk/stages/test_utils.go +++ b/zk/stages/test_utils.go @@ -11,6 +11,7 @@ type TestDatastreamClient struct { gerUpdates []types.GerUpdate 
lastWrittenTimeAtomic atomic.Int64 streamingAtomic atomic.Bool + stopReadingToChannel atomic.Bool progress atomic.Uint64 entriesChan chan interface{} errChan chan error @@ -28,10 +29,6 @@ func NewTestDatastreamClient(fullL2Blocks []types.FullL2Block, gerUpdates []type return client } -func (c *TestDatastreamClient) EnsureConnected() (bool, error) { - return true, nil -} - func (c *TestDatastreamClient) ReadAllEntriesToChannel() error { c.streamingAtomic.Store(true) defer c.streamingAtomic.Swap(false) @@ -43,9 +40,24 @@ func (c *TestDatastreamClient) ReadAllEntriesToChannel() error { c.entriesChan <- &c.gerUpdates[i] } + c.entriesChan <- nil // needed to stop processing + + for { + if c.stopReadingToChannel.Load() { + break + } + } + return nil } +func (c *TestDatastreamClient) RenewEntryChannel() { +} + +func (c *TestDatastreamClient) StopReadingToChannel() { + c.stopReadingToChannel.Store(true) +} + func (c *TestDatastreamClient) GetEntryChan() *chan interface{} { return &c.entriesChan } @@ -54,14 +66,14 @@ func (c *TestDatastreamClient) GetErrChan() chan error { return c.errChan } -func (c *TestDatastreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, int, error) { +func (c *TestDatastreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, error) { for _, l2Block := range c.fullL2Blocks { if l2Block.L2BlockNumber == blockNum { - return &l2Block, types.CmdErrOK, nil + return &l2Block, nil } } - return nil, -1, nil + return nil, nil } func (c *TestDatastreamClient) GetLatestL2Block() (*types.FullL2Block, error) { @@ -75,10 +87,6 @@ func (c *TestDatastreamClient) GetLastWrittenTimeAtomic() *atomic.Int64 { return &c.lastWrittenTimeAtomic } -func (c *TestDatastreamClient) GetStreamingAtomic() *atomic.Bool { - return &c.streamingAtomic -} - func (c *TestDatastreamClient) GetProgressAtomic() *atomic.Uint64 { return &c.progress } diff --git a/zk/tests/unwinds/unwind.sh b/zk/tests/unwinds/unwind.sh index d83eeebbdca..84b5f436180 100755 --- a/zk/tests/unwinds/unwind.sh +++ b/zk/tests/unwinds/unwind.sh @@ -60,19 +60,9 @@ go run ./cmd/integration state_stages_zkevm \ # now get a dump of the datadir at this point go run ./cmd/hack --action=dumpAll --chaindata="$dataPath/rpc-datadir/chaindata" --output="$dataPath/phase1-dump2" -# now sync again -timeout $secondTimeout ./build/bin/cdk-erigon \ - --datadir="$dataPath/rpc-datadir" \ - --config=./dynamic-integration8.yaml \ - --zkevm.sync-limit=${stopBlock} - -# dump the data again into the post folder -go run ./cmd/hack --action=dumpAll --chaindata="$dataPath/rpc-datadir/chaindata" --output="$dataPath/phase2-dump2" mkdir -p "$dataPath/phase1-diffs/pre" mkdir -p "$dataPath/phase1-diffs/post" -mkdir -p "$dataPath/phase2-diffs/pre" -mkdir -p "$dataPath/phase2-diffs/post" # iterate over the files in the pre-dump folder for file in $(ls $dataPath/phase1-dump1); do @@ -84,14 +74,26 @@ for file in $(ls $dataPath/phase1-dump1); do echo "No difference found in $filename" else if [ "$filename" = "Code.txt" ] || [ "$filename" = "HashedCodeHash.txt" ] || [ "$filename" = "hermez_l1Sequences.txt" ] || [ "$filename" = "hermez_l1Verifications.txt" ] || [ "$filename" = "HermezSmt.txt" ] || [ "$filename" = "PlainCodeHash.txt" ] || [ "$filename" = "SyncStage.txt" ] || [ "$filename" = "BadHeaderNumber.txt" ]; then - echo "Expected differences in $filename" + echo "Phase 1 Expected differences in $filename" else - echo "Unexpected differences in $filename" + echo "Phase 1 Unexpected differences in $filename" exit 1 fi fi done +# now sync again 
+timeout $secondTimeout ./build/bin/cdk-erigon \ + --datadir="$dataPath/rpc-datadir" \ + --config=./dynamic-integration8.yaml \ + --zkevm.sync-limit=${stopBlock} + +# dump the data again into the post folder +go run ./cmd/hack --action=dumpAll --chaindata="$dataPath/rpc-datadir/chaindata" --output="$dataPath/phase2-dump2" + +mkdir -p "$dataPath/phase2-diffs/pre" +mkdir -p "$dataPath/phase2-diffs/post" + # iterate over the files in the pre-dump folder for file in $(ls $dataPath/phase2-dump1); do # get the filename @@ -99,12 +101,12 @@ for file in $(ls $dataPath/phase2-dump1); do # diff the files and if there is a difference found copy the pre and post files into the diffs folder if cmp -s $dataPath/phase2-dump1/$filename $dataPath/phase2-dump2/$filename; then - echo "No difference found in $filename" + echo "Phase 2 No difference found in $filename" else if [ "$filename" = "BadHeaderNumber.txt" ]; then - echo "Expected differences in $filename" + echo "Phase 2 Expected differences in $filename" else - echo "Unexpected differences in $filename" + echo "Phase 2 Unexpected differences in $filename" exit 2 fi fi