Skip to content

Commit

Permalink
feat: add multi message package
Browse files Browse the repository at this point in the history
  • Loading branch information
cosinlink committed Mar 25, 2024
1 parent a972393 commit bb3e21a
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 10 deletions.
166 changes: 156 additions & 10 deletions x/oracle/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package keeper

import (
"context"

Check failure on line 4 in x/oracle/keeper/msg_server.go

View workflow job for this annotation

GitHub Actions / golangci-lint

File is not `gofumpt`-ed (gofumpt)
"encoding/hex"
"fmt"
"runtime/debug"

sdkerrors "cosmossdk.io/errors"
sdkmath "cosmossdk.io/math"
"encoding/hex"

Check failure on line 7 in x/oracle/keeper/msg_server.go

View workflow job for this annotation

GitHub Actions / golangci-lint

File is not `gofumpt`-ed (gofumpt)
"fmt"
proto "github.com/cosmos/gogoproto/proto"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"math/big"

Check failure on line 12 in x/oracle/keeper/msg_server.go

View workflow job for this annotation

GitHub Actions / golangci-lint

File is not `gofumpt`-ed (gofumpt)
"runtime/debug"
"strings"

govtypes "github.com/cosmos/cosmos-sdk/x/gov/types"

Expand All @@ -22,6 +25,34 @@ type msgServer struct {
Keeper
}

type CrossChainMessage struct {
ChannelId uint8
MsgBytes []byte
RelayFee *big.Int
AckRelayFee *big.Int
Sender common.Address
}

type MessagesType [][]byte

var (
Uint8, _ = abi.NewType("uint8", "", nil)
Bytes, _ = abi.NewType("bytes", "", nil)
Uint256, _ = abi.NewType("uint256", "", nil)
Address, _ = abi.NewType("address", "", nil)

MessageTypeArgs = abi.Arguments{
{Name: "ChannelId", Type: Uint8},
{Name: "MsgBytes", Type: Bytes},
{Name: "RelayFee", Type: Uint256},
{Name: "AckRelayFee", Type: Uint256},
{Name: "Sender", Type: Address},
}

MessagesAbiDefinition = fmt.Sprintf(`[{ "name" : "method", "type": "function", "outputs": %s}]`, `[{"type": "bytes[]"}]`)
MessagesAbi, _ = abi.JSON(strings.NewReader(MessagesAbiDefinition))
)

// NewMsgServerImpl returns an implementation of the oracle MsgServer interface
// for the provided Keeper.
func NewMsgServerImpl(k Keeper) types.MsgServer {
Expand Down Expand Up @@ -166,6 +197,106 @@ func (k Keeper) distributeReward(ctx sdk.Context, relayer sdk.AccAddress, signed
return nil
}

func (k Keeper) handleMultiMessagePackage(
ctx sdk.Context,
pack *types.Package,
packageHeader *sdk.PackageHeader,
srcChainId uint32,
) (crash bool, result sdk.ExecuteResult) {
out, err := MessagesAbi.Unpack("method", pack.Payload)
if err != nil {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "messages unpack failed, payload=%v", pack.Payload),
}
}

unpacked := abi.ConvertType(out[0], MessagesType{})
messages, ok := unpacked.(MessagesType)
if !ok {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "messages ConvertType failed, payload=%v", pack.Payload),
}
}

if len(messages) == 0 {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "empty messages, payload=%v", pack.Payload),
}
}

crash = false
result = sdk.ExecuteResult{}
payloads := make([][]byte, len(messages))

for i, message := range messages {
unpacked, err := MessageTypeArgs.Unpack(message)
if err != nil || len(unpacked) != 5 {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode message error %d, message=%v, error: %s", i, message, err),
}
}

channelIdType := abi.ConvertType(unpacked[0], uint8(0))
msgBytesType := abi.ConvertType(unpacked[1], []byte{})
ackRelayFeeType := abi.ConvertType(unpacked[3], big.NewInt(0))

channelId, ok := channelIdType.(uint8)
if !ok {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode channelId error %d, message=%v, error: %v", i, message, err),
}
}

msgBytes, ok := msgBytesType.([]byte)
if !ok {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode msgBytes error %d, message=%v, error: %v", i, message, err),
}
}

ackRelayFee, ok := ackRelayFeeType.(*big.Int)
if !ok {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode ackRelayFee error %d, message=%v, error: %v", i, message, err),
}
}

crossChainApp := k.CrossChainKeeper.GetCrossChainApp(sdk.ChannelID(channelId))
if crossChainApp == nil {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrChannelNotRegistered, "message %d, channel %d not registered", i, channelId),
}
}

msgHeader := sdk.PackageHeader{
PackageType: packageHeader.PackageType,
Timestamp: packageHeader.Timestamp,
RelayerFee: big.NewInt(0),
AckRelayerFee: ackRelayFee,
}

payload := append(make([]byte, sdk.SynPackageHeaderLength), msgBytes...)
crashSingleMsg, resultSingleMsg := executeClaim(ctx, crossChainApp, srcChainId, 0, payload, &msgHeader)
if crashSingleMsg {
return true, resultSingleMsg
}

payloads[i] = []byte{channelId}
ackRelayFeeBytes := bigIntToBytes32(ackRelayFee)
payloads[i] = append(payloads[i], ackRelayFeeBytes[:]...)
payloads[i] = append(payloads[i], resultSingleMsg.Payload...)
}

result.Payload, err = MessagesAbi.Pack("method", payloads)
if err != nil {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrInvalidMessagesResult, "messages result pack failed, payloads=%v, error=%s", payloads, err),
}
}

return crash, result
}

