From 3d1e0c6e8416378863aabd246d34c423816fc7cc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 19 May 2024 17:55:05 -0400 Subject: [PATCH] Restore the rrule for merge --- src/chainrules.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/chainrules.jl b/src/chainrules.jl index 642302a8c..5e602fcd4 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -13,6 +13,19 @@ CRC.@non_differentiable Base.printstyled(::Any...) CRC.@non_differentiable fieldcount(::Any) # Utilities +## DON'T REMOVE THIS CAUSES DOWNSTREAM FAILURES +function CRC.rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1, F2} + y = merge(nt1, nt2) + function ∇merge(dy) + dnt1 = NamedTuple((f1 => (f1 in F2 ? NoTangent() : getproperty(dy, f1)) + for f1 in F1)) + dnt2 = NamedTuple((f2 => getproperty(dy, f2) for f2 in F2)) + return (NoTangent(), dnt1, dnt2) + end + ∇merge(::Union{NoTangent, ZeroTangent}) = (NoTangent(), NoTangent(), NoTangent()) + return y, ∇merge +end + function CRC.rrule(::typeof(_eachslice), x, d::Val) return _eachslice(x, d), @closure(Δ->(NoTangent(), ∇_eachslice(Δ, x, d), NoTangent())) end