diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index 9517807e..cfef407b 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -69,12 +69,13 @@ end with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x) function truncated_inv_logabsdetjac(y, a, b) + y, a, b = promote(y, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded abs_y = abs(y) return log(b - a) - abs_y + 2 * LogExpFunctions.log1pexp(-abs_y) elseif lowerbounded || upperbounded - return convert(promote_type(typeof(y), typeof(a), typeof(b)), y) + return y else return zero(y) end @@ -82,10 +83,12 @@ end function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y) a, b = ib.orig.lb, ib.orig.ub - return truncated_inv_logabsdetjac.(y, a, b) + return sum(truncated_inv_logabsdetjac.(y, a, b)) end -with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) = transform(ib, y), logabsdetjac(ib, y) +function with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) + return transform(ib, y), logabsdetjac(ib, y) +end # It's only monotonically decreasing if it's only upper-bounded. # In the multivariate case, we can only say something reasonable if entries are monotonic.