From a6d4e8039db23954fff0f799a9b6ff0c3a9578f8 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Sep 2024 11:28:58 +0100 Subject: [PATCH] get local stan tests passing --- NEWS.md | 11 ++- R/dprimarycensoreddist.R | 8 ++ .../stan/functions/primary_censored_dist.stan | 84 ++++++++++--------- man/dot-extract_stan_functions.Rd | 23 +++++ man/pcd_stan_functions.Rd | 19 +++-- tests/testthat/setup.R | 8 +- tests/testthat/test-dprimarycensoreddist.R | 15 ++++ tests/testthat/test-rprimarycensoreddist.R | 14 ++++ .../test-stan-rpd-primarycensoreddist.R | 53 +++++++++++- 9 files changed, 184 insertions(+), 51 deletions(-) create mode 100644 man/dot-extract_stan_functions.Rd diff --git a/NEWS.md b/NEWS.md index 3301802..5c11238 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,11 @@ -# primarycensoreddist 0.0.0.1000 +# primarycensoreddist 0.1.0.1000 + +# primarycensoreddist 0.1.0 * Added package skeleton. -* Added prototype stan functions for censored distributions. +* Added checking input functions. +* Added stan functions for primary censored and truncated distributions. +* Added R functions for primary censored and truncated distributions. +* Added tests for primary censored and truncated distributions. +* Added tests to compare R and Stan implementations. +* Resolved R CMD check errors, warnings and notes. diff --git a/R/dprimarycensoreddist.R b/R/dprimarycensoreddist.R index 3ac3a18..aec9e79 100644 --- a/R/dprimarycensoreddist.R +++ b/R/dprimarycensoreddist.R @@ -36,6 +36,14 @@ dprimarycensoreddist <- function( check_pdist(pdist, D, ...) check_dprimary(dprimary, pwindow, dprimary_args) + if (max(x + swindow) > D) { + stop( + "Upper truncation point is greater than D. It is ", max(x + swindow), + " and D is ", D, ". Resolve this by increasing D to be the maximum", + " of x + swindow." + ) + } + result <- vapply(x, function(d) { if (d < 0) { return(0) # Return log(0) for non-positive delays diff --git a/inst/stan/functions/primary_censored_dist.stan b/inst/stan/functions/primary_censored_dist.stan index 90ee158..3f31b69 100644 --- a/inst/stan/functions/primary_censored_dist.stan +++ b/inst/stan/functions/primary_censored_dist.stan @@ -99,47 +99,47 @@ real primary_dist_lpdf(real x, int primary_dist_id, array[] real params, real mi * p, xc, theta, x_r, x_i * ); */ - real primary_censored_integrand(real x, real xc, array[] real theta, - array[] real x_r, array[] int x_i) { - // Unpack parameters - real d = theta[1]; - int dist_id = x_i[1]; - int primary_dist_id = x_i[2]; - real pwindow = x_r[1]; - real D = x_r[2]; - int dist_params_len = x_i[3]; - int primary_params_len = x_i[4]; +real primary_censored_integrand(real x, real xc, array[] real theta, + array[] real x_r, array[] int x_i) { + // Unpack parameters + real d = theta[1]; + int dist_id = x_i[1]; + int primary_dist_id = x_i[2]; + real pwindow = x_r[1]; + real D = x_r[2]; + int dist_params_len = x_i[3]; + int primary_params_len = x_i[4]; - // Extract distribution parameters - array[dist_params_len] real params; - if (dist_params_len) { - params = theta[2:(1 + dist_params_len)]; - } - array[primary_params_len] real primary_params; - if (primary_params_len) { - int theta_len = size(theta); - primary_params = theta[(theta_len - dist_params_len + 1):theta_len]; - } + // Extract distribution parameters + array[dist_params_len] real params; + if (dist_params_len) { + params = theta[2:(1 + dist_params_len)]; + } + array[primary_params_len] real primary_params; + if (primary_params_len) { + int theta_len = size(theta); + primary_params = theta[(theta_len - dist_params_len + 1):theta_len]; + } - // Compute adjusted delay - real d_adj = d - x; + // Compute adjusted delay + real d_adj = d - x; - // Compute log probabilities - real log_cdf = dist_lcdf(d_adj | params, dist_id); - real log_primary_pdf = primary_dist_lpdf( - x | primary_dist_id, primary_params, 0, pwindow - ); + // Compute log probabilities + real log_cdf = dist_lcdf(d_adj | params, dist_id); + real log_primary_pdf = primary_dist_lpdf( + x | primary_dist_id, primary_params, 0, pwindow + ); - if (is_inf(D)) { - // No truncation - return exp(log_cdf + log_primary_pdf); - }else{ - // Truncate at D - real D_adj = D - x; - real log_cdf_D = dist_lcdf(D_adj | params, dist_id); - return exp(log_cdf - log_cdf_D + log_primary_pdf); - } + if (is_inf(D)) { + // No truncation + return exp(log_cdf + log_primary_pdf); + } else { + // Truncate at D + real D_adj = D - x; + real log_cdf_D = dist_lcdf(D_adj | params, dist_id); + return exp(log_cdf - log_cdf_D + log_primary_pdf); } +} /** * Compute the primary event censored CDF for a single delay @@ -172,7 +172,7 @@ real primary_censored_dist_cdf(real d, int dist_id, array[] real params, int primary_dist_id, array[] real primary_params) { real result; - if (d <= 0 || d >= D) { + if (d <= 0 || d > D) { return 0; } @@ -219,7 +219,7 @@ real primary_censored_dist_lcdf(real d, int dist_id, array[] real params, data real pwindow, data real D, int primary_dist_id, array[] real primary_params) { - if (d <= 0 || d >= D) { + if (d <= 0 || d > D) { return negative_infinity(); } return log( @@ -260,8 +260,14 @@ real primary_censored_dist_lcdf(real d, int dist_id, array[] real params, real primary_censored_dist_lpmf(int d, int dist_id, array[] real params, data real pwindow, real swindow, data real D, int primary_dist_id, array[] real primary_params) { + + real d_upper = d + swindow; + if (d_upper > D) { + reject("Upper truncation point is greater than D. It is ", d_upper, + " and D is ", D, ". Resolve this by increasing D to be greater or equal to d + swindow."); + } real log_cdf_upper = primary_censored_dist_lcdf( - d + swindow | dist_id, params, pwindow, D, primary_dist_id, primary_params + d_upper | dist_id, params, pwindow, D, primary_dist_id, primary_params ); real log_cdf_lower = primary_censored_dist_lcdf( d | dist_id, params, pwindow, D, primary_dist_id, primary_params diff --git a/man/dot-extract_stan_functions.Rd b/man/dot-extract_stan_functions.Rd new file mode 100644 index 0000000..a0d01ae --- /dev/null +++ b/man/dot-extract_stan_functions.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/pcd-stan-tools.R +\name{.extract_stan_functions} +\alias{.extract_stan_functions} +\title{Extract function names or content from Stan code} +\usage{ +.extract_stan_functions(content, extract_names = TRUE, func_name = NULL) +} +\arguments{ +\item{content}{Character vector containing Stan code} + +\item{extract_names}{Logical, if TRUE extract function names, otherwise +extract function content} + +\item{func_name}{Optional, function name to extract content for} +} +\value{ +Character vector of function names or content +} +\description{ +Extract function names or content from Stan code +} +\keyword{internal} diff --git a/man/pcd_stan_functions.Rd b/man/pcd_stan_functions.Rd index 6cc0b77..795d68f 100644 --- a/man/pcd_stan_functions.Rd +++ b/man/pcd_stan_functions.Rd @@ -2,17 +2,26 @@ % Please edit documentation in R/pcd-stan-tools.R \name{pcd_stan_functions} \alias{pcd_stan_functions} -\title{List available Stan functions} +\title{Get Stan function names from Stan files} \usage{ pcd_stan_functions(stan_path = primarycensoreddist::pcd_stan_path()) } \arguments{ -\item{stan_path}{Character string, the path to the Stan code. Defaults to the -path to the Stan code in the primarycensoreddist package.} +\item{stan_path}{Character string specifying the path to the directory +containing Stan files. Defaults to the Stan path of the primarycensoreddist +package.} } \value{ -A character vector of available Stan function names +A character vector containing unique names of all functions found in +the Stan files. } \description{ -List available Stan functions +This function reads all Stan files in the specified directory and extracts +the names of all functions defined in those files. +} +\examples{ +\dontrun{ +stan_functions <- pcd_stan_functions() +print(stan_functions) +} } diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 46823e7..61e6857 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -6,14 +6,14 @@ if ( !on_ci() || (on_ci() && Sys.info()["sysname"] == "Linux" && not_on_cran()) ) { library(cmdstanr) # nolint + temp_path <- file.path(tempdir(), "pcd_stan_functions.stan") stan_functions <- pcd_load_stan_functions( - stan_path = "inst/stan/functions", wrap_in_block = TRUE, write_to_file = TRUE, - output_file = file.path("pcd_stan_functions.stan") + output_file = temp_path ) - model <- suppressMessages(suppressWarnings(cmdstanr::cmdstan_model( - file.path("pcd_stan_functions.stan") + model <- suppressMessages(suppressWarnings(cmdstanr::cmdstan_model( # nolint + temp_path ))) model$expose_functions(global = TRUE) } diff --git a/tests/testthat/test-dprimarycensoreddist.R b/tests/testthat/test-dprimarycensoreddist.R index ec7ff0e..cb59509 100644 --- a/tests/testthat/test-dprimarycensoreddist.R +++ b/tests/testthat/test-dprimarycensoreddist.R @@ -56,3 +56,18 @@ test_that("dprimarycensoreddist matches difference of pprimarycensoreddist", { expect_equal(pmf, cdf_diff, tolerance = 1e-6) }) + +test_that("dprimarycensoreddist throws an error for invalid upper truncation point", { + d <- 10 + pwindow <- 1 + swindow <- 1 + D <- 10 + + expect_error( + dpcens( + d, plnorm, pwindow, swindow, D, + meanlog = 0, sdlog = 1 + ), + "Upper truncation point is greater than D" + ) +}) diff --git a/tests/testthat/test-rprimarycensoreddist.R b/tests/testthat/test-rprimarycensoreddist.R index 6359e22..eb4fd03 100644 --- a/tests/testthat/test-rprimarycensoreddist.R +++ b/tests/testthat/test-rprimarycensoreddist.R @@ -9,3 +9,17 @@ test_that("rprimarycensoreddist generates samples within the correct range", { expect_true(all(samples >= 0 & samples < D)) }) + +test_that("rprimarycensoreddist handles different primary distributions", { + n <- 1000 + pwindow <- 5 + D <- 10 + r <- 0.5 + samples <- rpcens( + n, rlnorm, pwindow, + D = D, rprimary = rexpgrowth, rprimary_args = list(r = r), + meanlog = 0, sdlog = 1 + ) + + expect_true(all(samples >= 0 & samples < D)) +}) diff --git a/tests/testthat/test-stan-rpd-primarycensoreddist.R b/tests/testthat/test-stan-rpd-primarycensoreddist.R index 82c120e..ebb9ca5 100644 --- a/tests/testthat/test-stan-rpd-primarycensoreddist.R +++ b/tests/testthat/test-stan-rpd-primarycensoreddist.R @@ -51,6 +51,55 @@ test_that( } ) +test_that( + "Stan primary_censored_dist_lpmf throws an error for invalid upper truncation + point", + { + d <- 10 + dist_id <- 1 # Lognormal + params <- c(0, 1) # meanlog, sdlog + pwindow <- 1 + swindow <- 1 + D <- 10 + primary_dist_id <- 1 # Uniform + primary_params <- numeric(0) + + expect_error( + primary_censored_dist_lpmf( + d, dist_id, params, pwindow, swindow, D, primary_dist_id, primary_params + ), + "Upper truncation point is greater than D" + ) + } +) + + + +test_that( + "Stan primary_censored_dist matches R primarycensoreddist when d is the same + as D - 1", + { + d <- 10 + dist_id <- 1 # Lognormal + params <- c(0, 1) # meanlog, sdlog + pwindow <- 1 + primary_dist_id <- 1 # Uniform + primary_params <- numeric(0) + + stan_pmf <- primary_censored_dist_pmf( + d, dist_id, params, pwindow, 1, d + 1, primary_dist_id, primary_params + ) + r_pmf <- dprimarycensoreddist( + d, plnorm, + pwindow = pwindow, swindow = 1, D = d + 1, + meanlog = params[1], sdlog = params[2] + ) + + expect_equal(stan_pmf, r_pmf, tolerance = 1e-6) + } +) + + test_that("Stan primary_censored_dist_pmf matches R dprimarycensoreddist", { d <- 0:10 dist_id <- 1 # Lognormal @@ -217,6 +266,8 @@ test_that( expect_equal(stan_lpmf_approx, r_lpmf, tolerance = 1e-6) expect_equal(stan_lpmf_exact, r_lpmf, tolerance = 1e-8) - expect_true(all(abs(stan_lpmf_exact - r_lpmf) < abs(stan_lpmf_approx - r_lpmf))) + expect_true( + all(abs(stan_lpmf_exact - r_lpmf) <= abs(stan_lpmf_approx - r_lpmf)) + ) } )