Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Elide stack generation outside of looping control flow #1195

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 104 additions & 23 deletions src/compiler/emit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,50 +34,130 @@ xtuple(xs...) = xcall(:tuple, xs...)

concrete(T::DataType) = T
concrete(::Type{Type{T}}) where T = typeof(T)
concrete(T) = Any
concrete(@nospecialize _) = Any

# `true` when the id of block `b` is either 1 or equal to the number of blocks
# in the IR that `b` belongs to (`b.ir.blocks`).
function runonce(b)
    nblocks = length(b.ir.blocks)
    return b.id == 1 || b.id == nblocks
end

# TODO use a more efficient algorithm such as Johnson (1975)
# https://epubs.siam.org/doi/abs/10.1137/0204007
# `self_reaching(cfg, bid)` is `true` when node `bid` can get back to itself by
# following one or more successor edges, where `cfg[n]` iterates the successors
# of node `n`. `visited` is a scratch set that is mutated during the search; pass
# an emptied set to reuse an allocation across calls.
function self_reaching(cfg, bid, visited = BitSet())
    return reaches(cfg, bid, bid, visited)
end

# Depth-first search for a path of at least one edge from `from` to `to`.
# Every node that gets explored is recorded in `visited`, so no node is expanded
# twice; nodes already in `visited` on entry are never expanded at all.
function reaches(cfg, from, to, visited)
    for next in cfg[from]
        # A direct edge to the target ends the search immediately
        next === to && return true
        next in visited && continue
        push!(visited, next)
        reaches(cfg, next, to, visited) && return true
    end
    return false
end

function forward_stacks!(adj, F)
stks, recs = [], []
stks, recs = Tuple{Int, Alpha, Bool}[], Variable[]
pr = adj.primal
for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id))
if runonce(b)
push!(recs, Variable(α))
else
stk = pushfirst!(pr, xstack(Any))
push!(recs, stk)
push!(b, xcall(Zygote, :_push!, stk, Variable(α)))
blks = blocks(pr)
last_block = length(blks)
cfg = IRTools.CFG(pr)
cfgᵀ = cfg'
doms = IRTools.dominators(cfg)

reaching_visited = BitSet()
in_loop = map(1:last_block) do b
empty!(reaching_visited)
self_reaching(cfg, b, reaching_visited)
end
alphavars = Dict{Alpha, Variable}()
alpha_blocks = [α => b.id for b in blks for α in alphauses(block(adj.adjoint, b.id))]
for b in Iterators.reverse(blks)
filter!(alpha_blocks) do (α, bid)
if b.id in doms[bid]
# If a block dominates this block, α is guaranteed to be present here
αvar = Variable(α)
for br in branches(b)
map!(a -> a === α ? αvar : a, br.args, br.args)
end
push!(recs, b.id === last_block ? αvar : alphavars[α])
push!(stks, (bid, α, false))
elseif in_loop[bid]
# This block is in a loop, so we're forced to insert stacks
# Note: all alphas in loops will have stacks after the first iteration
stk = pushfirst!(pr, xstack(Any))
push!(recs, stk)
push!(block(pr, bid), xcall(Zygote, :_push!, stk, Variable(α)))
push!(stks, (bid, α, true))
else
# Fallback case, propagate alpha back through the CFG
argvar = nothing
if b.id > 1
# Need to make sure all predecessors have a branch to add arguments to
IRTools.explicitbranch!(b)
argvar = argument!(b, insert=false)
end
if b.id === last_block
# This alpha has been threaded all the way through to the exit block
alphavars[α] = argvar
end
for br in branches(b)
map!(a -> a === α ? argvar : a, br.args, br.args)
end
for pred in cfgᵀ[b.id]
pred >= b.id && continue # TODO is this needed?
pred_branches = branches(block(pr, pred))
idx = findfirst(br -> br.block === b.id, pred_branches)
if idx === nothing
throw(error("Predecessor $pred of block $(b.id) has no branch to $(b.id)"))
end
branch_here = pred_branches[idx]
push!(branch_here.args, α)
end
# We're not done with this alpha yet, revisit in predecessors
return true
end
return false
end
# Prune any alphas that don't exist on this path through the CFG
for br in branches(b)
map!(a -> a isa Alpha ? nothing : a, br.args, br.args)
end
push!(stks, (b.id, alpha(α)))
end
args = arguments(pr)[3:end]
@assert isempty(alpha_blocks)

