From a189437df37b697521029e07fb310a85c7ebf283 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 5 Aug 2024 17:37:03 +0100 Subject: [PATCH] Attempt to get better errors (#214) * Change how rules work * Fix typo * Attempt to get better errors * Print the right thing * Improve printing further * Simplify rule_type * Tidy up * Fix type stability issue * Bump patch --- Project.toml | 2 +- src/interpreter/s2s_reverse_mode_ad.jl | 41 ++++++++++++++------------ src/rrules/blas.jl | 2 +- src/utils.jl | 2 +- 4 files changed, 25 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index bce33e2b..587621e3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.29" +version = "0.2.30" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 50f05358..0f615974 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -718,23 +718,13 @@ function rule_type(interp::TapirInterpreter{C}, sig_or_mi) where {C} arg_types = map(_type, ir.argtypes) arg_fwds_types = Tuple{map(fcodual_type, arg_types)...} arg_rvs_types = Tuple{map(rdata_type ∘ tangent_type, arg_types)...} - fwds_return_codual = fcodual_type(Treturn) rvs_return_type = rdata_type(tangent_type(Treturn)) - if isconcretetype(fwds_return_codual) - return DerivedRule{ - MistyClosure{OpaqueClosure{arg_fwds_types, fwds_return_codual}}, - MistyClosure{OpaqueClosure{Tuple{rvs_return_type}, arg_rvs_types}}, - Val{isva}, - Val{length(ir.argtypes)}, - } - else - return DerivedRule{ - MistyClosure{OpaqueClosure{arg_fwds_types, P}} where {P<:fwds_return_codual}, - MistyClosure{OpaqueClosure{Tuple{rvs_return_type}, arg_rvs_types}}, - Val{isva}, - Val{length(ir.argtypes)}, - } - end + return DerivedRule{ + MistyClosure{OpaqueClosure{arg_fwds_types, fcodual_type(Treturn)}}, + MistyClosure{OpaqueClosure{Tuple{rvs_return_type}, arg_rvs_types}}, + Val{isva}, + Val{length(ir.argtypes)}, + } end """ @@ -827,7 +817,7 @@ function build_rrule( interp.oc_cache[(sig_or_mi, safety_on)] = (fwds_oc, pb_oc) end - raw_rule = rule_type(interp, sig_or_mi)(fwds_oc, pb_oc, Val(isva), Val(num_args(info))) + raw_rule = DerivedRule(fwds_oc, pb_oc, Val(isva), Val(num_args(info))) return safety_on ? SafeRRule(raw_rule) : raw_rule end @@ -1257,9 +1247,22 @@ mutable struct LazyDerivedRule{Tinterp<:TapirInterpreter, Trule} end end -function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N} +function (rule::LazyDerivedRule{T, Trule})(args::Vararg{Any, N}) where {N, T, Trule} if !isdefined(rule, :rule) - rule.rule = build_rrule(rule.interp, rule.mi; safety_on=rule.safety_on) + derived_rule = build_rrule(rule.interp, rule.mi; safety_on=rule.safety_on) + if derived_rule isa Trule + rule.rule = derived_rule + else + @warn "Unable to put rule in rule field. Rule should error." + println("derived_rule is of type") + display(typeof(derived_rule)) + println() + println("Expected type is") + display(Trule) + println() + derived_rule(args...) + error("Rule with bad type ran without error.") + end end return rule.rule(args...) end diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 4014b5d0..066db692 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -111,7 +111,7 @@ end # LEVEL 2 # -for (gemv, elty) in ((:dgemv_, :Float64), (:sgemm_, :Float32)) +for (gemv, elty) in ((:dgemv_, :Float64), (:sgemv_, :Float32)) @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(gemv))}}, diff --git a/src/utils.jl b/src/utils.jl index 8c2f722f..f928d969 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,7 +4,7 @@ Central definition of typeof, which is specific to the use-required in this package. """ _typeof(x) = Base._stable_typeof(x) -_typeof(x::Tuple) = Tuple{map(_typeof, x)...} +_typeof(x::Tuple) = Tuple{tuple_map(_typeof, x)...} _typeof(x::NamedTuple{names}) where {names} = NamedTuple{names, _typeof(Tuple(x))} """