Skip to content

Commit

Permalink
Merge pull request #1001 from mcabbott/broadget
Browse files Browse the repository at this point in the history
Faster generic broadcasting
  • Loading branch information
mcabbott authored Jun 25, 2021
2 parents b170521 + 8424c3e commit 8d5efcb
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 32 deletions.
23 changes: 16 additions & 7 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ end

"""
    StaticGetter{i}()

Callable singleton whose index `i` is a type parameter. Calling it on `v`
returns `v[i]`; calling it on `nothing` returns `nothing`.
"""
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing

"""
    _unzip(tuples, ::Val{N})

Return an `N`-tuple whose `i`-th entry is `map(StaticGetter{i}(), tuples)`,
i.e. the collection of `i`-th components of every element of `tuples`.

The body is generated so that the `N` calls to `map` are written out
explicitly in the resulting tuple expression.
"""
@generated function _unzip(tuples, ::Val{N}) where {N}
    # One `map(StaticGetter{i}(), tuples)` expression per component index.
    return Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i in 1:N)...)
end
Expand All @@ -214,19 +215,27 @@ _tryreverse(m, x) = x
# When the mapping function is `map` and `x` is a vector or tuple, hand back
# `x` in reversed order.
function _tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple})
    return reverse(x)
end

for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
@eval function $∇mapfunc(cx, f, args...)
@eval function $∇mapfunc(cx, f::F, args...) where {F}
ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> nothing
else
ys, backs = unzip(ys_and_backs)
ys = map(first, ys_and_backs)
ys, function (Δ)
isnothing(Δ) && return nothing
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
Δf_and_args_zipped = $mapfunc((f, δ) -> f(δ), _tryreverse($mapfunc, backs, Δ)...)
Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
Δf = reduce(accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
if Base.issingletontype(F) && length(args) == 1
Δarg = $mapfunc(((_,pb), δ) -> last(pb(δ)), ys_and_backs, Δ) # No unzip needed
(nothing, Δarg)
elseif Base.issingletontype(F) # Ensures `f` is pure: nothing captured & no state
Δargs = unzip($mapfunc(((_,pb), δ) -> Base.tail(pb(δ)), ys_and_backs, Δ))
(nothing, Δargs...)
else
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
Δf_and_args_zipped = $mapfunc(((_,pb), δ) -> pb(δ), _tryreverse($mapfunc, ys_and_backs, Δ)...)
Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
Δf = reduce(accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
end
end
end
Expand Down
42 changes: 28 additions & 14 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,31 +164,48 @@ end
# Build the lazy `broadcasted` object ourselves and materialize it, which
# avoids hitting special cases for `Adjoint` etc.
function _broadcast(f::F, x...) where {F}
    lazy = broadcasted(f, x...)
    return materialize(lazy)
end

_get(x::Tuple, i) = x[i]
_get(::Nothing, i) = nothing
"""
    collapse_nothings(xs)

Return `nothing` when `xs` is an array whose element type is exactly `Nothing`;
return `xs` itself for any other argument.
"""
function collapse_nothings(xs::AbstractArray{Nothing})
    return nothing
end
function collapse_nothings(xs)
    return xs
end

@adjoint function broadcasted(::AbstractArrayStyle, f, args...)
"""
    _dual_purefun(::Type{F}) -> Bool

Return `Base.issingletontype(F)` for a function type `F`, and `false` for every
other type. `typeof(^)` is special-cased to `false`.
"""
function _dual_purefun(::Type{F}) where {F<:Function}
    return Base.issingletontype(F)
end
function _dual_purefun(::Type)
    return false
end
# avoid DomainError from negative powers
function _dual_purefun(::Type{typeof(^)})
    return false
end

"""
    _dual_safearg(x) -> Bool

Return `true` when `x` is a `Numeric{<:Real}`, a `Ref` wrapping one, or a
`Type`, `Val` or `Symbol`; return `false` for anything else.
"""
function _dual_safearg(x::Numeric{<:Real})
    return true
end
function _dual_safearg(x::Ref{<:Numeric{<:Real}})
    return true
end
# non-differentiable types
function _dual_safearg(x::Union{Type,Val,Symbol})
    return true
end
function _dual_safearg(x)
    return false
end

@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
if T == Bool
return f.(args...), _->nothing
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args)
y, back = broadcast_forward(f, args...)
return y, ȳ -> (nothing, nothing, back(ȳ)...)
end
len = inclen(args)
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y = map(x -> x[1], y∂b)
∂b = map(x -> x[2], y∂b)
y, function (ȳ)
dxs_zip = map((∂b, ȳ) -> ∂b(ȳ), ∂b, ȳ)
dxs = collapse_nothings.(ntuple(i -> map(x -> _get(x, i), dxs_zip), len))
y = map(first, y∂b)
function ∇broadcasted(ȳ)
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
dxs = ntuple(len) do i
collapse_nothings(map(StaticGetter{i}(), dxs_zip))
end
(nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
end
y, ∇broadcasted
end

@adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...)
len = inclen(args)
y, ∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y, function (ȳ)
function ∇broadcasted0(ȳ)
dxs = ∂b(ȳ)
dxs === nothing && return nothing
(nothing, dxs...)
end
y, ∇broadcasted0
end

# Use the `map` adjoint in this special case, which is the same but applies
Expand All @@ -202,17 +219,14 @@ end

# Adjoint of `broadcast(f, args...)`: delegates to `_pullback` of `broadcasted`
# with the same `f` and `args`, passing along the current `__context__`.
@adjoint! (b::typeof(broadcast))(f, args...) = _pullback(__context__, broadcasted, f, args...)

# Forward Mode (mainly necessary for CUDA)
# Forward Mode -- necessary for CUDA, also used as a fast path above

import ForwardDiff
using ForwardDiff: Dual

# `Real` inputs are wrapped as `Dual(x, p)`; every other input is returned
# unchanged and `p` is ignored.
function dual(x, p)
    return x
end
function dual(x::Real, p)
    return Dual(x, p)
end

dualtype(::Type{Dual{G,T,P}}) where {G,T,P} = T
dualtype(T) = T

function dual_function(f::F) where F
function (args::Vararg{Any,N}) where N
ds = map(args, ntuple(identity,Val(N))) do x, i
Expand Down
51 changes: 41 additions & 10 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -500,14 +500,45 @@ end
@test 150_000_000 > @allocated gradient(loss, ones(1000,1000))
end

@testset "tuples & broadcasting" begin
@test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
@test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
@test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)

