From 3ddf945f1bef392c89dd0dffef038fee14c49826 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Wed, 29 Nov 2023 11:05:16 +0100 Subject: [PATCH 1/9] enable try/catch support on the happy (nothrow) path --- docs/src/limitations.md | 35 ++++++++++++++++++--------------- src/compiler/reverse.jl | 26 ++++++++++++++++++------- test/compiler.jl | 43 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 23 deletions(-) diff --git a/docs/src/limitations.md b/docs/src/limitations.md index f27f74305..b455e1850 100644 --- a/docs/src/limitations.md +++ b/docs/src/limitations.md @@ -82,27 +82,30 @@ julia> gradient(rand(3)) do y ## Try-catch statements -Any expressions involving `try`/`catch` statements is not supported. -```julia -function tryme(x) - try - 2 * x - catch e - throw(e) - end -end +Exceptions containing try-catch statements can be differentiated if the catch block is not reached (no error are thrown). -julia> gradient(rand(3)) do x - sum(tryme(x)) +```julia +julia> function safe_sqrt(x) + try + sqrt(x) + catch + 0. + end end -ERROR: Compiling Tuple{typeof(tryme), Vector{Float64}}: try/catch is not supported. -Refer to the Zygote documentation for fixes. -https://fluxml.ai/Zygote.jl/latest/limitations +safe_sqrt (generic function with 1 method) + +julia> gradient(safe_sqrt, 4.) +(0.25,) + +julia> val, pull = pullback(safe_sqrt, -1.) +(0.0, Zygote.var"#76#77"{Zygote.Pullback{Tuple{typeof(safe_sqrt), Float64}, Any}}(∂(safe_sqrt))) +julia> pull(1.) +ERROR: Can't differentiate function execution in catch block at #= REPL[2]:3 =#. Stacktrace: - ... ``` -Here `tryme` uses a `try`/`catch` statement, and Zygote throws an error when trying to differentiate it as expected. `try`/`catch` expressions are used for error handling, but they are less common in Julia compared to some other languages. + +Here, the `safe_sqrt` function catches DomainError from the sqrt call when the input is out of domain and safely returns 0. Zygote is able to differentiate the function when no error is thrown by the sqrt call, but fails to differentiate when the control flow goes through the catch block. ## Foreign call expressions diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 0583b3da6..f47560541 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -124,11 +124,6 @@ function instrument(ir::IR) ex = st.expr if isexpr(ex, :foreigncall, :isdefined) continue - elseif isexpr(ex, :enter, :leave) - error("""try/catch is not supported. - Refer to the Zygote documentation for fixes. - https://fluxml.ai/Zygote.jl/latest/limitations - """) elseif isexpr(ex, :(=)) @assert ex.args[1] isa GlobalRef pr[v] = xcall(Zygote, :global_set, QuoteNode(ex.args[1]), ex.args[2]) @@ -258,7 +253,7 @@ function adjointcfg(pr::Primal) end if isempty(preds) || (!isempty(branches(b)) && branches(b)[end] == IRTools.unreachable) # If `b` is unreachable, then no context produced by the primal should end up branching to `rb` - push!(rb, xcall(Core, :throw, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable` + push!(rb, xcall(Base, :error, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable` branch!(rb, 0) end end @@ -279,7 +274,7 @@ xaccum(ir, xs...) = push!(ir, xcall(Zygote, :accum, xs...)) function passthrough_expr(ex::Expr) # Metadata we want to preserve - isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo) && return true + isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo, :enter, :leave, :catch) && return true # ccalls and more that are safe to preserve/required for proper operation: # - jl_set_task_threadpoolid: added in 1.9 for @spawn isexpr(ex, :foreigncall) && unwrapquote(ex.args[1]) in (:jl_set_task_threadpoolid,) && return true @@ -297,9 +292,14 @@ function adjoint(pr::Primal) for i = 1:length(sigs[b.id]) grad(sigs[b.id][i], arguments(rb)[i]) end + + has_leave = false + # Backprop through statements for v in reverse(keys(b)) ex = b[v].expr + has_leave |= isexpr(ex, :leave) + if haskey(pr.pullbacks, v) g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)), line = b[v].line)) @@ -321,6 +321,18 @@ function adjoint(pr::Primal) continue end end + + # This is corresponds to a catch blocks which technically + # has predecessors but they are not modelled in the IRTools CFG. + # We put an error message at the beginning of said block. + if has_leave && isempty(predecessors(b)) && b.id != 1 + _, f_stmt = first(b) + li = pr.ir.lines[f_stmt.line] + li = LineNumberNode(Int(li.line), li.file) + pushfirst!(rb, stmt(xcall(Base, :error, + "Can't differentiate function execution in catch block at $(li)."))) + end + if b.id > 1 # Backprop through (predecessor) branch arguments gs = grad.(arguments(b)) for br in branches(rb) diff --git a/test/compiler.jl b/test/compiler.jl index 4f8776c90..e1ace79bf 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -245,3 +245,46 @@ end @test_nowarn g = back(1.) @test only(g) ∈ (1., 2.) end + +function f_try_catch(x,y) + z = x + y + try + if x < 0. + throw(DomainError("x is negative")) + end + z = 2z + x + y + catch err + @error "something went wrong" exception=(err,catch_backtrace()) + end + return 3z +end + +@testset "try/catch" begin + @testset "happy path (nothrow)" begin + res, (dx,dy) = withgradient(f_try_catch, 1., 2.) + @test res == 3 * (2 * (1. + 2.) + 1. + 2.) + @test dx == 3. * (2. + 1.) + @test dy == 3. * (2. + 1.) + end + + function foo_try(f) + y = 1 + try + y = f() + catch + y + end + y + end + + g, = gradient(x -> foo_try(() -> x), 1) # 1 + @test g == 1. + + vy, pull = pullback(foo_try, () -> 0//0) # bypass because of expr + @test vy === 1 + @test_throws ErrorException pull(1.) + + err = try pull(1.) catch ex; ex end + @test occursin("Can't differentiate function execution in catch block", + string(err)) +end From 8f416aa8e1ce8d940c24c644e3f926350571185d Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 4 Dec 2023 22:49:45 +0100 Subject: [PATCH 2/9] use stacks for possibly undefined pullbacks another possibility would be to thread the pullback values as block parameters through the control flow with arguments coming from the catch block being `nothing`. --- src/compiler/emit.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/compiler/emit.jl b/src/compiler/emit.jl index ca79f11ce..7b1c82ed6 100644 --- a/src/compiler/emit.jl +++ b/src/compiler/emit.jl @@ -36,22 +36,23 @@ concrete(T::DataType) = T concrete(::Type{Type{T}}) where T = typeof(T) concrete(T) = Any -runonce(b) = b.id in (1, length(b.ir.blocks)) +runonce(b) = b.id in (1, length(b.ir.blocks)) && + !any(((_,stmt),) -> isexpr(stmt.expr, :catch), b) function forward_stacks!(adj, F) stks, recs = [], [] pr = adj.primal for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id)) - if runonce(b) + is_stack = runonce(b) + if is_stack push!(recs, Variable(α)) else stk = pushfirst!(pr, xstack(Any)) push!(recs, stk) push!(b, xcall(Zygote, :_push!, stk, Variable(α))) end - push!(stks, (b.id, alpha(α))) + push!(stks, (b.id, alpha(α), is_stack)) end - args = arguments(pr)[3:end] rec = push!(pr, xtuple(recs...)) P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any} # P = Pullback{F,Any} # reduce specialisation @@ -68,11 +69,10 @@ function reverse_stacks!(adj, stks) self = argument!(entry, at = 1) t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t))) repl = Dict() - runonce(b) = b.id in (1, length(ir.blocks)) for b in blocks(ir) - for (i, (b′, α)) in enumerate(stks) + for (i, (b′, α, is_stack)) in enumerate(stks) b.id == b′ || continue - if runonce(b) + if is_stack val = insertafter!(ir, t, xcall(:getindex, t, i)) else stk = push!(entry, xcall(:getindex, t, i)) From 0836d94b4b209b824ceb0ba70ddc5f32df364560 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Wed, 6 Dec 2023 09:39:56 +0100 Subject: [PATCH 3/9] Update features.jl --- test/features.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/features.jl b/test/features.jl index 908ae5815..78dba0484 100644 --- a/test/features.jl +++ b/test/features.jl @@ -416,8 +416,7 @@ function pow_try(x) end end -@test_broken gradient(pow_try, 1) == (2,) -@test_throws Zygote.CompileError gradient(pow_try, 1) +@test gradient(pow_try, 1) == (2,) function pow_simd(x, n) r = 1 From 07ec290affd637c39cd1a52b9e83bd91551d1449 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Wed, 6 Dec 2023 11:10:19 +0100 Subject: [PATCH 4/9] Update emit.jl --- src/compiler/emit.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/compiler/emit.jl b/src/compiler/emit.jl index 7b1c82ed6..75bfc0e58 100644 --- a/src/compiler/emit.jl +++ b/src/compiler/emit.jl @@ -43,15 +43,15 @@ function forward_stacks!(adj, F) stks, recs = [], [] pr = adj.primal for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id)) - is_stack = runonce(b) - if is_stack + not_stack = runonce(b) + if not_stack push!(recs, Variable(α)) else stk = pushfirst!(pr, xstack(Any)) push!(recs, stk) push!(b, xcall(Zygote, :_push!, stk, Variable(α))) end - push!(stks, (b.id, alpha(α), is_stack)) + push!(stks, (b.id, alpha(α), not_stack)) end rec = push!(pr, xtuple(recs...)) P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any} @@ -70,9 +70,9 @@ function reverse_stacks!(adj, stks) t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t))) repl = Dict() for b in blocks(ir) - for (i, (b′, α, is_stack)) in enumerate(stks) + for (i, (b′, α, not_stack)) in enumerate(stks) b.id == b′ || continue - if is_stack + if not_stack val = insertafter!(ir, t, xcall(:getindex, t, i)) else stk = push!(entry, xcall(:getindex, t, i)) From c0e5ba1d7fbe759eebc5bc375350bf5c15d13509 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Tue, 2 Jan 2024 18:28:27 +0100 Subject: [PATCH 5/9] Update docs/src/limitations.md Co-authored-by: Frames White --- docs/src/limitations.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/limitations.md b/docs/src/limitations.md index b455e1850..4e9012ced 100644 --- a/docs/src/limitations.md +++ b/docs/src/limitations.md @@ -82,7 +82,7 @@ julia> gradient(rand(3)) do y ## Try-catch statements -Exceptions containing try-catch statements can be differentiated if the catch block is not reached (no error are thrown). +Code containting try-catch blocks can be differentiated as long as no exception is actually thrown. ```julia julia> function safe_sqrt(x) From 9e6e63b50970adaf80889e38d335906ff4e40f25 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 21 Jan 2024 12:41:34 +0100 Subject: [PATCH 6/9] add try/catch/finally test This requires updating the IRTools version to include the better ssa conversion for try catch branches. --- test/compiler.jl | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index e1ace79bf..3fdc7a912 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -246,7 +246,7 @@ end @test only(g) ∈ (1., 2.) end -function f_try_catch(x,y) +function throws_and_catches_if_x_negative(x,y) z = x + y try if x < 0. @@ -259,14 +259,49 @@ function f_try_catch(x,y) return 3z end +function try_catch_finally(cond, x) + + try + x = 2x + cond && throw(DomainError()) + catch + x = 2x + finally + x = 3x + end + + x +end + +if VERSION >= v"1.8" + # try/catch/else is invalid syntax prior to v1.8 + eval(Meta.parse(""" + function try_catch_else(cond, x) + end + """)) +end + @testset "try/catch" begin @testset "happy path (nothrow)" begin - res, (dx,dy) = withgradient(f_try_catch, 1., 2.) + res, (dx,dy) = withgradient(throws_and_catches_if_x_negative, 1., 2.) @test res == 3 * (2 * (1. + 2.) + 1. + 2.) @test dx == 3. * (2. + 1.) @test dy == 3. * (2. + 1.) end + @testset "try/catch/finally" begin + res, (_, dx,) = withgradient(try_catch_finally, false, 1.) + @test res == 6. + @test dx == 6. + + res, pull = pullback(try_catch_finally, true, 1.) + @test res == 12. + @test_throws ErrorException pull(1.) + err = try pull(1.) catch ex; ex end + @test occursin("Can't differentiate function execution in catch block", + string(err)) + end + function foo_try(f) y = 1 try From e00a28cd6881e4a7632187c5b18792c366e2b2cf Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 21 Jan 2024 12:47:14 +0100 Subject: [PATCH 7/9] improve line info display for error message --- src/compiler/reverse.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index f47560541..5ed79ea81 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -328,9 +328,8 @@ function adjoint(pr::Primal) if has_leave && isempty(predecessors(b)) && b.id != 1 _, f_stmt = first(b) li = pr.ir.lines[f_stmt.line] - li = LineNumberNode(Int(li.line), li.file) pushfirst!(rb, stmt(xcall(Base, :error, - "Can't differentiate function execution in catch block at $(li)."))) + "Can't differentiate function execution in catch block at $(li.file):$(li.line)."))) end if b.id > 1 # Backprop through (predecessor) branch arguments From 4d52ed47bc16e38ec384fb5ddc50094235576b5b Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Sun, 21 Jan 2024 19:26:58 +0100 Subject: [PATCH 8/9] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9101c516c..9760925c5 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1" ForwardDiff = "0.10" GPUArrays = "8.4.2, 9" GPUArraysCore = "0.1.1" -IRTools = "0.4.11" +IRTools = "0.4.12" LogExpFunctions = "0.3.1" MacroTools = "0.5" NaNMath = "0.3, 1" From d56dd270988e450834b906bf3d1d113db19b12de Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 29 Jan 2024 22:07:45 +0100 Subject: [PATCH 9/9] complete try/catch/else test --- test/compiler.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/compiler.jl b/test/compiler.jl index 3fdc7a912..af93ae4f3 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -277,6 +277,18 @@ if VERSION >= v"1.8" # try/catch/else is invalid syntax prior to v1.8 eval(Meta.parse(""" function try_catch_else(cond, x) + x = 2x + + try + x = 2x + cond && throw(nothing) + catch + x = 3x + else + x = 2x + end + + x end """)) end @@ -302,6 +314,13 @@ end string(err)) end + if VERSION >= v"1.8" + @testset "try/catch/else" begin + @test Zygote.gradient(try_catch_else, false, 1.0) == (nothing, 8.0) + @test_throws "Can't differentiate function execution in catch block" Zygote.gradient(try_catch_else, true, 1.0) + end + end + function foo_try(f) y = 1 try