Skip to content

Commit

Permalink
wip: add arbitrary logpdf
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Oct 11, 2023
1 parent a8c02ed commit fb4d2f0
Show file tree
Hide file tree
Showing 9 changed files with 736 additions and 43 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.9'
- 'nightly'
os:
Expand Down
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "1.0.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -14,11 +15,12 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"

[compat]
Distributions = "0.25"
DomainSets = "0.7"
Random = "1.9"
SpecialFunctions = "2.3"
Statistics = "1.9"
StatsAPI = "1.7"
StatsBase = "0.34"
SpecialFunctions = "2.3"
TinyHugeNumbers = "1.0"
julia = "1.9"

Expand All @@ -27,8 +29,9 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "CpuId", "Test", "ReTestItems", "LinearAlgebra", "StableRNGs"]
test = ["Aqua", "CpuId", "Test", "ReTestItems", "LinearAlgebra", "StableRNGs", "HCubature"]
7 changes: 6 additions & 1 deletion src/BayesBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module BayesBase

using TinyHugeNumbers

using StatsAPI, StatsBase, Statistics, Distributions, Random
using StatsAPI, StatsBase, DomainSets, Statistics, Distributions, Random

using StatsAPI: params

Expand Down Expand Up @@ -65,6 +65,10 @@ export failprob,
variate_form,
value_support

using DomainSets: dimension, Domain

export dimension, Domain

using Base: precision, prod, prod!

export precision, prod, prod!
Expand All @@ -78,5 +82,6 @@ include("promotion.jl")
include("prod.jl")

include("densities/factorizedjoint.jl")
include("densities/continouslogpdf.jl")

end
222 changes: 222 additions & 0 deletions src/densities/continouslogpdf.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
export ContinuousUnivariateLogPdf, ContinuousMultivariateLogPdf

# import DomainIntegrals
# import HCubature

import Base: isapprox, in

abstract type AbstractContinuousGenericLogPdf end

getdomain(dist::AbstractContinuousGenericLogPdf) = dist.domain
getlogpdf(dist::AbstractContinuousGenericLogPdf) = dist.logpdf

BayesBase.value_support(::Type{<:AbstractContinuousGenericLogPdf}) = Continuous
BayesBase.value_support(::AbstractContinuousGenericLogPdf) = Continuous

# We throw an error on purpose, since we do not want to use `AbstractContinuousGenericLogPdf` much without approximations
# We want to encourage a user to use approximate generic log-pdfs as much as possible instead
function __error_genericlogpdf_not_defined(dist::AbstractContinuousGenericLogPdf, f::Symbol)
return error(
"`$f` is not defined for `$(dist)`. Use functional form constraints to approximate the resulting generic log-pdf object and to use it in the inference procedure.",
)
end

function BayesBase.mean(dist::AbstractContinuousGenericLogPdf)
return __error_genericlogpdf_not_defined(dist, :mean)
end
function BayesBase.median(dist::AbstractContinuousGenericLogPdf)
return __error_genericlogpdf_not_defined(dist, :median)
end
function BayesBase.mode(dist::AbstractContinuousGenericLogPdf)
return __error_genericlogpdf_not_defined(dist, :mode)
end
function BayesBase.var(dist::AbstractContinuousGenericLogPdf)
return __error_genericlogpdf_not_defined(dist, :var)
end
function BayesBase.std(dist::AbstractContinuousGenericLogPdf)
return __error_genericlogpdf_not_defined(dist, :std)
end
function BayesBase.cov(dist::AbstractContinuousGenericLogPdf)
return __error_genericlogpdf_not_defined(dist, :cov)
end
function BayesBase.invcov(dist::AbstractContinuousGenericLogPdf)
return __error_genericlogpdf_not_defined(dist, :invcov)
end
function BayesBase.entropy(dist::AbstractContinuousGenericLogPdf)
return __error_genericlogpdf_not_defined(dist, :entropy)
end

function Base.precision(dist::AbstractContinuousGenericLogPdf)
return __error_genericlogpdf_not_defined(dist, :precision)
end

Base.eltype(dist::AbstractContinuousGenericLogPdf) = eltype(getdomain(dist))

