diff --git a/README.md b/README.md index 09e8f29d..8b9bc42c 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,19 @@ if err != nil { println(val) ``` +### Custom reconnect policy +By default, standard reconnect method will be used - `c.DefaultReconnect(3*time.Second, 3)` which will do 3 tries and wait 3 seconds before each. + +But you can use your own reconnection logic, this library support callbacks, in this case OnDisconnect callback can be used, you can set it like this: +```golang +client.SetOnDisconnect(func(addr, serverKey string) { + // ... do something +}) +``` + ### Features to implement -* ✅ ~~Support cell and slice as arguments for run get method~~ -* Reconnect +* ✅ Support cell and slice as arguments for run get method +* ✅ Reconnect on failure * More tests * Get account state method * Send external query method diff --git a/liteclient/client.go b/liteclient/client.go index 9c9e1c54..ac25e7de 100644 --- a/liteclient/client.go +++ b/liteclient/client.go @@ -7,8 +7,11 @@ import ( "errors" "io" "sync" + "time" ) +type OnDisconnectCallback func(addr, key string) + type LiteResponse struct { TypeID int32 Data []byte @@ -29,15 +32,21 @@ type Client struct { requester chan *LiteRequest activeConnections int32 + onDisconnect func(addr, key string) } var ErrNoActiveConnections = errors.New("no active connections") func NewClient() *Client { - return &Client{ + c := &Client{ activeReqs: map[string]*LiteRequest{}, requester: make(chan *LiteRequest), } + + // default reconnect policy + c.SetOnDisconnect(c.DefaultReconnect(3*time.Second, 3)) + + return c } func (c *Client) Do(ctx context.Context, typeID int32, payload []byte) (*LiteResponse, error) { @@ -91,3 +100,9 @@ func (c *Client) Do(ctx context.Context, typeID int32, payload []byte) (*LiteRes return nil, ctx.Err() } } + +func (c *Client) SetOnDisconnect(cb OnDisconnectCallback) { + c.mx.Lock() + c.onDisconnect = cb + c.mx.Unlock() +} diff --git a/liteclient/connection.go b/liteclient/connection.go index 16337323..e17ce6ec 100644 --- a/liteclient/connection.go +++ b/liteclient/connection.go @@ -1,7 +1,6 @@ package liteclient import ( - "bufio" "context" "crypto/cipher" "crypto/ed25519" @@ -13,13 +12,25 @@ import ( "errors" "fmt" "io" - "log" "math/big" "net" "sync/atomic" "time" ) +type connection struct { + addr string + serverKey string + + connResult chan error + + conn net.Conn + rCrypt cipher.Stream + wCrypt cipher.Stream + + client *Client +} + func (c *Client) Connect(ctx context.Context, addr, serverKey string) error { sKey, err := base64.StdEncoding.DecodeString(serverKey) if err != nil { @@ -32,201 +43,211 @@ func (c *Client) Connect(ctx context.Context, addr, serverKey string) error { return err } + cn := &connection{ + addr: addr, + serverKey: serverKey, + connResult: make(chan error, 1), + client: c, + } + // get timeout if exists till, ok := ctx.Deadline() if !ok { till = time.Now().Add(60 * time.Second) - - // if no timeout passed, we still set it to 60 sec, to not hang forever - timeoutCtx, cancel := context.WithTimeout(ctx, 60*time.Second) - defer cancel() - - ctx = timeoutCtx } - conn, err := net.DialTimeout("tcp", addr, till.Sub(time.Now())) + cn.conn, err = net.DialTimeout("tcp", addr, till.Sub(time.Now())) if err != nil { return err } // we need 160 random bytes from which we will construct encryption keys for packets - dd := make([]byte, 160) - if _, err := io.ReadFull(rand.Reader, dd); err != nil { + rnd := make([]byte, 160) + if _, err := io.ReadFull(rand.Reader, rnd); err != nil { return err } - r := bufio.NewReader(conn) - // build ciphers for incoming packets and for outgoing - rCrypt, err := newCipherCtr(dd[:32], dd[64:80]) + cn.rCrypt, err = newCipherCtr(rnd[:32], rnd[64:80]) if err != nil { return err } - wCrypt, err := newCipherCtr(dd[32:64], dd[80:96]) + cn.wCrypt, err = newCipherCtr(rnd[32:64], rnd[80:96]) if err != nil { return err } - hs, err := c.handshake(dd, privateKey, sKey) + err = cn.handshake(rnd, privateKey, sKey) if err != nil { return err } - // send handshake packet to establish connection - _, err = conn.Write(hs) - if err != nil { - return err - } + connResult := make(chan error, 1) + + go cn.listen(connResult) + + select { + case err = <-connResult: + if err != nil { + return err + } + + // start pings + go cn.startPings(5 * time.Second) + + // now we are ready to accept requests + go cn.serve(connResult) - connEvent := make(chan error, 1) + atomic.AddInt32(&c.activeConnections, 1) + return nil + case <-ctx.Done(): + return ctx.Err() + } +} +func (cn *connection) listen(connResult chan<- error) { var initialized bool - go func() { - // listen for incoming packets - for { - var sz uint32 - sz, err = c.readSize(r, rCrypt) - if err != nil { - break - } + var err error + // listen for incoming packets + for { + var sz uint32 + sz, err = cn.readSize() + if err != nil { + break + } - // should at least have nonce (its 32 bytes) and something else - if sz <= 32 { - err = errors.New("too small size of packet") - break - } + // should at least have nonce (its 32 bytes) and something else + if sz <= 32 { + err = errors.New("too small size of packet") + break + } - var data []byte - data, err = c.readData(r, rCrypt, sz) - if err != nil { - break - } + var data []byte + data, err = cn.readData(sz) + if err != nil { + break + } - checksum := data[len(data)-32:] - data = data[:len(data)-32] + checksum := data[len(data)-32:] + data = data[:len(data)-32] - err = c.validatePacket(data, checksum) - if err != nil { - break - } + err = validatePacket(data, checksum) + if err != nil { + break + } - // skip nonce - data = data[32:] + // skip nonce + data = data[32:] - // response for handshake is empty packet, it means that connection established - if len(data) == 0 { - if !initialized { - initialized = true - connEvent <- nil - } - continue + // response for handshake is empty packet, it means that connection established + if len(data) == 0 { + if !initialized { + initialized = true + connResult <- nil } + continue + } - var typeID int32 - var queryID string - var payload []byte + var typeID int32 + var queryID string + var payload []byte - typeID, queryID, payload, err = c.parseServerResp(data) - if err != nil { - break - } + typeID, queryID, payload, err = parseServerResp(data) + if err != nil { + break + } + + cn.client.mx.RLock() + ch := cn.client.activeReqs[queryID] + cn.client.mx.RUnlock() - c.mx.RLock() - ch := c.activeReqs[queryID] - c.mx.RUnlock() - - if ch != nil { - ch.RespChan <- &LiteResponse{ - TypeID: typeID, - Data: payload, - } - } else { - // handle system + if ch != nil { + ch.RespChan <- &LiteResponse{ + TypeID: typeID, + Data: payload, } } + } - if initialized { - // deactivate connection - atomic.AddInt32(&c.activeConnections, -1) - } + // force close in case of error + _ = cn.conn.Close() - connEvent <- err - _ = conn.Close() - }() + connResult <- err - select { - case err = <-connEvent: - if err != nil { - return err + if initialized { + // deactivate connection + atomic.AddInt32(&cn.client.activeConnections, -1) + + cn.client.mx.RLock() + dis := cn.client.onDisconnect + cn.client.mx.RUnlock() + + if dis != nil { + go dis(cn.addr, cn.serverKey) } + } +} - // start pings - go func() { - for { - time.Sleep(5 * time.Second) - - n, err := rand.Int(rand.Reader, new(big.Int).SetUint64(0xFFFFFFFFFFFFFF)) - if err != nil { - log.Println("rand err", err) - continue - } - - err = c.ping(conn, wCrypt, n.Uint64()) - if err != nil { - log.Println("ping err", err) - continue - } +func (cn *connection) serve(connResult <-chan error) { + for { + var req *LiteRequest + + select { + case req = <-cn.client.requester: + if req == nil { + // handle graceful shutdown + return } - }() + case <-connResult: + // on this stage it can be only connection issue + return + } - // now we are ready to accept requests - go func() { - for { - var req *LiteRequest - - select { - case req = <-c.requester: - if req == nil { - // handle graceful shutdown - return - } - case <-connEvent: - // on this stage it can be only connection issue - return - } - - err := c.queryLiteServer(conn, wCrypt, req.QueryID, req.TypeID, req.Data) - if err != nil { - // TODO: put request back to pool to pickup by next connection - - req.RespChan <- &LiteResponse{ - err: err, - } - - // err can happen only because of network error, anyway close it for some case - _ = conn.Close() - return - } + err := cn.queryLiteServer(req.QueryID, req.TypeID, req.Data) + if err != nil { + req.RespChan <- &LiteResponse{ + err: err, } - }() - atomic.AddInt32(&c.activeConnections, 1) - return nil - case <-ctx.Done(): - return ctx.Err() + // force close in case of error + _ = cn.conn.Close() + + return + } + } +} + +func (cn *connection) startPings(every time.Duration) { + for { + select { + case <-time.After(every): + } + + n, err := rand.Int(rand.Reader, new(big.Int).SetUint64(0xFFFFFFFFFFFFFF)) + if err != nil { + continue + } + + err = cn.ping(n.Uint64()) + if err != nil { + // force close in case of error + _ = cn.conn.Close() + + break + } } } -func (c *Client) readSize(reader io.Reader, cryptor cipher.Stream) (uint32, error) { +func (cn *connection) readSize() (uint32, error) { size := make([]byte, 4) - _, err := reader.Read(size) + _, err := cn.conn.Read(size) if err != nil { return 0, err } // decrypt packet - cryptor.XORKeyStream(size, size) + cn.rCrypt.XORKeyStream(size, size) sz := binary.LittleEndian.Uint32(size) @@ -237,20 +258,20 @@ func (c *Client) readSize(reader io.Reader, cryptor cipher.Stream) (uint32, erro return sz, nil } -func (c *Client) readData(reader io.Reader, cryptor cipher.Stream, sz uint32) ([]byte, error) { +func (cn *connection) readData(sz uint32) ([]byte, error) { var result []byte // read exact number of bytes requested, blocking operation left := int(sz) for left > 0 { data := make([]byte, left) - n, err := reader.Read(data) + n, err := cn.conn.Read(data) if err != nil { return nil, err } data = data[:n] - cryptor.XORKeyStream(data, data) + cn.rCrypt.XORKeyStream(data, data) result = append(result, data...) left -= n @@ -259,7 +280,7 @@ func (c *Client) readData(reader io.Reader, cryptor cipher.Stream, sz uint32) ([ return result, nil } -func (c *Client) send(w io.Writer, cryptor cipher.Stream, data []byte) error { +func (cn *connection) send(data []byte) error { buf := make([]byte, 4) // ADNL packet should have nonce @@ -279,11 +300,11 @@ func (c *Client) send(w io.Writer, cryptor cipher.Stream, data []byte) error { buf = append(buf, checksum...) // encrypt data - cryptor.XORKeyStream(buf, buf) + cn.wCrypt.XORKeyStream(buf, buf) // write all for len(buf) > 0 { - n, err := w.Write(buf) + n, err := cn.conn.Write(buf) if err != nil { return err } @@ -294,7 +315,7 @@ func (c *Client) send(w io.Writer, cryptor cipher.Stream, data []byte) error { return nil } -func (c *Client) handshake(data []byte, ourKey ed25519.PrivateKey, serverKey ed25519.PublicKey) ([]byte, error) { +func (cn *connection) handshake(data []byte, ourKey ed25519.PrivateKey, serverKey ed25519.PublicKey) error { hash := sha256.New() hash.Write(data) checksum := hash.Sum(nil) @@ -303,12 +324,12 @@ func (c *Client) handshake(data []byte, ourKey ed25519.PrivateKey, serverKey ed2 kid, err := keyID(serverKey) if err != nil { - return nil, err + return err } key, err := sharedKey(ourKey, serverKey) if err != nil { - return nil, err + return err } k := []byte{ @@ -325,7 +346,7 @@ func (c *Client) handshake(data []byte, ourKey ed25519.PrivateKey, serverKey ed2 ctr, err := newCipherCtr(k, iv) if err != nil { - return nil, err + return err } // encrypt data @@ -337,18 +358,24 @@ func (c *Client) handshake(data []byte, ourKey ed25519.PrivateKey, serverKey ed2 res = append(res, checksum...) res = append(res, data...) - return res, nil + // send handshake packet to establish connection + _, err = cn.conn.Write(res) + if err != nil { + return err + } + + return nil } -func (c *Client) ping(w io.Writer, cryptor cipher.Stream, qid uint64) error { +func (cn *connection) ping(qid uint64) error { data := make([]byte, 12) binary.LittleEndian.PutUint32(data, uint32(TCPPing)) binary.LittleEndian.PutUint64(data[4:], qid) - return c.send(w, cryptor, data) + return cn.send(data) } -func (c *Client) queryADNL(w io.Writer, cryptor cipher.Stream, qid, payload []byte) error { +func (cn *connection) queryADNL(qid, payload []byte) error { // bypass compiler negative check t := ADNLQuery @@ -370,10 +397,10 @@ func (c *Client) queryADNL(w io.Writer, cryptor cipher.Stream, qid, payload []by data = append(data, make([]byte, 4-left)...) } - return c.send(w, cryptor, data) + return cn.send(data) } -func (c *Client) queryLiteServer(w io.Writer, cryptor cipher.Stream, qid []byte, typeID int32, payload []byte) error { +func (cn *connection) queryLiteServer(qid []byte, typeID int32, payload []byte) error { data := make([]byte, 4) binary.LittleEndian.PutUint32(data, uint32(LiteServerQuery)) @@ -396,5 +423,31 @@ func (c *Client) queryLiteServer(w io.Writer, cryptor cipher.Stream, qid []byte, data = append(data, make([]byte, 4-left)...) } - return c.queryADNL(w, cryptor, qid, data) + return cn.queryADNL(qid, data) +} + +func (c *Client) DefaultReconnect(waitBeforeReconnect time.Duration, maxTries int) OnDisconnectCallback { + var tries int + + var cb OnDisconnectCallback + cb = func(addr, key string) { + ctx, cancel := context.WithTimeout(context.Background(), 7*time.Second) + defer cancel() + + err := c.Connect(ctx, addr, key) + if err != nil { + if tries < maxTries { + time.Sleep(waitBeforeReconnect) + tries++ + + cb(addr, key) + } + + return + } + + tries = 0 + } + + return cb } diff --git a/liteclient/crypto.go b/liteclient/crypto.go index 31098ca4..5caef52a 100644 --- a/liteclient/crypto.go +++ b/liteclient/crypto.go @@ -60,7 +60,7 @@ func newCipherCtr(key, iv []byte) (cipher.Stream, error) { return cipher.NewCTR(c, iv), nil } -func (c *Client) validatePacket(data []byte, recvChecksum []byte) error { +func validatePacket(data []byte, recvChecksum []byte) error { if len(data) < 32 { return errors.New("too small packet") } diff --git a/liteclient/parse.go b/liteclient/parse.go index cf740baa..1870eaa7 100644 --- a/liteclient/parse.go +++ b/liteclient/parse.go @@ -14,7 +14,7 @@ const ADNLQueryResponse int32 = 262964246 const LiteServerQuery int32 = 2039219935 -func (c *Client) parseServerResp(data []byte) (typ int32, queryID string, payload []byte, err error) { +func parseServerResp(data []byte) (typ int32, queryID string, payload []byte, err error) { if len(data) <= 4 { err = fmt.Errorf("too short adnl packet: %d", len(data)) return