diff --git a/src/psd_mat.jl b/src/psd_mat.jl index 0996928..7726e45 100644 --- a/src/psd_mat.jl +++ b/src/psd_mat.jl @@ -23,11 +23,11 @@ struct PSDMat{T<:Real,S<:AbstractMatrix} <: AbstractPDMat{T} PSDMat{T,S}(m::AbstractMatrix{T},c::CholType{T,S}) where {T,S} = new{T, S}(m, c) end -function PSDMat(mat::AbstractMatrix, chol::CholType) +function PSDMat(mat::AbstractMatrix, chol::CholType{T,S}) where {T,S} d = LinearAlgebra.checksquare(mat) size(chol, 1) == d || throw(DimensionMismatch("Dimensions of mat and chol are inconsistent.")) - PSDMat{eltype(mat),typeof(mat)}(mat, chol) + PSDMat{T, S}(d, convert(S, mat), chol) end PSDMat(mat::Matrix) = PSDMat(mat, cholesky(mat, VERSION >= v"1.8.0-rc1" ? RowMaximum() : Val(true); check=false)) diff --git a/test/psd_mat.jl b/test/psd_mat.jl index 86faccf..d9a34e6 100644 --- a/test/psd_mat.jl +++ b/test/psd_mat.jl @@ -69,6 +69,20 @@ end @test whiten!( similar(x), pd, x) ≈ whiten!( similar(x), psd, x) @test unwhiten!(similar(x), pd, x) ≈ unwhiten!(similar(x), psd, x) end + + @testset "Constructing from Matrix{<:Integer} works" begin + + m = [ + 1 0 + 0 1 + ] + + for type in (UInt8, Int8, UInt16, Int16, UInt32, Int32, UInt, Int) + @test PSDMat{Float64, Matrix{Float64}} == typeof(PSDMat(type.(m))) + end + + end + end @testset "Degenerate MvNormal" begin