From e1f1383788153826d766e8dae9eb8022c8eb181a Mon Sep 17 00:00:00 2001 From: ncave <777696+ncave@users.noreply.github.com> Date: Wed, 12 Jun 2024 09:58:44 -0700 Subject: [PATCH] Inline union case testers --- src/Fable.Transforms/FSharp2Fable.Util.fs | 11 +++++ src/Fable.Transforms/FSharp2Fable.fs | 34 ++++++-------- src/Fable.Transforms/Rust/Fable2Rust.fs | 54 +++++++++++++++++------ 3 files changed, 65 insertions(+), 34 deletions(-) diff --git a/src/Fable.Transforms/FSharp2Fable.Util.fs b/src/Fable.Transforms/FSharp2Fable.Util.fs index e4cb8e414e..1ec73744d8 100644 --- a/src/Fable.Transforms/FSharp2Fable.Util.fs +++ b/src/Fable.Transforms/FSharp2Fable.Util.fs @@ -1064,6 +1064,17 @@ module Patterns = let (|MemberFullName|) (memb: FSharpMemberOrFunctionOrValue) = memb.FullName + let (|UnionCaseTesterFor|_|) (memb: FSharpMemberOrFunctionOrValue) = + match memb.DeclaringEntity with + | Some ent when ent.IsFSharpUnion -> + // if memb.IsUnionCaseTester then // TODO: this currently fails, use when fixed + if memb.IsPropertyGetterMethod && memb.LogicalName.StartsWith("get_Is") then + let unionCaseName = memb.LogicalName |> Naming.replacePrefix "get_Is" "" + ent.UnionCases |> Seq.tryFind (fun uc -> uc.Name = unionCaseName) + else + None + | _ -> None + let (|RefType|_|) = function | TypeDefinition tdef as t when tdef.TryFullName = Some Types.refCell -> Some t diff --git a/src/Fable.Transforms/FSharp2Fable.fs b/src/Fable.Transforms/FSharp2Fable.fs index 8f8043717e..d8c0e6aedf 100644 --- a/src/Fable.Transforms/FSharp2Fable.fs +++ b/src/Fable.Transforms/FSharp2Fable.fs @@ -955,8 +955,8 @@ let private transformExpr (com: IFableCompiler) (ctx: Context) appliedGenArgs fs else args - match callee with - | Some(CreateEvent(callee, event) as createEvent) -> + match callee, memb with + | Some(CreateEvent(callee, event) as createEvent), _ -> let! callee = transformExpr com ctx [] callee let eventType = makeType ctx.GenericArgs createEvent.Type @@ -965,7 +965,10 @@ let private transformExpr (com: IFableCompiler) (ctx: Context) appliedGenArgs fs return makeCallFrom com ctx (makeRangeFrom fsExpr) typ callGenArgs (Some callee) args memb - | callee -> + | Some unionExpr, UnionCaseTesterFor unionCase -> + return! transformUnionCaseTest com ctx (makeRangeFrom fsExpr) unionExpr unionExpr.Type unionCase + + | callee, _ -> let r = makeRangeFrom fsExpr let! callee = transformExprOpt com ctx callee @@ -1488,19 +1491,11 @@ let private isIgnoredNonAttachedMember (memb: FSharpMemberOrFunctionOrValue) = | None -> false ) -let private isErasedUnionCaseTester (memb: FSharpMemberOrFunctionOrValue) = - // if memb.IsUnionCaseTester then // TODO: this currently fails, use when fixed - if memb.IsPropertyGetterMethod && memb.LogicalName.StartsWith("get_Is") then - match memb.DeclaringEntity with - | Some ent when ent.IsFSharpUnion -> - // return true only when the tester's own union case is erased - let unionCaseName = memb.LogicalName |> Naming.replacePrefix "get_Is" "" - - ent.UnionCases - |> Seq.exists (fun unionCase -> unionCase.Name = unionCaseName && hasAttrib Atts.erase unionCase.Attributes) - | _ -> false - else - false +let private isUnionCaseTester (memb: FSharpMemberOrFunctionOrValue) = + // memb.IsUnionCaseTester // TODO: this currently fails, use when fixed + match memb with + | UnionCaseTesterFor _ -> true + | _ -> false let private isCompilerGenerated (memb: FSharpMemberOrFunctionOrValue) (args: FSharpMemberOrFunctionOrValue list list) = memb.IsCompilerGenerated @@ -1919,11 +1914,8 @@ let private transformMemberDecl [] elif memb.IsImplicitConstructor then transformPrimaryConstructor com ctx memb args body - // Ignore union case testers for erased union cases - elif isErasedUnionCaseTester memb then - $"Erased union case tester will be ignored: {memb.LogicalName}" - |> addWarning com [] None - + // ignore union case testers as they will be inlined + elif isUnionCaseTester memb then [] // Ignore members generated by the F# compiler (for comparison and equality) elif isCompilerGenerated memb args then diff --git a/src/Fable.Transforms/Rust/Fable2Rust.fs b/src/Fable.Transforms/Rust/Fable2Rust.fs index 814e52c632..f4166aa7d7 100644 --- a/src/Fable.Transforms/Rust/Fable2Rust.fs +++ b/src/Fable.Transforms/Rust/Fable2Rust.fs @@ -2732,6 +2732,7 @@ module Util = let guardExpr = match guard with | Fable.Test(expr, Fable.TypeTest typ, r) -> transformTypeTest com ctx r true typ expr + | Fable.Test(expr, Fable.UnionCaseTest tag, r) -> transformUnionCaseTest com ctx r tag expr | _ -> transformExpr com ctx guard let thenExpr = transformLeaveContext com ctx None thenBody @@ -2846,6 +2847,38 @@ module Util = mkLetExpr pat downcastExpr | _ -> makeLibCall com ctx genArgsOpt "Native" "type_test" [ expr ] + let transformUnionCaseTest (com: IRustCompiler) ctx range tag (fableExpr: Fable.Expr) : Rust.Expr = + match fableExpr.Type with + | Fable.DeclaredType(entRef, genArgs) -> + let ent = com.GetEntity(entRef) + assert (ent.IsFSharpUnion) + // let genArgsOpt = transformGenArgs com ctx genArgs // TODO: + let unionCase = ent.UnionCases |> List.item tag + + let fields = + match fableExpr with + | Fable.IdentExpr ident -> + unionCase.UnionCaseFields + |> List.mapi (fun i _field -> + let fieldName = $"{ident.Name}_{tag}_{i}" + makeFullNameIdentPat fieldName + ) + | _ -> + if List.isEmpty unionCase.UnionCaseFields then + [] + else + [ WILD_PAT ] + + let unionCaseName = getUnionCaseName com ctx entRef unionCase + let pat = makeUnionCasePat unionCaseName fields + + let expr = + fableExpr + |> prepareRefForPatternMatch com ctx fableExpr.Type (tryGetIdentName fableExpr) + + mkLetExpr pat expr + | _ -> failwith "Should not happen" + let transformTest (com: IRustCompiler) ctx range kind (fableExpr: Fable.Expr) : Rust.Expr = match kind with | Fable.TypeTest typ -> transformTypeTest com ctx range false typ fableExpr @@ -2874,18 +2907,10 @@ module Util = let unionCase = ent.UnionCases |> List.item tag let fields = - match fableExpr with - | Fable.IdentExpr ident -> - unionCase.UnionCaseFields - |> List.mapi (fun i _field -> - let fieldName = $"{ident.Name}_{tag}_{i}" - makeFullNameIdentPat fieldName - ) - | _ -> - if List.isEmpty unionCase.UnionCaseFields then - [] - else - [ WILD_PAT ] + if List.isEmpty unionCase.UnionCaseFields then + [] + else + [ WILD_PAT ] let unionCaseName = getUnionCaseName com ctx entRef unionCase let pat = makeUnionCasePat unionCaseName fields @@ -2894,7 +2919,10 @@ module Util = fableExpr |> prepareRefForPatternMatch com ctx fableExpr.Type (tryGetIdentName fableExpr) - mkLetExpr pat expr + let guardExpr = mkLetExpr pat expr + let thenExpr = mkBoolLitExpr true + let elseExpr = mkBoolLitExpr false + mkIfThenElseExpr guardExpr thenExpr elseExpr | _ -> failwith "Should not happen" let transformSwitch (com: IRustCompiler) ctx (evalExpr: Fable.Expr) cases defaultCase targets : Rust.Expr =