
Core.Box & self-reference (fastmath and silent error fixed) #114

Closed
mcabbott opened this issue Apr 3, 2024 · 14 comments
Labels
enhancement New feature or request

Comments

@mcabbott

mcabbott commented Apr 3, 2024

Not sure I'm using this right, but I get a zero gradient for softmax no matter what I do... is this a bug?

julia> using Tapir

julia> Tapir_grad(f, xs...) = Tapir.value_and_pullback!!(Tapir.build_rrule(f, xs...), 1.0, f, xs...);

julia> Tapir_grad(prod, [2.3, 4.5])
(10.35, (NoTangent(), [4.5, 2.3]))

julia> Tapir_grad(x -> prod(map(sum, x)), (x=[2.0], y=[3.0, 4.0], z=5))
(70.0, (NoTangent(), (x = [35.0], y = [10.0, 10.0], z = NoTangent())))

julia> using NNlib

julia> Tapir_grad(first∘softmax, [1.0, 2.0, 0.0])  # this gives wrong answer
(0.24472847105479764, (NoTangent(), [0.0, 0.0, 0.0]))

julia> using DifferentiationInterface  # in case this understands how to call it

julia> DifferentiationInterface.gradient(first∘softmax, DifferentiationInterface.AutoTapir(), [1.0, 2.0, 0.0])
3-element Vector{Float64}:
 0.0
 0.0
 0.0

Compare:

julia> using Zygote

julia> Zygote.withgradient(first∘softmax, [1.0, 2.0, 0.0])
(val = 0.24472847105479764, grad = ([0.1848364465099787, -0.16280340198980442, -0.022033044520174294],))

julia> using ForwardDiff

julia> ForwardDiff.gradient(first∘softmax, [1.0, 2.0, 0.0])
3-element Vector{Float64}:
  0.18483644650997869
 -0.16280340198980436
 -0.022033044520174298
@willtebbutt
Member

This is interesting. Will dig into it.

@willtebbutt
Member

willtebbutt commented Apr 3, 2024

Lol, AD is descending into the implementation of fast_exp (I think this is the right method, but it's definitely one of the methods in this file).

This function contains a lot of operations with zero derivative (like rounding), and then reinterprets the result as a float.

The fix is for me to add rules for fast math, but this code shouldn't really have run in the first place. I think whichever intrinsic reinterpret eventually hits currently assumes that if you convert from a float to an int, the float can't have been used in a differentiable manner -- which is plainly wrong here. I'm pretty sure I should get Tapir to throw an error if it tries to reinterpret a float as an int.
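
To illustrate the failure mode, here's a toy in the same spirit (not the actual Base implementation): the result is assembled by integer arithmetic on the bit pattern and then reinterpreted back to a float, so naive AD only ever sees zero-derivative operations.

# Toy sketch: compute 2^round(x) by writing the IEEE-754 exponent bits directly.
# round has zero derivative almost everywhere, and the reinterpret drops any
# remaining tangent information, so AD of this returns 0 even though the
# function clearly varies with x.
function toy_exp2(x::Float64)
    n = round(Int, x)
    bits = UInt64(n + 1023) << 52      # biased exponent into the high bits
    return reinterpret(Float64, bits)  # bitcast the integer pattern back to a float
end

toy_exp2(3.2)  # 8.0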

Thanks for opening this!

To reproduce:

julia> foo(x) = @fastmath exp(x)
foo (generic function with 1 method)

julia> Tapir_grad(foo, 1.0)
(2.718281828459045, (NoTangent(), 0.0))

I verified locally that if you turn off fast math, differentiating the softmax gives the same numbers as Zygote and ForwardDiff.

@mcabbott
Author

mcabbott commented Apr 3, 2024

Sounds good.

I have another error when I try to simplify it:

julia> sm = let  # simplify NNlib's a little
       softmax(x::AbstractArray{T}; dims::Int = 1) where {T} = softmax!(similar(x, float(T)), x; dims)

       softmax!(x::AbstractArray; dims = 1) = softmax!(x, x; dims)

       function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
           max_ = maximum(x; dims)  # removed fastmath version
           if true # all(isfinite, max_)
               out .= exp.(x .- max_)  # removed @fastmath
           else
               @fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
           end
           tmp = sum!(max_, out)  # removed branching on dims isa Colon
           out ./= tmp
       end

       softmax
       end
(::var"#softmax#94"{var"#softmax#91#95"}) (generic function with 1 method)

julia> sm([1,2,3.]) ≈ softmax([1,2,3.])
true

julia> Tapir_grad(first∘sm, [1.0, 2.0, 0.0])
ERROR: StackOverflowError:
Stacktrace:
      [1] macro expansion
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:0 [inlined]
      [2] zero_tangent(x::Core.Box)
        @ Tapir ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:342
      [3] macro expansion
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:0 [inlined]
      [4] zero_tangent
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:342 [inlined]
      [5] macro expansion
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:0 [inlined]
      [6] zero_tangent(x::var"#softmax!#96"{var"#softmax!#92#97"})
        @ Tapir ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:342
--- the last 2 lines are repeated 1 more time ---
--- the last 6 lines are repeated 19337 more times ---
 [116031] macro expansion
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:0 [inlined]
 [116032] zero_tangent
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:342 [inlined]
--- the last 2 lines are repeated 1 more time ---
 [116035] macro expansion
        @ ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:0 [inlined]
 [116036] zero_tangent(x::ComposedFunction{typeof(first), var"#softmax#94"{var"#softmax#91#95"}})
        @ Tapir ~/.julia/packages/Tapir/BqxEi/src/tangents.jl:342
 [116037] zero_codual(x::Function)
        @ Tapir ~/.julia/packages/Tapir/BqxEi/src/codual.jl:20
 [116038] map
        @ ./tuple.jl:292 [inlined]
 [116039] value_and_pullback!!(::Tapir.DerivedRule{…}, ::Float64, ::Function, ::Vector{…})
        @ Tapir ~/.julia/packages/Tapir/BqxEi/src/interface.jl:51
Some type information was truncated. Use `show(err)` to see complete types.

@willtebbutt
Member

Another fun one! I think there is some kind of self-referential stuff going on with sm because it is boxed. In any case, if you do something like

__softmax(x::AbstractArray{T}; dims::Int = 1) where {T} = __softmax!(similar(x, float(T)), x; dims)

__softmax!(x::AbstractArray; dims = 1) = __softmax!(x, x; dims)

function __softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
    max_ = maximum(x; dims)  # removed fastmath version
    if true # all(isfinite, max_)
        out .= exp.(x .- max_)  # removed @fastmath
    else
        @fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
    end
    tmp = sum!(max_, out)  # removed branching on dims isa Colon
    out ./= tmp
end

Tapir_grad(first∘__softmax, [1.0, 2.0, 0.0])

it'll work.

I don't fully understand how the boxing mechanism works, so I'll have to look at that. If it is indeed a self-referential type issue, then this is a known limitation of the current zero_tangent implementation, which I've vaguely got on my list of things to sort out, but it's not super high on that list.

@mcabbott
Author

mcabbott commented Apr 3, 2024

Indeed, I don't know why Box is circular. (const sm = let.... was the first thing I tried, no change.) But perhaps throwing an informative error is easy even if handling it is nontrivial.

julia> dump(sm)
softmax (function of type var"#softmax#94"{var"#softmax#91#95"})
  #softmax#91: softmax#91 (function of type var"#softmax#91#95")
    softmax!: Core.Box
      contents: softmax! (function of type var"#softmax!#96"{var"#softmax!#92#97"})
        #softmax!#92: softmax!#92 (function of type var"#softmax!#92#97")
          softmax!: Core.Box#= circular reference @-3 =#
        #softmax!#93: Core.Box
          contents: softmax!#93 (function of type var"#softmax!#93#98")

@willtebbutt
Member

Hmm yes, that would be ideal. I can't say that I'm entirely sure how to detect whether circular referencing is going on though. If you have thoughts, I'd love to hear them.

@yebai
Contributor

yebai commented Apr 3, 2024

It seems such self-references will always cause StackOverflow errors. If true, can we catch StackOverflow errors and print a more informative message, e.g. suggest users check for self-references?
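
Roughly, at the point where the inputs get wrapped (a sketch only -- safe_zero_tangent is a hypothetical wrapper, not code from Tapir):

# Catch the StackOverflowError produced by recursing through a self-referential
# value and replace it with a hint about the likely cause.
function safe_zero_tangent(x)
    try
        return Tapir.zero_tangent(x)
    catch err
        err isa StackOverflowError || rethrow()
        error("StackOverflowError while building a zero tangent for $(typeof(x)). ",
              "This often indicates a self-referential type (e.g. a closure captured ",
              "via Core.Box); check the arguments for self-references.")
    end
end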

@mcabbott
Author

mcabbott commented Apr 3, 2024

You could just do this:

julia> Tapir.zero_tangent(x::Core.Box) = error("nope")

julia> Tapir.zero_tangent(sm)
ERROR: nope

I don't know how many other self-referential types show up in the wild.
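
A more general check could walk an object's fields and flag a revisit on the current recursion path. Rough sketch (has_self_reference is just illustrative -- nothing like it exists in Tapir, and it only handles plain structs and Core.Box; arrays etc. would need their own cases):

# Track objects on the current path by identity; meeting one again means a cycle.
function has_self_reference(x, path = IdDict{Any,Nothing}())
    isbits(x) && return false         # plain bits values can't participate in a cycle
    haskey(path, x) && return true    # already on the current path
    path[x] = nothing
    for i in 1:fieldcount(typeof(x))
        isdefined(x, i) || continue
        has_self_reference(getfield(x, i), path) && return true
    end
    delete!(path, x)                  # done exploring this branch
    return false
end

has_self_reference(sm)  # should be true for the boxed closure above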

I believe Enzyme's make_zero handles some such self-referential types, and uses an IdDict cache to do so, but in fact it fails on this one.

What the cache also lets it do is preserve === relations, leading to e.g. this difference:

julia> sh = [1.0, 2.0];

julia> nt = (a=sh, b=sh, c=copy(sh));

julia> zed = Tapir.zero_tangent((a=sh, b=sh, c=copy(sh)))
(a = [0.0, 0.0], b = [0.0, 0.0], c = [0.0, 0.0])

julia> zed.a === zed.b
false

julia> Tapir_grad(x -> sum(map(sum, x)), nt)
(9.0, (NoTangent(), (a = [1.0, 1.0], b = [1.0, 1.0], c = [1.0, 1.0])))

julia> Base.zero(nt::NamedTuple) = Enzyme.make_zero(nt);

julia> z2 = zero(nt);

julia> z2.a === z2.b
true

julia> Enzyme.gradient(Reverse, x -> sum(map(sum, x)), nt)
(a = [2.0, 2.0], b = [2.0, 2.0], c = [1.0, 1.0])
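
To illustrate just the aliasing part: a cache keyed on object identity makes repeated references map to one shared zero (zero_with_cache below is a made-up helper, not Enzyme's make_zero or Tapir's zero_tangent):

# Aliased arrays hit the cache and share a single zero; copies get their own.
zero_with_cache(x::AbstractArray{<:Real}, cache = IdDict()) =
    get!(() -> zero(x), cache, x)
zero_with_cache(x::NamedTuple, cache = IdDict()) =
    map(v -> zero_with_cache(v, cache), x)

z3 = zero_with_cache(nt)
z3.a === z3.b  # true, unlike zed.a === zed.b above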

@willtebbutt
Member

Good points @yebai, @mcabbott -- I'll add this to my todo list. It'll probably be a couple of weeks before I get round to it, though (some larger refactoring work is currently at the top of my list).

@willtebbutt
Member

Please do continue to open issues as you find problems though @mcabbott -- this has already been a very productive issue from my perspective.

@mcabbott mcabbott changed the title Zero gradient for softmax Zero gradient for softmax -- reinterpret & @fastmath, plus Core.Box & self-reference Apr 13, 2024
@willtebbutt willtebbutt added bug (numerical correctness) bug Something isn't working labels Apr 29, 2024
@willtebbutt willtebbutt changed the title Zero gradient for softmax -- reinterpret & @fastmath, plus Core.Box & self-reference Support @fastmath & Core.Box & self-reference May 1, 2024
@willtebbutt
Member

willtebbutt commented May 1, 2024

Okay. v0.2.6 of Tapir will throw an error if you try to bitcast anything into an IEEEFloat. So, for example, if you encounter Core.bitcast(Float64, x) in a programme that you're trying to differentiate, you'll get an (informative!) error.

I've not yet added proper support for this stuff, but at least we won't get a silent failure anymore. I'll continue to work on fixing up the fastmath and Box stuff.

edit: although my value_and_pullback!! interface is giving trouble now.

edit2: but if you run something like

foo(x) = first(softmax(x))
rule = Tapir.build_rrule(foo, randn(3))
Tapir.value_and_gradient!!(rule, foo, randn(3))

you should see a stack trace with this at the top:

ERROR: ArgumentError: It is not permissible to bitcast to a differentiable type during AD, as this risks dropping tangents, and therefore risks silently giving the wrong answer. If this call to bitcast appears as part of the implementation of a differentiable function, you should write a rule for this function, or modify its implementation to avoid the bitcast.
Stacktrace:
  [1] rrule!!(f::CoDual{typeof(Tapir.IntrinsicsWrappers.bitcast), NoFData}, t::CoDual{Type{Float64}, NoFData}, x::CoDual{UInt64, NoFData})
    @ Tapir.IntrinsicsWrappers ~/ml/ad_playground/Taped.jl/src/rrules/builtins.jl:98
  [2] RRuleZeroWrapper
    @ ~/ml/ad_playground/Taped.jl/src/interpreter/s2s_reverse_mode_ad.jl:232 [inlined]
  [3] exp_fast
    @ ./special/exp.jl:328 [inlined]

@willtebbutt willtebbutt added enhancement New feature or request and removed bug (numerical correctness) labels May 1, 2024
@willtebbutt
Member

willtebbutt commented May 2, 2024

v0.2.7 of Tapir now has support for the various variants of exp in Base.FastMath, and I've added integration tests for all of the other functions in there. So your example (when run in the manner shown in my previous comment) should now work fine. Most of them seem to be fine with Float64 inputs -- there are just a couple that someone will need to write / adapt rules for at some point.
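
As a quick check, the @fastmath reproduction from earlier, run through the interface in my previous comment, should now report a gradient of e ≈ 2.718 at x = 1.0 rather than the 0.0 seen at the start of this thread:

foo(x) = @fastmath exp(x)
rule = Tapir.build_rrule(foo, 1.0)
Tapir.value_and_gradient!!(rule, foo, 1.0)  # gradient w.r.t. x should be exp(1.0) ≈ 2.7183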

@willtebbutt willtebbutt changed the title Support @fastmath & Core.Box & self-reference ~~Support @fastmath &~~ Core.Box & self-reference May 3, 2024
@willtebbutt willtebbutt changed the title ~~Support @fastmath &~~ Core.Box & self-reference Core.Box & self-reference (fastmath and silent error fixed) May 3, 2024
@willtebbutt willtebbutt removed the bug Something isn't working label May 3, 2024
@willtebbutt willtebbutt added this to the A Milestone milestone May 13, 2024
@willtebbutt
Member

willtebbutt commented May 13, 2024

I have now gotten around to looking at this. I quite like @yebai's suggestion of providing an informative error message when we encounter a stack overflow. With #144, the error that you saw, @mcabbott, becomes

julia> Tapir_grad(first∘sm, [1.0, 2.0, 0.0])
ERROR: Found a StackOverFlow error when trying to wrap inputs. This often means that Tapir.jl has encountered a self-referential type. Tapir.jl is not presently able to handle self-referential types, so if you are indeed using a self-referential type somewhere, you will need to refactor to avoid it if you wish to use Tapir.jl.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] __create_coduals(args::Tuple{ComposedFunction{typeof(first), var"#softmax#4"{var"#softmax#1#5"}}, Vector{Float64}})
   @ Tapir ~/ml/ad_playground/Taped.jl/src/interface.jl:82
 [3] value_and_pullback!!(::Tapir.DerivedRule{…}, ::Float64, ::Function, ::Vector{…})
   @ Tapir ~/ml/ad_playground/Taped.jl/src/interface.jl:65
 [4] Tapir_grad(f::Function, xs::Vector{Float64})
   @ Main ./REPL[35]:1
 [5] top-level scope
   @ REPL[36]:1

caused by: StackOverflowError:
Stacktrace:
      [1] zero_tangent(x::Core.Box)
        @ Tapir ~/ml/ad_playground/Taped.jl/src/tangents.jl:293
      [2] macro expansion
        @ ~/ml/ad_playground/Taped.jl/src/tangents.jl:0 [inlined]
      [3] zero_tangent
        @ ~/ml/ad_playground/Taped.jl/src/tangents.jl:293 [inlined]
      [4] macro expansion
        @ ~/ml/ad_playground/Taped.jl/src/tangents.jl:0 [inlined]
      [5] zero_tangent(x::var"#softmax!#6"{var"#softmax!#2#7"})
        @ Tapir ~/ml/ad_playground/Taped.jl/src/tangents.jl:293
--- the last 2 lines are repeated 1 more time ---
--- the last 6 lines are repeated 18684 more times ---
 [112112] macro expansion
        @ ~/ml/ad_playground/Taped.jl/src/tangents.jl:0 [inlined]
 [112113] zero_tangent
        @ ~/ml/ad_playground/Taped.jl/src/tangents.jl:293 [inlined]
--- the last 2 lines are repeated 2 more times ---
 [112118] zero_codual
        @ ~/ml/ad_playground/Taped.jl/src/codual.jl:24 [inlined]
 [112119] macro expansion
        @ ./none:0 [inlined]
 [112120] tuple_map
        @ ./none:0 [inlined]
 [112121] __create_coduals(args::Tuple{ComposedFunction{typeof(first), var"#softmax#4"{var"#softmax#1#5"}}, Vector{Float64}})
        @ Tapir ~/ml/ad_playground/Taped.jl/src/interface.jl:79
 [112122] value_and_pullback!!(::Tapir.DerivedRule{…}, ::Float64, ::Function, ::Vector{…})
        @ Tapir ~/ml/ad_playground/Taped.jl/src/interface.jl:65
Some type information was truncated. Use `show(err)` to see complete types.

I've added code to catch this at the interface layer (we can't do this the whole way down without killing performance), so this will catch things when arguments are self-referential. When someone finds an example where this is a problem elsewhere in the pipeline, I'll figure out how to cover it.

I'll close this when #144 is merged.

@willtebbutt
Member

Closing via #144
