From cdc9d1d73e723ba302601fbc17fe71ecdc9410fb Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 14 May 2024 12:09:09 +0100 Subject: [PATCH] More safe mode options --- src/interpreter/s2s_reverse_mode_ad.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 560d4365..983eff5f 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -722,11 +722,13 @@ function rule_type(interp::TapirInterpreter{C}, ::Type{sig}) where {C, sig} end """ - build_rrule(args...) + build_rrule(args...; safety_on=false) Helper method. Only uses static information from `args`. """ -build_rrule(args...) = build_rrule(PInterp(), _typeof(TestUtils.__get_primals(args))) +function build_rrule(args...; safety_on=false) + return build_rrule(PInterp(), _typeof(TestUtils.__get_primals(args)); safety_on) +end """ build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C} @@ -736,7 +738,14 @@ for `rrule!!` for more info. If `safety_on` is `true`, then all calls to rules are replaced with calls to `SafeRRule`s. """ -function build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C} +function build_rrule( + interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false, silence_safety_messages=true +) where {C} + + # If we're compiling in safe mode, let the user know by default. + if !silence_safety_messages + @info "Compiling rule for $sig in safe mode. Disable for best performance." + end # Reset id count. This ensures that the IDs generated are the same each time this # function runs.