diff --git a/src/types/infer.ts b/src/types/infer.ts index af18130a..3a35c6ae 100644 --- a/src/types/infer.ts +++ b/src/types/infer.ts @@ -803,10 +803,18 @@ export class InferType { let rets: TypeNode[]; - if ( - resolvedCalleeT instanceof FunctionType || - resolvedCalleeT instanceof BuiltinFunctionType - ) { + if (resolvedCalleeT instanceof FunctionType) { + rets = resolvedCalleeT.returns; + + // Convert any calldata pointers back to memory for external calls + if (node.vExpression instanceof MemberAccess) { + rets = rets.map((retT) => + retT instanceof PointerType && retT.location === DataLocation.CallData + ? specializeType(generalizeType(retT)[0], DataLocation.Memory) + : retT + ); + } + } else if (resolvedCalleeT instanceof BuiltinFunctionType) { rets = resolvedCalleeT.returns; } else if (resolvedCalleeT instanceof EventType || resolvedCalleeT instanceof ErrorType) { rets = []; @@ -1298,43 +1306,11 @@ export class InferType { return normalT; } - private changeLocToMemory( - typ: FunctionType | FunctionLikeSetType - ): FunctionType | FunctionLikeSetType { - if (typ instanceof FunctionLikeSetType) { - const funTs = typ.defs.map((funT) => this.changeLocToMemory(funT) as FunctionType); - - return new FunctionLikeSetType(funTs); - } - - const params = typ.parameters.map((paramT) => - paramT instanceof PointerType && paramT.location === DataLocation.CallData - ? specializeType(generalizeType(paramT)[0], DataLocation.Memory) - : paramT - ); - - const rets = typ.returns.map((retT) => - retT instanceof PointerType && retT.location === DataLocation.CallData - ? specializeType(generalizeType(retT)[0], DataLocation.Memory) - : retT - ); - - return new FunctionType( - typ.name, - params, - rets, - typ.visibility, - typ.mutability, - typ.implicitFirstArg, - typ.src - ); - } - private typeOfMemberAccessImpl(node: MemberAccess, baseT: TypeNode): TypeNode | undefined { if (baseT instanceof UserDefinedType && baseT.definition instanceof ContractDefinition) { const contract = baseT.definition; - let fieldT = this.typeOfResolved(node.memberName, contract, true); + const fieldT = this.typeOfResolved(node.memberName, contract, true); assert( fieldT === undefined || @@ -1353,8 +1329,6 @@ export class InferType { } if (fieldT) { - fieldT = this.changeLocToMemory(fieldT); - if (builtinT instanceof BuiltinFunctionType) { return mergeFunTypes( fieldT as FunctionType | FunctionLikeSetType, diff --git a/test/integration/types/infer.spec.ts b/test/integration/types/infer.spec.ts index 90a0fa3c..119e6fd9 100644 --- a/test/integration/types/infer.spec.ts +++ b/test/integration/types/infer.spec.ts @@ -542,7 +542,8 @@ function compareTypeNodes( if ( inferredT instanceof FunctionType && parsedT instanceof FunctionType && - inferredT.visibility === FunctionVisibility.External && + (inferredT.visibility === FunctionVisibility.External || + parsedT.visibility === FunctionVisibility.External) && inferredT.parameters.length === parsedT.parameters.length && inferredT.parameters.length === parsedT.parameters.length ) {