Skip to content

Commit

Permalink
feat: add multi message package
Browse files Browse the repository at this point in the history
  • Loading branch information
cosinlink committed Mar 24, 2024
1 parent a972393 commit af840db
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 9 deletions.
174 changes: 165 additions & 9 deletions x/oracle/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package keeper

import (
"context"
"encoding/hex"
"fmt"
"runtime/debug"

sdkerrors "cosmossdk.io/errors"
sdkmath "cosmossdk.io/math"
"encoding/hex"
"fmt"
proto "github.com/cosmos/gogoproto/proto"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"math/big"
"runtime/debug"
"strings"

govtypes "github.com/cosmos/cosmos-sdk/x/gov/types"

Expand Down Expand Up @@ -166,20 +169,164 @@ func (k Keeper) distributeReward(ctx sdk.Context, relayer sdk.AccAddress, signed
return nil
}

func (k Keeper) handlePackage(
type CrossChainMessage struct {
ChannelId uint8
MsgBytes []byte
RelayFee *big.Int
AckRelayFee *big.Int
Sender common.Address
}

var (
Uint8, _ = abi.NewType("uint8", "", nil)
Bytes, _ = abi.NewType("bytes", "", nil)
Uint256, _ = abi.NewType("uint256", "", nil)
Address, _ = abi.NewType("address", "", nil)

messageTypeArgs = abi.Arguments{
{Name: "ChannelId", Type: Uint8},
{Name: "MsgBytes", Type: Bytes},
{Name: "RelayFee", Type: Uint256},
{Name: "AckRelayFee", Type: Uint256},
{Name: "Sender", Type: Address},
}
)

func (k Keeper) handleMultiMessagePackage(
ctx sdk.Context,
pack *types.Package,
packageHeader *sdk.PackageHeader,
srcChainId uint32,
destChainId uint32,
timestamp uint64,
) (sdkmath.Int, *types.EventPackageClaim, error) {
logger := k.Logger(ctx)

crossChainApp := k.CrossChainKeeper.GetCrossChainApp(pack.ChannelId)
if crossChainApp == nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", pack.ChannelId)
def := fmt.Sprintf(`[{ "name" : "method", "type": "function", "outputs": %s}]`, `[{"type": "bytes[]"}]`)
msgsAbi, err := abi.JSON(strings.NewReader(def))
if err != nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "invalid ABI definition %s: %v", def, err)
}

out, err := msgsAbi.Unpack("method", pack.Payload)
if err != nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "package unpack error, payload=%v: %v", pack.Payload, err)
}

type msgsType [][]byte
unpacked := abi.ConvertType(out[0], msgsType{})
messages, ok := unpacked.(msgsType)
if !ok {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "messages ConvertType failed")
}

crash := false
result := sdk.ExecuteResult{}
for i, message := range messages {
unpacked, err := messageTypeArgs.Unpack(message)
if err != nil || len(unpacked) != 5 {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode message error %d, message=%v, error: %v", i, message, err)
}

channelIdType := abi.ConvertType(unpacked[0], uint8(0))
msgBytesType := abi.ConvertType(unpacked[1], []byte{})
ackRelayFeeType := abi.ConvertType(unpacked[3], big.NewInt(0))

channelId, ok := channelIdType.(uint8)
if !ok {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode channelId error %d, message=%v, error: %v", i, message, err)
}

msgBytes, ok := msgBytesType.([]byte)
if !ok {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode msgBytes error %d, message=%v, error: %v", i, message, err)
}

ackRelayFee, ok := ackRelayFeeType.(*big.Int)
if !ok {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode ackRelayFee error %d, message=%v, error: %v", i, message, err)
}

crossChainApp := k.CrossChainKeeper.GetCrossChainApp(sdk.ChannelID(channelId))
if crossChainApp == nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", channelId)
}

msgHeader := sdk.PackageHeader{
PackageType: packageHeader.PackageType,
Timestamp: packageHeader.Timestamp,
RelayerFee: big.NewInt(0),
AckRelayerFee: ackRelayFee,
}

payload := append(make([]byte, sdk.SynPackageHeaderLength), msgBytes...)

