Skip to content

Commit

Permalink
Fix inferrence for svd
Browse files Browse the repository at this point in the history
  • Loading branch information
Kolaru committed Jul 19, 2021
1 parent f40b48a commit 7e3010d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
11 changes: 7 additions & 4 deletions src/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,14 @@ end

# Factorizations

function to_vec(x::SVD)
x_vec, back = to_vec([x.U, x.S, x.Vt])
function to_vec(x::F) where {F <: SVD}
# Convert the vector S to a matrix so we can work with a vector of matrices
# only and inferrence work
v = [x.U, reshape(x.S, length(x.S), 1), x.Vt]
x_vec, back = to_vec(v)
function SVD_from_vec(v)
U, S, Vt = back(v)
return SVD(U, S, Vt)
U, Smat, Vt = back(v)
return F(U, vec(Smat), Vt)
end
return x_vec, SVD_from_vec
end
Expand Down
2 changes: 1 addition & 1 deletion test/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ end
for dims in [(5, 5), (4, 6), (7, 3)]
M = randn(T, dims...)
P = M * M' + I # Positive definite matrix
test_to_vec(svd(M); check_inferred = false)
test_to_vec(svd(M); check_inferred = true)
test_to_vec(qr(M))
test_to_vec(cholesky(P))
end
Expand Down

0 comments on commit 7e3010d

Please sign in to comment.