Skip to content

Commit

Permalink
Fix signature computation
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed May 16, 2024
1 parent eaa8f65 commit 6c7b345
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
_type(x) = x
_type(x::CC.Const) = _typeof(x.val)
_type(x::CC.PartialStruct) = x.typ
_type(x::CC.Conditional) = Union{x.thentype, x.elsetype}
_type(x::CC.Conditional) = Union{_type(x.thentype), _type(x.elsetype)}

function CC.inlining_policy(
interp::TapirInterpreter{C},
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ function DynamicDerivedRule(interp::TapirInterpreter, safety_on::Bool)
end

function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N}
sig = Tuple{tuple_map(typeof, tuple_map(primal, args))...}
sig = signature_from_values(tuple_map(primal, args))
is_primitive(context_type(dynamic_rule.interp), sig) && return rrule!!(args...)
rule = get(dynamic_rule.cache, sig, nothing)
if rule === nothing
Expand Down
7 changes: 7 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ _typeof(x) = Base._stable_typeof(x)
_typeof(x::Tuple) = Tuple{map(_typeof, x)...}
_typeof(x::NamedTuple{names}) where {names} = NamedTuple{names, _typeof(Tuple(x))}

"""
signature_from_values(x::Tuple)
"""
signature_from_values(x::Tuple) = Tuple{map(Base._stable_typeof, x)...}

"""
tuple_map(f::F, x::Tuple) where {F}
Expand Down

0 comments on commit 6c7b345

Please sign in to comment.