BayesBase.paramfloattype(dist::AbstractContinuousGenericLogPdf) = deep_eltype(eltype(dist))
BayesBase.samplefloattype(dist::AbstractContinuousGenericLogPdf) = paramfloattype(dist)

(dist::AbstractContinuousGenericLogPdf)(x::Real) = logpdf(dist, x)
(dist::AbstractContinuousGenericLogPdf)(x::AbstractVector{<:Real}) = logpdf(dist, x)

function BayesBase.logpdf(dist::AbstractContinuousGenericLogPdf, x)
@assert x getdomain(dist) "x = $(x) does not belong to the domain ($(getdomain(dist))) of $dist"
lpdf = getlogpdf(dist)
return lpdf(x)
end

# We don't expect neither `pdf` nor `logpdf` to be normalised
BayesBase.pdf(dist::AbstractContinuousGenericLogPdf, x) = exp(logpdf(dist, x))

"""
ContinuousUnivariateLogPdf{ D <: DomainSets.Domain, F } <: AbstractContinuousGenericLogPdf
Generic continuous univariate distribution in a form of domain specification and logpdf function. Can be used in cases where no
known analytical distribution available.
# Arguments
- `domain`: domain specificatiom from `DomainSets.jl` package, by default the `domain` is set to `DomainSets.FullSpace()`. Use `BayesBase.UnspecifiedDomain()` to bypass domain checks.
- `logpdf`: callable object that represents the logdensity. Can be un-normalised.
"""
struct ContinuousUnivariateLogPdf{D<:DomainSets.Domain,F} <: AbstractContinuousGenericLogPdf
domain::D
logpdf::F

function ContinuousUnivariateLogPdf(domain::D, logpdf::F) where {D,F}
@assert dimension(domain) == 1 "Cannot create ContinuousUnivariateLogPdf. Dimension of domain = $(domain) is not equal to 1."
return new{D,F}(domain, logpdf)
end
end

function ContinuousUnivariateLogPdf(f::Function)
return ContinuousUnivariateLogPdf(DomainSets.FullSpace(), f)
end

BayesBase.variate_form(::Type{<:ContinuousUnivariateLogPdf}) = Univariate
BayesBase.variate_form(::ContinuousUnivariateLogPdf) = Univariate

function BayesBase.promote_variate_type(
::Type{Univariate}, ::Type{AbstractContinuousGenericLogPdf}
)
return ContinuousUnivariateLogPdf
end

function Base.show(io::IO, dist::ContinuousUnivariateLogPdf)
return print(io, "ContinuousUnivariateLogPdf(", getdomain(dist), ")")
end
function Base.show(io::IO, ::Type{<:ContinuousUnivariateLogPdf{D}}) where {D}
return print(io, "ContinuousUnivariateLogPdf{", D, "}")
end

function BayesBase.support(dist::ContinuousUnivariateLogPdf)
return getdomain(dist)
end

BayesBase.insupport(dist::ContinuousUnivariateLogPdf, x) = x getdomain(dist)

# Fallback for various optimisation packages which may pass arguments as vectors
function BayesBase.logpdf(dist::ContinuousUnivariateLogPdf, x::AbstractVector{<:Real})
@assert length(x) === 1 "`ContinuousUnivariateLogPdf` expects either float or a vector of a single float as an input for the `logpdf` function."
return logpdf(dist, first(x))
end

function Base.convert(
::Type{<:ContinuousUnivariateLogPdf}, domain::D, logpdf::F
) where {D<:DomainSets.Domain,F}
return ContinuousUnivariateLogPdf(domain, logpdf)
end

function BayesBase.convert_paramfloattype(
::Type{T}, dist::ContinuousUnivariateLogPdf
) where {T<:Real}
return convert(
ContinuousUnivariateLogPdf,
dist.domain,
(x) -> dist.logpdf(convert_paramfloattype(T, x)),
)
end

function BayesBase.vague(::Type{<:ContinuousUnivariateLogPdf})
return ContinuousUnivariateLogPdf(DomainSets.FullSpace(), (x) -> 1)
end