rec = push!(pr, xtuple(recs...))
# Pullback{F,Any} reduces specialisation
P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any}
# P = Pullback{F,Any} # reduce specialisation
rec = push!(pr, Expr(:call, P, rec))
ret = xtuple(pr.blocks[end].branches[end].args[1], rec)
ret = push!(pr, ret)
pr.blocks[end].branches[end].args[1] = ret
return pr, stks
end

# Helps constrain pullback function type in the backwards pass
# If we had the type, we could make this a PiNode
#
# `notnothing(x)` returns `x` unchanged for any value other than `nothing`.
# Passing `nothing` throws an `ErrorException`; the message is a constant string
# so that a failure here can be identified from the error output.
notnothing(::Nothing) = error("notnothing: value is unexpectedly `nothing`")
notnothing(x) = x

function reverse_stacks!(adj, stks)
ir = adj.adjoint
entry = blocks(ir)[end]
blcks = blocks(ir)
entry = blcks[end]
self = argument!(entry, at = 1)
t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t)))
repl = Dict()
runonce(b) = b.id in (1, length(ir.blocks))
for b in blocks(ir)
for (i, (b′, α)) in enumerate(stks)
t = pushfirst!(entry, xcall(:getfield, self, QuoteNode(:t)))
repl = Dict{Alpha,Variable}()
for b in blcks
for (i, (b′, α, use_stack)) in enumerate(stks)
b.id == b′ || continue
if runonce(b)
val = insertafter!(ir, t, xcall(:getindex, t, i))
else
stk = push!(entry, xcall(:getindex, t, i))
stk = push!(entry, xcall(Zygote, :Stack, stk))
# i.e. recs[i] from forward_stacks!
val = insertafter!(ir, t, xcall(:getindex, t, i))
if use_stack
stk = push!(entry, xcall(Zygote, :Stack, val))
val = pushfirst!(b, xcall(:pop!, stk))
elseif !runonce(b)
# The first and last blocks always run, so this check is redundant there
val = pushfirst!(b, xcall(Zygote, :notnothing, val))
end
repl[α] = val
end
Expand All @@ -87,6 +167,7 @@ end

function stacks!(adj, T)
forw, stks = forward_stacks!(adj, T)
IRTools.domorder!(forw)
back = reverse_stacks!(adj, stks)
permute!(back, length(back.blocks):-1:1)
IRTools.domorder!(back)
Expand Down
6 changes: 3 additions & 3 deletions src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ end
meta, forw, _ = g
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
forw = varargs!(meta, forw, 3)
# IRTools.verify(forw)
# verify(forw)
forw = slots!(pis!(inlineable!(forw)))
# be ready to swap to using chainrule if one is declared
cr_edge != nothing && edge!(meta, cr_edge)
cr_edge !== nothing && edge!(meta, cr_edge)
return update!(meta.code, forw)
end

Expand All @@ -53,7 +53,7 @@ end
end
meta, _, back = g
argnames!(meta, Symbol("#self#"), :Δ)
# IRTools.verify(back)
# verify(back)
back = slots!(inlineable!(back))
return update!(meta.code, back)
end
6 changes: 6 additions & 0 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ function accum_global(cx::Context, ref, x̄)
return
end

# Needed for nested AD
# Runs `accum_global` as usual and pairs its result with a pullback that
# ignores its argument and returns `nothing`.
function _pullback(::typeof(accum_global), cx::Context, ref, x̄)
    result = accum_global(cx, ref, x̄)
    accum_global_pullback(_) = nothing
    return result, accum_global_pullback
end

# Generic method: hands back its argument unchanged.
function unwrap(x)
    return x
end

@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)
Expand Down
62 changes: 49 additions & 13 deletions test/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Zygote, Test
using Zygote, IRTools, Test
using Zygote: pullback, @adjoint, Context

macro test_inferred(ex)
Expand All @@ -18,24 +18,22 @@ end

bad(x) = x
@adjoint bad(x) = x, Δ -> error("bad")
bad_adjoint_line = @__LINE__() - 1 # source location of above

function badly(x)
x = x + 1
x = bad(x)
return x
end
bad_pullback_line = @__LINE__() - 3 # should match source location of Pullback

y, back = pullback(badly, 2)
@test y == 3
@test_throws Exception back(1)
bt = try back(1) catch e stacktrace(catch_backtrace()) end

