Skip to content

Commit

Permalink
Test single-stream output
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Dec 11, 2024
1 parent 350d24a commit 6bb6999
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 8 deletions.
22 changes: 18 additions & 4 deletions R/distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
## because it's quite gross. Mostly it is trying to tidy up some of
## the ways that we might draw from multivariate normal distributions.
## This is complicated by wanting to cache the results of the vcv
## decomposition where possible.
## decomposition where possible. We don't expose this anywhere to the
## user, and doing so is difficult because we'd need a thread-safe way
## of doing matrix multiplication (and possibly Cholesky
## factorisation) but this involves LAPACK (and therefore linking to
## libfortran) and is not guaranteed to be thread-safe.

## It's also tangled up with the distribution support, being different
## to most distributions in having vector output and *matrix*
## input. The most effficient way of using this really requires that
## we have the decomposition cached wherever possible, so it does not
## neatly fit into our usual approach at all.

## This is the form of the Cholesky factorisation of a matrix we use
## in the multivariate normal sampling.
Expand Down Expand Up @@ -48,9 +58,13 @@ make_rmvnorm <- function(vcv) {
}
function(rng) {
n_streams <- length(rng)
stopifnot(any(n_streams == c(1, m)))
len <- if (n_streams == 1) n * m else n
rand <- monty_random_n_normal(len, 0, 1, rng)
if (n_streams == 1) {
rand <- matrix(monty_random_n_normal(n * m, 0, 1, rng), n, m)
} else if (n_streams == m) {
rand <- monty_random_n_normal(n, 0, 1, rng)
} else {
stop("Invalid input in rng")
}
if (n == 1) {
ret <- drop(rand) * r
} else {
Expand Down
36 changes: 32 additions & 4 deletions tests/testthat/test-distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ test_that("Can draw samples from many centred single-variable MVNs", {
r1 <- monty_rng_create(seed = 42, n_streams = 4)
r2 <- monty_rng_create(seed = 42, n_streams = 4)
vcv <- array(1, c(1, 1, 4))
x <- runif(4)
expect_equal(make_rmvnorm(vcv)(r2),
monty_random_normal(0, 1, r1))

expect_equal(make_rmvnorm(vcv)(r2),
monty_random_normal(0, 1, r1))
Expand All @@ -78,7 +75,6 @@ test_that("Can draw samples from many centred single-variable MVNs", {
test_that("Can draw samples from many bivariate MVNs", {
r1 <- monty_rng_create(seed = 42, n_streams = 5)
r2 <- monty_rng_create(seed = 42, n_streams = 5)
r3 <- monty_rng_create(seed = 42, n_streams = 5)
vcv <- array(0, c(2, 2, 5))
set.seed(1)
vcv[1, 1, ] <- 1:5
Expand All @@ -96,3 +92,35 @@ test_that("Can draw samples from many bivariate MVNs", {
}, numeric(2))
expect_identical(y, z)
})


test_that("can draw from single-variable mvns with a single stream", {
r1 <- monty_rng_create(seed = 42)
r2 <- monty_rng_create(seed = 42)
vcv <- array(1, c(1, 1, 4))

expect_equal(make_rmvnorm(vcv)(r2),
monty_random_n_normal(4, 0, 1, r1))
expect_equal(make_rmvnorm(0.1 * vcv)(r2),
monty_random_n_normal(4, 0, 1, r1) * sqrt(0.1))
expect_equal(make_rmvnorm(0.1 * 1:4 * vcv)(r2),
monty_random_n_normal(4, 0, 1, r1) * sqrt(0.1 * 1:4))
})


test_that("Can draw samples from many bivariate MVNs with a single stream", {
r1 <- monty_rng_create(seed = 42)
r2 <- monty_rng_create(seed = 42)
vcv <- array(0, c(2, 2, 5))
set.seed(1)
vcv[1, 1, ] <- 1:5
vcv[2, 2, ] <- 1
vcv[1, 2, ] <- vcv[2, 1, ] <- rnorm(5, 0, 0.1)

y <- make_rmvnorm(vcv)(r2)
expect_equal(dim(y), c(2, 5))

## A bit of work do do these separately:
z <- vapply(1:5, function(i) make_rmvnorm(vcv[, , i])(r1), numeric(2))
expect_identical(y, z)
})

0 comments on commit 6bb6999

Please sign in to comment.