From 1e5c4a1a5f7606f70cdad40e198585a6316818df Mon Sep 17 00:00:00 2001 From: Pavel Zverev <31499197+blitz-1306@users.noreply.github.com> Date: Wed, 31 May 2023 12:58:36 +0500 Subject: [PATCH] Function to check that `TypeNode` is ABI-encodable (#209) * Introduce InferType.isABIEncodable() method and related unit test. Tweak InferType.toABIEncodedType() to check if type is encodable. Other small fixes. * Address review remark: invert logic of InferType.isABIEncodable() to explicitl list allowed types instead of listing disallowed ones. * Fix test due to contract constructor signature change --- src/types/ast/import_ref_type.ts | 3 +- src/types/infer.ts | 84 +++++-- test/unit/types/abi_encodable.spec.ts | 343 ++++++++++++++++++++++++++ test/unit/types/castable.spec.ts | 2 +- 4 files changed, 410 insertions(+), 22 deletions(-) create mode 100644 test/unit/types/abi_encodable.spec.ts diff --git a/src/types/ast/import_ref_type.ts b/src/types/ast/import_ref_type.ts index f85c2900..5af73caf 100644 --- a/src/types/ast/import_ref_type.ts +++ b/src/types/ast/import_ref_type.ts @@ -14,7 +14,8 @@ export class ImportRefType extends TypeNode { assert( importStmt.vSymbolAliases.length === 0 && importStmt.unitAlias !== "", - `ImportRefTypes only applicable to unit alias imports, not ${importStmt.print()}` + "ImportRefTypes only applicable to unit alias imports, not {0}", + importStmt ); this.importStmt = importStmt; diff --git a/src/types/infer.ts b/src/types/infer.ts index 90f82272..b18f09b0 100644 --- a/src/types/infer.ts +++ b/src/types/infer.ts @@ -1,18 +1,16 @@ import { Decimal } from "decimal.js"; import { gte, lt } from "semver"; import { + ASTNode, AnyResolvable, ArrayTypeName, Assignment, - ASTNode, BinaryOperation, Conditional, ContractDefinition, ContractKind, ElementaryTypeName, ElementaryTypeNameExpression, - encodeEventSignature, - encodeFuncSignature, EnumDefinition, ErrorDefinition, EventDefinition, @@ -37,7 +35,6 @@ import { ModifierDefinition, NewExpression, ParameterList, - resolveAny, SourceUnit, StateVariableVisibility, StructDefinition, @@ -48,7 +45,10 @@ import { UserDefinedTypeName, UserDefinedValueTypeDefinition, VariableDeclaration, - VariableDeclarationStatement + VariableDeclarationStatement, + encodeEventSignature, + encodeFuncSignature, + resolveAny } from "../ast"; import { DataLocation } from "../ast/constants"; import { assert, eq, forAll, forAny, pp } from "../misc"; @@ -94,25 +94,25 @@ import { } from "./builtins"; import { evalConstantExpr } from "./eval_const"; import { SolTypeError } from "./misc"; -import { applySubstitution, buildSubstitutions, TypeSubstituion } from "./polymorphic"; +import { TypeSubstituion, applySubstitution, buildSubstitutions } from "./polymorphic"; import { types } from "./reserved"; import { BINARY_OPERATOR_GROUPS, + SUBDENOMINATION_MULTIPLIERS, castable, decimalToRational, enumToIntType, generalizeType, getABIEncoderVersion, - getFallbackRecvFuns, getFQDefName, + getFallbackRecvFuns, inferCommonVisiblity, isReferenceType, isVisiblityExternallyCallable, mergeFunTypes, smallestFittingType, specializeType, - stripSingletonParens, - SUBDENOMINATION_MULTIPLIERS + stripSingletonParens } from "./utils"; const unaryImpureOperators = ["++", "--"]; @@ -201,7 +201,7 @@ function isSupportedByEncoderV1(type: TypeNode): boolean { const [baseT] = generalizeType(type.elementT); return ( - isSupportedByEncoderV1(baseT) || + isSupportedByEncoderV1(baseT) && !(baseT instanceof ArrayType && baseT.size === undefined) ); } @@ -2341,6 +2341,54 @@ export class InferType { return false; } + isABIEncodable(type: TypeNode, encoderVersion: ABIEncoderVersion): boolean { + if ( + type instanceof AddressType || + type instanceof BoolType || + type instanceof BytesType || + type instanceof FixedBytesType || + (type instanceof FunctionType && + (type.visibility === FunctionVisibility.External || + type.visibility === FunctionVisibility.Public)) || + type instanceof IntType || + type instanceof IntLiteralType || + type instanceof StringLiteralType || + type instanceof StringType + ) { + return true; + } + + if (type instanceof PointerType) { + return this.isABIEncodable(type.to, encoderVersion); + } + + if (encoderVersion === ABIEncoderVersion.V1 && !isSupportedByEncoderV1(type)) { + return false; + } + + if (type instanceof ArrayType) { + return this.isABIEncodable(type.elementT, encoderVersion); + } + + if (type instanceof UserDefinedType) { + if ( + type.definition instanceof ContractDefinition || + type.definition instanceof EnumDefinition || + type.definition instanceof UserDefinedValueTypeDefinition + ) { + return true; + } + + if (type.definition instanceof StructDefinition) { + return type.definition.vMembers.every((field) => + this.isABIEncodable(this.variableDeclarationToTypeNode(field), encoderVersion) + ); + } + } + + return false; + } + /** * Convert an internal TypeNode to the external TypeNode that would correspond to it * after ABI-encoding with encoder version `encoderVersion`. Follows the following rules: @@ -2359,9 +2407,12 @@ export class InferType { encoderVersion: ABIEncoderVersion, normalizePointers = false ): TypeNode { - if (type instanceof MappingType) { - throw new Error("Cannot abi-encode mapping types"); - } + assert( + this.isABIEncodable(type, encoderVersion), + 'Can not ABI-encode type "{0}" with encoder "{1}"', + type, + encoderVersion + ); if (type instanceof ArrayType) { const elT = this.toABIEncodedType(type.elementT, encoderVersion); @@ -2401,13 +2452,6 @@ export class InferType { } if (type.definition instanceof StructDefinition) { - assert( - encoderVersion !== ABIEncoderVersion.V1 || isSupportedByEncoderV1(type), - "Type {0} is not supported by encoder {1}", - type, - encoderVersion - ); - const fieldTs = type.definition.vMembers.map((fieldT) => this.variableDeclarationToTypeNode(fieldT) ); diff --git a/test/unit/types/abi_encodable.spec.ts b/test/unit/types/abi_encodable.spec.ts new file mode 100644 index 00000000..c82248be --- /dev/null +++ b/test/unit/types/abi_encodable.spec.ts @@ -0,0 +1,343 @@ +import { expect } from "expect"; +import { + ABIEncoderVersion, + ASTNodeFactory, + ArrayType, + BuiltinFunctionType, + ContractKind, + DataLocation, + ErrorType, + EventType, + FunctionLikeSetType, + FunctionStateMutability, + FunctionType, + FunctionVisibility, + ImportRefType, + InferType, + IntLiteralType, + IntType, + LatestCompilerVersion, + MappingType, + Mutability, + PointerType, + RationalLiteralType, + StateVariableVisibility, + StringLiteralType, + SuperType, + TRest, + TVar, + TupleType, + TypeNameType, + TypeNode, + UserDefinedType, + types +} from "../../../src"; + +const cases: Array< + [ + TypeNode | ((factory: ASTNodeFactory, inference: InferType) => TypeNode), + string, + ABIEncoderVersion, + boolean + ] +> = [ + [new StringLiteralType("string"), LatestCompilerVersion, ABIEncoderVersion.V1, true], + [new IntLiteralType(1n), LatestCompilerVersion, ABIEncoderVersion.V1, true], + [ + new RationalLiteralType({ numerator: 1n, denominator: 2n }), + LatestCompilerVersion, + ABIEncoderVersion.V1, + false + ], + [types.stringMemory, LatestCompilerVersion, ABIEncoderVersion.V1, true], + [types.bytesMemory, LatestCompilerVersion, ABIEncoderVersion.V1, true], + [types.bytesCalldata, LatestCompilerVersion, ABIEncoderVersion.V1, true], + [types.bytes32, LatestCompilerVersion, ABIEncoderVersion.V1, true], + [types.address, LatestCompilerVersion, ABIEncoderVersion.V1, true], + [types.addressPayable, LatestCompilerVersion, ABIEncoderVersion.V1, true], + [types.uint256, LatestCompilerVersion, ABIEncoderVersion.V1, true], + [ + new MappingType(new IntType(8, false), new IntType(8, false)), + LatestCompilerVersion, + ABIEncoderVersion.V2, + false + ], + [new TupleType([]), LatestCompilerVersion, ABIEncoderVersion.V1, false], + [ + new FunctionType( + undefined, + [], + [], + FunctionVisibility.Private, + FunctionStateMutability.Pure + ), + LatestCompilerVersion, + ABIEncoderVersion.V1, + false + ], + [ + new FunctionType( + undefined, + [], + [], + FunctionVisibility.Internal, + FunctionStateMutability.Pure + ), + LatestCompilerVersion, + ABIEncoderVersion.V1, + false + ], + [ + new FunctionType( + undefined, + [], + [], + FunctionVisibility.Default, + FunctionStateMutability.Pure + ), + LatestCompilerVersion, + ABIEncoderVersion.V1, + false + ], + [ + new FunctionType( + undefined, + [], + [], + FunctionVisibility.External, + FunctionStateMutability.Pure + ), + LatestCompilerVersion, + ABIEncoderVersion.V1, + true + ], + [ + new FunctionType( + undefined, + [], + [], + FunctionVisibility.Public, + FunctionStateMutability.Pure + ), + LatestCompilerVersion, + ABIEncoderVersion.V1, + true + ], + [ + (factory) => { + const def = factory.makeContractDefinition( + "SomeContract", + 0, + ContractKind.Contract, + false, + true, + [], + [], + [] + ); + + def.linearizedBaseContracts.push(def.id); + + return new UserDefinedType(def.name, def); + }, + LatestCompilerVersion, + ABIEncoderVersion.V1, + true + ], + [ + (factory) => { + const def = factory.makeStructDefinition("SomeStruct", 0, "internal", []); + + return new UserDefinedType(def.name, def); + }, + LatestCompilerVersion, + ABIEncoderVersion.V1, + false + ], + [ + (factory) => { + const def = factory.makeStructDefinition("SomeStruct", 0, "internal", []); + + return new UserDefinedType(def.name, def); + }, + LatestCompilerVersion, + ABIEncoderVersion.V2, + true + ], + [ + (factory) => { + const def = factory.makeStructDefinition("SomeStruct", 0, "internal", [ + factory.makeVariableDeclaration( + false, + false, + "someVar", + 0, + false, + DataLocation.Storage, + StateVariableVisibility.Default, + Mutability.Mutable, + "", + "", + factory.makeMapping( + "", + factory.makeElementaryTypeName("uint8", "uint8"), + factory.makeElementaryTypeName("uint8", "uint8") + ) + ) + ]); + + return new UserDefinedType(def.name, def); + }, + LatestCompilerVersion, + ABIEncoderVersion.V2, + false + ], + [ + (factory) => { + const def = factory.makeStructDefinition("SomeStruct", 0, "internal", [ + factory.makeVariableDeclaration( + false, + false, + "someVar", + 0, + false, + DataLocation.Storage, + StateVariableVisibility.Default, + Mutability.Mutable, + "", + "", + factory.makeFunctionTypeName( + "", + FunctionVisibility.Internal, + FunctionStateMutability.View, + factory.makeParameterList([]), + factory.makeParameterList([]) + ) + ) + ]); + + return new UserDefinedType(def.name, def); + }, + LatestCompilerVersion, + ABIEncoderVersion.V2, + false + ], + [ + (factory) => { + const def = factory.makeEnumDefinition("SomeEnum", []); + + return new UserDefinedType(def.name, def); + }, + LatestCompilerVersion, + ABIEncoderVersion.V1, + true + ], + [ + new PointerType( + new ArrayType( + new PointerType(new ArrayType(new IntType(8, false)), DataLocation.Memory) + ), + DataLocation.Memory + ), + LatestCompilerVersion, + ABIEncoderVersion.V1, + false + ], + [ + new PointerType( + new ArrayType( + new PointerType(new ArrayType(new IntType(8, false)), DataLocation.Memory) + ), + DataLocation.Memory + ), + LatestCompilerVersion, + ABIEncoderVersion.V2, + true + ], + [ + new PointerType( + new ArrayType( + new PointerType(new ArrayType(new IntType(8, false), 3n), DataLocation.Memory), + 1n + ), + DataLocation.Memory + ), + LatestCompilerVersion, + ABIEncoderVersion.V1, + true + ], + [ + new EventType("SomeEvent", [new IntType(8, false)]), + LatestCompilerVersion, + ABIEncoderVersion.V2, + false + ], + [ + new ErrorType("SomeError", [new IntType(8, false)]), + LatestCompilerVersion, + ABIEncoderVersion.V2, + false + ], + [new TypeNameType(new IntType(8, false)), LatestCompilerVersion, ABIEncoderVersion.V2, false], + [ + new BuiltinFunctionType("keccak256", [types.bytesMemory], [types.bytes32]), + LatestCompilerVersion, + ABIEncoderVersion.V2, + false + ], + [ + (factory) => { + const def = factory.makeContractDefinition( + "SomeContract", + 0, + ContractKind.Contract, + false, + true, + [], + [], + [] + ); + + def.linearizedBaseContracts.push(def.id); + + return new SuperType(def); + }, + LatestCompilerVersion, + ABIEncoderVersion.V2, + false + ], + [new TVar("T"), LatestCompilerVersion, ABIEncoderVersion.V2, false], + [new TRest("..."), LatestCompilerVersion, ABIEncoderVersion.V2, false], + [new FunctionLikeSetType([]), LatestCompilerVersion, ABIEncoderVersion.V2, false], + [ + (factory) => { + const unit = factory.makeSourceUnit("some.sol", 0, "path/to/some.sol", new Map()); + const imp = factory.makeImportDirective( + "some.sol", + "path/to/some.sol", + "some", + [], + 0, + unit.id + ); + + return new ImportRefType(imp); + }, + LatestCompilerVersion, + ABIEncoderVersion.V2, + false + ] +]; + +describe("ABI encodability detection unit test (isABIEncodable())", () => { + const factory = new ASTNodeFactory(); + + for (const [original, compilerVersion, encoderVersion, expectation] of cases) { + const inference = new InferType(compilerVersion); + const originalT = original instanceof TypeNode ? original : original(factory, inference); + + it(`${originalT.pp()} -> ${expectation} (compiler ${compilerVersion}, encoder ${encoderVersion})`, () => { + expect(inference.isABIEncodable(originalT, encoderVersion)).toEqual(expectation); + }); + } +}); diff --git a/test/unit/types/castable.spec.ts b/test/unit/types/castable.spec.ts index 01e3c206..5437e583 100644 --- a/test/unit/types/castable.spec.ts +++ b/test/unit/types/castable.spec.ts @@ -265,7 +265,7 @@ describe("Type casting unit test (castable())", () => { const fromT = from instanceof TypeNode ? from : from(factory); const toT = to instanceof TypeNode ? to : to(factory); - it(`${fromT.pp()} -> ${toT.pp()}" expected to be ${expectation} (in ${compilerVersion})`, () => { + it(`"${fromT.pp()} -> ${toT.pp()}" expected to be ${expectation} (in ${compilerVersion})`, () => { expect(castable(fromT, toT, compilerVersion)).toEqual(expectation); }); }