From 39bd2dc346b1103057c788f785bf5479fa7cb3df Mon Sep 17 00:00:00 2001 From: Scott Fairclough Date: Thu, 31 Oct 2024 12:07:39 +0000 Subject: [PATCH] logic change to mod exp revert rules --- core/vm/contracts_zkevm.go | 44 +++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/core/vm/contracts_zkevm.go b/core/vm/contracts_zkevm.go index 73afa136536..a01e3025b7a 100644 --- a/core/vm/contracts_zkevm.go +++ b/core/vm/contracts_zkevm.go @@ -302,15 +302,20 @@ func (c *bigModExp_zkevm) RequiredGas(input []byte) uint64 { // Retrieve the operands and execute the exponentiation var ( - base = new(big.Int).SetBytes(getData(input, 0, baseLen.Uint64())) - exp = new(big.Int).SetBytes(getData(input, baseLen.Uint64(), expLen.Uint64())) - mod = new(big.Int).SetBytes(getData(input, baseLen.Uint64()+expLen.Uint64(), modLen.Uint64())) + base = new(big.Int).SetBytes(getData(input, 0, baseLen.Uint64())) + exp = new(big.Int).SetBytes(getData(input, baseLen.Uint64(), expLen.Uint64())) + mod = new(big.Int).SetBytes(getData(input, baseLen.Uint64()+expLen.Uint64(), modLen.Uint64())) + baseBitLen = base.BitLen() + expBitLen = exp.BitLen() + modBitLen = mod.BitLen() ) + // special revert case for ZK + if baseBitLen == 0 && modBitLen > 8192 { + return 0 + } + // limit to 8192 bits for base, exp, and mod in ZK - revert if we go over - baseBitLen := base.BitLen() - expBitLen := exp.BitLen() - modBitLen := mod.BitLen() if baseBitLen > 8192 || expBitLen > 8192 || modBitLen > 8192 { return 0 } @@ -389,22 +394,31 @@ func (c *bigModExp_zkevm) Run(input []byte) ([]byte, error) { } else { input = input[:0] } - // Handle a special case when both the base and mod length is zero - if baseLen == 0 && modLen == 0 { + + if modLen == 0 { return []byte{}, nil } + // Retrieve the operands and execute the exponentiation var ( - base = new(big.Int).SetBytes(getData(input, 0, baseLen)) - exp = new(big.Int).SetBytes(getData(input, baseLen, expLen)) - mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) - v []byte + base = new(big.Int).SetBytes(getData(input, 0, baseLen)) + exp = new(big.Int).SetBytes(getData(input, baseLen, expLen)) + mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) + v []byte + baseBitLen = base.BitLen() + expBitLen = exp.BitLen() + modBitLen = mod.BitLen() ) + if baseBitLen == 0 { + if modBitLen > 8192 { + return nil, ErrExecutionReverted + } else { + return []byte{}, nil + } + } + // limit to 8192 bits for base, exp, and mod in ZK - baseBitLen := base.BitLen() - expBitLen := exp.BitLen() - modBitLen := mod.BitLen() if baseBitLen > 8192 || expBitLen > 8192 || modBitLen > 8192 { return nil, ErrExecutionReverted }