Skip to content

Commit

Permalink
get local stan tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Sep 2, 2024
1 parent f5eb2a4 commit a6d4e80
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 51 deletions.
11 changes: 9 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# primarycensoreddist 0.0.0.1000
# primarycensoreddist 0.1.0.1000

# primarycensoreddist 0.1.0

* Added package skeleton.
* Added prototype stan functions for censored distributions.
* Added checking input functions.
* Added stan functions for primary censored and truncated distributions.
* Added R functions for primary censored and truncated distributions.
* Added tests for primary censored and truncated distributions.
* Added tests to compare R and Stan implementations.
* Resolved R CMD check errors, warnings and notes.
8 changes: 8 additions & 0 deletions R/dprimarycensoreddist.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ dprimarycensoreddist <- function(
check_pdist(pdist, D, ...)
check_dprimary(dprimary, pwindow, dprimary_args)

if (max(x + swindow) > D) {
stop(
"Upper truncation point is greater than D. It is ", max(x + swindow),
" and D is ", D, ". Resolve this by increasing D to be the maximum",
" of x + swindow."
)
}

result <- vapply(x, function(d) {
if (d < 0) {
return(0) # Return log(0) for non-positive delays
Expand Down
84 changes: 45 additions & 39 deletions inst/stan/functions/primary_censored_dist.stan
Original file line number Diff line number Diff line change
Expand Up @@ -99,47 +99,47 @@ real primary_dist_lpdf(real x, int primary_dist_id, array[] real params, real mi
* p, xc, theta, x_r, x_i
* );
*/
real primary_censored_integrand(real x, real xc, array[] real theta,
array[] real x_r, array[] int x_i) {
// Unpack parameters
real d = theta[1];
int dist_id = x_i[1];
int primary_dist_id = x_i[2];
real pwindow = x_r[1];
real D = x_r[2];
int dist_params_len = x_i[3];
int primary_params_len = x_i[4];
real primary_censored_integrand(real x, real xc, array[] real theta,
array[] real x_r, array[] int x_i) {
// Unpack parameters
real d = theta[1];
int dist_id = x_i[1];
int primary_dist_id = x_i[2];
real pwindow = x_r[1];
real D = x_r[2];
int dist_params_len = x_i[3];
int primary_params_len = x_i[4];

// Extract distribution parameters
array[dist_params_len] real params;
if (dist_params_len) {
params = theta[2:(1 + dist_params_len)];
}
array[primary_params_len] real primary_params;
if (primary_params_len) {
int theta_len = size(theta);
primary_params = theta[(theta_len - dist_params_len + 1):theta_len];
}
// Extract distribution parameters
array[dist_params_len] real params;
if (dist_params_len) {
params = theta[2:(1 + dist_params_len)];
}
array[primary_params_len] real primary_params;
if (primary_params_len) {
int theta_len = size(theta);
primary_params = theta[(theta_len - dist_params_len + 1):theta_len];
}

// Compute adjusted delay
real d_adj = d - x;
// Compute adjusted delay
real d_adj = d - x;

// Compute log probabilities
real log_cdf = dist_lcdf(d_adj | params, dist_id);
real log_primary_pdf = primary_dist_lpdf(
x | primary_dist_id, primary_params, 0, pwindow
);
// Compute log probabilities
real log_cdf = dist_lcdf(d_adj | params, dist_id);
real log_primary_pdf = primary_dist_lpdf(
x | primary_dist_id, primary_params, 0, pwindow
);

if (is_inf(D)) {
// No truncation
return exp(log_cdf + log_primary_pdf);
}else{
// Truncate at D
real D_adj = D - x;
real log_cdf_D = dist_lcdf(D_adj | params, dist_id);
return exp(log_cdf - log_cdf_D + log_primary_pdf);
}
if (is_inf(D)) {
// No truncation
return exp(log_cdf + log_primary_pdf);
} else {
// Truncate at D
real D_adj = D - x;
real log_cdf_D = dist_lcdf(D_adj | params, dist_id);
return exp(log_cdf - log_cdf_D + log_primary_pdf);
}
}

/**
* Compute the primary event censored CDF for a single delay
Expand Down Expand Up @@ -172,7 +172,7 @@ real primary_censored_dist_cdf(real d, int dist_id, array[] real params,
int primary_dist_id,
array[] real primary_params) {
real result;
if (d <= 0 || d >= D) {
if (d <= 0 || d > D) {
return 0;
}

Expand Down Expand Up @@ -219,7 +219,7 @@ real primary_censored_dist_lcdf(real d, int dist_id, array[] real params,
data real pwindow, data real D,
int primary_dist_id,
array[] real primary_params) {
if (d <= 0 || d >= D) {
if (d <= 0 || d > D) {
return negative_infinity();
}
return log(
Expand Down Expand Up @@ -260,8 +260,14 @@ real primary_censored_dist_lcdf(real d, int dist_id, array[] real params,
real primary_censored_dist_lpmf(int d, int dist_id, array[] real params,
data real pwindow, real swindow, data real D,
int primary_dist_id, array[] real primary_params) {

real d_upper = d + swindow;
if (d_upper > D) {
reject("Upper truncation point is greater than D. It is ", d_upper,
" and D is ", D, ". Resolve this by increasing D to be greater or equal to d + swindow.");
}
real log_cdf_upper = primary_censored_dist_lcdf(
d + swindow | dist_id, params, pwindow, D, primary_dist_id, primary_params
d_upper | dist_id, params, pwindow, D, primary_dist_id, primary_params
);
real log_cdf_lower = primary_censored_dist_lcdf(
d | dist_id, params, pwindow, D, primary_dist_id, primary_params
Expand Down
23 changes: 23 additions & 0 deletions man/dot-extract_stan_functions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 14 additions & 5 deletions man/pcd_stan_functions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ if (
!on_ci() || (on_ci() && Sys.info()["sysname"] == "Linux" && not_on_cran())
) {
library(cmdstanr) # nolint
temp_path <- file.path(tempdir(), "pcd_stan_functions.stan")
stan_functions <- pcd_load_stan_functions(
stan_path = "inst/stan/functions",
wrap_in_block = TRUE,
write_to_file = TRUE,
output_file = file.path("pcd_stan_functions.stan")
output_file = temp_path
)
model <- suppressMessages(suppressWarnings(cmdstanr::cmdstan_model(
file.path("pcd_stan_functions.stan")
model <- suppressMessages(suppressWarnings(cmdstanr::cmdstan_model( # nolint
temp_path
)))
model$expose_functions(global = TRUE)
}
15 changes: 15 additions & 0 deletions tests/testthat/test-dprimarycensoreddist.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,18 @@ test_that("dprimarycensoreddist matches difference of pprimarycensoreddist", {

expect_equal(pmf, cdf_diff, tolerance = 1e-6)
})

test_that("dprimarycensoreddist throws an error for invalid upper truncation point", {
d <- 10
pwindow <- 1
swindow <- 1
D <- 10

expect_error(
dpcens(
d, plnorm, pwindow, swindow, D,
meanlog = 0, sdlog = 1
),
"Upper truncation point is greater than D"
)
})
14 changes: 14 additions & 0 deletions tests/testthat/test-rprimarycensoreddist.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,17 @@ test_that("rprimarycensoreddist generates samples within the correct range", {

expect_true(all(samples >= 0 & samples < D))
})

test_that("rprimarycensoreddist handles different primary distributions", {
n <- 1000
pwindow <- 5
D <- 10
r <- 0.5
samples <- rpcens(
n, rlnorm, pwindow,
D = D, rprimary = rexpgrowth, rprimary_args = list(r = r),
meanlog = 0, sdlog = 1
)

expect_true(all(samples >= 0 & samples < D))
})
53 changes: 52 additions & 1 deletion tests/testthat/test-stan-rpd-primarycensoreddist.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,55 @@ test_that(
}
)

test_that(
"Stan primary_censored_dist_lpmf throws an error for invalid upper truncation
point",
{
d <- 10
dist_id <- 1 # Lognormal
params <- c(0, 1) # meanlog, sdlog
pwindow <- 1
swindow <- 1
D <- 10
primary_dist_id <- 1 # Uniform
primary_params <- numeric(0)

expect_error(
primary_censored_dist_lpmf(
d, dist_id, params, pwindow, swindow, D, primary_dist_id, primary_params
),
"Upper truncation point is greater than D"
)
}
)



test_that(
"Stan primary_censored_dist matches R primarycensoreddist when d is the same
as D - 1",
{
d <- 10
dist_id <- 1 # Lognormal
params <- c(0, 1) # meanlog, sdlog
pwindow <- 1
primary_dist_id <- 1 # Uniform
primary_params <- numeric(0)

stan_pmf <- primary_censored_dist_pmf(
d, dist_id, params, pwindow, 1, d + 1, primary_dist_id, primary_params
)
r_pmf <- dprimarycensoreddist(
d, plnorm,
pwindow = pwindow, swindow = 1, D = d + 1,
meanlog = params[1], sdlog = params[2]
)

expect_equal(stan_pmf, r_pmf, tolerance = 1e-6)
}
)


test_that("Stan primary_censored_dist_pmf matches R dprimarycensoreddist", {
d <- 0:10
dist_id <- 1 # Lognormal
Expand Down Expand Up @@ -217,6 +266,8 @@ test_that(

expect_equal(stan_lpmf_approx, r_lpmf, tolerance = 1e-6)
expect_equal(stan_lpmf_exact, r_lpmf, tolerance = 1e-8)
expect_true(all(abs(stan_lpmf_exact - r_lpmf) < abs(stan_lpmf_approx - r_lpmf)))
expect_true(
all(abs(stan_lpmf_exact - r_lpmf) <= abs(stan_lpmf_approx - r_lpmf))
)
}
)

0 comments on commit a6d4e80

Please sign in to comment.