From 6d3a43865b2524c7dd1a008e9802c6d0cd3a1831 Mon Sep 17 00:00:00 2001 From: Pavel Zverev <31499197+blitz-1306@users.noreply.github.com> Date: Wed, 25 Oct 2023 02:03:15 +0500 Subject: [PATCH] Handle additional edge cases for `evalConstantExpr()` (#230) * Handle additional edge cases for evalConstantExpr(): support Buffer values (for hex literals) in unary math ops. * Introduce castToType(). Tweak evalUnary() and evalBinary() to cover more edge-cases. Export utility functions for downstream. Add more edge-cases to tests. * review nit --------- Co-authored-by: Dimitar --- src/types/eval_const.ts | 195 +++++++++++++------- src/types/utils.ts | 4 + test/unit/types/eval_const.spec.ts | 276 +++++++++++++++++++++++++++++ 3 files changed, 409 insertions(+), 66 deletions(-) diff --git a/src/types/eval_const.ts b/src/types/eval_const.ts index 2f2fcef1..1b9cb95c 100644 --- a/src/types/eval_const.ts +++ b/src/types/eval_const.ts @@ -18,9 +18,21 @@ import { VariableDeclaration } from "../ast"; import { pp } from "../misc"; -import { BytesType, FixedBytesType, IntType, NumericLiteralType, StringType } from "./ast"; +import { + BytesType, + FixedBytesType, + IntType, + NumericLiteralType, + StringType, + TypeNode +} from "./ast"; import { InferType } from "./infer"; -import { BINARY_OPERATOR_GROUPS, SUBDENOMINATION_MULTIPLIERS, clampIntToType } from "./utils"; +import { + BINARY_OPERATOR_GROUPS, + SUBDENOMINATION_MULTIPLIERS, + clampIntToType, + fixedBytesTypeToIntType +} from "./utils"; /** * Tune up precision of decimal values to follow Solidity behavior. * Be careful with precision - setting it to large values causes NodeJS to crash. @@ -51,7 +63,7 @@ function str(value: Value): string { return value instanceof Decimal ? value.toString() : pp(value); } -function promoteToDec(v: Value): Decimal { +export function toDec(v: Value): Decimal { if (v instanceof Decimal) { return v; } @@ -71,10 +83,86 @@ function promoteToDec(v: Value): Decimal { throw new Error(`Expected number not ${v}`); } +export function toInt(v: Value): bigint { + if (typeof v === "bigint") { + return v; + } + + if (v instanceof Decimal && v.isInt()) { + return BigInt(v.toHex()); + } + + if (typeof v === "string") { + return v === "" ? 0n : BigInt("0x" + Buffer.from(v, "utf-8").toString("hex")); + } + + if (v instanceof Buffer) { + return v.length === 0 ? 0n : BigInt("0x" + v.toString("hex")); + } + + throw new Error(`Expected integer not ${v}`); +} + function demoteFromDec(d: Decimal): Decimal | bigint { return d.isInt() ? BigInt(d.toFixed()) : d; } +export function castToType(v: Value, fromT: TypeNode | undefined, toT: TypeNode): Value { + if (typeof v === "bigint") { + if (toT instanceof IntType) { + return clampIntToType(v, toT); + } + + if (toT instanceof FixedBytesType) { + if (fromT instanceof FixedBytesType && fromT.size < toT.size) { + return BigInt("0x" + v.toString(16).padEnd(toT.size * 2, "0")); + } + + return clampIntToType(v, fixedBytesTypeToIntType(toT)); + } + } + + if (typeof v === "string") { + if (toT instanceof BytesType) { + return Buffer.from(v, "utf-8"); + } + + if (toT instanceof FixedBytesType) { + if (v.length === 0) { + return 0n; + } + + const buf = Buffer.from(v, "utf-8"); + + if (buf.length < toT.size) { + return BigInt("0x" + buf.toString("hex").padEnd(toT.size * 2, "0")); + } + + return BigInt("0x" + buf.slice(0, toT.size).toString("hex")); + } + } + + if (v instanceof Buffer) { + if (toT instanceof StringType) { + return v.toString("utf-8"); + } + + if (toT instanceof FixedBytesType) { + if (v.length === 0) { + return 0n; + } + + if (v.length < toT.size) { + return BigInt("0x" + v.toString("hex").padEnd(toT.size * 2, "0")); + } + + return BigInt("0x" + v.slice(0, toT.size).toString("hex")); + } + } + + return v; +} + export function isConstant(expr: Expression | VariableDeclaration): boolean { if (expr instanceof Literal) { return true; @@ -259,8 +347,8 @@ export function evalBinaryImpl(operator: string, left: Value, right: Value): Val if (typeof left === "boolean" || typeof right === "boolean") { isEqual = left === right; } else { - const leftDec = promoteToDec(left); - const rightDec = promoteToDec(right); + const leftDec = toDec(left); + const rightDec = toDec(right); isEqual = leftDec.equals(rightDec); } @@ -283,8 +371,8 @@ export function evalBinaryImpl(operator: string, left: Value, right: Value): Val ); } - const leftDec = promoteToDec(left); - const rightDec = promoteToDec(right); + const leftDec = toDec(left); + const rightDec = toDec(right); if (operator === "<") { return leftDec.lessThan(rightDec); @@ -306,8 +394,8 @@ export function evalBinaryImpl(operator: string, left: Value, right: Value): Val } if (BINARY_OPERATOR_GROUPS.Arithmetic.includes(operator)) { - const leftDec = promoteToDec(left); - const rightDec = promoteToDec(right); + const leftDec = toDec(left); + const rightDec = toDec(right); let res: Decimal; @@ -331,28 +419,27 @@ export function evalBinaryImpl(operator: string, left: Value, right: Value): Val } if (BINARY_OPERATOR_GROUPS.Bitwise.includes(operator)) { - if (!(typeof left === "bigint" && typeof right === "bigint")) { - throw new EvalError(`${operator} expects integers not ${str(left)} and ${str(right)}`); - } + const leftInt = toInt(left); + const rightInt = toInt(right); if (operator === "<<") { - return left << right; + return leftInt << rightInt; } if (operator === ">>") { - return left >> right; + return leftInt >> rightInt; } if (operator === "|") { - return left | right; + return leftInt | rightInt; } if (operator === "&") { - return left & right; + return leftInt & rightInt; } if (operator === "^") { - return left ^ right; + return leftInt ^ rightInt; } throw new EvalError(`Unknown bitwise operator ${operator}`); @@ -396,13 +483,16 @@ export function evalLiteral(node: Literal): Value { export function evalUnary(node: UnaryOperation, inference: InferType): Value { try { const subT = inference.typeOf(node.vSubExpression); - const res = evalUnaryImpl(node.operator, evalConstantExpr(node.vSubExpression, inference)); + const sub = evalConstantExpr(node.vSubExpression, inference); - if (subT instanceof IntType && typeof res === "bigint") { - return clampIntToType(res, subT); + if (subT instanceof NumericLiteralType) { + return evalUnaryImpl(node.operator, sub); } - return res; + const resT = inference.typeOfUnaryOperation(node); + const res = evalUnaryImpl(node.operator, sub); + + return castToType(res, undefined, resT); } catch (e: unknown) { if (e instanceof EvalError && e.expr === undefined) { e.expr = node; @@ -417,21 +507,25 @@ export function evalBinary(node: BinaryOperation, inference: InferType): Value { const leftT = inference.typeOf(node.vLeftExpression); const rightT = inference.typeOf(node.vRightExpression); - const res = evalBinaryImpl( - node.operator, - evalConstantExpr(node.vLeftExpression, inference), - evalConstantExpr(node.vRightExpression, inference) - ); + let left = evalConstantExpr(node.vLeftExpression, inference); + let right = evalConstantExpr(node.vRightExpression, inference); - if (!(leftT instanceof NumericLiteralType && rightT instanceof NumericLiteralType)) { - const resT = inference.typeOfBinaryOperation(node); + if (leftT instanceof NumericLiteralType && rightT instanceof NumericLiteralType) { + return evalBinaryImpl(node.operator, left, right); + } - if (resT instanceof IntType && typeof res === "bigint") { - return clampIntToType(res, resT); - } + if (node.operator !== "**" && node.operator !== ">>" && node.operator !== "<<") { + const commonT = inference.inferCommonType(leftT, rightT); + + left = castToType(left, leftT, commonT); + right = castToType(right, rightT, commonT); } - return res; + const res = evalBinaryImpl(node.operator, left, right); + + const resT = inference.typeOfBinaryOperation(node); + + return castToType(res, undefined, resT); } catch (e: unknown) { if (e instanceof EvalError && e.expr === undefined) { e.expr = node; @@ -503,41 +597,10 @@ export function evalFunctionCall(node: FunctionCall, inference: InferType): Valu } const val = evalConstantExpr(node.vArguments[0], inference); - const castT = inference.typeOfElementaryTypeNameExpression(node.vExpression).type; - - if (typeof val === "bigint") { - if (castT instanceof IntType) { - return clampIntToType(val, castT); - } - - if (castT instanceof FixedBytesType) { - return clampIntToType(val, new IntType(castT.size * 8, false)); - } - } - - if (typeof val === "string") { - if (castT instanceof BytesType) { - return Buffer.from(val, "utf-8"); - } - - if (castT instanceof FixedBytesType) { - const buf = Buffer.from(val, "utf-8"); - - return BigInt("0x" + buf.slice(0, castT.size).toString("hex")); - } - } - - if (val instanceof Buffer) { - if (castT instanceof StringType) { - return val.toString("utf-8"); - } - - if (castT instanceof FixedBytesType) { - return BigInt("0x" + val.slice(0, castT.size).toString("hex")); - } - } + const fromT = inference.typeOf(node.vArguments[0]); + const toT = inference.typeOfElementaryTypeNameExpression(node.vExpression).type; - return val; + return castToType(val, fromT, toT); } /** diff --git a/src/types/utils.ts b/src/types/utils.ts index 448cdba1..73da214a 100644 --- a/src/types/utils.ts +++ b/src/types/utils.ts @@ -303,6 +303,10 @@ export function enumToIntType(decl: EnumDefinition): IntType { return new IntType(size, false); } +export function fixedBytesTypeToIntType(type: FixedBytesType): IntType { + return new IntType(type.size * 8, false, type.src); +} + export function getABIEncoderVersion(unit: SourceUnit, compilerVersion: string): ABIEncoderVersion { const predefined = unit.abiEncoderVersion; diff --git a/test/unit/types/eval_const.spec.ts b/test/unit/types/eval_const.spec.ts index 841fcc86..c70abbb7 100644 --- a/test/unit/types/eval_const.spec.ts +++ b/test/unit/types/eval_const.spec.ts @@ -1320,6 +1320,282 @@ const cases: Array<[string, (factory: ASTNodeFactory) => Expression, boolean, Va true, BigInt("0xe8") ], + [ + "Edge-case: bytes2(0xff00) == byte(0xff)", + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "==", + factory.makeFunctionCall( + "bytes2", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes2)", "bytes2"), + [factory.makeLiteral("", LiteralKind.Number, "ff00", "0xff00")] + ), + factory.makeFunctionCall( + "byte", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(byte)", "byte"), + [factory.makeLiteral("", LiteralKind.Number, "ff", "0xff")] + ) + ), + true, + true + ], + [ + "Edge-case: bytes2(0xcc00) < byte(0xee)", + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "<", + factory.makeFunctionCall( + "bytes2", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes2)", "bytes2"), + [factory.makeLiteral("", LiteralKind.Number, "cc00", "0xcc00")] + ), + factory.makeFunctionCall( + "byte", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(byte)", "byte"), + [factory.makeLiteral("", LiteralKind.Number, "ee", "0xee")] + ) + ), + true, + true + ], + [ + "Edge-case: (~bytes4(0xF0FF000F)) == 0x0f00fff0", + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "==", + factory.makeTupleExpression("", false, [ + factory.makeUnaryOperation( + "", + true, + "~", + factory.makeFunctionCall( + "bytes4", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes4)", "bytes4"), + [ + factory.makeLiteral( + "", + LiteralKind.Number, + "F0FF000F", + "0xF0FF000F" + ) + ] + ) + ) + ]), + factory.makeLiteral("", LiteralKind.Number, "0f00fff0", "0x0f00fff0") + ), + true, + true + ], + [ + "Edge-case: (~bytes4(0xFFFFFFFF)) == 0", + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "==", + factory.makeTupleExpression("", false, [ + factory.makeUnaryOperation( + "", + true, + "~", + factory.makeFunctionCall( + "bytes4", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes4)", "bytes4"), + [ + factory.makeLiteral( + "", + LiteralKind.Number, + "FFFFFFFF", + "0xFFFFFFFF" + ) + ] + ) + ) + ]), + factory.makeLiteral("", LiteralKind.Number, "00", "0") + ), + true, + true + ], + [ + "Edge-case: (~bytes4(0x00000000)) == 0xFFFFFFFF", + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "==", + factory.makeTupleExpression("", false, [ + factory.makeUnaryOperation( + "", + true, + "~", + factory.makeFunctionCall( + "bytes4", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes4)", "bytes4"), + [ + factory.makeLiteral( + "", + LiteralKind.Number, + "00000000", + "0x00000000" + ) + ] + ) + ) + ]), + factory.makeLiteral("", LiteralKind.Number, "FFFFFFFF", "0xFFFFFFFF") + ), + true, + true + ], + [ + 'Edge-case: bytes1(0x01) | hex"02" == 0x03', + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "==", + factory.makeBinaryOperation( + "", + "|", + factory.makeFunctionCall( + "bytes1", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes1)", "bytes1"), + [factory.makeLiteral("", LiteralKind.Number, "01", "0x01")] + ), + factory.makeLiteral("", LiteralKind.HexString, "02", "0x02") + ), + factory.makeLiteral("", LiteralKind.Number, "03", "0x03") + ), + true, + true + ], + [ + 'Edge-case: ~bytes1(hex"01") == 0xfe', + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "==", + factory.makeUnaryOperation( + "", + true, + "~", + factory.makeFunctionCall( + "bytes1", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes1)", "bytes1"), + [factory.makeLiteral("", LiteralKind.Number, "01", "0x01")] + ) + ), + factory.makeLiteral("", LiteralKind.Number, "fe", "0xfe") + ), + true, + true + ], + [ + 'Edge-case: bytes2(hex"ff") == bytes1(0xff)', + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "==", + factory.makeFunctionCall( + "bytes2", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes2)", "bytes2"), + [factory.makeLiteral("", LiteralKind.HexString, "ff", "ff")] + ), + factory.makeFunctionCall( + "bytes1", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes1)", "bytes1"), + [factory.makeLiteral("", LiteralKind.Number, "ff", "0xff")] + ) + ), + true, + true + ], + [ + 'Edge-case: bytes2(hex"cc") < bytes1(0xee)', + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "<", + factory.makeFunctionCall( + "bytes2", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes2)", "bytes2"), + [factory.makeLiteral("", LiteralKind.HexString, "cc", "cc")] + ), + factory.makeFunctionCall( + "bytes1", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes1)", "bytes1"), + [factory.makeLiteral("", LiteralKind.Number, "ee", "0xee")] + ) + ), + true, + true + ], + [ + 'Edge-case: bytes2(hex"") == 0', + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "==", + factory.makeFunctionCall( + "bytes2", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes2)", "bytes2"), + [factory.makeLiteral("", LiteralKind.HexString, "", "")] + ), + factory.makeLiteral("", LiteralKind.Number, "00", "0") + ), + true, + true + ], + [ + 'Edge-case: bytes2("") == 0', + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "==", + factory.makeFunctionCall( + "bytes2", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes2)", "bytes2"), + [factory.makeLiteral("", LiteralKind.String, "", "")] + ), + factory.makeLiteral("", LiteralKind.Number, "00", "0") + ), + true, + true + ], + [ + 'Edge-case: bytes2("ab") == 0x6162', + (factory: ASTNodeFactory) => + factory.makeBinaryOperation( + "", + "==", + factory.makeFunctionCall( + "bytes2", + FunctionCallKind.TypeConversion, + factory.makeElementaryTypeNameExpression("type(bytes2)", "bytes2"), + [factory.makeLiteral("", LiteralKind.String, "6162", "ab")] + ), + factory.makeLiteral("", LiteralKind.Number, "6162", "0x6162") + ), + true, + true + ], [ 'Identifier & IndexAccess (const A = "abcdef" (string), a[2])', (factory: ASTNodeFactory) => {