Core.Box & self-reference (fastmath and silent error fixed) #114
This is interesting. Will dig into it.
Lol, AD is descending into the implementation of `fast_exp` (I think this is the right method, but it's definitely one of the methods in this file). This function contains a lot of operations with zero derivative (like rounding), and then reinterprets the output as a float. The fix is for me to add rules for fast math, but this code shouldn't really have run in the first place. I think whichever intrinsic … Thanks for opening this!

To reproduce:

```julia
julia> foo(x) = @fastmath exp(x)
foo (generic function with 1 method)

julia> Tapir_grad(foo, 1.0)
(2.718281828459045, (NoTangent(), 0.0))
```

I verified locally that if you turn off fast math, differentiating the softmax gives the same numbers as Zygote and ForwardDiff.
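For intuition about why the gradient silently comes out as zero, here is a minimal sketch (a hypothetical toy function, not Tapir or Base code) of the pattern inside the fast `exp` implementation: the result is assembled from integer bit operations and then reinterpreted as a float, so every step AD can see carries a zero derivative.

```julia
# Toy model of the failure mode: the float round-trips through integer bits.
# Integer ops (shifts, rounding, table lookups) all have zero tangent, so
# naive AD through this body reports derivative 0 even though the value is fine.
bits_roundtrip(x::Float64) = reinterpret(Float64, reinterpret(UInt64, x))

bits_roundtrip(2.0)  # 2.0 — the value survives, but any tangent attached to
                     # `x` is dropped at the first reinterpret, as in exp_fast
```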
Sounds good. I have another error when I try to simplify it:

```julia
julia> sm = let # simplify NNlib's softmax a little
           softmax(x::AbstractArray{T}; dims::Int = 1) where {T} = softmax!(similar(x, float(T)), x; dims)
           softmax!(x::AbstractArray; dims = 1) = softmax!(x, x; dims)
           function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
               max_ = maximum(x; dims)  # removed fastmath version
               if true  # all(isfinite, max_)
                   out .= exp.(x .- max_)  # removed @fastmath
               else
                   @fastmath @. out = ifelse(isequal(max_, Inf), ifelse(isequal(x, Inf), 1, 0), exp(x - max_))
               end
               tmp = sum!(max_, out)  # removed branching on dims isa Colon
               out ./= tmp
           end
           softmax
       end
(::var"#softmax#94"{var"#softmax#91#95"}) (generic function with 1 method)

julia> sm([1,2,3.]) ≈ softmax([1,2,3.])
true

julia> Tapir_grad(first∘sm, [1.0, 2.0, 0.0])
ERROR: StackOverflowError:
Stacktrace:
      [1] macro expansion
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:0 [inlined]
      [2] zero_tangent(x::Core.Box)
        @ Tapir ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:342
      [3] macro expansion
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:0 [inlined]
      [4] zero_tangent
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:342 [inlined]
      [5] macro expansion
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:0 [inlined]
      [6] zero_tangent(x::var"#softmax!#96"{var"#softmax!#92#97"})
        @ Tapir ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:342
--- the last 2 lines are repeated 1 more time ---
--- the last 6 lines are repeated 19337 more times ---
 [116031] macro expansion
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:0 [inlined]
 [116032] zero_tangent
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:342 [inlined]
--- the last 2 lines are repeated 1 more time ---
 [116035] macro expansion
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:0 [inlined]
 [116036] zero_tangent(x::ComposedFunction{typeof(first), var"#softmax#94"{var"#softmax#91#95"}})
        @ Tapir ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:342
 [116037] zero_codual(x::Function)
        @ Tapir ~/.julia/packages/Tapir/BqxEi/src/codual.jl:20
 [116038] map
        @ ./tuple.jl:292 [inlined]
 [116039] value_and_pullback!!(::Tapir.DerivedRule{…}, ::Float64, ::Function, ::Vector{…})
        @ Tapir ~/.julia/packages/Tapir/BqxEi/src/interface.jl:51
Some type information was truncated. Use `show(err)` to see complete types.
```
Another fun one! I think there is some kind of self-referential stuff going on with the closures. If you instead define these at the top level:

```julia
__softmax(x::AbstractArray{T}; dims::Int = 1) where {T} = __softmax!(similar(x, float(T)), x; dims)
__softmax!(x::AbstractArray; dims = 1) = __softmax!(x, x; dims)
function __softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
    max_ = maximum(x; dims)  # removed fastmath version
    if true  # all(isfinite, max_)
        out .= exp.(x .- max_)  # removed @fastmath
    else
        @fastmath @. out = ifelse(isequal(max_, Inf), ifelse(isequal(x, Inf), 1, 0), exp(x - max_))
    end
    tmp = sum!(max_, out)  # removed branching on dims isa Colon
    out ./= tmp
end

Tapir_grad(first∘__softmax, [1.0, 2.0, 0.0])
```

it'll work. I don't fully understand how the boxing mechanism works, so I'll have to look at that. If it is indeed a self-referential type issue, then this is a known limitation of the current implementation.
Indeed, I don't know why the Box is circular:

```julia
julia> dump(sm)
softmax (function of type var"#softmax#94"{var"#softmax#91#95"})
  #softmax#91: softmax#91 (function of type var"#softmax#91#95")
    softmax!: Core.Box
      contents: softmax! (function of type var"#softmax!#96"{var"#softmax!#92#97"})
        #softmax!#92: softmax!#92 (function of type var"#softmax!#92#97")
          softmax!: Core.Box #= circular reference @-3 =#
          #softmax!#93: Core.Box
            contents: softmax!#93 (function of type var"#softmax!#93#98")
```
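If it helps, I believe a circular `Core.Box` can arise from nothing more than a self-recursive local function; a minimal sketch (hypothetical, unrelated to NNlib):

```julia
# A local function that refers to itself is captured via a Core.Box, and the
# box's contents is the closure that holds the box — a genuine cycle.
function make_rec()
    r(n) = n <= 0 ? 0 : r(n - 1)  # `r` mentions itself, so it gets boxed
    return r
end

rec = make_rec()
dump(rec)  # should show a field `r: Core.Box` with `#= circular reference =#`
```

The `let` block above appears to do the same thing: `softmax!` calls itself (and is called by `softmax`), so it is boxed, and the box ends up inside the closure it contains.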
Hmm yes, that would be ideal. I can't say that I'm entirely sure how to detect whether circular referencing is going on, though. If you have thoughts, I'd love to hear them.
It seems such self-references will always cause `StackOverflowError`s. If so, can we catch them and print a more informative message, e.g. suggesting that users check for self-references?
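For what that could look like, here's a sketch (a hypothetical wrapper, not Tapir's actual interface code) of catching the overflow at the entry point:

```julia
# Hypothetical helper: catch the StackOverflowError thrown by the recursive
# zero_tangent and re-throw it with a hint about self-references.
function zero_tangent_checked(x)
    try
        return Tapir.zero_tangent(x)
    catch err
        err isa StackOverflowError || rethrow()
        error("StackOverflowError while constructing a zero tangent. This often " *
              "means the value is self-referential (e.g. a closure boxed in a " *
              "Core.Box); check your function and arguments for self-references.")
    end
end
```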
You could just do this:

```julia
julia> Tapir.zero_tangent(x::Core.Box) = error("nope")

julia> Tapir.zero_tangent(sm)
ERROR: nope
```

I don't know how many other self-referential types show up in the wild. I believe Enzyme's `make_zero` handles some of them, and uses an IdDict cache to do so, but in fact it fails on this one. What the cache also lets it do is preserve aliasing:

```julia
julia> sh = [1.0, 2.0];

julia> nt = (a=sh, b=sh, c=copy(sh));

julia> zed = Tapir.zero_tangent((a=sh, b=sh, c=copy(sh)))
(a = [0.0, 0.0], b = [0.0, 0.0], c = [0.0, 0.0])

julia> zed.a === zed.b
false

julia> Tapir_grad(x -> sum(map(sum, x)), nt)
(9.0, (NoTangent(), (a = [1.0, 1.0], b = [1.0, 1.0], c = [1.0, 1.0])))

julia> Base.zero(nt::NamedTuple) = Enzyme.make_zero(nt);

julia> z2 = zero(nt);

julia> z2.a === z2.b
true

julia> Enzyme.gradient(Reverse, x -> sum(map(sum, x)), nt)
(a = [2.0, 2.0], b = [2.0, 2.0], c = [1.0, 1.0])
```
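For illustration, a minimal sketch of the cache idea (covering only arrays and NamedTuples; not Enzyme's or Tapir's actual code): the IdDict maps each object already visited to its zero, so aliased inputs get aliased zeros, and registering mutable containers before recursing into their fields is what would also break cycles.

```julia
# Hypothetical cache-based zero: identical objects map to one shared zero.
function cached_zero(x::Array{Float64}, seen::IdDict = IdDict())
    haskey(seen, x) && return seen[x]   # already visited ⇒ reuse, keeping ===
    return seen[x] = zero(x)
end
cached_zero(x::NamedTuple, seen::IdDict = IdDict()) =
    map(v -> cached_zero(v, seen), x)

sh = [1.0, 2.0]
z = cached_zero((a = sh, b = sh, c = copy(sh)))
z.a === z.b  # true — aliasing preserved, unlike Tapir.zero_tangent above
z.a === z.c  # false — distinct arrays get distinct zeros
```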
Please do continue to open issues as you find problems though, @mcabbott -- this has already been a very productive issue from my perspective.
Okay. v0.2.6 of Tapir will throw an error if you try to bitcast anything into a differentiable type. I've not yet added support for the `@fastmath` stuff, but we at least won't get a silent failure anymore. I'll continue to work on fixing up the fastmath and Box stuff.

edit: although my …

edit2: but if you run something like

```julia
foo(x) = first(softmax(x))
rule = Tapir.build_rrule(foo, randn(3))
Tapir.value_and_gradient!!(rule, foo, randn(3))
```

you should see a stack trace with this at the top:

```julia
ERROR: ArgumentError: It is not permissible to bitcast to a differentiable type during AD, as this risks dropping tangents, and therefore risks silently giving the wrong answer. If this call to bitcast appears as part of the implementation of a differentiable function, you should write a rule for this function, or modify its implementation to avoid the bitcast.
Stacktrace:
 [1] rrule!!(f::CoDual{typeof(Tapir.IntrinsicsWrappers.bitcast), NoFData}, t::CoDual{Type{Float64}, NoFData}, x::CoDual{UInt64, NoFData})
   @ Tapir.IntrinsicsWrappers ~/ml/ad_playground/Taped.jl/src/rrules/builtins.jl:98
 [2] RRuleZeroWrapper
   @ ~/ml/ad_playground/Taped.jl/src/interpreter/s2s_reverse_mode_ad.jl:232 [inlined]
 [3] exp_fast
   @ ./special/exp.jl:328 [inlined]
```
v0.2.7 of Tapir now has support for the various variants of `exp`.
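For reference, writing a rule is what stops AD from descending into the bit-twiddling implementation in the first place. A sketch in ChainRules style (Tapir's internal rule format is different; this is purely illustrative):

```julia
using ChainRulesCore

# Treat exp_fast as a primitive: its derivative is just its value, so AD
# never needs to see the reinterpret/bitcast tricks inside the implementation.
function ChainRulesCore.rrule(::typeof(Base.FastMath.exp_fast), x::Float64)
    y = Base.FastMath.exp_fast(x)
    exp_fast_pullback(ȳ) = (NoTangent(), ȳ * y)  # d/dx exp(x) = exp(x)
    return y, exp_fast_pullback
end
```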
I have now gotten around to looking at this. I quite like @yebai's suggestion of providing an informative error message when we encounter a stack overflow. With #144, the error that you saw @mcabbott becomes:

```julia
julia> Tapir_grad(first∘sm, [1.0, 2.0, 0.0])
ERROR: Found a StackOverFlow error when trying to wrap inputs. This often means that Tapir.jl has encountered a self-referential type. Tapir.jl is not presently able to handle self-referential types, so if you are indeed using a self-referential type somewhere, you will need to refactor to avoid it if you wish to use Tapir.jl.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] __create_coduals(args::Tuple{ComposedFunction{typeof(first), var"#softmax#4"{var"#softmax#1#5"}}, Vector{Float64}})
   @ Tapir ~/ml/ad_playground/Taped.jl/src/interface.jl:82
 [3] value_and_pullback!!(::Tapir.DerivedRule{…}, ::Float64, ::Function, ::Vector{…})
   @ Tapir ~/ml/ad_playground/Taped.jl/src/interface.jl:65
 [4] Tapir_grad(f::Function, xs::Vector{Float64})
   @ Main ./REPL[35]:1
 [5] top-level scope
   @ REPL[36]:1

caused by: StackOverflowError:
Stacktrace:
      [1] zero_tangent(x::Core.Box)
        @ Tapir ~/ml/ad_playground/Taped.jl/src/tangents.jl:293
      [2] macro expansion
        @ ~/ml/ad_playground/Taped.jl/src/tangents.jl:0 [inlined]
      [3] zero_tangent
        @ ~/ml/ad_playground/Taped.jl/src/tangents.jl:293 [inlined]
      [4] macro expansion
        @ ~/ml/ad_playground/Taped.jl/src/tangents.jl:0 [inlined]
      [5] zero_tangent(x::var"#softmax!#6"{var"#softmax!#2#7"})
        @ Tapir ~/ml/ad_playground/Taped.jl/src/tangents.jl:293
--- the last 2 lines are repeated 1 more time ---
--- the last 6 lines are repeated 18684 more times ---
 [112112] macro expansion
        @ ~/ml/ad_playground/Taped.jl/src/tangents.jl:0 [inlined]
 [112113] zero_tangent
        @ ~/ml/ad_playground/Taped.jl/src/tangents.jl:293 [inlined]
--- the last 2 lines are repeated 2 more times ---
 [112118] zero_codual
        @ ~/ml/ad_playground/Taped.jl/src/codual.jl:24 [inlined]
 [112119] macro expansion
        @ ./none:0 [inlined]
 [112120] tuple_map
        @ ./none:0 [inlined]
 [112121] __create_coduals(args::Tuple{ComposedFunction{typeof(first), var"#softmax#4"{var"#softmax#1#5"}}, Vector{Float64}})
        @ Tapir ~/ml/ad_playground/Taped.jl/src/interface.jl:79
 [112122] value_and_pullback!!(::Tapir.DerivedRule{…}, ::Float64, ::Function, ::Vector{…})
        @ Tapir ~/ml/ad_playground/Taped.jl/src/interface.jl:65
Some type information was truncated. Use `show(err)` to see complete types.
```

I've added code to catch this at the interface layer (we can't do this the whole way down without killing performance), so this will catch things when arguments are self-referential. When someone finds an example where this is a problem elsewhere in the pipeline, I'll figure out how to cover it. I'll close this when #144 is merged.
Closing via #144 |
Not sure I'm using this right, but I get a zero gradient for `softmax` no matter what I do... is this a bug? Compare: