Skip to content

Commit

Permalink
stan pcens tests mostly complete
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Aug 30, 2024
1 parent c017676 commit e420fe3
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 91 deletions.
82 changes: 55 additions & 27 deletions R/pcd-stan-tools.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,61 @@
#'
#' @return A character string with the path to the Stan code
#' @export
#' @aliases pcd_stan_path
pcd_stan_path <- function() {
  # Look up the "stan" directory bundled with the installed package.
  # system.file() gives back "" when the directory cannot be located.
  stan_dir <- system.file("stan", package = "primarycensoreddist")
  stan_dir
}

#' List available Stan functions
#' Extract function names or content from Stan code
#'
#' Scans lines of Stan code for function definitions, i.e. lines that start
#' with a return type (`real`, `vector`, `matrix`, `void` or `int`) followed
#' by an identifier and an opening parenthesis.
#'
#' @param content Character vector containing Stan code, one element per line
#'
#' @param extract_names Logical, if TRUE extract function names, otherwise
#' extract function content
#'
#' @param func_name Optional, function name to extract content for. Must be
#' supplied when `extract_names = FALSE`; it is compared exactly against the
#' names of the definitions found in `content`.
#'
#' @return If `extract_names = TRUE`, a character vector of function names in
#' order of appearance. Otherwise, the lines of `content` running from each
#' definition of `func_name` to the line on which its braces balance, or
#' `character(0)` if no such definition exists.
#' @keywords internal
.extract_stan_functions <- function(
    content, extract_names = TRUE, func_name = NULL) {
  func_pattern <- paste0(
    "^(real|vector|matrix|void|int)\\s+",
    "(\\w+)\\s*\\("
  )
  def_lines <- grep(func_pattern, content)
  # Consume the whole line so that only the captured identifier is left;
  # replacing just the matched prefix would keep the argument list attached.
  def_names <- sub(
    paste0(func_pattern, ".*$"), "\\2", content[def_lines]
  )
  if (extract_names) {
    return(def_names)
  }
  if (is.null(func_name)) {
    stop(
      "`func_name` must be supplied when `extract_names = FALSE`",
      call. = FALSE
    )
  }

  # Compare against the defined names rather than grepping the raw line, so a
  # function that merely mentions `func_name` in its arguments is not picked.
  starts <- def_lines[def_names %in% func_name]

  count_char <- function(x, ch) {
    lengths(regmatches(x, gregexpr(ch, x, fixed = TRUE)))
  }

  bodies <- lapply(starts, function(start) {
    remaining <- content[start:length(content)]
    # Count every brace on each line: the opening brace usually sits at the
    # end of the signature line, not at the start of a line.
    opened <- cumsum(count_char(remaining, "{"))
    closed <- cumsum(count_char(remaining, "}"))
    # The definition ends on the first line where at least one brace has
    # been opened and all opened braces have been closed again.
    n_lines <- which(opened > 0 & opened == closed)[1]
    if (is.na(n_lines)) {
      # Unbalanced braces: fall back to everything up to the end of input.
      n_lines <- length(remaining)
    }
    remaining[seq_len(n_lines)]
  })
  # unlist() of an empty list is NULL; coerce so the result is always character
  as.character(unlist(bodies))
}

