From fb4d2f058eb8d9fc0b82645df701c461678d821e Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Wed, 11 Oct 2023 11:51:49 +0200 Subject: [PATCH] wip: add arbitrary logpdf --- .github/workflows/CI.yml | 1 - Project.toml | 7 +- src/BayesBase.jl | 7 +- src/densities/continouslogpdf.jl | 222 +++++++++++++ src/densities/factorizedjoint.jl | 4 +- src/prod.jl | 61 ++-- src/statsfuns.jl | 41 +++ test/densities/continouslogpdf_tests.jl | 409 ++++++++++++++++++++++++ test/statsfuns_tests.jl | 27 ++ 9 files changed, 736 insertions(+), 43 deletions(-) create mode 100644 src/densities/continouslogpdf.jl create mode 100644 test/densities/continouslogpdf_tests.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 64f849e..6d11abc 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,7 +19,6 @@ jobs: fail-fast: false matrix: version: - - '1.0' - '1.9' - 'nightly' os: diff --git a/Project.toml b/Project.toml index d9b01c6..0811c9b 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.0.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -14,11 +15,12 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" [compat] Distributions = "0.25" +DomainSets = "0.7" Random = "1.9" +SpecialFunctions = "2.3" Statistics = "1.9" StatsAPI = "1.7" StatsBase = "0.34" -SpecialFunctions = "2.3" TinyHugeNumbers = "1.0" julia = "1.9" @@ -27,8 +29,9 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CpuId = "adafc99b-e345-5852-983c-f28acb93d879" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "CpuId", "Test", "ReTestItems", "LinearAlgebra", "StableRNGs"] +test = ["Aqua", "CpuId", "Test", "ReTestItems", "LinearAlgebra", "StableRNGs", "HCubature"] diff --git a/src/BayesBase.jl b/src/BayesBase.jl index d9a66ac..281927d 100644 --- a/src/BayesBase.jl +++ b/src/BayesBase.jl @@ -2,7 +2,7 @@ module BayesBase using TinyHugeNumbers -using StatsAPI, StatsBase, Statistics, Distributions, Random +using StatsAPI, StatsBase, DomainSets, Statistics, Distributions, Random using StatsAPI: params @@ -65,6 +65,10 @@ export failprob, variate_form, value_support +using DomainSets: dimension, Domain + +export dimension, Domain + using Base: precision, prod, prod! export precision, prod, prod! @@ -78,5 +82,6 @@ include("promotion.jl") include("prod.jl") include("densities/factorizedjoint.jl") +include("densities/continouslogpdf.jl") end diff --git a/src/densities/continouslogpdf.jl b/src/densities/continouslogpdf.jl new file mode 100644 index 0000000..f27d3ac --- /dev/null +++ b/src/densities/continouslogpdf.jl @@ -0,0 +1,222 @@ +export ContinuousUnivariateLogPdf, ContinuousMultivariateLogPdf + +# import DomainIntegrals +# import HCubature + +import Base: isapprox, in + +abstract type AbstractContinuousGenericLogPdf end + +getdomain(dist::AbstractContinuousGenericLogPdf) = dist.domain +getlogpdf(dist::AbstractContinuousGenericLogPdf) = dist.logpdf + +BayesBase.value_support(::Type{<:AbstractContinuousGenericLogPdf}) = Continuous +BayesBase.value_support(::AbstractContinuousGenericLogPdf) = Continuous + +# We throw an error on purpose, since we do not want to use `AbstractContinuousGenericLogPdf` much without approximations +# We want to encourage a user to use approximate generic log-pdfs as much as possible instead +function __error_genericlogpdf_not_defined(dist::AbstractContinuousGenericLogPdf, f::Symbol) + return error( + "`$f` is not defined for `$(dist)`. Use functional form constraints to approximate the resulting generic log-pdf object and to use it in the inference procedure.", + ) +end + +function BayesBase.mean(dist::AbstractContinuousGenericLogPdf) + return __error_genericlogpdf_not_defined(dist, :mean) +end +function BayesBase.median(dist::AbstractContinuousGenericLogPdf) + return __error_genericlogpdf_not_defined(dist, :median) +end +function BayesBase.mode(dist::AbstractContinuousGenericLogPdf) + return __error_genericlogpdf_not_defined(dist, :mode) +end +function BayesBase.var(dist::AbstractContinuousGenericLogPdf) + return __error_genericlogpdf_not_defined(dist, :var) +end +function BayesBase.std(dist::AbstractContinuousGenericLogPdf) + return __error_genericlogpdf_not_defined(dist, :std) +end +function BayesBase.cov(dist::AbstractContinuousGenericLogPdf) + return __error_genericlogpdf_not_defined(dist, :cov) +end +function BayesBase.invcov(dist::AbstractContinuousGenericLogPdf) + return __error_genericlogpdf_not_defined(dist, :invcov) +end +function BayesBase.entropy(dist::AbstractContinuousGenericLogPdf) + return __error_genericlogpdf_not_defined(dist, :entropy) +end + +function Base.precision(dist::AbstractContinuousGenericLogPdf) + return __error_genericlogpdf_not_defined(dist, :precision) +end + +Base.eltype(dist::AbstractContinuousGenericLogPdf) = eltype(getdomain(dist)) + +BayesBase.paramfloattype(dist::AbstractContinuousGenericLogPdf) = deep_eltype(eltype(dist)) +BayesBase.samplefloattype(dist::AbstractContinuousGenericLogPdf) = paramfloattype(dist) + +(dist::AbstractContinuousGenericLogPdf)(x::Real) = logpdf(dist, x) +(dist::AbstractContinuousGenericLogPdf)(x::AbstractVector{<:Real}) = logpdf(dist, x) + +function BayesBase.logpdf(dist::AbstractContinuousGenericLogPdf, x) + @assert x ∈ getdomain(dist) "x = $(x) does not belong to the domain ($(getdomain(dist))) of $dist" + lpdf = getlogpdf(dist) + return lpdf(x) +end + +# We don't expect neither `pdf` nor `logpdf` to be normalised +BayesBase.pdf(dist::AbstractContinuousGenericLogPdf, x) = exp(logpdf(dist, x)) + +""" + ContinuousUnivariateLogPdf{ D <: DomainSets.Domain, F } <: AbstractContinuousGenericLogPdf + +Generic continuous univariate distribution in a form of domain specification and logpdf function. Can be used in cases where no +known analytical distribution available. + +# Arguments +- `domain`: domain specificatiom from `DomainSets.jl` package, by default the `domain` is set to `DomainSets.FullSpace()`. Use `BayesBase.UnspecifiedDomain()` to bypass domain checks. +- `logpdf`: callable object that represents the logdensity. Can be un-normalised. +""" +struct ContinuousUnivariateLogPdf{D<:DomainSets.Domain,F} <: AbstractContinuousGenericLogPdf + domain::D + logpdf::F + + function ContinuousUnivariateLogPdf(domain::D, logpdf::F) where {D,F} + @assert dimension(domain) == 1 "Cannot create ContinuousUnivariateLogPdf. Dimension of domain = $(domain) is not equal to 1." + return new{D,F}(domain, logpdf) + end +end + +function ContinuousUnivariateLogPdf(f::Function) + return ContinuousUnivariateLogPdf(DomainSets.FullSpace(), f) +end + +BayesBase.variate_form(::Type{<:ContinuousUnivariateLogPdf}) = Univariate +BayesBase.variate_form(::ContinuousUnivariateLogPdf) = Univariate + +function BayesBase.promote_variate_type( + ::Type{Univariate}, ::Type{AbstractContinuousGenericLogPdf} +) + return ContinuousUnivariateLogPdf +end + +function Base.show(io::IO, dist::ContinuousUnivariateLogPdf) + return print(io, "ContinuousUnivariateLogPdf(", getdomain(dist), ")") +end +function Base.show(io::IO, ::Type{<:ContinuousUnivariateLogPdf{D}}) where {D} + return print(io, "ContinuousUnivariateLogPdf{", D, "}") +end + +function BayesBase.support(dist::ContinuousUnivariateLogPdf) + return getdomain(dist) +end + +BayesBase.insupport(dist::ContinuousUnivariateLogPdf, x) = x ∈ getdomain(dist) + +# Fallback for various optimisation packages which may pass arguments as vectors +function BayesBase.logpdf(dist::ContinuousUnivariateLogPdf, x::AbstractVector{<:Real}) + @assert length(x) === 1 "`ContinuousUnivariateLogPdf` expects either float or a vector of a single float as an input for the `logpdf` function." + return logpdf(dist, first(x)) +end + +function Base.convert( + ::Type{<:ContinuousUnivariateLogPdf}, domain::D, logpdf::F +) where {D<:DomainSets.Domain,F} + return ContinuousUnivariateLogPdf(domain, logpdf) +end + +function BayesBase.convert_paramfloattype( + ::Type{T}, dist::ContinuousUnivariateLogPdf +) where {T<:Real} + return convert( + ContinuousUnivariateLogPdf, + dist.domain, + (x) -> dist.logpdf(convert_paramfloattype(T, x)), + ) +end + +function BayesBase.vague(::Type{<:ContinuousUnivariateLogPdf}) + return ContinuousUnivariateLogPdf(DomainSets.FullSpace(), (x) -> 1) +end + +# We do not check typeof of a different functions because in most of the cases lambdas have different types, but they can still be the same +function BayesBase.isequal_typeof( + ::ContinuousUnivariateLogPdf{D,F1}, ::ContinuousUnivariateLogPdf{D,F2} +) where {D,F1<:Function,F2<:Function} + return true +end + +## + +""" + ContinuousMultivariateLogPdf{ D <: DomainSets.Domain, F } <: AbstractContinuousGenericLogPdf + +Generic continuous multivariate distribution in a form of domain specification and logpdf function. Can be used in cases where no +known analytical distribution available. + +# Arguments +- `domain`: multidimensional domain specification from `DomainSets.jl` package. Use `BayesBase.UnspecifiedDomain()` to bypass domain checks. +- `logpdf`: callable object that accepts an `AbstractVector` as an input and represents the logdensity. Can be un-normalised. +""" +struct ContinuousMultivariateLogPdf{D<:DomainSets.Domain,F} <: + AbstractContinuousGenericLogPdf + domain::D + logpdf::F + + function ContinuousMultivariateLogPdf( + domain::D, logpdf::F + ) where {D<:DomainSets.Domain,F} + @assert DomainSets.dimension(domain) !== 1 "Cannot create ContinuousMultivariateLogPdf. Dimension of domain = $(domain) should not be equal to 1. Use, for example, `DomainSets.FullSpace() ^ 2` to create 2-dimensional full space domain." + return new{D,F}(domain, logpdf) + end +end + +BayesBase.variate_form(::Type{<:ContinuousMultivariateLogPdf}) = Multivariate +BayesBase.variate_form(::ContinuousMultivariateLogPdf) = Multivariate + +function BayesBase.promote_variate_type( + ::Type{Multivariate}, ::Type{AbstractContinuousGenericLogPdf} +) + return ContinuousMultivariateLogPdf +end + +function ContinuousMultivariateLogPdf(dims::Int, f::Function) + return ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dims, f) +end + +function Base.show(io::IO, dist::ContinuousMultivariateLogPdf) + return print(io, "ContinuousMultivariateLogPdf(", getdomain(dist), ")") +end +function Base.show(io::IO, ::Type{<:ContinuousMultivariateLogPdf{D}}) where {D} + return print(io, "ContinuousMultivariateLogPdf{", D, "}") +end + +BayesBase.support(dist::ContinuousMultivariateLogPdf) = getdomain(dist) +BayesBase.insupport(dist::ContinuousMultivariateLogPdf, x) = x ∈ getdomain(dist) + +function Base.convert( + ::Type{<:ContinuousMultivariateLogPdf}, domain::D, logpdf::F +) where {D<:DomainSets.Domain,F} + return ContinuousMultivariateLogPdf(domain, logpdf) +end + +function BayesBase.convert_paramfloattype( + ::Type{T}, dist::ContinuousMultivariateLogPdf +) where {T<:Real} + return convert( + ContinuousMultivariateLogPdf, + dist.domain, + (x) -> dist.logpdf(convert_paramfloattype(T, x)), + ) +end + +function BayesBase.vague(::Type{<:ContinuousMultivariateLogPdf}, dims::Int) + return ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dims, (x) -> 1) +end + +# We do not check typeof of a different functions because in most of the cases lambdas have different types, but they can still be the same +function BayesBase.isequal_typeof( + ::ContinuousMultivariateLogPdf{D,F1}, ::ContinuousMultivariateLogPdf{D,F2} +) where {D,F1<:Function,F2<:Function} + return true +end diff --git a/src/densities/factorizedjoint.jl b/src/densities/factorizedjoint.jl index d5ab56a..addcb1c 100644 --- a/src/densities/factorizedjoint.jl +++ b/src/densities/factorizedjoint.jl @@ -20,9 +20,9 @@ Base.@propagate_inbounds function Base.getindex(joint::FactorizedJoint, i::Int) return getindex(components(joint), i) end -BayesBase.length(joint::FactorizedJoint) = length(joint.multipliers) +Base.length(joint::FactorizedJoint) = length(joint.multipliers) -function BayesBase.isapprox(x::FactorizedJoint, y::FactorizedJoint; kwargs...) +function Base.isapprox(x::FactorizedJoint, y::FactorizedJoint; kwargs...) return length(x) === length(y) && all( tuple -> isapprox(tuple[1], tuple[2]; kwargs...), zip(components(x), components(y)), diff --git a/src/prod.jl b/src/prod.jl index b17aee4..5b21bbc 100644 --- a/src/prod.jl +++ b/src/prod.jl @@ -1,10 +1,8 @@ -import Distributions: VariateForm, ValueSupport, variate_form, value_support, support import Base: prod, prod!, show, showerror export prod, default_prod_rule, - fuse_supports, ClosedProd, PreserveTypeProd, PreserveTypeLeftProd, @@ -171,7 +169,9 @@ Base.prod(::ClosedProd, left, ::Missing) = left Base.prod(::ClosedProd, ::Missing, ::Missing) = missing # We assume that we want to preserve the `Distribution` when working with two `Distribution`s -Base.prod(::ClosedProd, left::Distribution, right::Distribution) = prod(PreserveTypeProd(Distribution), left, right) +function Base.prod(::ClosedProd, left::Distribution, right::Distribution) + return prod(PreserveTypeProd(Distribution), left, right) +end # This is a hidden prod strategy to ensure symmetricity in the `default_prod_rule`. # Most of the automatic prod rule resolution relies on the `symmetric_default_prod_rule` instead of just `default_prod_rule` @@ -197,22 +197,6 @@ function symmetric_default_prod_rule(::UnspecifiedProd, ::UnspecifiedProd, left, return UnspecifiedProd() end -""" - fuse_supports(left, right) - -Fuse supports of two distributions of `left` and `right`. -By default, checks that the supports are identical and throws an error otherwise. -Can implement specific fusions for specific distributions. - -See also: [`prod`](@ref), [`ProductOf`](@ref) -""" -function fuse_supports(left, right) - if !isequal(support(left), support(right)) - error("Cannot form a `ProductOf` $(left) & `$(right)`. Support is incompatible.") - end - return support(left) -end - """ ProductOf @@ -242,20 +226,20 @@ function Base.show(io::IO, product::ProductOf) return print(io, "ProductOf(", getleft(product), ",", getright(product), ")") end -function Distributions.support(product::ProductOf) - return fuse_supports(getleft(product), getright(product)) +function BayesBase.support(product::ProductOf) + return fuse_supports(support(getleft(product)), support(getright(product))) end -Distributions.pdf(product::ProductOf, x) = exp(logpdf(product, x)) +BayesBase.pdf(product::ProductOf, x) = exp(logpdf(product, x)) -function Distributions.logpdf(product::ProductOf, x) - return Distributions.logpdf(getleft(product), x) + - Distributions.logpdf(getright(product), x) +function BayesBase.logpdf(product::ProductOf, x) + @assert x ∈ support(product) "The `$(x)` does not belong to the support of the product `$(product)`" + return logpdf(getleft(product), x) + logpdf(getright(product), x) end -Distributions.variate_form(::P) where {P<:ProductOf} = variate_form(P) +BayesBase.variate_form(::P) where {P<:ProductOf} = variate_form(P) -function Distributions.variate_form(::Type{ProductOf{L,R}}) where {L,R} +function BayesBase.variate_form(::Type{ProductOf{L,R}}) where {L,R} return _check_product_variate_form(variate_form(L), variate_form(R)) end @@ -269,9 +253,9 @@ function _check_product_variate_form( ) end -Distributions.value_support(::P) where {P<:ProductOf} = value_support(P) +BayesBase.value_support(::P) where {P<:ProductOf} = value_support(P) -function Distributions.value_support(::Type{ProductOf{L,R}}) where {L,R} +function BayesBase.value_support(::Type{ProductOf{L,R}}) where {L,R} return _check_product_value_support(value_support(L), value_support(R)) end @@ -399,7 +383,7 @@ function Base.push!(product::LinearizedProductOf{F}, item::F) where {F} return LinearizedProductOf(push!(vector, item), vlength + 1) end -Distributions.support(dist::LinearizedProductOf) = support(first(dist.vector)) +BayesBase.support(dist::LinearizedProductOf) = support(first(dist.vector)) Base.length(product::LinearizedProductOf) = product.length Base.eltype(product::LinearizedProductOf) = eltype(first(product.vector)) @@ -412,23 +396,26 @@ function BayesBase.samplefloattype(product::LinearizedProductOf) return samplefloattype(first(product.vector)) end -Distributions.variate_form(::Type{<:LinearizedProductOf{F}}) where {F} = variate_form(F) -Distributions.variate_form(::LinearizedProductOf{F}) where {F} = variate_form(F) +BayesBase.variate_form(::Type{<:LinearizedProductOf{F}}) where {F} = variate_form(F) +BayesBase.variate_form(::LinearizedProductOf{F}) where {F} = variate_form(F) -Distributions.value_support(::Type{<:LinearizedProductOf{F}}) where {F} = value_support(F) -Distributions.value_support(::LinearizedProductOf{F}) where {F} = value_support(F) +BayesBase.value_support(::Type{<:LinearizedProductOf{F}}) where {F} = value_support(F) +BayesBase.value_support(::LinearizedProductOf{F}) where {F} = value_support(F) function Base.show(io::IO, product::LinearizedProductOf{F}) where {F} return print(io, "LinearizedProductOf(", F, ", length = ", product.length, ")") end -function Distributions.logpdf(dist::LinearizedProductOf, x) +function BayesBase.logpdf(product::LinearizedProductOf, x) + @assert x ∈ support(product) "The `$(x)` does not belong to the support of the product `$(product)`" return mapreduce( - (d) -> logpdf(d, x), +, view(dist.vector, 1:min(dist.length, length(dist.vector))) + (d) -> logpdf(d, x), + +, + view(product.vector, 1:min(product.length, length(product.vector))), ) end -Distributions.pdf(dist::LinearizedProductOf, x) = exp(logpdf(dist, x)) +BayesBase.pdf(dist::LinearizedProductOf, x) = exp(logpdf(dist, x)) # We assume that it is better (really) to preserve the type of the `LinearizedProductOf`, it is just faster for the compiler function BayesBase.default_prod_rule(::Type{F}, ::Type{LinearizedProductOf{F}}) where {F} diff --git a/src/statsfuns.jl b/src/statsfuns.jl index 9a27682..b7203a3 100644 --- a/src/statsfuns.jl +++ b/src/statsfuns.jl @@ -24,6 +24,10 @@ export mirrorlog, sampling_optimized, components, component, + UnspecifiedDomain, + UnspecifiedDimension, + fuse_supports, + isequal_typeof, distribution_typewrapper """ @@ -173,11 +177,48 @@ Returns `k`-th component of a distribution `d` (joint or a mixture). """ function component end +"""Unknown domain that is used as a placeholder when exact domain knowledge is unavailable""" +struct UnspecifiedDomain <: Domain{Any} end + +"""Unknown dimension is equal and not equal to any number""" +struct UnspecifiedDimension end + +DomainSets.dimension(::UnspecifiedDomain) = UnspecifiedDimension() + +Base.in(::Any, ::UnspecifiedDomain) = true + +Base.:(!=)(::UnspecifiedDimension, ::Int) = true +Base.:(!==)(::UnspecifiedDimension, ::Int) = true +Base.:(==)(::UnspecifiedDimension, ::Int) = true + +""" + fuse_supports(left, right) + +Fuses supports `left` and `right`. +By default, checks that the inputs are identical and throws an error otherwise. +Can implement specific fusions for specific supports. +""" +function fuse_supports(left, right) + @assert isequal(left, right) "Cannot automatically fuse supports of $(left) & `$(right)`." + return left +end + +fuse_supports(left::UnspecifiedDomain, right) = right +fuse_supports(left, right::UnspecifiedDomain) = left +fuse_supports(::UnspecifiedDomain, ::UnspecifiedDomain) = UnspecifiedDomain() + """ Strips type parameters from the type of the `distribution`. """ distribution_typewrapper(distribution) = generated_distribution_typewrapper(distribution) +""" + isequal_typeof(left, right) + +Alias for `typeof(left) === typeof(right)`, but can be specialized. +""" +isequal_typeof(left, right) = typeof(left) === typeof(right) + # Returns a wrapper distribution for a `<:Distribution` type, this function uses internals of Julia # It is not ideal, but is fine for now, if Julia changes it internals such that does not work # We will need to write the `distribution_typewrapper` method for each support member of exponential family diff --git a/test/densities/continouslogpdf_tests.jl b/test/densities/continouslogpdf_tests.jl new file mode 100644 index 0000000..b7941cc --- /dev/null +++ b/test/densities/continouslogpdf_tests.jl @@ -0,0 +1,409 @@ + +@testitem "ContinuousUnivariateLogPdf: Constructor" begin + import DomainSets: FullSpace + + f = (x) -> -x^2 + d1 = ContinuousUnivariateLogPdf(f) + d2 = ContinuousUnivariateLogPdf(FullSpace(), f) + + @test typeof(d1) === typeof(d2) + @test eltype(d1) === Float64 + @test eltype(d2) === Float64 + @test paramfloattype(d1) === Float64 + @test samplefloattype(d1) === Float64 + @test paramfloattype(d2) === Float64 + @test samplefloattype(d2) === Float64 + + @test_throws AssertionError ContinuousUnivariateLogPdf(FullSpace()^2, f) +end + +@testitem "ContinuousUnivariateLogPdf: Intentional errors" begin + dist = ContinuousUnivariateLogPdf((x) -> x) + @test_throws ErrorException mean(dist) + @test_throws ErrorException median(dist) + @test_throws ErrorException mode(dist) + @test_throws ErrorException var(dist) + @test_throws ErrorException std(dist) + @test_throws ErrorException cov(dist) + @test_throws ErrorException invcov(dist) + @test_throws ErrorException entropy(dist) + @test_throws ErrorException precision(dist) +end + +@testitem "ContinuousUnivariateLogPdf: pdf/logpdf" begin + import DomainSets: FullSpace, HalfLine + + d1 = ContinuousUnivariateLogPdf(FullSpace(), (x) -> -x^2) + + f32_points1 = range(Float32(-10.0), Float32(10.0); length=50) + f64_points1 = range(-10.0, 10.0; length=50) + bf_points1 = range(BigFloat(-10.0), BigFloat(10.0); length=50) + points1 = vcat(f32_points1, f64_points1, bf_points1) + + @test all(map(p -> -p^2 == d1(p), points1)) + @test all(map(p -> -p^2 == logpdf(d1, p), points1)) + @test all(map(p -> exp(-p^2) == pdf(d1, p), points1)) + @test all(map(p -> -p^2 == d1([p]), points1)) + @test all(map(p -> -p^2 == logpdf(d1, [p]), points1)) + @test all(map(p -> exp(-p^2) == pdf(d1, [p]), points1)) + + d2 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> -x^4) + + f32_points2 = range(Float32(0.0), Float32(10.0); length=50) + f64_points2 = range(0.0, 10.0; length=50) + bf_points2 = range(BigFloat(0.0), BigFloat(10.0); length=50) + points2 = vcat(f32_points2, f64_points2, bf_points2) + + @test all(map(p -> -p^4 == d2(p), points2)) + @test all(map(p -> -p^4 == logpdf(d2, p), points2)) + @test all(map(p -> exp(-p^4) == pdf(d2, p), points2)) + @test all(map(p -> -p^4 == d2([p]), points2)) + @test all(map(p -> -p^4 == logpdf(d2, [p]), points2)) + @test all(map(p -> exp(-p^4) == pdf(d2, [p]), points2)) + + @test_throws AssertionError d2(-1.0) + @test_throws AssertionError logpdf(d2, -1.0) + @test_throws AssertionError pdf(d2, -1.0) + @test_throws AssertionError d2([-1.0]) + @test_throws AssertionError logpdf(d2, [-1.0]) + @test_throws AssertionError pdf(d2, [-1.0]) + + @test_throws AssertionError d2(Float32(-1.0)) + @test_throws AssertionError logpdf(d2, Float32(-1.0)) + @test_throws AssertionError pdf(d2, Float32(-1.0)) + @test_throws AssertionError d2([Float32(-1.0)]) + @test_throws AssertionError logpdf(d2, [Float32(-1.0)]) + @test_throws AssertionError pdf(d2, [Float32(-1.0)]) + + @test_throws AssertionError d2(BigFloat(-1.0)) + @test_throws AssertionError logpdf(d2, BigFloat(-1.0)) + @test_throws AssertionError pdf(d2, BigFloat(-1.0)) + @test_throws AssertionError d2([BigFloat(-1.0)]) + @test_throws AssertionError logpdf(d2, [BigFloat(-1.0)]) + @test_throws AssertionError pdf(d2, [BigFloat(-1.0)]) + + d3 = ContinuousUnivariateLogPdf(FullSpace(Float32), (x) -> -x^2) + + @test all(map(p -> -p^2 == d3(p), points1)) + @test all(map(p -> -p^2 == logpdf(d3, p), points1)) + @test all(map(p -> exp(-p^2) == pdf(d3, p), points1)) + @test all(map(p -> -p^2 == d3([p]), points1)) + @test all(map(p -> -p^2 == logpdf(d3, [p]), points1)) + @test all(map(p -> exp(-p^2) == pdf(d3, [p]), points1)) + + d4 = ContinuousUnivariateLogPdf(FullSpace(BigFloat), (x) -> -x^2) + + @test all(map(p -> -p^2 == d4(p), points1)) + @test all(map(p -> -p^2 == logpdf(d4, p), points1)) + @test all(map(p -> exp(-p^2) == pdf(d4, p), points1)) + @test all(map(p -> -p^2 == d4([p]), points1)) + @test all(map(p -> -p^2 == logpdf(d4, [p]), points1)) + @test all(map(p -> exp(-p^2) == pdf(d4, [p]), points1)) + + d5 = ContinuousUnivariateLogPdf(HalfLine{Float32}(), (x) -> -x^2) + + @test all(map(p -> -p^2 == d5(p), points2)) + @test all(map(p -> -p^2 == logpdf(d5, p), points2)) + @test all(map(p -> exp(-p^2) == pdf(d5, p), points2)) + @test all(map(p -> -p^2 == d5([p]), points2)) + @test all(map(p -> -p^2 == logpdf(d5, [p]), points2)) + @test all(map(p -> exp(-p^2) == pdf(d5, [p]), points2)) + + d6 = ContinuousUnivariateLogPdf(HalfLine{BigFloat}(), (x) -> -x^2) + + @test all(map(p -> -p^2 == d6(p), points2)) + @test all(map(p -> -p^2 == logpdf(d6, p), points2)) + @test all(map(p -> exp(-p^2) == pdf(d6, p), points2)) + @test all(map(p -> -p^2 == d6([p]), points2)) + @test all(map(p -> -p^2 == logpdf(d6, [p]), points2)) + @test all(map(p -> exp(-p^2) == pdf(d6, [p]), points2)) +end + +@testitem "ContinuousUnivariateLogPdf: test domain in logpdf" begin + import DomainSets: FullSpace, HalfLine + + d1 = ContinuousUnivariateLogPdf(FullSpace(), (x) -> -x^2) + d2 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> -x^4) + + # This also throws a warning in stdout + @test_throws AssertionError logpdf(d1, [1.0, 1.0]) + @test_throws AssertionError logpdf(d2, [1.0, 1.0]) +end + +@testitem "ContinuousUnivariateLogPdf: support" begin + import DomainSets: FullSpace, HalfLine + + d1 = ContinuousUnivariateLogPdf(FullSpace(), (x) -> 1.0) + @test 1.0 ∈ support(d1) + @test -1.0 ∈ support(d1) + + d2 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> 1.0) + @test 1.0 ∈ support(d2) + @test -1.0 ∉ support(d2) +end + +@testitem "ContinuousUnivariateLogPdf: vague" begin + d = vague(ContinuousUnivariateLogPdf) + + @test typeof(d) <: ContinuousUnivariateLogPdf + @test d(rand()) ≈ 1 +end + +@testitem "ContinuousUnivariateLogPdf: prod" begin + import DomainSets: FullSpace, HalfLine + + dist = ContinuousUnivariateLogPdf(FullSpace(), (x) -> 2.0 * -x^2) + d2 = ContinuousUnivariateLogPdf(FullSpace(), (x) -> 3.0 * -x^2) + + product = prod(GenericProd(), dist, d2) + pt1 = ContinuousUnivariateLogPdf(FullSpace(), (x) -> logpdf(dist, x) + logpdf(d2, x)) + + @test variate_form(typeof(product)) === variate_form(typeof(dist)) + @test variate_form(typeof(product)) === variate_form(typeof(d2)) + @test value_support(typeof(product)) === value_support(typeof(dist)) + @test value_support(typeof(product)) === value_support(typeof(d2)) + @test support(product) === support(dist) + @test support(product) === support(d2) + + for x in rand(10) + @test isapprox(logpdf(product, x), logpdf(pt1, x)) + @test isapprox(pdf(product, x), pdf(pt1, x)) + end + + result = ContinuousUnivariateLogPdf(HalfLine(), (x) -> 2.0 * -x^2) + d4 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> 3.0 * -x^2) + + pr2 = prod(GenericProd(), result, d4) + pt2 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(result, x) + logpdf(d4, x)) + + @test variate_form(typeof(pr2)) === variate_form(typeof(result)) + @test variate_form(typeof(pr2)) === variate_form(typeof(d4)) + @test value_support(typeof(pr2)) === value_support(typeof(result)) + @test value_support(typeof(pr2)) === value_support(typeof(d4)) + @test support(pr2) === support(result) + @test support(pr2) === support(d4) + + for x in rand(10) + @test isapprox(logpdf(pr2, x), logpdf(pt2, x)) + @test isapprox(pdf(pr2, x), pdf(pt2, x)) + end + + d5 = ContinuousUnivariateLogPdf(FullSpace(), (x) -> 2.0 * -x^2) + d6 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> 2.0 * -x^2) + + @test_throws AssertionError logpdf(prod(GenericProd(), d5, d6), 1.0) # domains are different + @test_throws AssertionError logpdf(prod(GenericProd(), d5, d6), -1.0) # domains are different +end + +@testitem "ContinuousUnivariateLogPdf: vectorised-prod" begin + import DomainSets: FullSpace + + f = (x) -> 2.0 * -x^2 + dist = ContinuousUnivariateLogPdf(FullSpace(), f) + result = ContinuousUnivariateLogPdf(FullSpace(), (x) -> 3 * f(x)) + product = prod(GenericProd(), prod(GenericProd(), dist, dist), dist) + + @test product isa LinearizedProductOf + + @test variate_form(typeof(product)) === variate_form(typeof(dist)) + @test variate_form(typeof(product)) === variate_form(typeof(result)) + @test value_support(typeof(product)) === value_support(typeof(dist)) + @test value_support(typeof(product)) === value_support(typeof(result)) + @test support(product) === support(dist) + @test support(product) === support(result) + + for x in rand(10) + @test logpdf(product, x) ≈ logpdf(result, x) + @test pdf(product, x) ≈ pdf(result, x) + end + + # Test internal side-effects + another_product = prod(GenericProd(), product, dist) + + for x in rand(10) + @test logpdf(product, x) ≈ logpdf(result, x) + @test pdf(product, x) ≈ pdf(result, x) + + @test logpdf(another_product, x) ≈ (logpdf(product, x) + logpdf(dist, x)) + @test pdf(another_product, x) ≈ (pdf(product, x) * pdf(dist, x)) + end +end + +@testitem "ContinuousUnivariateLogPdf: convert" begin + d = DomainSets.FullSpace() + l = (x) -> 1.0 + + c = convert(ContinuousUnivariateLogPdf, d, l) + @test typeof(c) <: ContinuousUnivariateLogPdf + @test isapprox(c, ContinuousUnivariateLogPdf(d, l), atol=1e-12) + + c2 = convert(ContinuousUnivariateLogPdf, c) + @test typeof(c2) <: ContinuousUnivariateLogPdf + @test isapprox(c2, ContinuousUnivariateLogPdf(d, l), atol=1e-12) +end + +@testitem "ContinuousMultivariateLogPdf: Constructor" begin + f = (x) -> -x'x + dist = ContinuousMultivariateLogPdf(2, f) + d2 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, f) + + @test typeof(dist) === typeof(d2) + @test dist ≈ d2 + @test paramfloattype(dist) === Float64 + @test samplefloattype(dist) === Float64 + @test paramfloattype(d2) === Float64 + @test samplefloattype(d2) === Float64 + + @test_throws AssertionError ContinuousMultivariateLogPdf(DomainSets.FullSpace(), f) + @test_throws MethodError ContinuousMultivariateLogPdf(f) +end + +@testitem "ContinuousMultivariateLogPdf: Intentional errors" begin + dist = ContinuousMultivariateLogPdf(2, (x) -> -x'x) + @test_throws ErrorException mean(dist) + @test_throws ErrorException median(dist) + @test_throws ErrorException mode(dist) + @test_throws ErrorException var(dist) + @test_throws ErrorException std(dist) + @test_throws ErrorException cov(dist) + @test_throws ErrorException invcov(dist) + @test_throws ErrorException entropy(dist) + @test_throws ErrorException precision(dist) +end + +@testitem "ContinuousMultivariateLogPdf: pdf/logpdf" begin + dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> -x'x) + + f32_points1 = range(Float32(-10.0), Float32(10.0); length=5) + f64_points1 = range(-10.0, 10.0; length=5) + bf_points1 = range(BigFloat(-10.0), BigFloat(10.0); length=5) + + points1 = vcat( + vec(map(collect, Iterators.product(f32_points1, f32_points1))), + vec(map(collect, Iterators.product(f64_points1, f64_points1))), + vec(map(collect, Iterators.product(bf_points1, bf_points1))), + ) + + @test all(map(p -> -p'p == dist(p), points1)) + @test all(map(p -> -p'p == logpdf(dist, p), points1)) + @test all(map(p -> exp(-p'p) == pdf(dist, p), points1)) + + d2 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> -x'x / 4) + + f32_points2 = range(Float32(0.0), Float32(10.0); length=5) + f64_points2 = range(0.0, 10.0; length=5) + bf_points2 = range(BigFloat(0.0), BigFloat(10.0); length=5) + + points2 = vcat( + vec(map(collect, Iterators.product(f32_points2, f32_points2))), + vec(map(collect, Iterators.product(f64_points2, f64_points2))), + vec(map(collect, Iterators.product(bf_points2, bf_points2))), + ) + + @test all(map(p -> -p'p / 4 == d2(p), points2)) + @test all(map(p -> -p'p / 4 == logpdf(d2, p), points2)) + @test all(map(p -> exp(-p'p / 4) == pdf(d2, p), points2)) +end + +@testitem "ContinuousMultivariateLogPdf: test domain in logpdf" begin + for dim in (2, 3, 4) + dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dim, (x) -> -x'x) + d2 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^dim, (x) -> -x'x) + + # This also throws a warning in stdout + @test_logs (:warn, r".*incompatible combination.*") @test_throws AssertionError logpdf( + dist, ones(dim + 1) + ) + @test_logs (:warn, r".*incompatible combination.*") @test_throws AssertionError logpdf( + d2, ones(dim + 1) + ) + end +end + +@testitem "ContinuousMultivariateLogPdf: vague" begin + d = vague(ContinuousMultivariateLogPdf, 2) + + @test typeof(d) <: ContinuousMultivariateLogPdf + @test d ≈ ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 2.0) +end + +@testitem "ContinuousMultivariateLogPdf: prod" begin + dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 2.0 * -x'x) + d2 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 3.0 * -x'x) + + product = prod(ProdAnalytical(), dist, d2) + pt1 = ContinuousMultivariateLogPdf( + DomainSets.FullSpace()^2, (x) -> logpdf(dist, x) + logpdf(d2, x) + ) + + @test getdomain(product) === getdomain(dist) + @test getdomain(product) === getdomain(d2) + @test variate_form(typeof(product)) === variate_form(typeof(dist)) + @test variate_form(typeof(product)) === variate_form(typeof(d2)) + @test value_support(typeof(product)) === value_support(typeof(dist)) + @test value_support(typeof(product)) === value_support(typeof(d2)) + @test support(product) === support(dist) + @test support(product) === support(d2) + @test isapprox(product, pt1, atol=1e-12) + + result = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> 2.0 * -x'x) + d4 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> 3.0 * -x'x) + + pr2 = prod(ProdAnalytical(), result, d4) + pt2 = ContinuousMultivariateLogPdf( + DomainSets.HalfLine()^2, (x) -> logpdf(result, x) + logpdf(d4, x) + ) + + @test getdomain(pr2) === getdomain(result) + @test getdomain(pr2) === getdomain(d4) + @test variate_form(typeof(pr2)) === variate_form(typeof(result)) + @test variate_form(typeof(pr2)) === variate_form(typeof(d4)) + @test value_support(typeof(pr2)) === value_support(typeof(result)) + @test value_support(typeof(pr2)) === value_support(typeof(d4)) + @test support(pr2) === support(result) + @test support(pr2) === support(d4) + @test isapprox(pr2, pt2, atol=1e-12) + + @test !isapprox(product, pr2; atol=1e-12) + + d5 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 2.0 * -x'x) + d6 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> 2.0 * -x'x) + @test_throws AssertionError prod(ProdAnalytical(), d5, d6) +end + +@testitem "ContinuousMultivariateLogPdf: vectorised-prod" begin + f = (x) -> 2.0 * -x'x + dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, f) + d2 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, f) + result = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> f(x) + f(x)) + + product = prod(ProdAnalytical(), dist, d2) + + @test product isa GenericLogPdfVectorisedProduct + @test getdomain(product) === getdomain(dist) + @test getdomain(product) === getdomain(d2) + @test variate_form(typeof(product)) === variate_form(typeof(dist)) + @test variate_form(typeof(product)) === variate_form(typeof(d2)) + @test value_support(typeof(product)) === value_support(typeof(dist)) + @test value_support(typeof(product)) === value_support(typeof(d2)) + @test support(product) === support(dist) + @test support(product) === support(d2) + + for point in [rand(Float64, 2) for _ in 1:10] + @test pdf(product, point) ≈ pdf(result, point) + @test logpdf(product, point) ≈ logpdf(result, point) + end +end + +@testitem "ContinuousMultivariateLogPdf: convert" begin + d = DomainSets.FullSpace()^2 + l = (x) -> 1.0 + + c = convert(ContinuousMultivariateLogPdf, d, l) + @test typeof(c) <: ContinuousMultivariateLogPdf + @test isapprox(c, ContinuousMultivariateLogPdf(d, l), atol=1e-12) + + c2 = convert(ContinuousMultivariateLogPdf, c) + @test typeof(c2) <: ContinuousMultivariateLogPdf + @test isapprox(c2, ContinuousMultivariateLogPdf(d, l), atol=1e-12) +end \ No newline at end of file diff --git a/test/statsfuns_tests.jl b/test/statsfuns_tests.jl index a5a7b01..f9bdda1 100644 --- a/test/statsfuns_tests.jl +++ b/test/statsfuns_tests.jl @@ -26,4 +26,31 @@ end end end +@testitem "UnspecifiedDomain" begin + using DomainSets + @test 1 ∈ UnspecifiedDomain() + @test (1, 1) ∈ UnspecifiedDomain() + @test [ 0, 1 ] ∈ UnspecifiedDomain() + + @test fuse_supports(UnspecifiedDomain(), UnspecifiedDomain()) === UnspecifiedDomain() + @test fuse_supports(RealLine(), UnspecifiedDomain()) === RealLine() + @test fuse_supports(UnspecifiedDomain(), RealLine()) === RealLine() +end + +@testitem "UnspecifiedDimension" begin + using DomainSets + + @test UnspecifiedDimension() == 1 + @test UnspecifiedDimension() == 2 + @test UnspecifiedDimension() != 1 + @test UnspecifiedDimension() != 2 +end + +@testitem "isequal_typeof" begin + @test !isequal_typeof(1, 1.0) + @test isequal_typeof(1.0, 1.0) + @test !isequal_typeof([ 1.0 ], 1.0) + @test !isequal_typeof([ 1.0 ], [ 1 ]) + @test isequal_typeof([ 1.0 ], [ 1.0 ]) +end