cacheCtx, write := ctx.CacheContext()
crashSingleMsg, resultSingleMsg := executeClaim(cacheCtx, crossChainApp, srcChainId, 0, payload, &msgHeader)
if resultSingleMsg.IsOk() {
write()
}

if crashSingleMsg {
crash = true
result.Err = resultSingleMsg.Err
break
}
}

// write ack package
var sendSequence int64 = -1
if packageHeader.PackageType == sdk.SynCrossChainPackageType {
if crash {
if len(pack.Payload) < sdk.SynPackageHeaderLength {
logger.Error("found payload without header",
"channelID", pack.ChannelId, "sequence", pack.Sequence, "payload", hex.EncodeToString(pack.Payload))
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidPackage, "payload without header")
}

sendSeq, ibcErr := k.CrossChainKeeper.CreateRawIBCPackageWithFee(ctx, sdk.ChainID(srcChainId), pack.ChannelId,
sdk.FailAckCrossChainPackageType, pack.Payload[sdk.SynPackageHeaderLength:], packageHeader.AckRelayerFee, sdk.NilAckRelayerFee)
if ibcErr != nil {
logger.Error("failed to write FailAckCrossChainPackage", "err", ibcErr)
return sdkmath.ZeroInt(), nil, ibcErr
}

sendSequence = int64(sendSeq)
} else if len(result.Payload) != 0 {
sendSeq, err := k.CrossChainKeeper.CreateRawIBCPackageWithFee(ctx, sdk.ChainID(srcChainId), pack.ChannelId,
sdk.AckCrossChainPackageType, result.Payload, packageHeader.AckRelayerFee, sdk.NilAckRelayerFee)
if err != nil {
logger.Error("failed to write AckCrossChainPackage", "err", err)
return sdkmath.ZeroInt(), nil, err
}
sendSequence = int64(sendSeq)
}
}

claimEvent := &types.EventPackageClaim{
SrcChainId: srcChainId,
DestChainId: destChainId,
ChannelId: uint32(pack.ChannelId),
PackageType: uint32(packageHeader.PackageType),
ReceiveSequence: pack.Sequence,
SendSequence: sendSequence,
RelayerFee: packageHeader.RelayerFee.String(),
AckRelayerFee: packageHeader.AckRelayerFee.String(),
Crash: crash,
ErrorMsg: result.ErrMsg(),
}

return sdkmath.NewIntFromBigInt(packageHeader.RelayerFee), claimEvent, nil
}

func (k Keeper) handlePackage(
ctx sdk.Context,
pack *types.Package,
srcChainId uint32,
destChainId uint32,
timestamp uint64,
) (sdkmath.Int, *types.EventPackageClaim, error) {
logger := k.Logger(ctx)

sequence := k.CrossChainKeeper.GetReceiveSequence(ctx, sdk.ChainID(srcChainId), pack.ChannelId)
if sequence != pack.Sequence {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidReceiveSequence,
Expand All @@ -201,6 +348,15 @@ func (k Keeper) handlePackage(
"package type %d is invalid", packageHeader.PackageType)
}

if pack.ChannelId == types.MultiMessageChannelId {
return k.handleMultiMessagePackage(ctx, pack, &packageHeader, srcChainId, destChainId)
}

crossChainApp := k.CrossChainKeeper.GetCrossChainApp(pack.ChannelId)
if crossChainApp == nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", pack.ChannelId)
}

