diff --git a/Project.toml b/Project.toml index 5105e6a06..e4426b59d 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1" ForwardDiff = "0.10" GPUArrays = "8.4.2, 9" GPUArraysCore = "0.1.1" -IRTools = "0.4.4" +IRTools = "0.4.11" LogExpFunctions = "0.3.1" MacroTools = "0.5" NaNMath = "0.3, 1" diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 333323e83..0583b3da6 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -244,7 +244,6 @@ Variable(a::Alpha) = Variable(a.id) sig(b::IRTools.Block) = unique([arg for br in branches(b) for arg in br.args if arg isa Variable]) sig(pr::Primal) = Dict(b.id => sig(b) for b in blocks(pr.ir)) -# TODO unreachables? function adjointcfg(pr::Primal) ir = empty(pr.ir) return!(ir, nothing) @@ -257,7 +256,9 @@ function adjointcfg(pr::Primal) push!(rb, xcall(Base, :(!==), alpha(pr.branches[b.id]), BranchNumber(i))) branch!(rb, preds[i].id, unless = cond) end - if !isempty(branches(b)) && branches(b)[end] == IRTools.unreachable + if isempty(preds) || (!isempty(branches(b)) && branches(b)[end] == IRTools.unreachable) + # If `b` is unreachable, then no context produced by the primal should end up branching to `rb` + push!(rb, xcall(Core, :throw, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable` branch!(rb, 0) end end diff --git a/test/compiler.jl b/test/compiler.jl index c9b091f78..4f8776c90 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -225,3 +225,23 @@ end # issue 897 @test gradient(x -> sum(norm, collect(eachcol(x))), ones(3, 400))[1] ≈ fill(0.5773502691896258, 3, 400) + +# issue 1118 & 1380 +function f_1380(x) + if rand(Bool) + return x + else + return 2x + end + + # unreachable + return nothing +end + +@testset "unreachable block" begin + y, back = Zygote.pullback(f_1380, 1.) + # There should not be a compiler error + local g + @test_nowarn g = back(1.) + @test only(g) ∈ (1., 2.) +end