diff --git a/src/lib/base.jl b/src/lib/base.jl
index e259a999d..88738afb1 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -231,3 +231,47 @@ end
   fallback_Fix2(y) = f(y, x)
   return _pullback(__context__, fallback_Fix2, y)
 end
+
+# Reverse rules for the `Dict` constructors.
+#
+# The cotangent `Δ` of a `Dict` is queried with `get(Δ, key, nothing)`, and the
+# cotangent of a `Pair` is the NamedTuple `(first=..., second=...)`. Keys are not
+# differentiated, so `first` is always `nothing`; `second` is `nothing` when `Δ`
+# holds no entry for that key. A `nothing` cotangent for the whole dictionary is
+# passed through unchanged.
+
+# Dict(k1 => v1, k2 => v2, ...)
+function _pullback(::AContext, ::typeof(Dict), xs::Pair...)
+  Dict_pullback(::Nothing) = nothing
+  function Dict_pullback(Δ)
+    x̄s = map(x -> (first=nothing, second=get(Δ, x.first, nothing)), xs)
+    return (nothing, x̄s...)
+  end
+  return Dict(xs...), Dict_pullback
+end
+
+# Dict([k1 => v1, k2 => v2, ...])
+function _pullback(::AContext, ::typeof(Dict), xs::AbstractVector{<:Pair})
+  Dict_pullback(::Nothing) = nothing
+  function Dict_pullback(Δ)
+    x̄s = [(first=nothing, second=get(Δ, x.first, nothing)) for x in xs]
+    return (nothing, x̄s)
+  end
+  return Dict(xs), Dict_pullback
+end
+
+# Any other iterable of pairs, e.g. a generator: materialise it with `collect`
+# and reuse the vector-of-pairs rule.
+function _pullback(cx::AContext, ::typeof(Dict), xs)
+  a, collect_back = _pullback(cx, collect, xs)
+  y, dict_back = _pullback(cx, Dict, a)
+  Dict_pullback(::Nothing) = nothing
+  function Dict_pullback(Δ)
+    # Each inner pullback returns `(f̄, x̄)`: slot 1 is for the called function,
+    # slot 2 for its single argument.
+    ā = dict_back(Δ)[2]
+    x̄s = collect_back(ā)[2]
+    return (nothing, x̄s)
+  end
+  return y, Dict_pullback
+end
diff --git a/test/features.jl b/test/features.jl
index e4fe61140..85317103d 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -835,3 +835,33 @@ end
   end
   @test gradient(f760, 3)[1] ≈ 123.93054835019153
 end
+
+@testset "Dict constructors" begin
+  # pair
+  g = gradient(1 => 2) do x
+    d = Dict(x)
+    d[1]
+  end[1]
+  @test g == (first = nothing, second = 1)
+
+  # pairs
+  g = gradient(1 => 2, 2 => 3, 4 => 10) do x1, x2, x3
+    d = Dict(x1, x2, x3)
+    d[1] + 2*d[4]
+  end
+  @test g == ((first = nothing, second = 1), nothing, (first = nothing, second = 2.0))
+
+  # array of pairs
+  g = gradient(2) do c
+    d = Dict([i => i*c for i in 1:3])
+    d[1] + 2*d[2]
+  end[1]
+  @test g == 5
+
+  # generator of pairs
+  g = gradient(2) do c
+    d = Dict(i => i*c for i in 1:3)
+    d[1] + 2*d[2]
+  end[1]
+  @test g == 5
+end