diff --git a/src/lib/base.jl b/src/lib/base.jl index c0efa83cf..4ee5947a5 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -17,6 +17,19 @@ end end end +# IdSet (needed for nested AD with implicit params) + +grad_mut(::IdSet) = IdSet() + +function _pullback(cx::AContext, ::typeof(push!), s::IdSet, @nospecialize(x)) + res = push!(s, x) + function idset_push!_pullback(_) + Δ = pop!(grad_mut(cx, s), x, nothing) + (nothing, Δ, nothing) + end + return res, idset_push!_pullback +end + # Dictionaries grad_mut(d::AbstractDict) = Dict()