Skip to content

Commit

Permalink
limit modexp call to 8192 bit inputs (#1391)
Browse files Browse the repository at this point in the history
* limit modexp call to 8192 bit inputs

* logic change to mod exp revert rules

* mod len 0 logic in modExp

* remove comment from modexp

* more mod exp zk tweaks
  • Loading branch information
hexoscott authored Nov 5, 2024
1 parent b6cac45 commit 642df8f
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 11 deletions.
65 changes: 54 additions & 11 deletions core/vm/contracts_zkevm.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,34 @@ func (c *bigModExp_zkevm) RequiredGas(input []byte) uint64 {
} else {
input = input[:0]
}

// Retrieve the operands and execute the exponentiation
var (
base = new(big.Int).SetBytes(getData(input, 0, baseLen.Uint64()))
exp = new(big.Int).SetBytes(getData(input, baseLen.Uint64(), expLen.Uint64()))
mod = new(big.Int).SetBytes(getData(input, baseLen.Uint64()+expLen.Uint64(), modLen.Uint64()))
baseBitLen = base.BitLen()
expBitLen = exp.BitLen()
modBitLen = mod.BitLen()
)

// zk special cases
// - if mod = 0 we consume gas as normal
// - if base is 0 and mod < 8192 we consume gas as normal
// - if neither of the above are true we check for reverts and return 0 gas fee

if modBitLen == 0 {
// consume as normal - will return 0
} else if baseBitLen == 0 {
if modBitLen > 8192 {
return 0
} else {
// consume as normal - will return 0
}
} else if baseBitLen > 8192 || expBitLen > 8192 || modBitLen > 8192 {
return 0
}

// Retrieve the head 32 bytes of exp for the adjusted exponent length
var expHead *big.Int
if big.NewInt(int64(len(input))).Cmp(baseLen) <= 0 {
Expand Down Expand Up @@ -373,21 +401,36 @@ func (c *bigModExp_zkevm) Run(input []byte) ([]byte, error) {
} else {
input = input[:0]
}
// Handle a special case when both the base and mod length is zero
if baseLen == 0 && modLen == 0 {
return []byte{}, nil
}

// Retrieve the operands and execute the exponentiation
var (
base = new(big.Int).SetBytes(getData(input, 0, baseLen))
exp = new(big.Int).SetBytes(getData(input, baseLen, expLen))
mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen))
v []byte
base = new(big.Int).SetBytes(getData(input, 0, baseLen))
exp = new(big.Int).SetBytes(getData(input, baseLen, expLen))
mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen))
v []byte
baseBitLen = base.BitLen()
expBitLen = exp.BitLen()
modBitLen = mod.BitLen()
)

if modBitLen == 0 {
return []byte{}, nil
}

if baseBitLen == 0 {
if modBitLen > 8192 {
return nil, ErrExecutionReverted
} else {
return common.LeftPadBytes([]byte{}, int(modLen)), nil
}
}

// limit to 8192 bits for base, exp, and mod in ZK
if baseBitLen > 8192 || expBitLen > 8192 || modBitLen > 8192 {
return nil, ErrExecutionReverted
}

switch {
case mod.BitLen() == 0:
// Modulo 0 is undefined, return zero
return common.LeftPadBytes([]byte{}, int(modLen)), nil
case base.Cmp(libcommon.Big1) == 0:
//If base == 1, then we can just return base % mod (if mod >= 1, which it is)
v = base.Mod(base, mod).Bytes()
Expand Down
69 changes: 69 additions & 0 deletions core/vm/contracts_zkevm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package vm

import (
"testing"
"math/big"
)

var (
big0 = big.NewInt(0)
big10 = big.NewInt(10)
big8194 = big.NewInt(0).Lsh(big.NewInt(1), 8194)
)

func Test_ModExpZkevm_Gas(t *testing.T) {
modExp := bigModExp_zkevm{enabled: true, eip2565: true}

cases := map[string]struct {
base *big.Int
exp *big.Int
mod *big.Int
expected uint64
}{
"simple test": {big10, big10, big10, 200},
"0 mod - normal gas": {big10, big10, big0, 200},
"base 0 - mod < 8192 - normal gas": {big0, big10, big10, 200},
"base 0 - mod > 8192 - 0 gas": {big0, big10, big8194, 0},
"base over 8192 - 0 gas": {big8194, big10, big10, 0},
"exp over 8192 - 0 gas": {big10, big8194, big10, 0},
"mod over 8192 - 0 gas": {big10, big10, big8194, 0},
}

for name, test := range cases {
t.Run(name, func(t *testing.T) {
input := make([]byte, 0)

base := len(test.base.Bytes())
exp := len(test.exp.Bytes())
mod := len(test.mod.Bytes())

input = append(input, uint64To32Bytes(base)...)
input = append(input, uint64To32Bytes(exp)...)
input = append(input, uint64To32Bytes(mod)...)
input = append(input, uint64ToDeterminedBytes(test.base, base)...)
input = append(input, uint64ToDeterminedBytes(test.exp, exp)...)
input = append(input, uint64ToDeterminedBytes(test.mod, mod)...)

gas := modExp.RequiredGas(input)

if gas != test.expected {
t.Errorf("Expected %d, got %d", test.expected, gas)
}
})
}
}

func uint64To32Bytes(input int) []byte {
bigInt := new(big.Int).SetUint64(uint64(input))
bytes := bigInt.Bytes()
result := make([]byte, 32)
copy(result[32-len(bytes):], bytes)
return result
}

func uint64ToDeterminedBytes(input *big.Int, length int) []byte {
bytes := input.Bytes()
result := make([]byte, length)
copy(result[length-len(bytes):], bytes)
return result
}

0 comments on commit 642df8f

Please sign in to comment.