Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ProjectTo convert Tangent back to Diagonal, etc, when safe #446

Closed
wants to merge 21 commits into from

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Aug 22, 2021

The example from #441 (comment) is x -> sqrt(Diagonal(x)), whose implementation is sqrt.(x.diag). At present, Zygote returns a "structural" tangent for this, i.e. a NamedTuple. When the .diag is the very first operation being performed, this is returned, but if it occurs after other operations, then their gradients will tend not to understand this.

This PR proposes that there should be a method ProjectTo{Diagonal}(::Tangent) which converts this back to the "natural" form, i.e. to another Diagonal. To try it out:

# ] dev ChainRulesCore
using Zygote, ChainRulesCore, LinearAlgebra

gradient(x -> sum(sqrt.(x.diag)), Diagonal([1,2,3]))  # returns a NamedTuple containing a vector.
gradient(x -> sum(sqrt.((cbrt.(x)).diag)), Diagonal([1,2,3]))  # gradient of cbrt.(x) fails 

# Instead of hacking the gradient for getproperty, for now insert this:
function ChainRulesCore.rrule(::typeof(identity), x::AbstractArray)
  x, dx -> (NoTangent(), ProjectTo(x)(dx))
end

gradient(x -> sum(sqrt.(identity(x).diag)), Diagonal([1,2,3]))[1]  # returns a Diagonal
gradient(x -> sum(sqrt.((identity(cbrt.(x))).diag)), Diagonal([1,2,3]))[1]  # works!
# Note these print T = Any, from (::Tangent{T}) where T = (@show T; ...)

There are were two immediate hurdles here. One is that, to work with Base's sqrt(::Diagonal) method, you would have to insert a projection step into Zygote's litereal_getproperty adjoint definition. It's not immediately obvious to me how to do that. Done in FluxML/Zygote.jl#1104

The second is that Zygote makes a Tangent{Any} here. I thought that dx::Tangent{T} was supposed to always have typeof(x) == T exactly. If that's not true, then must we worry about getting a Tangent which doesn't come from a Diagonal at all? See below, I guess this wants FluxML/Zygote.jl#1057

I added a similar line for UpperTriangular. At present x::UpperTriangular accepts dx::Diagonal as a "natural" gradient. If it must accept dx::Tangent{Diagonal} too, well we could write a method for that. But how generally that can work I don't know. I'm not sure it's worth trying.

Beyond the immediate, this is still an easy case of the problems discussed in #441 (not 411, sorry!), in that the Tangent contains a Vector and can trivially be re-wrapped to form a Diagonal. (Before finding this Any, I was trying to make dispatch restrict it to this case.) It doesn't address what should happen if the content of the Tangent{Diagonal} is some other weird structural thing, which cannot itself be ProjectTo'd to an AbstractVector -- that's what we don't have concrete examples of yet.

Edit -- this particular example is now handled explicitly by JuliaDiff/ChainRules.jl#509.

Edit' -- the Tangent{Any} isn't an inference failure, it's explicitly constructed that way because the input NamedTuple doesn't have the type, here:
https://github.com/FluxML/Zygote.jl/blob/05d0c2ae04f334a2ec61e42decfe1172d0f2e6e8/src/compiler/chainrules.jl#L126-L129

@codecov-commenter
Copy link

codecov-commenter commented Aug 22, 2021

Codecov Report

Merging #446 (2cd6f89) into main (a9840b4) will decrease coverage by 2.87%.
The diff coverage is 82.69%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #446      +/-   ##
==========================================
- Coverage   93.03%   90.16%   -2.88%     
==========================================
  Files          15       15              
  Lines         862      925      +63     
