Skip to content

Commit

Permalink
More safe mode options
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed May 14, 2024
1 parent 4c4be36 commit cdc9d1d
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -722,11 +722,13 @@ function rule_type(interp::TapirInterpreter{C}, ::Type{sig}) where {C, sig}
end

"""
build_rrule(args...)
build_rrule(args...; safety_on=false)
Helper method. Only uses static information from `args`.
"""
build_rrule(args...) = build_rrule(PInterp(), _typeof(TestUtils.__get_primals(args)))
function build_rrule(args...; safety_on=false)
return build_rrule(PInterp(), _typeof(TestUtils.__get_primals(args)); safety_on)
end

"""
build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C}
Expand All @@ -736,7 +738,14 @@ for `rrule!!` for more info.
If `safety_on` is `true`, then all calls to rules are replaced with calls to `SafeRRule`s.
"""
function build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C}
function build_rrule(
interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false, silence_safety_messages=true
) where {C}

# If we're compiling in safe mode, let the user know by default.
if !silence_safety_messages
@info "Compiling rule for $sig in safe mode. Disable for best performance."
end

# Reset id count. This ensures that the IDs generated are the same each time this
# function runs.
Expand Down

0 comments on commit cdc9d1d

Please sign in to comment.