diff --git a/zk/datastream/client/stream_client.go b/zk/datastream/client/stream_client.go index 4d5b61e12a9..fa52fdd07a1 100644 --- a/zk/datastream/client/stream_client.go +++ b/zk/datastream/client/stream_client.go @@ -106,17 +106,24 @@ func (c *StreamClient) GetEntryNumberLimit() uint64 { func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (fullBLock *types.FullL2Block, errorCode int, err error) { var ( socketErr error = nil - connected bool = true + connected bool = c.conn != nil ) for { + + select { + case <-c.ctx.Done(): + log.Warn("[Datastream client] Context done - stopping") + return nil, errorCode, nil + default: + } if connected { if fullBLock, errorCode, err, socketErr = c.getL2BlockByNumber(blockNum); err != nil { return nil, errorCode, err } - } - if socketErr == nil { - break + if socketErr == nil { + break + } } time.Sleep(1 * time.Second) connected = c.handleSocketError(socketErr) @@ -178,18 +185,25 @@ func (c *StreamClient) getL2BlockByNumber(blockNum uint64) (l2Block *types.FullL func (c *StreamClient) GetLatestL2Block() (l2Block *types.FullL2Block, err error) { var ( socketErr error = nil - connected bool = true + connected bool = c.conn != nil ) for { + select { + case <-c.ctx.Done(): + log.Warn("[Datastream client] Context done - stopping") + return nil, nil + default: + } if connected { if l2Block, err, socketErr = c.getLatestL2Block(); err != nil { return nil, err } + if socketErr == nil { + break + } } - if socketErr == nil { - break - } + time.Sleep(1 * time.Second) connected = c.handleSocketError(socketErr) } @@ -383,17 +397,23 @@ func (c *StreamClient) RenewEntryChannel() { func (c *StreamClient) ReadAllEntriesToChannel() (err error) { var ( socketErr error = nil - connected bool = true + connected bool = c.conn != nil ) for { + select { + case <-c.ctx.Done(): + log.Warn("[Datastream client] Context done - stopping") + return nil + default: + } if connected { if err, socketErr = c.readAllEntriesToChannel(); err != nil { return err } - } - if socketErr == nil { - break + if socketErr == nil { + break + } } time.Sleep(1 * time.Second) @@ -404,7 +424,9 @@ func (c *StreamClient) ReadAllEntriesToChannel() (err error) { } func (c *StreamClient) handleSocketError(socketErr error) bool { - log.Warn(fmt.Sprintf("Socket error: %s", socketErr)) + if socketErr != nil { + log.Warn(fmt.Sprintf("Socket error: %s", socketErr)) + } if err := c.tryReConnect(); err != nil { log.Warn(fmt.Sprintf("Failed to reconnect the datastream client: %s", err)) return false