==========================================
+ Hits          802      834      +32     
- Misses         60       91      +31     
Impacted Files Coverage Δ
src/tangent_types/abstract_zero.jl 84.61% <79.41%> (-10.84%) ⬇️
src/projection.jl 89.41% <84.28%> (-7.89%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update a9840b4...2cd6f89. Read the comment docs.

@mcabbott mcabbott added ProjectTo related to the projection functionality Structural Tangent Related to the `Tangent` type for structured (composite) values labels Aug 22, 2021
@mcabbott mcabbott marked this pull request as ready for review October 15, 2021 19:33
src/projection.jl Outdated Show resolved Hide resolved
src/projection.jl Outdated Show resolved Hide resolved
@mcabbott mcabbott changed the title Make ProjectTo convert Tangent back to Diagonal Make ProjectTo convert Tangent back to Diagonal, etc, when safe Oct 15, 2021
@mcabbott
Copy link
Member Author

mcabbott commented Oct 16, 2021

[Edited!]

The simplest effect is like so, although right now this works only if it accepts Tangent{Any}:

julia> Zygote.gradient(x -> parent(x)[1], Diagonal([1,2,3]))[1]
3×3 Diagonal{Float64, Vector{Float64}}:
 1.0        
     0.0    
         0.0

Zygote is applying the projection at the last backward step above. But the point of this is really to apply it at the first step. I think this can be done as follows: This is done by FluxML/Zygote.jl#1104 now.

With that, the gradient can propagate through for instance further broadcasting steps:

julia> gradient(x -> sum(sqrt.((cbrt.(x)).diag)), Diagonal([1,2,3]))[1]
3×3 Diagonal{Float64, Vector{Float64}}:
 0.166667              
          0.0935385    
                    0.0667187

Without this PR, Zygote would like to return a NamedTuple. Which cannot be handled by the gradient of broadcasting.

In fact the generic (::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx is not wide enough right now -- projection gives an error on Tangent{Any}:

julia> pullback(x -> parent(x)[1], Diagonal([1,2,3]))[2](1.0)  # no projection
((diag = [1.0, 0.0, 0.0],),)

julia> gradient(x -> parent(x)[1], Diagonal([1,2,3]))[1]  # with projection
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}})(::ChainRulesCore.Tangent{Any, NamedTuple{(:diag,), Tuple{Vector{Float64}}}})

That error is a problem with or without this PR. Maybe it should be widened? This should also fixed by FluxML/Zygote.jl#1104, which uses zygote2differential to produce Tangent{T} instead of Tangent{Any}.


# Diagonal
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))
function (project::ProjectTo{Diagonal})(dx::Tangent) # structural => natural
return dx.diag isa Tangent ? dx.diag : Diagonal(project.diag(dx.diag))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain to me why the right thing to do is to return the diag field of dx if it's a Tangent? I would have thought the correct thing to do would be to throw an error or something, since the conversion can't be performed.

Copy link
Member Author

@mcabbott mcabbott Oct 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking that if you manage to produce something like this:

dx = Tangent{Diagonal}(; diag=Tanget{SVector}(; data = ??))

then we can't wrap dx.diag. I'm not really sure that can happen, though.

Oh I see. In that case I meant to return dx untouched. But I have a typo.

I meant also to think about dx.diag isa Thunk, can that happen?

dx.diag isa AbstractZero is handled by the Diagonal constructor now, this PR. Although perhaps a cleaner pattern would be to handle that here. What's nice about that is that project.diag(dx.diag) could probably produce a Zero?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed. I reversed the comparison to

return dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx

@oxinabox
Copy link
Member

I am going to leave this to @willtebbutt to review.
and unsubscribe.
Ping me if i am needed

@mcabbott
Copy link
Member Author

mcabbott commented Nov 3, 2021

Spotted today, this is an example in the wild of what this PR wants to do:

https://github.com/SciML/DiffEqFlux.jl/blob/61cc51e1b63709d55655c53d50244cc3932bd60e/src/DiffEqFlux.jl#L71-L73

That's for Tridiagonal. Notice that it always has to make two zero Vectors on each call, which seems a bit unfortunate. And is probably an argument for allowing these to make Diagonal, Bidiagonal, subspaces.