# https://github.com/FluxML/Zygote.jl/issues/975
gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
@test gt[1] == gv[1]
@test collect(gt[2]) ≈ gv[2]
@testset "tricky broadcasting" begin
@test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
@test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
@test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)

# https://github.com/FluxML/Zygote.jl/issues/975
gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
@test gt[1] == gv[1]
@test collect(gt[2]) gv[2]

# closure captures y -- can't use ForwardDiff
@test gradient((x,y) -> sum((z->z^2+y[1]).(x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
@test gradient((x,y) -> sum((z->z^2+y[1]), x), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
@test gradient((x,y) -> sum(map((z->z^2+y[1]), x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
@test gradient((x,y) -> mapreduce((z->z^2+y[1]), +, x), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])

# type unstable
@test gradient(xs -> sum((x -> x<2 ? false : x^2).(xs)), [1,2,3])[1][2:3] == [4, 6]
@test gradient(xs -> sum((x -> x<2 ? false : x^2), xs), [1,2,3])[1][2:3] == [4, 6]
@test gradient(xs -> sum(map((x -> x<2 ? false : x^2), xs)), [1,2,3])[1][2:3] == [4, 6]
@test gradient(xs -> mapreduce((x -> x<2 ? false : x^2), +, xs), [1,2,3])[1][2:3] == [4, 6]

# with Ref, Val, Symbol
@test gradient(x -> sum(x .+ Ref(x[1])), [1,2,3]) == ([4,1,1],)
@test gradient(x -> sum(x .+ (x[1],)), [1,2,3]) == ([4,1,1],)
@test gradient(x -> sum((firsttuple).(x, :ignore)), [1,2,3]) == ([1,1,1],)
@test gradient(x -> sum((firsttuple).(x, Symbol)), [1,2,3]) == ([1,1,1],)
_f(x,::Val{y}=Val(2)) where {y} = x/y
@test gradient(x -> sum(_f.(x, Val(2))), [1,2,3]) == ([0.5, 0.5, 0.5],)
@test gradient(x -> sum(_f.(x)), [1,2,3]) == ([0.5, 0.5, 0.5],)
@test gradient(x -> sum(map(_f, x)), [1,2,3]) == ([0.5, 0.5, 0.5],)

@test gradient(x -> sum(x ./ [1,2,4]), [1,2,pi]) == ([1.0, 0.5, 0.25],)
@test gradient(x -> sum(map(/, x, [1,2,4])), [1,2,pi]) == ([1.0, 0.5, 0.25],)

# negative powers
@test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], [1,-1,2])[1] [1.0, -0.25, 8.0]
@test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], -1)[1] [-1.0, -0.25, -0.0625]
@test gradient((x,p) -> sum(z -> z^p, x), [1.0,2.0,4.0], -1)[1] [-1.0, -0.25, -0.0625]
@test gradient((x,p) -> mapreduce(z -> z^p, +, x), [1.0,2.0,4.0], -1)[1] [-1.0, -0.25, -0.0625]
end
3 changes: 2 additions & 1 deletion test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,8 @@ end
end

@testset "broadcast" begin
@test gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1
# Before https://github.com/FluxML/Zygote.jl/pull/1001 this gave [1 1 1; 1 0 1; 1 1 -1]
  @test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1]

a = rand(3)
b = rand(2,2)
Expand Down

0 comments on commit 8d5efcb

Please sign in to comment.