Skip to content

Commit

Permalink
Merge pull request #992 from DhairyaLGandhi/dg/941
Browse files Browse the repository at this point in the history
Differentiate `push!` with implicit Params
  • Loading branch information
CarloLucibello authored Jun 24, 2021
2 parents 18a6f2a + ce8eb91 commit 87e2f12
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,15 @@ Base.adjoint(f::Function) = x -> gradient(f, x)[1]

# TODO store ids only
struct Params
order::Buffer{Any, Vector{Any}}
order::Buffer # {Any, Vector{Any}}
params::IdSet{Any}
Params() = new(Buffer([], false), IdSet())
end

Params() = Params(Buffer([], false), IdSet())
Params(xs) = Params(Buffer(xs, false), IdSet(xs))
Params(ps::Params) = ps
Params(xs::Tuple) = Params(collect(xs))

@forward Params.order Base.iterate, Base.length, Base.getindex
@forward Params.params Base.in

Expand Down Expand Up @@ -103,6 +107,20 @@ function Base.push!(ps::Params, x)
return ps
end

@adjoint! function Base.push!(xs::IdSet, x...)
l = length(x)
push!(xs, x...), Δ -> begin
(Δ, ntuple(_ -> nothing, l)...)
end
end

@adjoint! function Base.push!(xs::Params, x::AbstractArray{T}...) where T
sz_x = size.(x)
push!(xs, x...), Δ -> begin
(Δ, map(x -> Ones{T}(x...), sz_x)...)
end
end

Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)

function Base.delete!(ps::Params, x)
Expand All @@ -114,8 +132,6 @@ function Base.delete!(ps::Params, x)
return ps
end

Params(xs) = push!(Params(), xs...)

Base.Broadcast.broadcasted(f, ps::Params) = broadcasted(f, ps.order)

Base.:(==)(x::Params, y::Params) = x.order.data == y.order.data
Expand Down
40 changes: 40 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,44 @@ end
@test all(abs.(gs[w]) .<= 1e-5)
@test all(abs.(gs[b]) .<= 1e-5)
end

@testset "Params nesting" begin
struct Dense{F,T,S}
W::T
b::S
σ::F
end

(d::Dense)(x) = d.σ.(d.W * x .+ d.b)
d = Dense(ones(Float32, 3,3), zeros(Float32, 3), identity)
ps = Zygote.Params([d.W, d.b])
r = ones(Float32, 3,3)

gs = gradient(ps) do
p, pb = pullback(ps) do
sum(d(r))
end
g = pb(p)
sum(g[d.W]) # + sum(g[d.b])
end

@test gs[d.W] fill(81f0, (3,3))

# Test L2
l2g = gradient(ps) do
sum(sum(x .^ 2) for x in ps)
end
@test l2g[d.W] fill(2.f0, size(d.W))
@test l2g[d.b] fill(0.f0, size(d.b))

# Can be safely removed - creating Params within
# gradient calls may break between releases.
sgs = gradient(ps) do
sum(sum(x) for x in Zygote.Params([d.W, d.b]))
end
@test sgs[d.W] fill(1.f0, size(d.W))
@test sgs[d.b] fill(1.f0, size(d.b))
end


end

0 comments on commit 87e2f12

Please sign in to comment.