diff --git a/Project.toml b/Project.toml index 9eedc951..104f6397 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.9" +version = "0.2.10" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 560d4365..983eff5f 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -722,11 +722,13 @@ function rule_type(interp::TapirInterpreter{C}, ::Type{sig}) where {C, sig} end """ - build_rrule(args...) + build_rrule(args...; safety_on=false) Helper method. Only uses static information from `args`. """ -build_rrule(args...) = build_rrule(PInterp(), _typeof(TestUtils.__get_primals(args))) +function build_rrule(args...; safety_on=false) + return build_rrule(PInterp(), _typeof(TestUtils.__get_primals(args)); safety_on) +end """ build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C} @@ -736,7 +738,14 @@ for `rrule!!` for more info. If `safety_on` is `true`, then all calls to rules are replaced with calls to `SafeRRule`s. """ -function build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C} +function build_rrule( + interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false, silence_safety_messages=true +) where {C} + + # If we're compiling in safe mode, let the user know by default. + if !silence_safety_messages + @info "Compiling rule for $sig in safe mode. Disable for best performance." + end # Reset id count. This ensures that the IDs generated are the same each time this # function runs.