@test trace_contains(bt, nothing, "compiler.jl", 20)
if VERSION >= v"1.6-"
@test_broken trace_contains(bt, :badly, "compiler.jl", 24)
else
@test trace_contains(bt, :badly, "compiler.jl", 24)
end
bt = try back(1) catch e stacktrace(catch_backtrace()) end
@test trace_contains(bt, nothing, "compiler.jl", bad_adjoint_line)
@test trace_contains(bt, nothing, "compiler.jl", bad_pullback_line)

# Type inference checks

Expand All @@ -58,10 +56,9 @@ y, back = @test_inferred pullback(f, 5)
y, back = @test_inferred pullback(Core._apply, +, (1, 2, 3))
@test_inferred back(1)

# TODO fix bcast inference
# bcast(x) = x .* 5
# y, back = @test_inferred pullback(bcast, [1,2,3])
# @test_inferred back([1,1,1])
bcast(x) = x .* 5
y, back = @test_inferred pullback(bcast, [1,2,3])
@test_inferred back([1,1,1])

foo = let a = 4
x -> x*a
Expand Down Expand Up @@ -91,6 +88,45 @@ struct Funky
y
end

@testset "stack elision" begin
# Returns `true` when the two IRs generated for signature `T` contain no call
# to `Zygote._push!` (second returned IR, `forw`) and no call to `Zygote.Stack`
# (third returned IR, `back`); returns `false` as soon as one is found.
function isstackfree(T)
    _, forw, back = Zygote._generate_pullback_via_decomposition(T)
    for (ir, name) in ((forw, :_push!), (back, :Stack))
        target = GlobalRef(Zygote, name)
        for (_, stmt) in ir
            ex = stmt.expr
            # A statement's payload is not guaranteed to be an `Expr`; guard
            # before touching `.args` so other payloads are simply skipped
            Meta.isexpr(ex, :call) || continue
            isempty(ex.args) && continue
            first(ex.args) == target && return false
        end
    end
    return true
end

# Branch-only (loop-free) power function used as a fixture: exponents 0 to 3 are
# special-cased, anything else falls through to `^`. Note that `n == 0` returns
# the literal `1` regardless of the type of `x`.
function knockoff_pow(x, n)
    if n == 0
        return 1
    elseif n == 1
        return x
    elseif n == 2
        return x * x
    elseif n == 3
        return x * x * x
    else
        return x ^ n
    end
end

# Loop-free fixture with nested branches that merge again. When `fancy_tan` is
# false it returns `tan(x)` directly. Otherwise it builds a sine and a cosine
# (each optionally via the reciprocal of `csc` / `sec`), applies the no-op
# updates `+ 0` and `* 1`, and returns their quotient.
function roundabout_trig(x, fancy_sin, fancy_cos, fancy_tan)
    if fancy_tan
        if fancy_sin
            num = inv(csc(x))
        else
            num = sin(x)
        end
        if fancy_cos
            den = inv(sec(x))
        else
            den = cos(x)
        end
        num += 0
        den *= 1
        return num / den
    else
        return tan(x)
    end
end

@test !isstackfree(Tuple{typeof(pow), Int, Int})
@test isstackfree(Tuple{typeof(knockoff_pow), Int, Int})
@test isstackfree(Tuple{typeof(roundabout_trig), Float64, Bool, Bool, Bool})
end

@testset "issue #851" begin
f = Funky(1, 1);
function Base.getproperty(f::Funky, i::Symbol)
Expand Down Expand Up @@ -128,7 +164,7 @@ end
d_two = Zygote.pullback(two_svds, X)[2](Δoutput)
d_one = Zygote.pullback(one_svd, X)[2](Δoutput)
@test d_one == d_two
end
end

# this test fails if adjoint for literal_getproperty is added
# https://github.com/FluxML/Zygote.jl/issues/922#issuecomment-804128905
Expand Down
2 changes: 1 addition & 1 deletion test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ end == (2,)
global_param = 3

@testset "Global Params" begin
cx = Zygote.Context()
cx = Zygote.Context{true}(nothing) # only makes sense with implicit params
y, back = Zygote._pullback(cx, x -> x*global_param, 2)
@test y == 6
@test back(1) == (nothing, 3)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Zygote, Test
using Zygote: gradient, ZygoteRuleConfig
using CUDA
using CUDA: has_cuda
using LinearAlgebra

@testset "all" begin # Overall testset ensures it keeps running after failure

Expand Down
1 change: 0 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using LinearAlgebra
using ForwardDiff
using Zygote: hessian_dual, hessian_reverse

Expand Down