Skip to content

Commit

Permalink
Merge pull request #758 from JuliaDiff/ChrisRackauckas-patch-1
Browse files Browse the repository at this point in the history
Add Tridiagonal construction rule
  • Loading branch information
oxinabox authored Jun 10, 2024
2 parents be9c221 + ce288b9 commit 78f00cb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,16 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
end
return y, logdet_pullback
end

#####
##### Tridiagonal
#####

function rrule(::Type{Tridiagonal}, dl, d, du)
y = Tridiagonal(dl, d, du)
function Tridiagonal_pullback(ȳ)
∂y = unthunk(ȳ)
return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1))
end
return y, Tridiagonal_pullback
end
4 changes: 4 additions & 0 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,8 @@
end
end
end

@testset "Tridiagonal" begin
test_rrule(Tridiagonal, [1.0, 4.0], [2.0, 3.0, 4.0], [5.0, 3.0])
end
end

0 comments on commit 78f00cb

Please sign in to comment.