diff --git a/R/pcd-stan-tools.R b/R/pcd-stan-tools.R index 1928ff7..ceccddb 100644 --- a/R/pcd-stan-tools.R +++ b/R/pcd-stan-tools.R @@ -2,17 +2,61 @@ #' #' @return A character string with the path to the Stan code #' @export -#' @aliases pcd_stan_path pcd_stan_path <- function() { system.file("stan", package = "primarycensoreddist") } -#' List available Stan functions +#' Extract function names or content from Stan code +#' +#' @param content Character vector containing Stan code +#' +#' @param extract_names Logical, if TRUE extract function names, otherwise +#' extract function content +#' +#' @param func_name Optional, function name to extract content for +#' +#' @return Character vector of function names or content +#' @keywords internal +.extract_stan_functions <- function( + content, extract_names = TRUE, func_name = NULL) { + func_pattern <- paste0( + "^(real|vector|matrix|void|int)\\s+", + "(\\w+)\\s*\\(" + ) + if (extract_names) { + func_lines <- grep(func_pattern, content, value = TRUE) + return(gsub(func_pattern, "\\2", func_lines)) + } else { + start_line <- grep(paste0(func_pattern, ".*", func_name), content) + if (length(start_line) > 0) { + end_line <- which( + cumsum(grepl("^\\s*\\{", content[start_line:length(content)])) == + cumsum(grepl("^\\s*\\}", content[start_line:length(content)])) + )[1] + start_line - 1 + return(content[start_line:end_line]) + } + return(character(0)) + } +} + +#' Get Stan function names from Stan files +#' +#' This function reads all Stan files in the specified directory and extracts +#' the names of all functions defined in those files. +#' +#' @param stan_path Character string specifying the path to the directory +#' containing Stan files. Defaults to the Stan path of the primarycensoreddist +#' package. +#' +#' @return A character vector containing unique names of all functions found in +#' the Stan files. #' -#' @inheritParams pcd_load_stan_functions -#' @return A character vector of available Stan function names #' @export -#' @aliases pcd_stan_functions +#' @examples +#' \dontrun{ +#' stan_functions <- pcd_stan_functions() +#' print(stan_functions) +#' } pcd_stan_functions <- function( stan_path = primarycensoreddist::pcd_stan_path()) { stan_files <- list.files( @@ -23,13 +67,7 @@ pcd_stan_functions <- function( functions <- character(0) for (file in stan_files) { content <- readLines(file) - func_lines <- grep( - "^(real|vector|matrix|void)\\s+\\w+\\s*\\(", content, - value = TRUE - ) - functions <- c( - functions, gsub("^.*?\\s+(\\w+)\\s*\\(.*$", "\\1", func_lines) - ) + functions <- c(functions, .extract_stan_functions(content)) } unique(functions) } @@ -53,7 +91,6 @@ pcd_stan_functions <- function( #' #' @return A character string containing the requested Stan functions #' @export -#' @aliases pcd_load_stan_functions pcd_load_stan_functions <- function( functions = NULL, stan_path = primarycensoreddist::pcd_stan_path(), wrap_in_block = FALSE, write_to_file = FALSE, @@ -71,20 +108,11 @@ pcd_load_stan_functions <- function( all_content <- c(all_content, content) } else { for (func in functions) { - start_line <- grep( - paste0("^(real|vector|matrix|void)\\s+", func, "\\s*\\("), content + func_content <- .extract_stan_functions( + content, + extract_names = FALSE, func_name = func ) - if (length(start_line) > 0) { - end_line <- which( - cumsum( - grepl("^\\s*\\{", content[start_line:length(content)]) - ) == - cumsum( - grepl("^\\s*\\}", content[start_line:length(content)]) - ) - )[1] + start_line - 1 - all_content <- c(all_content, content[start_line:end_line]) - } + all_content <- c(all_content, func_content) } } } @@ -104,7 +132,7 @@ pcd_load_stan_functions <- function( if (write_to_file) { writeLines(result, output_file) - message("Stan functions written to:", output_file, "\n") + message("Stan functions written to: ", output_file, "\n") } return(result) diff --git a/inst/stan/functions/primary_censored_dist.stan b/inst/stan/functions/primary_censored_dist.stan index 9396166..90ee158 100644 --- a/inst/stan/functions/primary_censored_dist.stan +++ b/inst/stan/functions/primary_censored_dist.stan @@ -111,10 +111,9 @@ real primary_dist_lpdf(real x, int primary_dist_id, array[] real params, real mi int primary_params_len = x_i[4]; // Extract distribution parameters - int n_params = size(theta) - 1; array[dist_params_len] real params; if (dist_params_len) { - params = theta[3:(2 + dist_params_len)]; + params = theta[2:(1 + dist_params_len)]; } array[primary_params_len] real primary_params; if (primary_params_len) { @@ -316,6 +315,8 @@ real primary_censored_dist_pmf(int d, int dist_id, array[] real params, * @param pwindow Primary event window * @param primary_dist_id Primary distribution identifier * @param primary_params Primary distribution parameters + * @param approx_truncation Binary; if 1, use approximate truncation method + * and if 0, use exact truncation method. * * @return Vector of primary event censored log PMFs for delays \[0, 1\] to * \[max_delay, max_delay + 1\]. @@ -326,6 +327,7 @@ real primary_censored_dist_pmf(int d, int dist_id, array[] real params, * 2. Assumes integer delays (swindow = 1) * 3. Is more computationally efficient for multiple delay calculation as it * reduces the number of integration calls. + * 4. Allows for approximate or exact truncation handling * * @code * // Example: Weibull delay distribution with uniform primary distribution @@ -336,21 +338,24 @@ real primary_censored_dist_pmf(int d, int dist_id, array[] real params, * real pwindow = 7.0; * int primary_dist_id = 1; // Uniform * array[0] real primary_params = {}; + * int approx_truncation = 1; // Use approximate truncation * vector[max_delay] log_pmf = * primary_censored_sone_lpmf_vectorized( - * max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params + * max_delay, D, dist_id, params, pwindow, primary_dist_id, + * primary_params, approx_truncation * ); */ vector primary_censored_sone_lpmf_vectorized( int max_delay, data real D, int dist_id, array[] real params, data real pwindow, - int primary_dist_id, array[] real primary_params + int primary_dist_id, array[] real primary_params, + int approx_truncation ) { - vector[max_delay] log_pmfs; - real log_normalizer; int upper_interval = max_delay + 1; - vector[max_delay] log_cdfs; + vector[upper_interval] log_pmfs; + vector[upper_interval] log_cdfs; + real log_normalizer; // Check if D is at least max_delay + 1 if (D < upper_interval) { @@ -358,25 +363,43 @@ vector primary_censored_sone_lpmf_vectorized( } // Compute log CDFs - for (d in 1:upper_interval) { - log_cdfs[d] = primary_censored_dist_lcdf( - d | dist_id, params, pwindow, D, primary_dist_id, primary_params - ); + if (approx_truncation) { + for (d in 1:upper_interval) { + log_cdfs[d] = primary_censored_dist_lcdf( + d | dist_id, params, pwindow, positive_infinity(), primary_dist_id, + primary_params + ); + } + } else { + for (d in 1:upper_interval) { + log_cdfs[d] = primary_censored_dist_lcdf( + d | dist_id, params, pwindow, D, primary_dist_id, primary_params + ); + } } // Compute log normalizer using upper_interval - if (D > upper_interval) { - log_normalizer = primary_censored_dist_lcdf( - upper_interval | dist_id, params, pwindow, D, primary_dist_id, primary_params - ); + if (approx_truncation) { + if (D > upper_interval) { + if (is_inf(D)) { + log_normalizer = 0; // No normalization needed for infinite D + } else { + log_normalizer = primary_censored_dist_lcdf( + upper_interval | dist_id, params, pwindow, positive_infinity(), + primary_dist_id, primary_params + ); + } + } else { + log_normalizer = log_cdfs[upper_interval]; + } } else { - log_normalizer = log_cdfs[upper_interval]; + log_normalizer = 0; // No external normalization for exact truncation } // Compute log PMFs log_pmfs[1] = log_cdfs[1] - log_normalizer; - for (d in 1:max_delay) { - log_pmfs[d] = log_diff_exp(log_cdfs[d+1], log_cdfs[d]) - log_normalizer; + for (d in 2:upper_interval) { + log_pmfs[d] = log_diff_exp(log_cdfs[d], log_cdfs[d-1]) - log_normalizer; } return log_pmfs; @@ -392,6 +415,7 @@ vector primary_censored_sone_lpmf_vectorized( * @param pwindow Primary event window * @param primary_dist_id Primary distribution identifier * @param primary_params Primary distribution parameters + * @param approx_truncation Logical; if TRUE, use approximate truncation method * * @return Vector of primary event censored PMFs for integer delays 1 to max_delay * @@ -400,6 +424,7 @@ vector primary_censored_sone_lpmf_vectorized( * max_delay + 1\] in one call. * 2. Assumes integer delays (swindow = 1) * 3. Is more computationally efficient for multiple delay calculations + * 4. Allows for approximate or exact truncation handling * * @code * // Example: Weibull delay distribution with uniform primary distribution @@ -410,20 +435,22 @@ vector primary_censored_sone_lpmf_vectorized( * real pwindow = 7.0; * int primary_dist_id = 1; // Uniform * array[0] real primary_params = {}; + * int approx_truncation = 1; // Use approximate truncation * vector[max_delay] pmf = * primary_censored_sone_pmf_vectorized( - * max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params + * max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, approx_truncation * ); */ vector primary_censored_sone_pmf_vectorized( int max_delay, data real D, int dist_id, array[] real params, data real pwindow, int primary_dist_id, - array[] real primary_params + array[] real primary_params, + int approx_truncation ) { return exp( primary_censored_sone_lpmf_vectorized( - max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params + max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, approx_truncation ) ); } diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 38e4573..46823e7 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -5,14 +5,15 @@ options(datatable.print.keys = FALSE) if ( !on_ci() || (on_ci() && Sys.info()["sysname"] == "Linux" && not_on_cran()) ) { - library(cmdstanr) + library(cmdstanr) # nolint stan_functions <- pcd_load_stan_functions( + stan_path = "inst/stan/functions", wrap_in_block = TRUE, write_to_file = TRUE, - output_file = file.path(tempdir(), "pcd_stan_functions.stan") + output_file = file.path("pcd_stan_functions.stan") ) model <- suppressMessages(suppressWarnings(cmdstanr::cmdstan_model( - file.path(tempdir(), "pcd_stan_functions.stan") + file.path("pcd_stan_functions.stan") ))) model$expose_functions(global = TRUE) } diff --git a/tests/testthat/test-stan-rpd-primarycensoreddist.R b/tests/testthat/test-stan-rpd-primarycensoreddist.R index aaf7fb1..82c120e 100644 --- a/tests/testthat/test-stan-rpd-primarycensoreddist.R +++ b/tests/testthat/test-stan-rpd-primarycensoreddist.R @@ -9,7 +9,7 @@ test_that("Stan primary_censored_dist_cdf matches R pprimarycensoreddist", { pwindow <- 1 D <- Inf primary_dist_id <- 1 # Uniform - primary_params <- numeric(0) + primary_params <- array(numeric(0)) stan_cdf <- sapply( d, primary_censored_dist_cdf, dist_id, params, pwindow, D, @@ -61,54 +61,162 @@ test_that("Stan primary_censored_dist_pmf matches R dprimarycensoreddist", { primary_dist_id <- 1 # Uniform primary_params <- numeric(0) - stan_pmf <- sapply(d, primary_censored_dist_pmf, dist_id, params, pwindow, swindow, D, primary_dist_id, primary_params) - r_pmf <- dprimarycensoreddist(d, plnorm, pwindow = pwindow, swindow = swindow, D = D, meanlog = params[1], sdlog = params[2]) + stan_pmf <- sapply( + d, primary_censored_dist_pmf, dist_id, params, pwindow, swindow, D, + primary_dist_id, primary_params + ) + r_pmf <- dprimarycensoreddist( + d, plnorm, + pwindow = pwindow, swindow = swindow, D = D, + meanlog = params[1], sdlog = params[2] + ) expect_equal(stan_pmf, r_pmf, tolerance = 1e-6) }) -test_that("Stan primary_censored_dist_lpmf matches R dprimarycensoreddist with log = TRUE", { - d <- 0:10 - dist_id <- 1 # Lognormal - params <- c(0, 1) # meanlog, sdlog - pwindow <- 1 - swindow <- 1 - D <- Inf - primary_dist_id <- 1 # Uniform - primary_params <- numeric(0) +test_that( + "Stan primary_censored_dist_lpmf matches R dprimarycensoreddist with + log = TRUE", + { + d <- 0:10 + dist_id <- 1 # Lognormal + params <- c(0, 1) # meanlog, sdlog + pwindow <- 1 + swindow <- 1 + D <- Inf + primary_dist_id <- 1 # Uniform + primary_params <- numeric(0) - stan_lpmf <- sapply(d, primary_censored_dist_lpmf, dist_id, params, pwindow, swindow, D, primary_dist_id, primary_params) - r_lpmf <- dprimarycensoreddist(d, plnorm, pwindow = pwindow, swindow = swindow, D = D, meanlog = params[1], sdlog = params[2], log = TRUE) + stan_lpmf <- sapply( + d, primary_censored_dist_lpmf, dist_id, params, pwindow, swindow, D, + primary_dist_id, primary_params + ) + r_lpmf <- log( + dprimarycensoreddist( + d, plnorm, + pwindow = pwindow, swindow = swindow, D = D, + meanlog = params[1], sdlog = params[2] + ) + ) - expect_equal(stan_lpmf, r_lpmf, tolerance = 1e-6) -}) + expect_equal(stan_lpmf, r_lpmf, tolerance = 1e-6) + } +) -test_that("Stan primary_censored_sone_pmf_vectorized matches R dprimarycensoreddist", { - max_delay <- 10 - dist_id <- 1 # Lognormal - params <- c(0, 1) # meanlog, sdlog - pwindow <- 1 - D <- Inf - primary_dist_id <- 1 # Uniform - primary_params <- numeric(0) +test_that( + "Stan primary_censored_sone_pmf_vectorized matches R dprimarycensoreddist", + { + max_delay <- 10 + dist_id <- 1 # Lognormal + params <- c(0, 1) # meanlog, sdlog + pwindow <- 1 + D <- Inf + primary_dist_id <- 1 # Uniform + primary_params <- numeric(0) - stan_pmf <- primary_censored_sone_pmf_vectorized(max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params) - r_pmf <- dprimarycensoreddist(1:max_delay, plnorm, pwindow = pwindow, swindow = 1, D = D, meanlog = params[1], sdlog = params[2]) + stan_pmf_approx <- primary_censored_sone_pmf_vectorized( + max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, 1 + ) + stan_pmf_exact <- primary_censored_sone_pmf_vectorized( + max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, 0 + ) + r_pmf <- dprimarycensoreddist( + 0:max_delay, plnorm, + pwindow = pwindow, swindow = 1, D = D, + meanlog = params[1], sdlog = params[2] + ) - expect_equal(stan_pmf, r_pmf, tolerance = 1e-6) -}) + expect_equal(stan_pmf_approx, r_pmf, tolerance = 1e-6) + expect_equal(stan_pmf_exact, r_pmf, tolerance = 1e-6) + } +) -test_that("Stan primary_censored_sone_lpmf_vectorized matches R dprimarycensoreddist with log = TRUE", { - max_delay <- 10 - dist_id <- 1 # Lognormal - params <- c(0, 1) # meanlog, sdlog - pwindow <- 1 - D <- Inf - primary_dist_id <- 1 # Uniform - primary_params <- numeric(0) +test_that( + "Stan primary_censored_sone_pmf_vectorized matches R dprimarycensoreddist + with finite D", + { + max_delay <- 10 + dist_id <- 1 # Lognormal + params <- c(0, 1) # meanlog, sdlog + pwindow <- 1 + D <- 15 + primary_dist_id <- 1 # Uniform + primary_params <- numeric(0) - stan_lpmf <- primary_censored_sone_lpmf_vectorized(max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params) - r_lpmf <- dprimarycensoreddist(1:max_delay, plnorm, pwindow = pwindow, swindow = 1, D = D, meanlog = params[1], sdlog = params[2], log = TRUE) + stan_pmf_approx <- primary_censored_sone_pmf_vectorized( + max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, 1 + ) + stan_pmf_exact <- primary_censored_sone_pmf_vectorized( + max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, 0 + ) + r_pmf <- dprimarycensoreddist( + 0:max_delay, plnorm, + pwindow = pwindow, swindow = 1, D = D, + meanlog = params[1], sdlog = params[2] + ) - expect_equal(stan_lpmf, r_lpmf, tolerance = 1e-6) -}) + expect_equal(stan_pmf_approx, r_pmf, tolerance = 1e-2) + expect_equal(stan_pmf_exact, r_pmf, tolerance = 1e-6) + expect_true(all(abs(stan_pmf_exact - r_pmf) < abs(stan_pmf_approx - r_pmf))) + } +) + +test_that( + "Stan primary_censored_sone_pmf_vectorized matches R dprimarycensoreddist + with D equal to max_delay + 1", + { + max_delay <- 10 + dist_id <- 1 # Lognormal + params <- c(0, 1) # meanlog, sdlog + pwindow <- 1 + D <- max_delay + 1 + primary_dist_id <- 1 # Uniform + primary_params <- numeric(0) + + stan_pmf_approx <- primary_censored_sone_pmf_vectorized( + max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, 1 + ) + stan_pmf_exact <- primary_censored_sone_pmf_vectorized( + max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, 0 + ) + r_pmf <- dprimarycensoreddist( + 0:max_delay, plnorm, + pwindow = pwindow, swindow = 1, D = D, + meanlog = params[1], sdlog = params[2] + ) + + expect_equal(stan_pmf_approx, r_pmf, tolerance = 1e-3) + expect_equal(stan_pmf_exact, r_pmf, tolerance = 1e-6) + expect_true(all(abs(stan_pmf_exact - r_pmf) < abs(stan_pmf_approx - r_pmf))) + } +) + +test_that( + "Stan primary_censored_sone_lpmf_vectorized matches R dprimarycensoreddist + with log = TRUE", + { + max_delay <- 10 + dist_id <- 1 # Lognormal + params <- c(0, 1) # meanlog, sdlog + pwindow <- 1 + D <- Inf + primary_dist_id <- 1 # Uniform + primary_params <- numeric(0) + + stan_lpmf_approx <- primary_censored_sone_lpmf_vectorized( + max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, 1 + ) + stan_lpmf_exact <- primary_censored_sone_lpmf_vectorized( + max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, 0 + ) + r_lpmf <- dprimarycensoreddist( + 0:max_delay, plnorm, + pwindow = pwindow, swindow = 1, D = D, + meanlog = params[1], sdlog = params[2], log = TRUE + ) + + expect_equal(stan_lpmf_approx, r_lpmf, tolerance = 1e-6) + expect_equal(stan_lpmf_exact, r_lpmf, tolerance = 1e-8) + expect_true(all(abs(stan_lpmf_exact - r_lpmf) < abs(stan_lpmf_approx - r_lpmf))) + } +)