Comment on lines -396 to -463
for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg
for UL in (:UpperTriangular, :LowerTriangular, :UpperHessenberg)
@eval begin
ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x)))
(project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx))
function (project::ProjectTo{$UL})(dx::Diagonal)
sub = project.parent
sub_one = ProjectTo{project_type(sub)}(;
element=sub.element, axes=(sub.axes[1],)
)
return Diagonal(sub_one(dx.diag))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To explain what's going on here:

  • First, this used to include UnitUpperTriangular for which it was wrong. The gradient of that has to be zero on the diagonal, not one. So that has moved to its own case, where it does UnitUpperTriangular(dx) .- I instead, which makes in fact an UpperTriangular, not a subtype.
  • Second, the handling of (project::ProjectTo{$UL})(dx::Diagonal) is much simplified. Instead of inventing the projector needed and handling the diagonal by hand, it exploits the fact that map(ProjectTo{Float32}, ::Diagonal) already knows what to do.

The second idea greatly simplifies many more exotic examples below, such as (project::ProjectTo{Tridiagonal})(dx::Bidiagonal) = project.full(dx).

Comment on lines 57 to 68
LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractVector, du::AbstractZero) = Bidiagonal(d, dl, :L)
function LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractZero, du::AbstractVector)
d = fill!(similar(dl, length(dl) + 1), 0)
Tridiagonal(convert(typeof(d), dl), d, convert(typeof(d), du))
end
# two Zeros:
LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractZero) = Diagonal(d)
LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractZero, du::AbstractVector) = Bidiagonal(d, du, :U)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, accepting Diagonal and Bidiagonal as the gradients of Tridiagonal avoids many cases of making an empty array to pad out the type. These types only accept same-type vectors, so you cannot pad it with a Fill.

@mcabbott
Copy link
Member Author

mcabbott commented Nov 8, 2021

OK I think this is done. Plays well with Zygote 6.30:

julia> gradient(x -> sqrt(sum((x .^ 2).ev)), Bidiagonal([1,2,3], [4,5], :U))[1]
3×3 Bidiagonal{Float64, Vector{Float64}}:
 0.0  0.624695    
     0.0       0.780869
              0.0

This was wrong before:

julia> gradient(x -> sum(abs, x), UnitUpperTriangular(rand(3,3)))[1]
3×3 UpperTriangular{Float64, Matrix{Float64}}:
 0.0  1.0  1.0
     0.0  1.0
         0.0

Two bugs I give up on for now:

julia> UpperTriangular(Fill(3,3,3)) - I  # not my fault. Fixed in Julia 1.8
ERROR: ArgumentError: Cannot setindex! to 2 for an AbstractFill with value 3.

julia> pullback(x -> sqrt(sum((x .^ 2).dv)), Bidiagonal([1,2,3], [4,5], :U))[2](1)[1]  # My code makes a Diagonal, I have no idea where the Tri comes from
3×3 Tridiagonal{Float64, Vector{Float64}}:
 0.267261  0.0         
 0.0       0.534522  0.0
          0.0       0.801784

@oxinabox
Copy link
Member

oxinabox commented Mar 2, 2022

Sorry this is taking so long to review, It is just slightly too big (conceptually) for me to review during the time i normally have aside to review things, and so I keep starting to review it and not completing it.

@mcabbott
Copy link
Member Author

mcabbott commented Jun 7, 2022

Bump?

This is one of the last pieces of the ProjectTo story worked out last summer, but somehow hasn't made it. We'd have heard if its absence was holding anyone back, so it can't be that important, but it does clarify what the design is, a bit.

@oxinabox
Copy link
Member

sorry this has been on my to do for a long time.
It's not forgotten its just big and needs some thinking

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ProjectTo related to the projection functionality Structural Tangent Related to the `Tangent` type for structured (composite) values
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants