Skip to content

Commit

Permalink
CompatHelper: bump compat for DimensionalData to 0.29, (keep existing…
Browse files Browse the repository at this point in the history
… compat) (#87)

* CompatHelper: bump compat for DimensionalData to 0.29, (keep existing compat)

* Use DimensionalData.maplayers if available

* Update test/mcmcdiagnostictools.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Increment patch number

---------

Co-authored-by: CompatHelper Julia <compathelper_noreply@julialang.org>
Co-authored-by: Seth Axen <seth@sethaxen.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 9, 2024
1 parent dd5f879 commit 3abecb2
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "InferenceObjects"
uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
authors = ["Seth Axen <seth.axen@gmail.com> and contributors"]
version = "0.4.6"
version = "0.4.7"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand All @@ -23,7 +23,7 @@ InferenceObjectsPosteriorStatsExt = ["PosteriorStats", "StatsBase"]
[compat]
ArviZExampleData = "0.1.10"
Dates = "1.9"
DimensionalData = "0.27, 0.28"
DimensionalData = "0.27, 0.28, 0.29"
EvoTrees = "0.16"
MCMCDiagnosticTools = "0.3.4"
MLJBase = "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using DimensionalData: DimensionalData, Dimensions, LookupArrays
using InferenceObjects: InferenceObjects, Random
using MCMCDiagnosticTools: MCMCDiagnosticTools

maplayers = isdefined(DimensionalData, :maplayers) ? DimensionalData.maplayers : map

include("utils.jl")
include("bfmi.jl")
include("ess_rhat.jl")
Expand Down
2 changes: 1 addition & 1 deletion ext/InferenceObjectsMCMCDiagnosticToolsExt/ess_rhat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ end
for f in (:ess, :rhat)
@eval begin
function MCMCDiagnosticTools.$f(data::InferenceObjects.Dataset; kwargs...)
ds = map(data) do var
ds = maplayers(data) do var
return _as_dimarray(MCMCDiagnosticTools.$f(_params_array(var); kwargs...), var)
end
return DimensionalData.rebuild(ds; metadata=DimensionalData.NoMetadata())
Expand Down
2 changes: 1 addition & 1 deletion ext/InferenceObjectsMCMCDiagnosticToolsExt/mcse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function MCMCDiagnosticTools.mcse(data::InferenceObjects.InferenceData; kwargs..
return MCMCDiagnosticTools.mcse(data.posterior; kwargs...)
end
function MCMCDiagnosticTools.mcse(data::InferenceObjects.Dataset; kwargs...)
ds = map(data) do var
ds = maplayers(data) do var
return _as_dimarray(MCMCDiagnosticTools.mcse(_params_array(var); kwargs...), var)
end
return DimensionalData.rebuild(ds; metadata=DimensionalData.NoMetadata())
Expand Down
2 changes: 1 addition & 1 deletion ext/InferenceObjectsMCMCDiagnosticToolsExt/rstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ end
function MCMCDiagnosticTools.rstar(
rng::Random.AbstractRNG, clf, data::InferenceObjects.Dataset; kwargs...
)
data_array = cat(map(_as_3d_array _params_array, data)...; dims=3)
data_array = cat(maplayers(_as_3d_array _params_array, data)...; dims=3)
return MCMCDiagnosticTools.rstar(rng, clf, data_array; kwargs...)
end
function MCMCDiagnosticTools.rstar(
Expand Down
2 changes: 2 additions & 0 deletions src/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ for f in [:data, :dims, :refdims, :metadata, :layerdims, :layermetadata]
end
end

DimensionalData.modify(f, s::Dataset) = Dataset(DimensionalData.modify(f, parent(s)))

# Warning: this is not an API function and probably should be implemented abstractly upstream
DimensionalData.show_after(io, mime, ::Dataset) = nothing

Expand Down
14 changes: 10 additions & 4 deletions test/mcmcdiagnostictools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ using Random
using Statistics
using Test

if !isdefined(DimensionalData, :maplayers)
maplayers = map
end

@testset "MCMCDiagnosticTools integration" begin
nchains, ndraws = 4, 10
sizes = (x=(), y=(2,), z=(3, 5))
Expand All @@ -16,12 +20,12 @@ using Test
dict1 = Dict(Symbol(k) => randn(ndraws, nchains, sz...) for (k, sz) in pairs(sizes))
idata1 = from_dict(dict1; dims, coords, sample_stats=Dict(:energy => energy))
# permute dimensions to test that diagnostics are invariant to dimension order
post2 = map(idata1.posterior) do var
post2 = maplayers(idata1.posterior) do var
n = ndims(var)
permdims = ((3:n)..., 2, 1)
return permutedims(var, permdims)
end
sample_stats2 = map(permutedims, idata1.sample_stats)
sample_stats2 = maplayers(permutedims, idata1.sample_stats)
idata2 = InferenceData(; posterior=post2, sample_stats=sample_stats2)

@testset for f in (ess, rhat, ess_rhat, mcse)
Expand All @@ -35,7 +39,7 @@ using Test
@test issetequal(keys(metric), keys(idata1.posterior))
@test metric == f(idata1.posterior; kind)
@test metric2 == f(idata2.posterior; kind)
@test all(map(, metric2, metric))
@test all(maplayers(, metric2, metric))
for k in keys(sizes)
@test all(
hasdim(
Expand Down Expand Up @@ -81,7 +85,9 @@ using Test
r4 = rstar(rng, classifier(rng), idata2.posterior; subset)
rng = Random.seed!(123)
post_mat = cat(
map(var -> reshape(parent(var), ndraws, nchains, :), idata1.posterior)...;
maplayers(
var -> reshape(parent(var), ndraws, nchains, :), idata1.posterior
)...;
dims=3,
)
r5 = rstar(rng, classifier(rng), post_mat; subset)
Expand Down

2 comments on commit 3abecb2

@sethaxen
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/119074

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.7 -m "<description of version>" 3abecb22b3451f8f946461fb88e6fcbc95deb79b
git push origin v0.4.7

Please sign in to comment.