func (k Keeper) handlePackage(
ctx sdk.Context,
pack *types.Package,
Expand All @@ -175,11 +306,6 @@ func (k Keeper) handlePackage(
) (sdkmath.Int, *types.EventPackageClaim, error) {
logger := k.Logger(ctx)

crossChainApp := k.CrossChainKeeper.GetCrossChainApp(pack.ChannelId)
if crossChainApp == nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", pack.ChannelId)
}

sequence := k.CrossChainKeeper.GetReceiveSequence(ctx, sdk.ChainID(srcChainId), pack.ChannelId)
if sequence != pack.Sequence {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidReceiveSequence,
Expand All @@ -201,8 +327,20 @@ func (k Keeper) handlePackage(
"package type %d is invalid", packageHeader.PackageType)
}

crash := false
var result sdk.ExecuteResult
cacheCtx, write := ctx.CacheContext()
crash, result := executeClaim(cacheCtx, crossChainApp, srcChainId, sequence, pack.Payload, &packageHeader)

if pack.ChannelId == types.MultiMessageChannelId {
crash, result = k.handleMultiMessagePackage(cacheCtx, pack, &packageHeader, srcChainId)
} else {
crossChainApp := k.CrossChainKeeper.GetCrossChainApp(pack.ChannelId)
if crossChainApp == nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", pack.ChannelId)
}
crash, result = executeClaim(cacheCtx, crossChainApp, srcChainId, sequence, pack.Payload, &packageHeader)
}

if result.IsOk() {
write()
}
Expand Down Expand Up @@ -295,3 +433,11 @@ func executeClaim(
}
return crash, result
}

func bigIntToBytes32(x *big.Int) [32]byte {
var b [32]byte
xBytes := x.Bytes()
numPadding := 32 - len(xBytes)
copy(b[numPadding:], xBytes)
return b
}
61 changes: 61 additions & 0 deletions x/oracle/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package keeper_test

import (
"encoding/hex"
"fmt"
"github.com/ethereum/go-ethereum/common/hexutil"
"math/big"
"time"

Expand All @@ -9,13 +12,22 @@ import (

"github.com/cosmos/cosmos-sdk/bsc/rlp"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/oracle/keeper"
"github.com/cosmos/cosmos-sdk/x/oracle/testutil"
"github.com/cosmos/cosmos-sdk/x/oracle/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"

"github.com/ethereum/go-ethereum/accounts/abi"
)

type DummyCrossChainApp struct{}

type packUnpackTest struct {
def string
unpacked interface{}
packed string
}

func (ta *DummyCrossChainApp) ExecuteSynPackage(ctx sdk.Context, header *sdk.CrossChainAppContext, payload []byte) sdk.ExecuteResult {
return sdk.ExecuteResult{}
}
Expand Down Expand Up @@ -179,3 +191,52 @@ func (s *TestSuite) TestInvalidClaim() {
s.Require().NotNil(err, "process claim should return error")
s.Require().Contains(err.Error(), "is not the same in payload header")
}

func (s *TestSuite) TestMultiMessageDecode() {
msg1, _ := hexutil.Decode("000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737431000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
msg2, _ := hexutil.Decode("000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737432000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")

tests := []packUnpackTest{
packUnpackTest{
def: `[{"type": "bytes[]"}]`,
packed: "000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000022000000000000000000000000000000000000000000000000000000000000001c0000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000005746573743100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001c0000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737432000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
unpacked: [][]byte{msg1, msg2},
},
}

for i, test := range tests {
encb, err := hex.DecodeString(test.packed)
s.Require().Nilf(err, "invalid hex %s: %v", test.packed, err)

out, err := keeper.MessagesAbi.Unpack("method", encb)
s.Require().Nilf(err, "test %d (%v) failed: %v", i, test.def, err)

unpacked := abi.ConvertType(out[0], keeper.MessagesType{})
messages, ok := unpacked.(keeper.MessagesType)
s.Require().Truef(ok, "ConvertType failed: %v", unpacked)

for _, message := range messages {
fmt.Println("message", hex.EncodeToString(message))

unpacked, err := keeper.MessageTypeArgs.Unpack(message)
s.Require().Nil(err, "unpack error")

fmt.Println("unpacked", unpacked)

channelIdType := abi.ConvertType(unpacked[0], uint8(0))
msgBytesType := abi.ConvertType(unpacked[1], []byte{})
ackRelayFeeType := abi.ConvertType(unpacked[3], big.NewInt(0))

channelId, ok := channelIdType.(uint8)
s.Require().Truef(ok, "channelId unpacked failed")

msgBytes, ok := msgBytesType.([]byte)
s.Require().Truef(ok, "msgBytes unpacked failed")

ackRelayFee, ok := ackRelayFeeType.(*big.Int)
s.Require().Truef(ok, "ackRelayFee unpacked failed")

fmt.Println(channelId, msgBytes, ackRelayFee)
}
}
}
2 changes: 2 additions & 0 deletions x/oracle/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ var (
ErrInvalidDestChainId = errors.Register(ModuleName, 14, "dest chain id is invalid")
ErrInvalidSrcChainId = errors.Register(ModuleName, 15, "src chain id is invalid")
ErrInvalidAddress = errors.Register(ModuleName, 16, "address is invalid")
ErrInvalidMultiMessage = errors.Register(ModuleName, 17, "multi message is invalid")
ErrInvalidMessagesResult = errors.Register(ModuleName, 18, "multi message result is invalid")
)
1 change: 1 addition & 0 deletions x/oracle/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ const (
// RelayPackagesChannelId is not a communication channel actually, we just use it to record sequence.
RelayPackagesChannelName = "relayPackages"
RelayPackagesChannelId sdk.ChannelID = 0x00
MultiMessageChannelId sdk.ChannelID = 0x08
)

0 comments on commit bb3e21a

Please sign in to comment.