Skip to content

Commit

Permalink
Merge pull request #1474 from Pangoraw/try_catch
Browse files Browse the repository at this point in the history
Support try/catch on the happy (nothrow) path
  • Loading branch information
ToucheSir authored Sep 24, 2024
2 parents fe393b0 + d56dd27 commit 406a6f5
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
ForwardDiff = "0.10"
GPUArrays = "8.4.2, 9, 10"
GPUArraysCore = "0.1.1"
IRTools = "0.4.11"
IRTools = "0.4.12"
LogExpFunctions = "0.3.1"
MacroTools = "0.5"
NaNMath = "0.3, 1"
Expand Down
35 changes: 19 additions & 16 deletions docs/src/limitations.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,30 @@ julia> gradient(rand(3)) do y

## Try-catch statements

Any expressions involving `try`/`catch` statements is not supported.
```julia
function tryme(x)
try
2 * x
catch e
throw(e)
end
end
Code containting try-catch blocks can be differentiated as long as no exception is actually thrown.

julia> gradient(rand(3)) do x
sum(tryme(x))
```julia
julia> function safe_sqrt(x)
try
sqrt(x)
catch
0.
end
end
ERROR: Compiling Tuple{typeof(tryme), Vector{Float64}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
safe_sqrt (generic function with 1 method)

julia> gradient(safe_sqrt, 4.)
(0.25,)

julia> val, pull = pullback(safe_sqrt, -1.)
(0.0, Zygote.var"#76#77"{Zygote.Pullback{Tuple{typeof(safe_sqrt), Float64}, Any}}((safe_sqrt)))

julia> pull(1.)
ERROR: Can't differentiate function execution in catch block at #= REPL[2]:3 =#.
Stacktrace:
...
```
Here `tryme` uses a `try`/`catch` statement, and Zygote throws an error when trying to differentiate it as expected. `try`/`catch` expressions are used for error handling, but they are less common in Julia compared to some other languages.

Here, the `safe_sqrt` function catches DomainError from the sqrt call when the input is out of domain and safely returns 0. Zygote is able to differentiate the function when no error is thrown by the sqrt call, but fails to differentiate when the control flow goes through the catch block.

## Foreign call expressions

Expand Down
14 changes: 7 additions & 7 deletions src/compiler/emit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,23 @@ concrete(T::DataType) = T
concrete(::Type{Type{T}}) where T = typeof(T)
concrete(T) = Any

runonce(b) = b.id in (1, length(b.ir.blocks))
runonce(b) = b.id in (1, length(b.ir.blocks)) &&
!any(((_,stmt),) -> isexpr(stmt.expr, :catch), b)

function forward_stacks!(adj, F)
stks, recs = [], []
pr = adj.primal
for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id))
if runonce(b)
not_stack = runonce(b)
if not_stack
push!(recs, Variable(α))
else
stk = pushfirst!(pr, xstack(Any))
push!(recs, stk)
push!(b, xcall(Zygote, :_push!, stk, Variable(α)))
end
push!(stks, (b.id, alpha(α)))
push!(stks, (b.id, alpha(α), not_stack))
end
args = arguments(pr)[3:end]
rec = push!(pr, xtuple(recs...))
P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any}
# P = Pullback{F,Any} # reduce specialisation
Expand All @@ -68,11 +69,10 @@ function reverse_stacks!(adj, stks)
self = argument!(entry, at = 1)
t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t)))
repl = Dict()
runonce(b) = b.id in (1, length(ir.blocks))
for b in blocks(ir)
for (i, (b′, α)) in enumerate(stks)
for (i, (b′, α, not_stack)) in enumerate(stks)
b.id == b′ || continue
if runonce(b)
if not_stack
val = insertafter!(ir, t, xcall(:getindex, t, i))
else
stk = push!(entry, xcall(:getindex, t, i))
Expand Down
25 changes: 18 additions & 7 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,6 @@ function instrument(ir::IR)
ex = st.expr
if isexpr(ex, :foreigncall, :isdefined)
continue
elseif isexpr(ex, :enter, :leave)
error("""try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
""")
elseif isexpr(ex, :(=))
@assert ex.args[1] isa GlobalRef
pr[v] = xcall(Zygote, :global_set, QuoteNode(ex.args[1]), ex.args[2])
Expand Down Expand Up @@ -262,7 +257,7 @@ function adjointcfg(pr::Primal)
end
if isempty(preds) || (!isempty(branches(b)) && branches(b)[end] == IRTools.unreachable)
# If `b` is unreachable, then no context produced by the primal should end up branching to `rb`
push!(rb, xcall(Core, :throw, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable`
push!(rb, xcall(Base, :error, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable`
branch!(rb, 0)
end
end
Expand All @@ -283,7 +278,7 @@ xaccum(ir, xs...) = push!(ir, xcall(Zygote, :accum, xs...))

function passthrough_expr(ex::Expr)
# Metadata we want to preserve
isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo) && return true
isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo, :enter, :leave, :catch) && return true
# ccalls and more that are safe to preserve/required for proper operation:
# - jl_set_task_threadpoolid: added in 1.9 for @spawn
isexpr(ex, :foreigncall) && unwrapquote(ex.args[1]) in (:jl_set_task_threadpoolid,) && return true
Expand All @@ -301,9 +296,14 @@ function adjoint(pr::Primal)
for i = 1:length(sigs[b.id])
grad(sigs[b.id][i], arguments(rb)[i])
end

has_leave = false

# Backprop through statements
for v in reverse(keys(b))
ex = b[v].expr
has_leave |= isexpr(ex, :leave)

if haskey(pr.pullbacks, v)
g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)),
line = b[v].line))
Expand All @@ -325,6 +325,17 @@ function adjoint(pr::Primal)
continue
end
end

# This is corresponds to a catch blocks which technically
# has predecessors but they are not modelled in the IRTools CFG.
# We put an error message at the beginning of said block.
if has_leave && isempty(predecessors(b)) && b.id != 1
_, f_stmt = first(b)
li = pr.ir.lines[f_stmt.line]
pushfirst!(rb, stmt(xcall(Base, :error,
"Can't differentiate function execution in catch block at $(li.file):$(li.line).")))
end

if b.id > 1 # Backprop through (predecessor) branch arguments
gs = grad.(arguments(b))
for br in branches(rb)
Expand Down
97 changes: 97 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,100 @@ end
@test_nowarn g = back(1.)
@test only(g) (1., 2.)
end

function throws_and_catches_if_x_negative(x,y)
z = x + y
try
if x < 0.
throw(DomainError("x is negative"))
end
z = 2z + x + y
catch err
@error "something went wrong" exception=(err,catch_backtrace())
end
return 3z
end

function try_catch_finally(cond, x)

try
x = 2x
cond && throw(DomainError())
catch
x = 2x
finally
x = 3x
end

x
end

if VERSION >= v"1.8"
# try/catch/else is invalid syntax prior to v1.8
eval(Meta.parse("""
function try_catch_else(cond, x)
x = 2x
try
x = 2x
cond && throw(nothing)
catch
x = 3x
else
x = 2x
end
x
end
"""))
end

@testset "try/catch" begin
@testset "happy path (nothrow)" begin
res, (dx,dy) = withgradient(throws_and_catches_if_x_negative, 1., 2.)
@test res == 3 * (2 * (1. + 2.) + 1. + 2.)
@test dx == 3. * (2. + 1.)
@test dy == 3. * (2. + 1.)
end

@testset "try/catch/finally" begin
res, (_, dx,) = withgradient(try_catch_finally, false, 1.)
@test res == 6.
@test dx == 6.

res, pull = pullback(try_catch_finally, true, 1.)
@test res == 12.
@test_throws ErrorException pull(1.)
err = try pull(1.) catch ex; ex end
@test occursin("Can't differentiate function execution in catch block",
string(err))
end

if VERSION >= v"1.8"
@testset "try/catch/else" begin
@test Zygote.gradient(try_catch_else, false, 1.0) == (nothing, 8.0)
@test_throws "Can't differentiate function execution in catch block" Zygote.gradient(try_catch_else, true, 1.0)
end
end

function foo_try(f)
y = 1
try
y = f()
catch
y
end
y
end

g, = gradient(x -> foo_try(() -> x), 1) # 1
@test g == 1.

vy, pull = pullback(foo_try, () -> 0//0) # bypass because of expr
@test vy === 1
@test_throws ErrorException pull(1.)

err = try pull(1.) catch ex; ex end
@test occursin("Can't differentiate function execution in catch block",
string(err))
end
3 changes: 1 addition & 2 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,7 @@ function pow_try(x)
end
end

@test_broken gradient(pow_try, 1) == (2,)
@test_throws Zygote.CompileError gradient(pow_try, 1)
@test gradient(pow_try, 1) == (2,)

function pow_simd(x, n)
r = 1
Expand Down

0 comments on commit 406a6f5

Please sign in to comment.