Skip to content

Commit

Permalink
allow 3d arrays for Diagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 12, 2022
1 parent c20360a commit 487d216
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,10 @@ end
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))
function (project::ProjectTo{Diagonal})(dx::AbstractArray)
ind = diagind(size(dx,1), size(dx,2), 0)
return Diagonal(project.diag(dx[ind]))
end
function (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal}) # structural => natural
return dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx
end
Expand Down
1 change: 1 addition & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ struct NoSuperType end
@testset "LinearAlgebra: sparse structured matrices" begin
pdiag = ProjectTo(Diagonal(1:3))
@test pdiag(reshape(1:9, 3, 3)) == Diagonal([1, 5, 9])
@test pdiag(reshape(1:9, 3, 3, 1)) == Diagonal([1, 5, 9])
@test pdiag(pdiag(reshape(1:9, 3, 3))) == pdiag(reshape(1:9, 3, 3))
@test pdiag(rand(ComplexF32, 3, 3)) isa Diagonal{Float64}
@test pdiag(Diagonal(1.0:3.0)) === Diagonal(1.0:3.0)
Expand Down

0 comments on commit 487d216

Please sign in to comment.