#' Get Stan function names from Stan files
#'
#' This function reads all Stan files in the specified directory and extracts
#' the names of all functions defined in those files.
#'
#' @param stan_path Character string specifying the path to the directory
#' containing Stan files. Defaults to the Stan path of the primarycensoreddist
#' package.
#'
#' @return A character vector containing unique names of all functions found in
#' the Stan files.
#'
#' @inheritParams pcd_load_stan_functions
#' @return A character vector of available Stan function names
#' @export
#' @aliases pcd_stan_functions
#' @examples
#' \dontrun{
#' stan_functions <- pcd_stan_functions()
#' print(stan_functions)
#' }
pcd_stan_functions <- function(
stan_path = primarycensoreddist::pcd_stan_path()) {
stan_files <- list.files(
Expand All @@ -23,13 +67,7 @@ pcd_stan_functions <- function(
functions <- character(0)
for (file in stan_files) {
content <- readLines(file)
func_lines <- grep(
"^(real|vector|matrix|void)\\s+\\w+\\s*\\(", content,
value = TRUE
)
functions <- c(
functions, gsub("^.*?\\s+(\\w+)\\s*\\(.*$", "\\1", func_lines)
)
functions <- c(functions, .extract_stan_functions(content))
}
unique(functions)
}
Expand All @@ -53,7 +91,6 @@ pcd_stan_functions <- function(
#'
#' @return A character string containing the requested Stan functions
#' @export
#' @aliases pcd_load_stan_functions
pcd_load_stan_functions <- function(
functions = NULL, stan_path = primarycensoreddist::pcd_stan_path(),
wrap_in_block = FALSE, write_to_file = FALSE,
Expand All @@ -71,20 +108,11 @@ pcd_load_stan_functions <- function(
all_content <- c(all_content, content)
} else {
for (func in functions) {
start_line <- grep(
paste0("^(real|vector|matrix|void)\\s+", func, "\\s*\\("), content
func_content <- .extract_stan_functions(
content,
extract_names = FALSE, func_name = func
)
if (length(start_line) > 0) {
end_line <- which(
cumsum(
grepl("^\\s*\\{", content[start_line:length(content)])
) ==
cumsum(
grepl("^\\s*\\}", content[start_line:length(content)])
)
)[1] + start_line - 1
all_content <- c(all_content, content[start_line:end_line])
}
all_content <- c(all_content, func_content)
}
}
}
Expand All @@ -104,7 +132,7 @@ pcd_load_stan_functions <- function(

if (write_to_file) {
writeLines(result, output_file)
message("Stan functions written to:", output_file, "\n")
message("Stan functions written to: ", output_file, "\n")
}

return(result)
Expand Down
69 changes: 48 additions & 21 deletions inst/stan/functions/primary_censored_dist.stan
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,9 @@ real primary_dist_lpdf(real x, int primary_dist_id, array[] real params, real mi
int primary_params_len = x_i[4];

// Extract distribution parameters
int n_params = size(theta) - 1;
array[dist_params_len] real params;
if (dist_params_len) {
params = theta[3:(2 + dist_params_len)];
params = theta[2:(1 + dist_params_len)];
}
array[primary_params_len] real primary_params;
if (primary_params_len) {
Expand Down Expand Up @@ -316,6 +315,8 @@ real primary_censored_dist_pmf(int d, int dist_id, array[] real params,
* @param pwindow Primary event window
* @param primary_dist_id Primary distribution identifier
* @param primary_params Primary distribution parameters
* @param approx_truncation Binary; if 1, use approximate truncation method
* and if 0, use exact truncation method.
*
* @return Vector of primary event censored log PMFs for delays \[0, 1\] to
* \[max_delay, max_delay + 1\].
Expand All @@ -326,6 +327,7 @@ real primary_censored_dist_pmf(int d, int dist_id, array[] real params,
* 2. Assumes integer delays (swindow = 1)
* 3. Is more computationally efficient for multiple delay calculation as it
* reduces the number of integration calls.
* 4. Allows for approximate or exact truncation handling
*
* @code
* // Example: Weibull delay distribution with uniform primary distribution
Expand All @@ -336,47 +338,68 @@ real primary_censored_dist_pmf(int d, int dist_id, array[] real params,
* real pwindow = 7.0;
* int primary_dist_id = 1; // Uniform
* array[0] real primary_params = {};
* int approx_truncation = 1; // Use approximate truncation
* vector[max_delay] log_pmf =
* primary_censored_sone_lpmf_vectorized(
* max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params
* max_delay, D, dist_id, params, pwindow, primary_dist_id,
* primary_params, approx_truncation
* );
*/
vector primary_censored_sone_lpmf_vectorized(
int max_delay, data real D, int dist_id,
array[] real params, data real pwindow,
int primary_dist_id, array[] real primary_params
int primary_dist_id, array[] real primary_params,
int approx_truncation
) {

vector[max_delay] log_pmfs;
real log_normalizer;
int upper_interval = max_delay + 1;
vector[max_delay] log_cdfs;
vector[upper_interval] log_pmfs;
vector[upper_interval] log_cdfs;
real log_normalizer;

// Check if D is at least max_delay + 1
if (D < upper_interval) {
reject("D must be at least max_delay + 1");
}

// Compute log CDFs
for (d in 1:upper_interval) {
log_cdfs[d] = primary_censored_dist_lcdf(
d | dist_id, params, pwindow, D, primary_dist_id, primary_params
);
if (approx_truncation) {
for (d in 1:upper_interval) {
log_cdfs[d] = primary_censored_dist_lcdf(
d | dist_id, params, pwindow, positive_infinity(), primary_dist_id,
primary_params
);
}
} else {
for (d in 1:upper_interval) {
log_cdfs[d] = primary_censored_dist_lcdf(
d | dist_id, params, pwindow, D, primary_dist_id, primary_params
);
}
}

// Compute log normalizer using upper_interval
if (D > upper_interval) {
log_normalizer = primary_censored_dist_lcdf(
upper_interval | dist_id, params, pwindow, D, primary_dist_id, primary_params
);
if (approx_truncation) {
if (D > upper_interval) {
if (is_inf(D)) {
log_normalizer = 0; // No normalization needed for infinite D
} else {
log_normalizer = primary_censored_dist_lcdf(
upper_interval | dist_id, params, pwindow, positive_infinity(),
primary_dist_id, primary_params
);
}
} else {
log_normalizer = log_cdfs[upper_interval];
}
} else {
log_normalizer = log_cdfs[upper_interval];
log_normalizer = 0; // No external normalization for exact truncation
}

// Compute log PMFs
log_pmfs[1] = log_cdfs[1] - log_normalizer;
for (d in 1:max_delay) {
log_pmfs[d] = log_diff_exp(log_cdfs[d+1], log_cdfs[d]) - log_normalizer;
for (d in 2:upper_interval) {
log_pmfs[d] = log_diff_exp(log_cdfs[d], log_cdfs[d-1]) - log_normalizer;
}

return log_pmfs;
Expand All @@ -392,6 +415,7 @@ vector primary_censored_sone_lpmf_vectorized(
* @param pwindow Primary event window
* @param primary_dist_id Primary distribution identifier
* @param primary_params Primary distribution parameters
* @param approx_truncation Binary; if 1, use approximate truncation method
* and if 0, use exact truncation method.
*
* @return Vector of primary event censored PMFs for integer delays 1 to max_delay
*
Expand All @@ -400,6 +424,7 @@ vector primary_censored_sone_lpmf_vectorized(
* max_delay + 1\] in one call.
* 2. Assumes integer delays (swindow = 1)
* 3. Is more computationally efficient for multiple delay calculations
* 4. Allows for approximate or exact truncation handling
*
* @code
* // Example: Weibull delay distribution with uniform primary distribution
Expand All @@ -410,20 +435,22 @@ vector primary_censored_sone_lpmf_vectorized(
* real pwindow = 7.0;
* int primary_dist_id = 1; // Uniform
* array[0] real primary_params = {};
* int approx_truncation = 1; // Use approximate truncation
* vector[max_delay] pmf =
* primary_censored_sone_pmf_vectorized(
* max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params
* max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, approx_truncation
* );
*/
vector primary_censored_sone_pmf_vectorized(
int max_delay, data real D, int dist_id,
array[] real params, data real pwindow,
int primary_dist_id,
array[] real primary_params
array[] real primary_params,
int approx_truncation
) {
return exp(
primary_censored_sone_lpmf_vectorized(
max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params
max_delay, D, dist_id, params, pwindow, primary_dist_id, primary_params, approx_truncation
)
);
}
7 changes: 4 additions & 3 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ options(datatable.print.keys = FALSE)
if (
!on_ci() || (on_ci() && Sys.info()["sysname"] == "Linux" && not_on_cran())
) {
library(cmdstanr)
library(cmdstanr) # nolint
stan_functions <- pcd_load_stan_functions(
stan_path = "inst/stan/functions",
wrap_in_block = TRUE,
write_to_file = TRUE,
output_file = file.path(tempdir(), "pcd_stan_functions.stan")
output_file = file.path("pcd_stan_functions.stan")
)
model <- suppressMessages(suppressWarnings(cmdstanr::cmdstan_model(
file.path(tempdir(), "pcd_stan_functions.stan")
file.path("pcd_stan_functions.stan")
)))
model$expose_functions(global = TRUE)
}
Loading

0 comments on commit e420fe3

Please sign in to comment.