# We do not check typeof of a different functions because in most of the cases lambdas have different types, but they can still be the same
function BayesBase.isequal_typeof(
::ContinuousUnivariateLogPdf{D,F1}, ::ContinuousUnivariateLogPdf{D,F2}
) where {D,F1<:Function,F2<:Function}
return true
end

##

"""
ContinuousMultivariateLogPdf{ D <: DomainSets.Domain, F } <: AbstractContinuousGenericLogPdf
Generic continuous multivariate distribution in a form of domain specification and logpdf function. Can be used in cases where no
known analytical distribution available.
# Arguments
- `domain`: multidimensional domain specification from `DomainSets.jl` package. Use `BayesBase.UnspecifiedDomain()` to bypass domain checks.
- `logpdf`: callable object that accepts an `AbstractVector` as an input and represents the logdensity. Can be un-normalised.
"""
struct ContinuousMultivariateLogPdf{D<:DomainSets.Domain,F} <:
AbstractContinuousGenericLogPdf
domain::D
logpdf::F

function ContinuousMultivariateLogPdf(
domain::D, logpdf::F
) where {D<:DomainSets.Domain,F}
@assert DomainSets.dimension(domain) !== 1 "Cannot create ContinuousMultivariateLogPdf. Dimension of domain = $(domain) should not be equal to 1. Use, for example, `DomainSets.FullSpace() ^ 2` to create 2-dimensional full space domain."
return new{D,F}(domain, logpdf)
end
end

BayesBase.variate_form(::Type{<:ContinuousMultivariateLogPdf}) = Multivariate
BayesBase.variate_form(::ContinuousMultivariateLogPdf) = Multivariate

function BayesBase.promote_variate_type(
::Type{Multivariate}, ::Type{AbstractContinuousGenericLogPdf}
)
return ContinuousMultivariateLogPdf
end

function ContinuousMultivariateLogPdf(dims::Int, f::Function)
return ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dims, f)
end

function Base.show(io::IO, dist::ContinuousMultivariateLogPdf)
return print(io, "ContinuousMultivariateLogPdf(", getdomain(dist), ")")
end
function Base.show(io::IO, ::Type{<:ContinuousMultivariateLogPdf{D}}) where {D}
return print(io, "ContinuousMultivariateLogPdf{", D, "}")
end

BayesBase.support(dist::ContinuousMultivariateLogPdf) = getdomain(dist)
BayesBase.insupport(dist::ContinuousMultivariateLogPdf, x) = x getdomain(dist)

function Base.convert(
::Type{<:ContinuousMultivariateLogPdf}, domain::D, logpdf::F
) where {D<:DomainSets.Domain,F}
return ContinuousMultivariateLogPdf(domain, logpdf)
end

function BayesBase.convert_paramfloattype(
::Type{T}, dist::ContinuousMultivariateLogPdf
) where {T<:Real}
return convert(
ContinuousMultivariateLogPdf,
dist.domain,
(x) -> dist.logpdf(convert_paramfloattype(T, x)),
)
end

function BayesBase.vague(::Type{<:ContinuousMultivariateLogPdf}, dims::Int)
return ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dims, (x) -> 1)
end

# We do not check typeof of a different functions because in most of the cases lambdas have different types, but they can still be the same
function BayesBase.isequal_typeof(
::ContinuousMultivariateLogPdf{D,F1}, ::ContinuousMultivariateLogPdf{D,F2}
) where {D,F1<:Function,F2<:Function}
return true
end
4 changes: 2 additions & 2 deletions src/densities/factorizedjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ Base.@propagate_inbounds function Base.getindex(joint::FactorizedJoint, i::Int)
return getindex(components(joint), i)
end

BayesBase.length(joint::FactorizedJoint) = length(joint.multipliers)
Base.length(joint::FactorizedJoint) = length(joint.multipliers)

function BayesBase.isapprox(x::FactorizedJoint, y::FactorizedJoint; kwargs...)
function Base.isapprox(x::FactorizedJoint, y::FactorizedJoint; kwargs...)
return length(x) === length(y) && all(
tuple -> isapprox(tuple[1], tuple[2]; kwargs...),
zip(components(x), components(y)),
Expand Down
Loading

0 comments on commit fb4d2f0

Please sign in to comment.