Skip to content

Commit

Permalink
update to ImplicitDifferentiation 0.5 (#13)
Browse files Browse the repository at this point in the history
* update to ImplicitDifferentiation 0.5

* update to ID 0.5

* bump version
  • Loading branch information
mohamed82008 committed Sep 13, 2023
1 parent ae304ae commit 3df0cf0
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 137 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiableFactorizations"
uuid = "f7876f94-e99c-4755-b0c6-59dc4ff4934d"
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
version = "0.2.0"
version = "0.2.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -12,7 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[compat]
ChainRulesCore = "1"
ComponentArrays = "0.11, 0.12, 0.13"
ImplicitDifferentiation = "0.4"
ImplicitDifferentiation = "0.5"
julia = "1"

[extras]
Expand Down
254 changes: 128 additions & 126 deletions src/DifferentiableFactorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,191 +6,193 @@ using LinearAlgebra, ImplicitDifferentiation, ComponentArrays, ChainRulesCore

# QR

# Residual function for the QR factorization: unpacks `Q` and `R` from `x` and
# stacks two vectorized terms, `UpperTriangular(Q' * Q) - I` plus the strictly
# lower part of `R` (`LowerTriangular(R) - Diagonal(R)`), and `Q * R - A`.
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
# First copy: accepts a third argument that it ignores (`_`).
function qr_conditions(A, x, _)
(; Q, R) = x
return vcat(
vec(UpperTriangular(Q' * Q) + LowerTriangular(R) - I - Diagonal(R)),
vec(Q * R - A),
)
# Second copy: same residual, two arguments only.
function qr_conditions(A, x)
(; Q, R) = x
return vcat(
vec(UpperTriangular(Q' * Q) + LowerTriangular(R) - I - Diagonal(R)),
vec(Q * R - A),
)
end
# Forward solve for the QR factorization: calls `qr(A)`, copies the first
# `size(A, 2)` columns of `Q`, and packs `Q` and `R` into a ComponentVector.
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
function qr_forward(A)
qr_res = qr(A)
Q = copy(qr_res.Q[:, 1:size(A, 2)])
(; R) = qr_res
# First copy returns the ComponentVector together with a literal `0`.
return ComponentVector(; Q, R), 0
qr_res = qr(A)
Q = copy(qr_res.Q[:, 1:size(A, 2)])
(; R) = qr_res
# Second copy returns the ComponentVector alone.
return ComponentVector(; Q, R)
end
# Implicit function built from `qr_forward` and `qr_conditions`, and the
# wrapper `diff_qr` that returns its `Q` and `R` as a named tuple.
# NOTE(review): diff view without +/- markers; each pair of lines is presumably
# the removed form followed by the added form. Confirm against the repository.
# First form: two-argument constructor.
const _diff_qr = ImplicitFunction(qr_forward, qr_conditions)
# Second form: also passes `DirectLinearSolver()` and `nothing`.
const _diff_qr = ImplicitFunction(qr_forward, qr_conditions, DirectLinearSolver(), nothing)
function diff_qr(A)
# First form: takes element 1 of the result before destructuring.
(; Q, R) = _diff_qr(A)[1]
return (; Q, R)
# Second form: destructures the result directly.
(; Q, R) = _diff_qr(A)
return (; Q, R)
end

# Cholesky

# Residual function for the Cholesky factor `U`: the vectorized sum of
# `UpperTriangular(U' * U) - UpperTriangular(A)` and the strictly lower part of
# `U` (`LowerTriangular(U) - Diagonal(U)`).
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
# First copy: accepts a third argument that it ignores (`_`).
function cholesky_conditions(A, U, _)
return vec(UpperTriangular(U' * U) + LowerTriangular(U) - UpperTriangular(A) - Diagonal(U))
# Second copy: same expression, two arguments, reformatted over several lines.
function cholesky_conditions(A, U)
return vec(
UpperTriangular(U' * U) + LowerTriangular(U) - UpperTriangular(A) - Diagonal(U),
)
end
# Forward solve for the Cholesky factorization: calls `cholesky(A)` and returns
# its `U` factor.
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
function cholesky_forward(A)
ch_res = cholesky(A)
# First copy returns `U` together with a literal `0`.
return ch_res.U, 0
ch_res = cholesky(A)
# Second copy returns `U` alone.
return ch_res.U
end
# Implicit function built from `cholesky_forward` and `cholesky_conditions`,
# and the wrapper `diff_cholesky` that returns the named tuple `(; L = U', U)`.
# NOTE(review): diff view without +/- markers; each pair is presumably the
# removed form followed by the added form. Confirm against the repository.
# First form: two-argument constructor.
const _diff_cholesky = ImplicitFunction(cholesky_forward, cholesky_conditions)
# Second form: also passes `DirectLinearSolver()` and `nothing`.
const _diff_cholesky =
ImplicitFunction(cholesky_forward, cholesky_conditions, DirectLinearSolver(), nothing)
function diff_cholesky(A)
# First form: takes element 1 of the result.
U = _diff_cholesky(A)[1]
return (; L = U', U)
# Second form: uses the result directly.
U = _diff_cholesky(A)
return (; L = U', U)
end

# LU

# Residual function for the LU factorization. Both copies stack
# `UpperTriangular(L) - I` plus the strictly lower part of `U`
# (`LowerTriangular(U) - Diagonal(U)`), and `L * U` minus the row-permuted `A`.
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
# First copy: reads `p` from `LU`, converts it to `Vector{Int}` for indexing,
# ignores its third argument, and appends `p` itself to the residual.
function lu_conditions(A, LU, _)
(; L, U, p) = LU
pint = convert(Vector{Int}, p)
return vcat(
vec(UpperTriangular(L) + LowerTriangular(U) - Diagonal(U) - I),
vec(L * U - A[pint, :]),
p,
)
# Second copy: receives `p` as the third argument, indexes `A` with it directly,
# and does not append it to the residual.
function lu_conditions(A, LU, p)
(; L, U) = LU
return vcat(
vec(UpperTriangular(L) + LowerTriangular(U) - Diagonal(U) - I),
vec(L * U - A[p, :]),
)
end
# Forward solve for the LU factorization: calls `lu(A)` and unpacks `L`, `U`
# and the permutation `p`.
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
function lu_forward(A)
lu_res = lu(A)
(; L, U, p) = lu_res
# First copy packs `p` into the ComponentVector and returns a literal `0` second.
return ComponentVector(; L, U, p), 0
lu_res = lu(A)
(; L, U, p) = lu_res
# Second copy packs only `L` and `U` and returns `p` as the second value.
return ComponentVector(; L, U), p
end
# Implicit function built from `lu_forward` and `lu_conditions`, and the
# wrapper `diff_lu` that returns `L`, `U` and `p` as a named tuple.
# NOTE(review): diff view without +/- markers; each group is presumably the
# removed form followed by the added form. Confirm against the repository.
# First form: two-argument constructor.
const _diff_lu = ImplicitFunction(lu_forward, lu_conditions)
# Second form: also passes `DirectLinearSolver()` and `nothing`.
const _diff_lu = ImplicitFunction(lu_forward, lu_conditions, DirectLinearSolver(), nothing)
function diff_lu(A)
# First form: `p` comes out of the component vector and is converted to
# `Vector{Int}` on return.
temp = _diff_lu(A)[1]
(; L, U, p) = temp
return (; L, U, p = convert(Vector{Int}, p))
# Second form: `p` is the second returned value and is passed through as is.
temp, p = _diff_lu(A)
(; L, U) = temp
return (; L, U, p)
end

# Eigen

# Wrap a single value in a ComponentVector under the field name `A`.
function comp_vec(A)
    return ComponentVector((; A = A))
end
# Wrap two values in a ComponentVector under the field names `A` and `B`.
function comp_vec(A, B)
    return ComponentVector((; A = A, B = B))
end
# Reverse rule for `comp_vec(A)`: returns the primal output and a pullback that
# converts the incoming cotangent `Δ` to the output's type and hands back its
# `A` field (with `NoTangent()` for the function itself).
# NOTE(review): diff view without +/- markers; the body appears twice, presumably
# the removed copy followed by the added copy. Confirm against the repository.
# The `_Δ` assignment target was missing from both copies in this text and is
# restored here; the following line reads `_Δ.A`.
function ChainRulesCore.rrule(::typeof(comp_vec), A)
out = comp_vec(A)
T = typeof(out)
return out, Δ -> begin
_Δ = convert(T, Δ)
(NoTangent(), _Δ.A)
end
out = comp_vec(A)
T = typeof(out)
return out, Δ -> begin
_Δ = convert(T, Δ)
(NoTangent(), _Δ.A)
end
end
# Reverse rule for `comp_vec(A, B)` followed by the eigen residual function.
# NOTE(review): diff view without +/- markers; one hunk spans both definitions,
# so the first copy of each (presumably removed) precedes the second copies
# (presumably added). Confirm against the repository.
# The `_Δ` assignment target was missing from both rrule copies in this text and
# is restored here; the following line reads `_Δ.A` and `_Δ.B`.
#
# rrule: returns the primal output and a pullback that converts the cotangent
# `Δ` to the output's type and hands back its `A` and `B` fields.
function ChainRulesCore.rrule(::typeof(comp_vec), A, B)
out = comp_vec(A, B)
T = typeof(out)
return out, Δ -> begin
_Δ = convert(T, Δ)
(NoTangent(), _Δ.A, _Δ.B)
end
end

# eigen_conditions: unpacks `s` and `V` from `sV` and `A` from `AB`; uses
# `AB.B` when that property exists and the identity `I` otherwise. The residual
# stacks `vec(A * V - B * V * Diagonal(s))` and `diag(V' * B * V) .- 1`.
# First copy: accepts a third argument that it ignores (`_`).
function eigen_conditions(AB, sV, _)
(; s, V) = sV
(; A) = AB
if hasproperty(AB, :B)
(; B) = AB
else
B = I
end
return vcat(
vec(A * V - B * V * Diagonal(s)),
diag(V' * B * V) .- 1,
)
# Second copy of the rrule body (header shared with the first copy above).
out = comp_vec(A, B)
T = typeof(out)
return out, Δ -> begin
_Δ = convert(T, Δ)
(NoTangent(), _Δ.A, _Δ.B)
end
end

# Second copy of eigen_conditions: same residual, two arguments only.
function eigen_conditions(AB, sV)
(; s, V) = sV
(; A) = AB
if hasproperty(AB, :B)
(; B) = AB
else
B = I
end
return vcat(vec(A * V - B * V * Diagonal(s)), diag(V' * B * V) .- 1)
end
# Forward solve for the eigen decomposition and the implicit function built on
# it. `eigen_forward` calls `eigen(A, B)` when `AB` has a `B` property and
# `eigen(A)` otherwise, then packs the values `s` and vectors `V`.
# NOTE(review): diff view without +/- markers; one hunk spans the function and
# the constant, so the first copy of each (presumably removed) precedes the
# second copies (presumably added). Confirm against the repository.
function eigen_forward(AB)
(; A) = AB
if hasproperty(AB, :B)
(; B) = AB
eig_res = eigen(A, B)
else
eig_res = eigen(A)
end
s = eig_res.values
V = eig_res.vectors
# First copy returns the ComponentVector together with a literal `0`.
return ComponentVector(; s, V), 0
end

# First form of the constant: two-argument constructor.
const _diff_eigen = ImplicitFunction(eigen_forward, eigen_conditions)
# Second copy of the eigen_forward body (header shared with the first copy).
(; A) = AB
if hasproperty(AB, :B)
(; B) = AB
eig_res = eigen(A, B)
else
eig_res = eigen(A)
end
s = eig_res.values
V = eig_res.vectors
# Second copy returns the ComponentVector alone.
return ComponentVector(; s, V)
end

# Second form of the constant: also passes `DirectLinearSolver()` and `nothing`.
const _diff_eigen =
ImplicitFunction(eigen_forward, eigen_conditions, DirectLinearSolver(), nothing)
# Eigen decomposition of `A` through `_diff_eigen`; wraps `A` with `comp_vec`
# and returns `s` and `V` as a named tuple.
# NOTE(review): diff view without +/- markers; the first pair of body lines is
# presumably removed and the second pair added. Confirm against the repository.
function diff_eigen(A)
# First form: takes element 1 of the result before destructuring.
(; s, V) = _diff_eigen(comp_vec(A))[1]
return (; s , V)
# Second form: destructures the result directly.
(; s, V) = _diff_eigen(comp_vec(A))
return (; s, V)
end
# Two-argument eigen decomposition through `_diff_eigen`; wraps `A` and `B`
# with `comp_vec` and returns `s` and `V` as a named tuple.
# NOTE(review): diff view without +/- markers; the first pair of body lines is
# presumably removed and the second pair added. Confirm against the repository.
function diff_eigen(A, B)
# First form: takes element 1 of the result before destructuring.
(; s, V) = _diff_eigen(comp_vec(A, B))[1]
return (; s , V)
# Second form: destructures the result directly.
(; s, V) = _diff_eigen(comp_vec(A, B))
return (; s, V)
end

# Residual function for the Schur factorization: unpacks `Z` and `T` from `Z_T`
# and stacks `vec(Z' * A * Z - T)` with the vectorized sum of `Z' * Z - I` and
# the strictly lower part of `T` (`LowerTriangular(T) - Diagonal(T)`).
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
# First copy: accepts a third argument that it ignores (`_`).
function schur_conditions(A, Z_T, _)
(; Z, T) = Z_T
return vcat(
vec(Z' * A * Z - T),
vec(Z' * Z - I + LowerTriangular(T) - Diagonal(T)),
)
# Second copy: same residual, two arguments, written on one line.
function schur_conditions(A, Z_T)
(; Z, T) = Z_T
return vcat(vec(Z' * A * Z - T), vec(Z' * Z - I + LowerTriangular(T) - Diagonal(T)))
end
# Forward solve for the Schur factorization: calls `schur(A)` and packs its `Z`
# and `T` into a ComponentVector.
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
function schur_forward(A)
schur_res = schur(A)
(; Z, T) = schur_res
# First copy returns the ComponentVector together with a literal `0`.
return ComponentVector(; Z, T), 0
schur_res = schur(A)
(; Z, T) = schur_res
# Second copy returns the ComponentVector alone.
return ComponentVector(; Z, T)
end
# Implicit function built from `schur_forward` and `schur_conditions`.
# NOTE(review): diff view without +/- markers; the first definition is presumably
# the removed form and the second the added form. Confirm against the repository.
# First form: two-argument constructor.
const _diff_schur = ImplicitFunction(schur_forward, schur_conditions)
# Second form: also passes `DirectLinearSolver()` and `nothing`.
const _diff_schur =
ImplicitFunction(schur_forward, schur_conditions, DirectLinearSolver(), nothing)

# Build a lower `Bidiagonal` matrix from `v1` and `v2` (passed straight to
# `Bidiagonal(v1, v2, :L)`).
# NOTE(review): the return line appears twice because this is a diff view
# without +/- markers; the two copies are textually identical here.
function bidiag(v1, v2)
return Bidiagonal(v1, v2, :L)
return Bidiagonal(v1, v2, :L)
end
# Reverse rule for `bidiag` followed by the generalized Schur residual function.
# NOTE(review): diff view without +/- markers; one hunk spans both definitions,
# so the first copy of each (presumably removed) precedes the second copies
# (presumably added). Confirm against the repository.
#
# rrule: returns `bidiag(v1, v2)` and a pullback yielding `NoTangent()`, the
# main diagonal of `Δ`, and the first subdiagonal of `Δ`.
function ChainRulesCore.rrule(::typeof(bidiag), v1, v2)
bidiag(v1, v2), Δ -> begin
NoTangent(), diag(Δ), diag(Δ, -1)
end
end

# gen_schur_conditions: unpacks `left`, `right`, `S`, `T` and `A`, `B`, then
# stacks four vectorized residuals: `left * S * right' - A`,
# `left * T * right' - B`, and one triangular constraint term each for
# `left' * left` (combined with `S`) and `right' * right` (combined with `T`).
# First copy: accepts a third argument that it ignores (`_`).
function gen_schur_conditions(AB, left_right_S_T, _)
(; left, right, S, T) = left_right_S_T
(; A, B) = AB
return vcat(
vec(left * S * right' - A),
vec(left * T * right' - B),
vec(UpperTriangular(left' * left) - I + LowerTriangular(S) - bidiag(diag(S), diag(S, -1) .+ (diag(S, -1) .* diag(T, 1)))),
vec(UpperTriangular(right' * right) - I + LowerTriangular(T) - Diagonal(T)),
)
# Second copy of the rrule body (header shared with the first copy above).
bidiag(v1, v2), Δ -> begin
NoTangent(), diag(Δ), diag(Δ, -1)
end
end

# Second copy of gen_schur_conditions: same residual, two arguments, with the
# third term split across lines.
function gen_schur_conditions(AB, left_right_S_T)
(; left, right, S, T) = left_right_S_T
(; A, B) = AB
return vcat(
vec(left * S * right' - A),
vec(left * T * right' - B),
vec(
UpperTriangular(left' * left) - I + LowerTriangular(S) -
bidiag(diag(S), diag(S, -1) .+ (diag(S, -1) .* diag(T, 1))),
),
vec(UpperTriangular(right' * right) - I + LowerTriangular(T) - Diagonal(T)),
)
end
# Forward solve for the generalized Schur factorization: calls `schur(A, B)`
# and packs `left`, `right`, `S` and `T` into a ComponentVector.
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
function gen_schur_forward(AB)
(; A, B) = AB
schur_res = schur(A, B)
(; left, right, S, T) = schur_res
# First copy returns the ComponentVector together with a literal `0`.
return ComponentVector(; left, right, S, T), 0
(; A, B) = AB
schur_res = schur(A, B)
(; left, right, S, T) = schur_res
# Second copy returns the ComponentVector alone.
return ComponentVector(; left, right, S, T)
end
# Implicit function built from `gen_schur_forward` and `gen_schur_conditions`.
# NOTE(review): diff view without +/- markers; the first definition is presumably
# the removed form and the second the added form. Confirm against the repository.
# First form: two-argument constructor.
const _diff_gen_schur = ImplicitFunction(gen_schur_forward, gen_schur_conditions)
# Second form: also passes `DirectLinearSolver()` and `nothing`.
const _diff_gen_schur =
ImplicitFunction(gen_schur_forward, gen_schur_conditions, DirectLinearSolver(), nothing)

# Generalized Schur factorization through `_diff_gen_schur`; wraps `A` and `B`
# with `comp_vec` and returns `left`, `right`, `S`, `T` as a named tuple.
# NOTE(review): diff view without +/- markers; the first pair of body lines is
# presumably removed and the second pair added. Confirm against the repository.
function diff_schur(A, B)
# First form: takes element 1 of the result before destructuring.
(; left, right, S, T) = _diff_gen_schur(comp_vec(A, B))[1]
return (; left, right, S, T)
# Second form: destructures the result directly.
(; left, right, S, T) = _diff_gen_schur(comp_vec(A, B))
return (; left, right, S, T)
end
# Schur factorization of `A` through `_diff_schur`; returns `Z` and `T` as a
# named tuple.
# NOTE(review): diff view without +/- markers; the first pair of body lines is
# presumably removed and the second pair added. Confirm against the repository.
function diff_schur(A)
# First form: takes element 1 of the result before destructuring.
(; Z, T) = _diff_schur(A)[1]
return (; Z, T)
# Second form: destructures the result directly.
(; Z, T) = _diff_schur(A)
return (; Z, T)
end

# SVD

# Residual function for the SVD: unpacks `U`, `S`, `V` from `USV`, computes
# `V' * V` once, and stacks `vec(U * Diagonal(S) * V' - A)`,
# `vec(UpperTriangular(V' * V) + LowerTriangular(U' * U) - 2I)` and
# `diag(V' * V) .- 1`.
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
# First copy: accepts a third argument that it ignores (`_`).
function svd_conditions(A, USV, _)
(; U, S, V) = USV
VtV = V' * V
return vcat(
vec(U * Diagonal(S) * V' - A),
vec(UpperTriangular(VtV) + LowerTriangular(U' * U) - 2I),
diag(VtV) .- 1,
)
# Second copy: same residual, two arguments only.
function svd_conditions(A, USV)
(; U, S, V) = USV
VtV = V' * V
return vcat(
vec(U * Diagonal(S) * V' - A),
vec(UpperTriangular(VtV) + LowerTriangular(U' * U) - 2I),
diag(VtV) .- 1,
)
end
# Forward solve for the SVD: calls `svd(A)` and packs `U`, `S` and `V` into a
# ComponentVector.
# NOTE(review): diff view without +/- markers; the first copy is presumably the
# removed version and the second the added one. Confirm against the repository.
function svd_forward(A)
svd_res = svd(A)
(; U, S, V) = svd_res
# First copy returns the ComponentVector together with a literal `0`.
return ComponentVector(; U, S, V), 0
svd_res = svd(A)
(; U, S, V) = svd_res
# Second copy returns the ComponentVector alone.
return ComponentVector(; U, S, V)
end

# Implicit function built from `svd_forward` and `svd_conditions`, and the
# wrapper `diff_svd` that returns `U`, `S` and `V` as a named tuple.
# NOTE(review): diff view without +/- markers; each group is presumably the
# removed form followed by the added form. Confirm against the repository.
# First form: two-argument constructor.
const _diff_svd = ImplicitFunction(svd_forward, svd_conditions)
# Second form: also passes `DirectLinearSolver()` and `nothing`.
const _diff_svd =
ImplicitFunction(svd_forward, svd_conditions, DirectLinearSolver(), nothing)
function diff_svd(A)
# First form: takes element 1 of the result before destructuring.
(; U, S, V) = _diff_svd(A)[1]
return (; U, S , V)
# Second form: destructures the result directly.
(; U, S, V) = _diff_svd(A)
return (; U, S, V)
end

end
Loading

2 comments on commit 3df0cf0

@mohamed82008
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/91389

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" 3df0cf01c91e882e387c36c3a665a0ecbadebe00
git push origin v0.2.1

Please sign in to comment.