Skip to content

Commit

Permalink
Merge pull request #1465 from Pangoraw/unreachable_block
Browse files Browse the repository at this point in the history
Handle unreachable blocks in the adjoint CFG
  • Loading branch information
ToucheSir authored Oct 18, 2023
2 parents cf7f7d0 + bcf996a commit b152846
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
ForwardDiff = "0.10"
GPUArrays = "8.4.2, 9"
GPUArraysCore = "0.1.1"
IRTools = "0.4.4"
IRTools = "0.4.11"
LogExpFunctions = "0.3.1"
MacroTools = "0.5"
NaNMath = "0.3, 1"
Expand Down
5 changes: 3 additions & 2 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ Variable(a::Alpha) = Variable(a.id)
sig(b::IRTools.Block) = unique([arg for br in branches(b) for arg in br.args if arg isa Variable])
sig(pr::Primal) = Dict(b.id => sig(b) for b in blocks(pr.ir))

# TODO unreachables?
function adjointcfg(pr::Primal)
ir = empty(pr.ir)
return!(ir, nothing)
Expand All @@ -257,7 +256,9 @@ function adjointcfg(pr::Primal)
push!(rb, xcall(Base, :(!==), alpha(pr.branches[b.id]), BranchNumber(i)))
branch!(rb, preds[i].id, unless = cond)
end
if !isempty(branches(b)) && branches(b)[end] == IRTools.unreachable
if isempty(preds) || (!isempty(branches(b)) && branches(b)[end] == IRTools.unreachable)
# If `b` is unreachable, then no context produced by the primal should end up branching to `rb`
push!(rb, xcall(Core, :throw, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable`
branch!(rb, 0)
end
end
Expand Down
20 changes: 20 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,23 @@ end

# issue 897
@test gradient(x -> sum(norm, collect(eachcol(x))), ones(3, 400))[1] fill(0.5773502691896258, 3, 400)

# issue 1118 & 1380
function f_1380(x)
if rand(Bool)
return x
else
return 2x
end

# unreachable
return nothing
end

@testset "unreachable block" begin
y, back = Zygote.pullback(f_1380, 1.)
# There should not be a compiler error
local g
@test_nowarn g = back(1.)
@test only(g) (1., 2.)
end

0 comments on commit b152846

Please sign in to comment.