Skip to content

Commit

Permalink
work on stan code
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Aug 28, 2024
1 parent 3ec6d57 commit c614596
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 18 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ BugReports: https://github.com/epinowcast/primarycensoreddist/issues/
Depends:
R (>= 4.0.0)
Suggests:
cmdstanr,
knitr,
spelling,
testthat (>= 3.1.9),
usethis
Additional_repositories:
https://stan-dev.r-universe.dev
VignetteBuilder:
knitr
Config/Needs/hexsticker: hexSticker, sysfonts
Config/Needs/website: r-lib/pkgdown, epinowcast/enwtheme
Config/Needs/cmdstanr: stan-dev/cmdstanr
Config/testthat/edition: 3
Encoding: UTF-8
Language: en-GB
Expand Down
2 changes: 1 addition & 1 deletion R/pcd-stan-tools.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pcd_load_stan_functions <- function(
wrap_in_block = FALSE, write_to_file = FALSE,
output_file = "pcd_stan_functions.stan") {
stan_files <- list.files(
stan_paths,
stan_path,
pattern = "\\.stan$", full.names = TRUE,
recursive = TRUE
)
Expand Down
10 changes: 5 additions & 5 deletions inst/stan/functions/expgrowth.stan
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ real expgrowth_pdf(real x, real min, real max, real r) {
if (x < min || x > max) {
return 0;
}
if (fabs(r) < 1e-10) {
if (abs(r) < 1e-10) {
return 1 / (max - min);
}
return r * exp(r * (x - min)) / (exp(r * max) - exp(r * min));
Expand All @@ -30,7 +30,7 @@ real expgrowth_lpdf(real x, real min, real max, real r) {
if (x < min || x > max) {
return negative_infinity();
}
if (fabs(r) < 1e-10) {
if (abs(r) < 1e-10) {
return -log(max - min);
}
return log(r) + r * (x - min) - log(exp(r * max) - exp(r * min));
Expand All @@ -52,7 +52,7 @@ real expgrowth_cdf(real x, real min, real max, real r) {
if (x > max) {
return 1;
}
if (fabs(r) < 1e-10) {
if (abs(r) < 1e-10) {
return (x - min) / (max - min);
}
return (exp(r * (x - min)) - exp(r * min)) / (exp(r * max) - exp(r * min));
Expand All @@ -74,7 +74,7 @@ real expgrowth_lcdf(real x, real min, real max, real r) {
if (x > max) {
return 0;
}
return log(expgrowth_cdf(x, min, max, r));
return log(expgrowth_cdf(x | min, max, r));
}

/**
Expand All @@ -87,7 +87,7 @@ real expgrowth_lcdf(real x, real min, real max, real r) {
*/
real expgrowth_rng(real min, real max, real r) {
real u = uniform_rng(0, 1);
if (fabs(r) < 1e-10) {
if (abs(r) < 1e-10) {
return min + u * (max - min);
}
return min + log(u * (exp(r * max) - exp(r * min)) + exp(r * min)) / r;
Expand Down
16 changes: 8 additions & 8 deletions inst/stan/functions/primary_censored_dist.stan
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,17 @@ real primary_dist_lpdf(real x, int primary_dist_id, array[] real params, real mi
* array[2] int x_i = {1, 1}; // dist_id = 1 (Lognormal), primary_dist_id = 1 (Uniform)
* real integrand_value = primary_censored_integrand(p, xc, theta, x_r, x_i);
*/
real primary_censored_integrand(real p, real xc, array[] real theta, array[] real x_r,
array[] int x_i) {
real primary_censored_integrand(real p, real xc, array[] real theta,
data array[] real x_r, data array[] int x_i) {
real d = xc;
real pwindow = x_r[1];
int dist_id = x_i[1];
int primary_dist_id = x_i[2];
real d_adj = d - p;

real log_cdf = dist_lcdf(d_adj, theta, dist_id);
real log_cdf = dist_lcdf(d_adj | theta, dist_id);
real log_primary_pdf = primary_dist_lpdf(
p, primary_dist_id, theta, 0, pwindow
p | primary_dist_id, theta, 0, pwindow
);

return exp(log_cdf + log_primary_pdf);
Expand All @@ -132,7 +132,7 @@ real primary_censored_integrand(real p, real xc, array[] real theta, array[] rea
* real integrand_value = primary_censored_integrand_truncated(p, xc, theta, x_r, x_i);
*/
real primary_censored_integrand_truncated(real p, real xc, array[] real theta,
array[] real x_r, array[] int x_i) {
data array[] real x_r, data array[] int x_i) {
real d = xc;
real pwindow = x_r[1];
int dist_id = x_i[1];
Expand All @@ -141,10 +141,10 @@ real primary_censored_integrand_truncated(real p, real xc, array[] real theta,
real D = theta[size(theta)];
real D_adj = D - p;

real log_cdf = dist_lcdf(d_adj, theta[1:(size(theta)-1)], dist_id);
real log_cdf_D = dist_lcdf(D_adj, theta[1:(size(theta)-1)], dist_id);
real log_cdf = dist_lcdf(d_adj | theta[1:(size(theta)-1)], dist_id);
real log_cdf_D = dist_lcdf(D_adj | theta[1:(size(theta)-1)], dist_id);
real log_primary_pdf = primary_dist_lpdf(
p, primary_dist_id, theta[1:(size(theta)-1)], 0, pwindow
p | primary_dist_id, theta[1:(size(theta)-1)], 0, pwindow
);

return exp(log_cdf - log_cdf_D + log_primary_pdf);
Expand Down
13 changes: 10 additions & 3 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@ options(datatable.print.class = FALSE)
options(datatable.print.keys = FALSE)

if (on_ci() && Sys.info()["sysname"] == "Linux" && not_on_cran()) {
# we only expose stan functions on linux CI
# because we only test these functions on linux
suppressMessages(suppressWarnings())
library(cmdstanr)
stan_functions <- pcd_load_stan_functions(
stan_path = file.path("inst", "stan"),
wrap_in_block = TRUE,
write_to_file = TRUE,
output_file = file.path(tempdir(), "pcd_stan_functions.stan")
)
model <- suppressMessages(suppressWarnings(cmdstanr::cmdstan_model(
file.path(tempdir(), "pcd_stan_functions.stan")
)))
}

0 comments on commit c614596

Please sign in to comment.