cacheCtx, write := ctx.CacheContext()
crash, result := executeClaim(cacheCtx, crossChainApp, srcChainId, sequence, pack.Payload, &packageHeader)
if result.IsOk() {
Expand Down
107 changes: 107 additions & 0 deletions x/oracle/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package keeper_test

import (
"encoding/hex"
"fmt"
"github.com/ethereum/go-ethereum/common"
"math/big"
"strings"
"time"

"github.com/golang/mock/gomock"
Expand All @@ -12,6 +16,8 @@ import (
"github.com/cosmos/cosmos-sdk/x/oracle/testutil"
"github.com/cosmos/cosmos-sdk/x/oracle/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"

"github.com/ethereum/go-ethereum/accounts/abi"
)

type DummyCrossChainApp struct{}
Expand Down Expand Up @@ -179,3 +185,104 @@ func (s *TestSuite) TestInvalidClaim() {
s.Require().NotNil(err, "process claim should return error")
s.Require().Contains(err.Error(), "is not the same in payload header")
}

type packUnpackTest struct {
def string
unpacked interface{}
packed string
}

var (
Uint8, _ = abi.NewType("uint8", "", nil)
Bytes, _ = abi.NewType("bytes", "", nil)
Uint256, _ = abi.NewType("uint256", "", nil)
Address, _ = abi.NewType("address", "", nil)

messageTypeArgs = abi.Arguments{
{Name: "ChannelId", Type: Uint8},
{Name: "MsgBytes", Type: Bytes},
{Name: "RelayFee", Type: Uint256},
{Name: "AckRelayFee", Type: Uint256},
{Name: "Sender", Type: Address},
}

/* messageType, _ = abi.NewType("tuple", "", []abi.ArgumentMarshaling{
{Name: "ChannelId", Type: "uint8"},
{Name: "MsgBytes", Type: "bytes"},
{Name: "RelayFee", Type: "uint256"},
{Name: "AckRelayFee", Type: "uint256"},
{Name: "Sender", Type: "address"},
})
messageTypeArgs = abi.Arguments{
{Type: messageType},
}*/
)

type CrossChainMessage struct {
ChannelId uint8
MsgBytes []byte
RelayFee *big.Int
AckRelayFee *big.Int
Sender common.Address
}

func (s *TestSuite) TestMultiMessage() {
test := packUnpackTest{
def: `[{"type": "bytes[]"}]`,
packed: "000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000022000000000000000000000000000000000000000000000000000000000000001c0000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000005746573743100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001c0000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737432000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
unpacked: [][]byte{{0xf0, 0xf0, 0xf0}, {0xf0, 0xf0, 0xf0}},
}

i := 0
//Unpack
fmt.Println(test.packed)

def := fmt.Sprintf(`[{ "name" : "method", "type": "function", "outputs": %s}]`, test.def)
messagesAbi, err := abi.JSON(strings.NewReader(def))
if err != nil {
fmt.Printf("invalid ABI definition %s: %v", def, err)
}
encb, err := hex.DecodeString(test.packed)
if err != nil {
fmt.Printf("invalid hex %s: %v", test.packed, err)
}

out, err := messagesAbi.Unpack("method", encb)
if err != nil {
fmt.Printf("test %d (%v) failed: %v", i, test.def, err)
return
}

type unpackedType [][]byte
unpacked := abi.ConvertType(out[0], unpackedType{})
messages, ok := unpacked.(unpackedType)
if !ok {
fmt.Printf("ConvertType failed: %v", unpacked)
}

fmt.Println("messages", messages)

for _, message := range messages {
fmt.Println("message", hex.EncodeToString(message))

unpacked, err := messageTypeArgs.Unpack(message)
s.Require().Nil(err, "unpack error")

fmt.Println("unpacked", unpacked)

channelIdType := abi.ConvertType(unpacked[0], uint8(0))
msgBytesType := abi.ConvertType(unpacked[1], []byte{})
ackRelayFeeType := abi.ConvertType(unpacked[3], big.NewInt(0))

channelId, ok := channelIdType.(uint8)
fmt.Println(ok, channelId)

msgBytes, ok := msgBytesType.([]byte)
fmt.Println(ok, msgBytes)

ackRelayFee, ok := ackRelayFeeType.(*big.Int)
fmt.Println(ok, ackRelayFee)
}

}
1 change: 1 addition & 0 deletions x/oracle/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ var (
ErrInvalidDestChainId = errors.Register(ModuleName, 14, "dest chain id is invalid")
ErrInvalidSrcChainId = errors.Register(ModuleName, 15, "src chain id is invalid")
ErrInvalidAddress = errors.Register(ModuleName, 16, "address is invalid")
ErrInvalidMultiMessage = errors.Register(ModuleName, 17, "multi message is invalid")
)
1 change: 1 addition & 0 deletions x/oracle/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ const (
// RelayPackagesChannelId is not a communication channel actually, we just use it to record sequence.
RelayPackagesChannelName = "relayPackages"
RelayPackagesChannelId sdk.ChannelID = 0x00
MultiMessageChannelId sdk.ChannelID = 0x08
)

0 comments on commit af840db

Please sign in to comment.