From 636d79d392a0338350a065bb64459a7f30e20813 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 30 Apr 2024 02:57:30 -0500 Subject: [PATCH] Migrate core BCF computation to C++ --- R/bcf.R | 490 +++++++++++++++++------------------------ R/cpp11.R | 52 +++++ src/Makevars | 1 + src/bcf.cpp | 303 +++++++++++++++++++++++++ src/cpp11.cpp | 180 ++++++++++++--- src/sampler.cpp | 40 ++-- src/stochtree-cpp | 2 +- src/stochtree_types.h | 8 +- tools/debug/bcf_demo.R | 44 ++++ 9 files changed, 771 insertions(+), 349 deletions(-) create mode 100644 src/bcf.cpp create mode 100644 tools/debug/bcf_demo.R diff --git a/R/bcf.R b/R/bcf.R index ee55268..80a5028 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -96,6 +96,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, X_test = NULL, Z_tes num_burnin = 0, num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = T, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5, random_seed = -1) { + # TODO: Add optional vector of case / variance weights + # Convert all input data to matrices if not already converted if ((is.null(dim(X_train))) && (!is.null(X_train))) { X_train <- as.matrix(X_train) @@ -142,6 +144,10 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, X_test = NULL, Z_tes # Determine whether a test set is provided has_test = !is.null(X_test) + # Data dimensions + n_train = nrow(X_train) + if (has_test) n_test = nrow(X_test) + # Convert y_train to numeric vector if not already converted if (!is.null(dim(y_train))) { y_train <- as.matrix(y_train) @@ -155,6 +161,14 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, X_test = NULL, Z_tes adaptive_coding <- F } + # Check whether treatment is univariate + univariate_treatment <- ncol(Z_train) == 1 + + # Adaptive coding will be ignored for multivariate treatments + if ((!univariate_treatment) && (adaptive_coding)) { + adaptive_coding <- F + } + # Estimate if pre-estimated propensity score is not provided if ((is.null(pi_train)) && (propensity_covariate != "none")) { # Estimate using xgboost with some elementary hyperparameter tuning @@ -262,6 +276,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, X_test = NULL, Z_tes if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu) if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau) if (is.null(sigma_leaf_mu)) sigma_leaf_mu <- var(resid_train)/(num_trees_mu) + # TODO handle this in the case of multivariate Z if (is.null(sigma_leaf_tau)) sigma_leaf_tau <- var(resid_train)/(2*num_trees_tau) current_sigma2 <- sigma2 current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) @@ -269,206 +284,101 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, X_test = NULL, Z_tes # Container of variance parameter samples num_samples <- num_gfr + num_burnin + num_mcmc - if (sample_sigma_global) global_var_samples <- rep(0, num_samples) - if (sample_sigma_leaf_mu) leaf_scale_mu_samples <- rep(0, num_samples) - if (sample_sigma_leaf_tau) leaf_scale_tau_samples <- rep(0, num_samples) + if (sample_sigma_global) global_var_samples <- rep(0., num_samples) + if (sample_sigma_leaf_mu) leaf_scale_mu_samples <- rep(0., num_samples) + if (sample_sigma_leaf_tau) leaf_scale_tau_samples <- rep(0., num_samples) + + # Container of adaptive coding samples + if (adaptive_coding) { + b_0_samples <- rep(0., num_samples) + b_1_samples <- rep(0., num_samples) + } + + # Container of prediction samples + mu_hat_train = matrix(0., nrow = n_train, ncol = num_samples) + if (!univariate_treatment) tau_hat_train = array(0., dim = c(n_train, ncol(Z_train), num_samples)) + else tau_hat_train = matrix(0., nrow = n_train, ncol = num_samples) + y_hat_train = matrix(0., nrow = n_train, ncol = num_samples) + if (has_test) { + mu_hat_test = matrix(0., nrow = n_test, ncol = num_samples) + if (!univariate_treatment) tau_hat_test = array(0., dim = c(n_test, ncol(Z_test), num_samples)) + else tau_hat_test = matrix(0., nrow = n_test, ncol = num_samples) + y_hat_test = matrix(0., nrow = n_test, ncol = num_samples) + } # Prepare adaptive coding structure if ((!is.numeric(b_0)) || (!is.numeric(b_1)) || (length(b_0) > 1) || (length(b_1) > 1)) { stop("b_0 and b_1 must be single numeric values") } - if (adaptive_coding) { - b_0_samples <- rep(0, num_samples) - b_1_samples <- rep(0, num_samples) - current_b_0 <- b_0 - current_b_1 <- b_1 - tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 - if (has_test) tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 - } else { - tau_basis_train <- Z_train - if (has_test) tau_basis_test <- Z_test - } - - # Data - forest_dataset_mu_train <- createForestDataset(X_train_mu) - forest_dataset_tau_train <- createForestDataset(X_train_tau, tau_basis_train) - if (has_test) forest_dataset_mu_test <- createForestDataset(X_test_mu) - if (has_test) forest_dataset_tau_test <- createForestDataset(X_test_tau, tau_basis_test) - outcome_train <- createOutcome(resid_train) # Random number generator (std::mt19937) if (is.null(random_seed)) random_seed = sample(1:10000,1,F) rng <- createRNG(random_seed) - # Sampling data structures - forest_model_mu <- createForestModel(forest_dataset_mu_train, feature_types_mu, num_trees_mu, nrow(X_train_mu), alpha_mu, beta_mu, min_samples_leaf_mu) - forest_model_tau <- createForestModel(forest_dataset_tau_train, feature_types_tau, num_trees_tau, nrow(X_train_tau), alpha_tau, beta_tau, min_samples_leaf_tau) - # Container of forest samples forest_samples_mu <- createForestContainer(num_trees_mu, 1, T) - forest_samples_tau <- createForestContainer(num_trees_tau, 1, F) - - # Initialize the leaves of each tree in the prognostic forest - forest_samples_mu$set_root_leaves(0, mean(resid_train) / num_trees_mu) - update_residual_forest_container_cpp(forest_dataset_mu_train$data_ptr, outcome_train$data_ptr, - forest_samples_mu$forest_container_ptr, forest_model_mu$tracker_ptr, - F, 0, F) + forest_samples_tau <- createForestContainer(num_trees_tau, ncol(Z_train), F) - # Initialize the leaves of each tree in the treatment effect forest - forest_samples_tau$set_root_leaves(0, 0.) - update_residual_forest_container_cpp(forest_dataset_tau_train$data_ptr, outcome_train$data_ptr, - forest_samples_tau$forest_container_ptr, forest_model_tau$tracker_ptr, - T, 0, F) - - # Run GFR (warm start) if specified - if (num_gfr > 0){ - for (i in 1:num_gfr) { - # Sample the prognostic forest - forest_model_mu$sample_one_iteration( - forest_dataset_mu_train, outcome_train, forest_samples_mu, rng, feature_types_mu, - 0, current_leaf_scale_mu, variable_weights_mu, - current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T - ) - - # Sample variance parameters (if requested) - if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda) - current_sigma2 <- global_var_samples[i] - } - if (sample_sigma_leaf_mu) { - leaf_scale_mu_samples[i] <- sample_tau_one_iteration(forest_samples_mu, rng, a_leaf_mu, b_leaf_mu, i-1) - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_samples[i]) - } - - # Sample the treatment forest - forest_model_tau$sample_one_iteration( - forest_dataset_tau_train, outcome_train, forest_samples_tau, rng, feature_types_tau, - 1, current_leaf_scale_tau, variable_weights_tau, - current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T - ) - - # Sample coding parameters (if requested) - if (adaptive_coding) { - mu_x_raw_train <- forest_samples_mu$predict_raw_single_forest(forest_dataset_mu_train, i-1) - tau_x_raw_train <- forest_samples_tau$predict_raw_single_forest(forest_dataset_tau_train, i-1) - s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0)) - s_tt1 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==1)) - partial_resid_mu_train <- resid_train - mu_x_raw_train - s_ty0 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==0)) - s_ty1 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==1)) - current_b_0 <- rnorm(1, (s_ty0/(s_tt0 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt0 + 2*current_sigma2))) - current_b_1 <- rnorm(1, (s_ty1/(s_tt1 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt1 + 2*current_sigma2))) - tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 - forest_dataset_tau_train$update_basis(tau_basis_train) - b_0_samples[i] <- current_b_0 - b_1_samples[i] <- current_b_1 - if (has_test) { - tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 - forest_dataset_tau_test$update_basis(tau_basis_test) - } - } - - # Sample variance parameters (if requested) - if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda) - current_sigma2 <- global_var_samples[i] - } - if (sample_sigma_leaf_tau) { - leaf_scale_tau_samples[i] <- sample_tau_one_iteration(forest_samples_tau, rng, a_leaf_tau, b_leaf_tau, i-1) - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_samples[i]) - } - } - } + # Initialize the leaves of each tree in the prognostic / treatment forests + init_mu_leaf <- mean(resid_train) / num_trees_mu + init_tau_leaf <- mean(resid_train) / num_trees_tau + forest_samples_mu$set_root_leaves(0, init_mu_leaf) + forest_samples_tau$set_root_leaves(0, init_tau_leaf) - # Run MCMC - if (num_burnin + num_mcmc > 0) { - for (i in (num_gfr+1):num_samples) { - # Sample the prognostic forest - forest_model_mu$sample_one_iteration( - forest_dataset_mu_train, outcome_train, forest_samples_mu, rng, feature_types_mu, - 0, current_leaf_scale_mu, variable_weights_mu, - current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T - ) - - # Sample variance parameters (if requested) - if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda) - current_sigma2 <- global_var_samples[i] - } - if (sample_sigma_leaf_mu) { - leaf_scale_mu_samples[i] <- sample_tau_one_iteration(forest_samples_mu, rng, a_leaf_mu, b_leaf_mu, i-1) - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_samples[i]) - } - - # Sample the treatment forest - forest_model_tau$sample_one_iteration( - forest_dataset_tau_train, outcome_train, forest_samples_tau, rng, feature_types_tau, - 1, current_leaf_scale_tau, variable_weights_tau, - current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T - ) - - # Sample coding parameters (if requested) - if (adaptive_coding) { - mu_x_raw_train <- forest_samples_mu$predict_raw_single_forest(forest_dataset_mu_train, i-1) - tau_x_raw_train <- forest_samples_tau$predict_raw_single_forest(forest_dataset_tau_train, i-1) - s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0)) - s_tt1 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==1)) - partial_resid_mu_train <- resid_train - mu_x_raw_train - s_ty0 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==0)) - s_ty1 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==1)) - current_b_0 <- rnorm(1, (s_ty0/(s_tt0 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt0 + 2*current_sigma2))) - current_b_1 <- rnorm(1, (s_ty1/(s_tt1 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt1 + 2*current_sigma2))) - tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 - forest_dataset_tau_train$update_basis(tau_basis_train) - b_0_samples[i] <- current_b_0 - b_1_samples[i] <- current_b_1 - if (has_test) { - tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 - forest_dataset_tau_test$update_basis(tau_basis_test) - } - } - - # Sample variance parameters (if requested) - if (sample_sigma_global) { - global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda) - current_sigma2 <- global_var_samples[i] - } - if (sample_sigma_leaf_tau) { - leaf_scale_tau_samples[i] <- sample_tau_one_iteration(forest_samples_tau, rng, a_leaf_tau, b_leaf_tau, i-1) - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_samples[i]) - } - } + # Prepare the BCF sampler to run + bcf_ptr <- bcf_init_cpp(univariate_treatment) + bcf_add_train_no_weights_cpp(bcf_ptr, X_train_mu, X_train_tau, Z_train, resid_train, binary_treatment) + if (has_test) bcf_add_test_cpp(bcf_ptr, X_test_mu, X_test_tau, Z_test) + if (sample_sigma_global) bcf_reset_global_var_samples_cpp(bcf_ptr, global_var_samples) + if (sample_sigma_leaf_mu) bcf_reset_prognostic_leaf_var_samples_cpp(bcf_ptr, leaf_scale_mu_samples) + if (sample_sigma_leaf_tau) bcf_reset_treatment_leaf_var_samples_cpp(bcf_ptr, leaf_scale_tau_samples) + if (adaptive_coding) { + bcf_reset_treatment_coding_samples_cpp(bcf_ptr, b_1_samples) + bcf_reset_control_coding_samples_cpp(bcf_ptr, b_0_samples) } + bcf_reset_train_prediction_samples_cpp(bcf_ptr, mu_hat_train, tau_hat_train, y_hat_train, n_train, num_samples, ncol(Z_train)) + if (has_test) bcf_reset_test_prediction_samples_cpp(bcf_ptr, mu_hat_test, tau_hat_test, y_hat_test, n_test, num_samples, ncol(Z_test)) - # Forest predictions - mu_hat_train <- forest_samples_mu$predict(forest_dataset_mu_train)*y_std_train + y_bar_train - # tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_tau_train)*y_std_train - if (adaptive_coding) { - tau_hat_train_raw <- forest_samples_tau$predict_raw(forest_dataset_tau_train) - tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples))*y_std_train + # Run the BCF sampler + if (univariate_treatment) { + sample_bcf_univariate_cpp( + bcf_ptr, forest_samples_mu$forest_container_ptr, forest_samples_tau$forest_container_ptr, rng$rng_ptr, cutpoint_grid_size, + sigma_leaf_mu, sigma_leaf_tau, alpha_mu, alpha_tau, beta_mu, beta_tau, + min_samples_leaf_mu, min_samples_leaf_tau, nu, lambda, a_leaf_mu, a_leaf_tau, + b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b_1, b_0, + feature_types_mu, feature_types_tau, num_gfr, num_burnin, num_mcmc, + init_mu_leaf, init_tau_leaf + ) } else { - tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_tau_train)*y_std_train + if (is.null(dim(sigma_leaf_tau))) { + if (ncol(Z_train) == 1) sigma_leaf_tau <- as.matrix(sigma_leaf_tau) + else sigma_leaf_tau <- diag(rep(sigma_leaf_tau, ncol(Z_train))) + } + sample_bcf_multivariate_cpp( + bcf_ptr, forest_samples_mu$forest_container_ptr, forest_samples_tau$forest_container_ptr, rng$rng_ptr, cutpoint_grid_size, + sigma_leaf_mu, sigma_leaf_tau, alpha_mu, alpha_tau, beta_mu, beta_tau, + min_samples_leaf_mu, min_samples_leaf_tau, nu, lambda, a_leaf_mu, a_leaf_tau, + b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b_1, b_0, + feature_types_mu, feature_types_tau, num_gfr, num_burnin, num_mcmc, + init_mu_leaf, init_tau_leaf + ) } - y_hat_train <- mu_hat_train + tau_hat_train * as.numeric(Z_train) + + # Rescale predictions and parameters + mu_hat_train <- mu_hat_train*y_std_train + y_bar_train + tau_hat_train <- tau_hat_train*y_std_train + y_hat_train <- y_hat_train*y_std_train + y_bar_train if (has_test) { - mu_hat_test <- forest_samples_mu$predict(forest_dataset_mu_test)*y_std_train + y_bar_train - # tau_hat_test <- forest_samples_tau$predict(forest_dataset_tau_test)*y_std_train - if (adaptive_coding) { - tau_hat_test_raw <- forest_samples_tau$predict_raw(forest_dataset_tau_test) - tau_hat_test <- t(t(tau_hat_test_raw) * (b_1_samples - b_0_samples))*y_std_train - } else { - tau_hat_test <- forest_samples_tau$predict_raw(forest_dataset_tau_test)*y_std_train - } - y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test) + mu_hat_test <- mu_hat_test*y_std_train + y_bar_train + tau_hat_test <- tau_hat_test*y_std_train + y_hat_test <- y_hat_test*y_std_train + y_bar_train + } + if (adaptive_coding) { + b_0_samples <- b_0_samples*y_std_train + b_1_samples <- b_1_samples*y_std_train } - - # Global error variance if (sample_sigma_global) sigma2_samples <- global_var_samples*(y_std_train^2) - - # Leaf parameter variance for prognostic forest if (sample_sigma_leaf_mu) sigma_leaf_mu_samples <- leaf_scale_mu_samples - - # Leaf parameter variance for treatment effect forest if (sample_sigma_leaf_tau) sigma_leaf_tau_samples <- leaf_scale_tau_samples # Return results as a list @@ -518,115 +428,115 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, X_test = NULL, Z_tes return(result) } -#' Predict from a sampled BCF model on new data -#' -#' @param bcf Object of type `bcf` containing draws of a Bayesian causal forest model and associated sampling outputs. -#' @param X_test Covariates used to determine tree leaf predictions for each observation. -#' @param Z_test Treatments used for prediction. -#' @param pi_test (Optional) Propensities used for prediction. Default: `NULL`. -#' -#' @return List of three `nrow(X_test)` by `bcf$num_samples` matrices: prognostic function estimates, treatment effect estimates and outcome predictions. -#' @export -#' -#' @examples -#' n <- 500 -#' x1 <- rnorm(n) -#' x2 <- rnorm(n) -#' x3 <- rnorm(n) -#' x4 <- as.numeric(rbinom(n,1,0.5)) -#' x5 <- as.numeric(sample(1:3,n,replace=T)) -#' X <- cbind(x1,x2,x3,x4,x5) -#' p <- ncol(X) -#' g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))} -#' mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} -#' mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} -#' tau1 <- function(x) {rep(3,nrow(x))} -#' tau2 <- function(x) {1+2*x[,2]*x[,4]} -#' mu_x <- mu1(X) -#' tau_x <- tau2(X) -#' pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 -#' Z <- rbinom(n,1,pi_x) -#' E_XZ <- mu_x + Z*tau_x -#' snr <- 4 -#' y <- E_XZ + rnorm(n, 0, 1)*(sd(E_XZ)/snr) -#' test_set_pct <- 0.2 -#' n_test <- round(test_set_pct*n) -#' n_train <- n - n_test -#' test_inds <- sort(sample(1:n, n_test, replace = F)) -#' train_inds <- (1:n)[!((1:n) %in% test_inds)] -#' X_test <- X[test_inds,] -#' X_train <- X[train_inds,] -#' pi_test <- pi_x[test_inds] -#' pi_train <- pi_x[train_inds] -#' Z_test <- Z[test_inds] -#' Z_train <- Z[train_inds] -#' y_test <- y[test_inds] -#' y_train <- y[train_inds] -#' mu_test <- mu_x[test_inds] -#' mu_train <- mu_x[train_inds] -#' tau_test <- tau_x[test_inds] -#' tau_train <- tau_x[train_inds] -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train) -#' preds <- predict(bcf_model, X_test, Z_test, pi_test) -#' # plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") -#' # abline(0,1,col="red",lty=3,lwd=3) -#' # plot(rowMeans(preds$tau_hat), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") -#' # abline(0,1,col="red",lty=3,lwd=3) -predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL){ - # Convert all input data to matrices if not already converted - if ((is.null(dim(X_test))) && (!is.null(X_test))) { - X_test <- as.matrix(X_test) - } - if ((is.null(dim(Z_test))) && (!is.null(Z_test))) { - Z_test <- as.matrix(as.numeric(Z_test)) - } - if ((is.null(dim(pi_test))) && (!is.null(pi_test))) { - pi_test <- as.matrix(pi_test) - } - - # Data checks - if ((bcf$model_params$propensity_covariate != "none") && (is.null(pi_test))) { - stop("pi_test must be provided for this model") - } - if (nrow(X_test) != nrow(Z_test)) { - stop("X_test and Z_test must have the same number of rows") - } - if (bcf$model_params$num_covariates != ncol(X_test)) { - stop("X_test and must have the same number of columns as the covariates used to train the model") - } - - # Add propensities to any covariate set - if (bcf$model_params$propensity_covariate == "both") { - X_test_mu <- cbind(X_test, pi_test) - X_test_tau <- cbind(X_test, pi_test) - } else if (bcf$model_params$propensity_covariate == "mu") { - X_test_mu <- cbind(X_test, pi_test) - X_test_tau <- X_test - } else if (bcf$model_params$propensity_covariate == "tau") { - X_test_mu <- X_test - X_test_tau <- cbind(X_test, pi_test) - } - - # Create prediction datasets - prediction_dataset_mu <- createForestDataset(X_test_mu) - prediction_dataset_tau <- createForestDataset(X_test_tau, Z_test) - - # Compute and return predictions - y_std <- bcf$model_params$outcome_scale - y_bar <- bcf$model_params$outcome_mean - mu_hat_test <- bcf$forests_mu$predict(prediction_dataset_mu)*y_std + y_bar - if (bcf$model_params$adaptive_coding) { - tau_hat_test_raw <- bcf$forests_tau$predict_raw(prediction_dataset_tau) - tau_hat_test <- t(t(tau_hat_test_raw) * (bcf$b_1_samples - bcf$b_0_samples))*y_std - } else { - tau_hat_test <- bcf$forests_tau$predict_raw(forest_dataset_tau_test)*y_std - } - y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test) - - result <- list( - "mu_hat" = mu_hat_test, - "tau_hat" = tau_hat_test, - "y_hat" = y_hat_test - ) - return(result) -} +#' #' Predict from a sampled BCF model on new data +#' #' +#' #' @param bcf Object of type `bcf` containing draws of a Bayesian causal forest model and associated sampling outputs. +#' #' @param X_test Covariates used to determine tree leaf predictions for each observation. +#' #' @param Z_test Treatments used for prediction. +#' #' @param pi_test (Optional) Propensities used for prediction. Default: `NULL`. +#' #' +#' #' @return List of three `nrow(X_test)` by `bcf$num_samples` matrices: prognostic function estimates, treatment effect estimates and outcome predictions. +#' #' @export +#' #' +#' #' @examples +#' #' n <- 500 +#' #' x1 <- rnorm(n) +#' #' x2 <- rnorm(n) +#' #' x3 <- rnorm(n) +#' #' x4 <- as.numeric(rbinom(n,1,0.5)) +#' #' x5 <- as.numeric(sample(1:3,n,replace=T)) +#' #' X <- cbind(x1,x2,x3,x4,x5) +#' #' p <- ncol(X) +#' #' g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))} +#' #' mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} +#' #' mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} +#' #' tau1 <- function(x) {rep(3,nrow(x))} +#' #' tau2 <- function(x) {1+2*x[,2]*x[,4]} +#' #' mu_x <- mu1(X) +#' #' tau_x <- tau2(X) +#' #' pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 +#' #' Z <- rbinom(n,1,pi_x) +#' #' E_XZ <- mu_x + Z*tau_x +#' #' snr <- 4 +#' #' y <- E_XZ + rnorm(n, 0, 1)*(sd(E_XZ)/snr) +#' #' test_set_pct <- 0.2 +#' #' n_test <- round(test_set_pct*n) +#' #' n_train <- n - n_test +#' #' test_inds <- sort(sample(1:n, n_test, replace = F)) +#' #' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' #' X_test <- X[test_inds,] +#' #' X_train <- X[train_inds,] +#' #' pi_test <- pi_x[test_inds] +#' #' pi_train <- pi_x[train_inds] +#' #' Z_test <- Z[test_inds] +#' #' Z_train <- Z[train_inds] +#' #' y_test <- y[test_inds] +#' #' y_train <- y[train_inds] +#' #' mu_test <- mu_x[test_inds] +#' #' mu_train <- mu_x[train_inds] +#' #' tau_test <- tau_x[test_inds] +#' #' tau_train <- tau_x[train_inds] +#' #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train) +#' #' preds <- predict(bcf_model, X_test, Z_test, pi_test) +#' #' # plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") +#' #' # abline(0,1,col="red",lty=3,lwd=3) +#' #' # plot(rowMeans(preds$tau_hat), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") +#' #' # abline(0,1,col="red",lty=3,lwd=3) +#' predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL){ +#' # Convert all input data to matrices if not already converted +#' if ((is.null(dim(X_test))) && (!is.null(X_test))) { +#' X_test <- as.matrix(X_test) +#' } +#' if ((is.null(dim(Z_test))) && (!is.null(Z_test))) { +#' Z_test <- as.matrix(as.numeric(Z_test)) +#' } +#' if ((is.null(dim(pi_test))) && (!is.null(pi_test))) { +#' pi_test <- as.matrix(pi_test) +#' } +#' +#' # Data checks +#' if ((bcf$model_params$propensity_covariate != "none") && (is.null(pi_test))) { +#' stop("pi_test must be provided for this model") +#' } +#' if (nrow(X_test) != nrow(Z_test)) { +#' stop("X_test and Z_test must have the same number of rows") +#' } +#' if (bcf$model_params$num_covariates != ncol(X_test)) { +#' stop("X_test and must have the same number of columns as the covariates used to train the model") +#' } +#' +#' # Add propensities to any covariate set +#' if (bcf$model_params$propensity_covariate == "both") { +#' X_test_mu <- cbind(X_test, pi_test) +#' X_test_tau <- cbind(X_test, pi_test) +#' } else if (bcf$model_params$propensity_covariate == "mu") { +#' X_test_mu <- cbind(X_test, pi_test) +#' X_test_tau <- X_test +#' } else if (bcf$model_params$propensity_covariate == "tau") { +#' X_test_mu <- X_test +#' X_test_tau <- cbind(X_test, pi_test) +#' } +#' +#' # Create prediction datasets +#' prediction_dataset_mu <- createForestDataset(X_test_mu) +#' prediction_dataset_tau <- createForestDataset(X_test_tau, Z_test) +#' +#' # Compute and return predictions +#' y_std <- bcf$model_params$outcome_scale +#' y_bar <- bcf$model_params$outcome_mean +#' mu_hat_test <- bcf$forests_mu$predict(prediction_dataset_mu)*y_std + y_bar +#' if (bcf$model_params$adaptive_coding) { +#' tau_hat_test_raw <- bcf$forests_tau$predict_raw(prediction_dataset_tau) +#' tau_hat_test <- t(t(tau_hat_test_raw) * (bcf$b_1_samples - bcf$b_0_samples))*y_std +#' } else { +#' tau_hat_test <- bcf$forests_tau$predict_raw(forest_dataset_tau_test)*y_std +#' } +#' y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test) +#' +#' result <- list( +#' "mu_hat" = mu_hat_test, +#' "tau_hat" = tau_hat_test, +#' "y_hat" = y_hat_test +#' ) +#' return(result) +#' } diff --git a/R/cpp11.R b/R/cpp11.R index ba12da8..61fc272 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -1,5 +1,57 @@ # Generated by cpp11: do not edit by hand +bcf_init_cpp <- function(univariate_treatment) { + .Call(`_stochtree_bcf_init_cpp`, univariate_treatment) +} + +bcf_add_train_with_weights_cpp <- function(bcf_wrapper, X_train_mu, X_train_tau, Z_train, y_train, weights_train, treatment_binary) { + invisible(.Call(`_stochtree_bcf_add_train_with_weights_cpp`, bcf_wrapper, X_train_mu, X_train_tau, Z_train, y_train, weights_train, treatment_binary)) +} + +bcf_add_train_no_weights_cpp <- function(bcf_wrapper, X_train_mu, X_train_tau, Z_train, y_train, treatment_binary) { + invisible(.Call(`_stochtree_bcf_add_train_no_weights_cpp`, bcf_wrapper, X_train_mu, X_train_tau, Z_train, y_train, treatment_binary)) +} + +bcf_add_test_cpp <- function(bcf_wrapper, X_test_mu, X_test_tau, Z_test) { + invisible(.Call(`_stochtree_bcf_add_test_cpp`, bcf_wrapper, X_test_mu, X_test_tau, Z_test)) +} + +bcf_reset_global_var_samples_cpp <- function(bcf_wrapper, data_vector) { + invisible(.Call(`_stochtree_bcf_reset_global_var_samples_cpp`, bcf_wrapper, data_vector)) +} + +bcf_reset_prognostic_leaf_var_samples_cpp <- function(bcf_wrapper, data_vector) { + invisible(.Call(`_stochtree_bcf_reset_prognostic_leaf_var_samples_cpp`, bcf_wrapper, data_vector)) +} + +bcf_reset_treatment_leaf_var_samples_cpp <- function(bcf_wrapper, data_vector) { + invisible(.Call(`_stochtree_bcf_reset_treatment_leaf_var_samples_cpp`, bcf_wrapper, data_vector)) +} + +bcf_reset_treatment_coding_samples_cpp <- function(bcf_wrapper, data_vector) { + invisible(.Call(`_stochtree_bcf_reset_treatment_coding_samples_cpp`, bcf_wrapper, data_vector)) +} + +bcf_reset_control_coding_samples_cpp <- function(bcf_wrapper, data_vector) { + invisible(.Call(`_stochtree_bcf_reset_control_coding_samples_cpp`, bcf_wrapper, data_vector)) +} + +bcf_reset_train_prediction_samples_cpp <- function(bcf_wrapper, muhat, tauhat, yhat, num_obs, num_samples, treatment_dim) { + invisible(.Call(`_stochtree_bcf_reset_train_prediction_samples_cpp`, bcf_wrapper, muhat, tauhat, yhat, num_obs, num_samples, treatment_dim)) +} + +bcf_reset_test_prediction_samples_cpp <- function(bcf_wrapper, muhat, tauhat, yhat, num_obs, num_samples, treatment_dim) { + invisible(.Call(`_stochtree_bcf_reset_test_prediction_samples_cpp`, bcf_wrapper, muhat, tauhat, yhat, num_obs, num_samples, treatment_dim)) +} + +sample_bcf_univariate_cpp <- function(bcf_wrapper, forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu_int, feature_types_tau_int, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau) { + invisible(.Call(`_stochtree_sample_bcf_univariate_cpp`, bcf_wrapper, forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu_int, feature_types_tau_int, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau)) +} + +sample_bcf_multivariate_cpp <- function(bcf_wrapper, forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau_r, alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu_int, feature_types_tau_int, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau) { + invisible(.Call(`_stochtree_sample_bcf_multivariate_cpp`, bcf_wrapper, forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau_r, alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu_int, feature_types_tau_int, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau)) +} + create_forest_dataset_cpp <- function() { .Call(`_stochtree_create_forest_dataset_cpp`) } diff --git a/src/Makevars b/src/Makevars index 9bbe204..f9a0433 100644 --- a/src/Makevars +++ b/src/Makevars @@ -4,6 +4,7 @@ CPP_PKGROOT=stochtree-cpp PKG_CPPFLAGS= -I$(CPP_PKGROOT)/include -I$(CPP_PKGROOT)/dependencies/boost_math/include -I$(CPP_PKGROOT)/dependencies/eigen OBJECTS = \ + bcf.o \ data.o \ predictor.o \ sampler.o \ diff --git a/src/bcf.cpp b/src/bcf.cpp new file mode 100644 index 0000000..b618cc5 --- /dev/null +++ b/src/bcf.cpp @@ -0,0 +1,303 @@ +#include +#include "stochtree_types.h" +#include +#include +#include +#include +#include +#include + +[[cpp11::register]] +cpp11::external_pointer bcf_init_cpp(bool univariate_treatment = true) { + std::unique_ptr bcf_ptr = std::make_unique(univariate_treatment); + return cpp11::external_pointer(bcf_ptr.release()); +} + +[[cpp11::register]] +void bcf_add_train_with_weights_cpp( + cpp11::external_pointer bcf_wrapper, cpp11::doubles_matrix<> X_train_mu, + cpp11::doubles_matrix<> X_train_tau, cpp11::doubles_matrix<> Z_train, + cpp11::doubles y_train, cpp11::doubles weights_train, bool treatment_binary +) { + // Data dimensions + int n = X_train_mu.nrow(); + int X_train_mu_cols = X_train_mu.ncol(); + int X_train_tau_cols = X_train_tau.ncol(); + int Z_train_cols = Z_train.ncol(); + + // Pointers to R data + double* X_train_mu_data_ptr = REAL(PROTECT(X_train_mu)); + double* X_train_tau_data_ptr = REAL(PROTECT(X_train_tau)); + double* Z_train_data_ptr = REAL(PROTECT(Z_train)); + double* y_train_data_ptr = REAL(PROTECT(y_train)); + double* weights_train_data_ptr = REAL(PROTECT(weights_train)); + + // Load training data into BCF model + bcf_wrapper->LoadTrain( + y_train_data_ptr, n, X_train_mu_data_ptr, X_train_mu_cols, + X_train_tau_data_ptr, X_train_tau_cols, Z_train_data_ptr, + Z_train_cols, treatment_binary, weights_train_data_ptr + ); + + // UNPROTECT the SEXPs created to point to the R data + UNPROTECT(5); +} + +[[cpp11::register]] +void bcf_add_train_no_weights_cpp( + cpp11::external_pointer bcf_wrapper, cpp11::doubles_matrix<> X_train_mu, + cpp11::doubles_matrix<> X_train_tau, cpp11::doubles_matrix<> Z_train, + cpp11::doubles y_train, bool treatment_binary +) { + // Data dimensions + int n = X_train_mu.nrow(); + int X_train_mu_cols = X_train_mu.ncol(); + int X_train_tau_cols = X_train_tau.ncol(); + int Z_train_cols = Z_train.ncol(); + + // Pointers to R data + double* X_train_mu_data_ptr = REAL(PROTECT(X_train_mu)); + double* X_train_tau_data_ptr = REAL(PROTECT(X_train_tau)); + double* Z_train_data_ptr = REAL(PROTECT(Z_train)); + double* y_train_data_ptr = REAL(PROTECT(y_train)); + + // Load training data into BCF model + bcf_wrapper->LoadTrain( + y_train_data_ptr, n, X_train_mu_data_ptr, X_train_mu_cols, + X_train_tau_data_ptr, X_train_tau_cols, Z_train_data_ptr, + Z_train_cols, treatment_binary + ); + + // UNPROTECT the SEXPs created to point to the R data + UNPROTECT(4); +} + +[[cpp11::register]] +void bcf_add_test_cpp( + cpp11::external_pointer bcf_wrapper, cpp11::doubles_matrix<> X_test_mu, + cpp11::doubles_matrix<> X_test_tau, cpp11::doubles_matrix<> Z_test +) { + // Data dimensions + int n = X_test_mu.nrow(); + int X_test_mu_cols = X_test_mu.ncol(); + int X_test_tau_cols = X_test_tau.ncol(); + int Z_test_cols = Z_test.ncol(); + + // Pointers to R data + double* X_test_mu_data_ptr = REAL(PROTECT(X_test_mu)); + double* X_test_tau_data_ptr = REAL(PROTECT(X_test_tau)); + double* Z_test_data_ptr = REAL(PROTECT(Z_test)); + + // Load test data into BCF model + bcf_wrapper->LoadTest( + X_test_mu_data_ptr, n, X_test_mu_cols, + X_test_tau_data_ptr, X_test_tau_cols, + Z_test_data_ptr, Z_test_cols + ); + + // UNPROTECT the SEXPs created to point to the R data + UNPROTECT(3); +} + +[[cpp11::register]] +void bcf_reset_global_var_samples_cpp( + cpp11::external_pointer bcf_wrapper, + cpp11::doubles data_vector +) { + // Data dimensions + int n = data_vector.size(); + + // Pointer to R data + double* data_ptr = REAL(PROTECT(data_vector)); + + // Map Eigen array to data in the R container + bcf_wrapper->ResetGlobalVarSamples(data_ptr, n); + + // UNPROTECT the SEXP created to point to the R data + UNPROTECT(1); +} + +[[cpp11::register]] +void bcf_reset_prognostic_leaf_var_samples_cpp( + cpp11::external_pointer bcf_wrapper, + cpp11::doubles data_vector +) { + // Data dimensions + int n = data_vector.size(); + + // Pointer to R data + double* data_ptr = REAL(PROTECT(data_vector)); + + // Map Eigen array to data in the R container + bcf_wrapper->ResetPrognosticLeafVarSamples(data_ptr, n); + + // UNPROTECT the SEXP created to point to the R data + UNPROTECT(1); +} + +[[cpp11::register]] +void bcf_reset_treatment_leaf_var_samples_cpp( + cpp11::external_pointer bcf_wrapper, + cpp11::doubles data_vector +) { + // Data dimensions + int n = data_vector.size(); + + // Pointer to R data + double* data_ptr = REAL(PROTECT(data_vector)); + + // Map Eigen array to data in the R container + bcf_wrapper->ResetTreatmentLeafVarSamples(data_ptr, n); + + // UNPROTECT the SEXP created to point to the R data + UNPROTECT(1); +} + +[[cpp11::register]] +void bcf_reset_treatment_coding_samples_cpp( + cpp11::external_pointer bcf_wrapper, + cpp11::doubles data_vector +) { + // Data dimensions + int n = data_vector.size(); + + // Pointer to R data + double* data_ptr = REAL(PROTECT(data_vector)); + + // Map Eigen array to data in the R container + bcf_wrapper->ResetTreatedCodingSamples(data_ptr, n); + + // UNPROTECT the SEXP created to point to the R data + UNPROTECT(1); +} + +[[cpp11::register]] +void bcf_reset_control_coding_samples_cpp( + cpp11::external_pointer bcf_wrapper, + cpp11::doubles data_vector +) { + // Data dimensions + int n = data_vector.size(); + + // Pointer to R data + double* data_ptr = REAL(PROTECT(data_vector)); + + // Map Eigen array to data in the R container + bcf_wrapper->ResetControlCodingSamples(data_ptr, n); + + // UNPROTECT the SEXP created to point to the R data + UNPROTECT(1); +} + +[[cpp11::register]] +void bcf_reset_train_prediction_samples_cpp( + cpp11::external_pointer bcf_wrapper, + cpp11::doubles_matrix<> muhat, cpp11::doubles tauhat, cpp11::doubles_matrix<> yhat, + int num_obs, int num_samples, int treatment_dim +) { + // Pointers to R data + double* muhat_data_ptr = REAL(PROTECT(muhat)); + double* tauhat_data_ptr = REAL(PROTECT(tauhat)); + double* yhat_data_ptr = REAL(PROTECT(yhat)); + + // Map Eigen array to data in the R container + bcf_wrapper->ResetTrainPredictionSamples(muhat_data_ptr, tauhat_data_ptr, yhat_data_ptr, num_obs, num_samples, treatment_dim); + + // UNPROTECT the SEXP created to point to the R data + UNPROTECT(3); +} + +[[cpp11::register]] +void bcf_reset_test_prediction_samples_cpp( + cpp11::external_pointer bcf_wrapper, + cpp11::doubles_matrix<> muhat, cpp11::doubles tauhat, cpp11::doubles_matrix<> yhat, + int num_obs, int num_samples, int treatment_dim +) { + // Pointers to R data + double* muhat_data_ptr = REAL(PROTECT(muhat)); + double* tauhat_data_ptr = REAL(PROTECT(tauhat)); + double* yhat_data_ptr = REAL(PROTECT(yhat)); + + // Map Eigen array to data in the R container + bcf_wrapper->ResetTestPredictionSamples(muhat_data_ptr, tauhat_data_ptr, yhat_data_ptr, num_obs, num_samples, treatment_dim); + + // UNPROTECT the SEXP created to point to the R data + UNPROTECT(3); +} + +[[cpp11::register]] +void sample_bcf_univariate_cpp( + cpp11::external_pointer bcf_wrapper, + cpp11::external_pointer forest_samples_mu, + cpp11::external_pointer forest_samples_tau, + cpp11::external_pointer rng, + int cutpoint_grid_size, double sigma_leaf_mu, double sigma_leaf_tau, + double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, + int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, + double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, + double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, + cpp11::integers feature_types_mu_int, cpp11::integers feature_types_tau_int, + int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau +) { + // Convert feature_types + std::vector feature_types_mu(feature_types_mu_int.size()); + for (int i = 0; i < feature_types_mu_int.size(); i++) { + feature_types_mu.at(i) = static_cast(feature_types_mu_int.at(i)); + } + std::vector feature_types_tau(feature_types_tau_int.size()); + for (int i = 0; i < feature_types_tau_int.size(); i++) { + feature_types_tau.at(i) = static_cast(feature_types_tau_int.at(i)); + } + + // Run the sampler + bcf_wrapper->SampleBCF(forest_samples_mu.get(), forest_samples_tau.get(), rng.get(), + cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, alpha_mu, alpha_tau, + beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, + a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, + b1, b0, feature_types_mu, feature_types_tau, num_gfr, num_burnin, num_mcmc, + leaf_init_mu, leaf_init_tau); +} + +[[cpp11::register]] +void sample_bcf_multivariate_cpp( + cpp11::external_pointer bcf_wrapper, + cpp11::external_pointer forest_samples_mu, + cpp11::external_pointer forest_samples_tau, + cpp11::external_pointer rng, + int cutpoint_grid_size, double sigma_leaf_mu, cpp11::doubles_matrix<> sigma_leaf_tau_r, + double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, + int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, + double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, + double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, + cpp11::integers feature_types_mu_int, cpp11::integers feature_types_tau_int, + int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau +) { + // Convert feature_types + std::vector feature_types_mu(feature_types_mu_int.size()); + for (int i = 0; i < feature_types_mu_int.size(); i++) { + feature_types_mu.at(i) = static_cast(feature_types_mu_int.at(i)); + } + std::vector feature_types_tau(feature_types_tau_int.size()); + for (int i = 0; i < feature_types_tau_int.size(); i++) { + feature_types_tau.at(i) = static_cast(feature_types_tau_int.at(i)); + } + + // Convert sigma_leaf_tau + Eigen::MatrixXd sigma_leaf_tau; + int num_row = sigma_leaf_tau_r.nrow(); + int num_col = sigma_leaf_tau_r.ncol(); + sigma_leaf_tau.resize(num_row, num_col); + for (int i = 0; i < num_row; i++) { + for (int j = 0; j < num_col; j++) { + sigma_leaf_tau(i,j) = sigma_leaf_tau_r(i,j); + } + } + + // Run the sampler + bcf_wrapper->SampleBCF(forest_samples_mu.get(), forest_samples_tau.get(), rng.get(), + cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, alpha_mu, alpha_tau, + beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, + a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, + b1, b0, feature_types_mu, feature_types_tau, num_gfr, num_burnin, num_mcmc, + leaf_init_mu, leaf_init_tau); +} diff --git a/src/cpp11.cpp b/src/cpp11.cpp index df1faaf..0fb7469 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -5,6 +5,109 @@ #include "cpp11/declarations.hpp" #include +// bcf.cpp +cpp11::external_pointer bcf_init_cpp(bool univariate_treatment); +extern "C" SEXP _stochtree_bcf_init_cpp(SEXP univariate_treatment) { + BEGIN_CPP11 + return cpp11::as_sexp(bcf_init_cpp(cpp11::as_cpp>(univariate_treatment))); + END_CPP11 +} +// bcf.cpp +void bcf_add_train_with_weights_cpp(cpp11::external_pointer bcf_wrapper, cpp11::doubles_matrix<> X_train_mu, cpp11::doubles_matrix<> X_train_tau, cpp11::doubles_matrix<> Z_train, cpp11::doubles y_train, cpp11::doubles weights_train, bool treatment_binary); +extern "C" SEXP _stochtree_bcf_add_train_with_weights_cpp(SEXP bcf_wrapper, SEXP X_train_mu, SEXP X_train_tau, SEXP Z_train, SEXP y_train, SEXP weights_train, SEXP treatment_binary) { + BEGIN_CPP11 + bcf_add_train_with_weights_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>>(X_train_mu), cpp11::as_cpp>>(X_train_tau), cpp11::as_cpp>>(Z_train), cpp11::as_cpp>(y_train), cpp11::as_cpp>(weights_train), cpp11::as_cpp>(treatment_binary)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void bcf_add_train_no_weights_cpp(cpp11::external_pointer bcf_wrapper, cpp11::doubles_matrix<> X_train_mu, cpp11::doubles_matrix<> X_train_tau, cpp11::doubles_matrix<> Z_train, cpp11::doubles y_train, bool treatment_binary); +extern "C" SEXP _stochtree_bcf_add_train_no_weights_cpp(SEXP bcf_wrapper, SEXP X_train_mu, SEXP X_train_tau, SEXP Z_train, SEXP y_train, SEXP treatment_binary) { + BEGIN_CPP11 + bcf_add_train_no_weights_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>>(X_train_mu), cpp11::as_cpp>>(X_train_tau), cpp11::as_cpp>>(Z_train), cpp11::as_cpp>(y_train), cpp11::as_cpp>(treatment_binary)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void bcf_add_test_cpp(cpp11::external_pointer bcf_wrapper, cpp11::doubles_matrix<> X_test_mu, cpp11::doubles_matrix<> X_test_tau, cpp11::doubles_matrix<> Z_test); +extern "C" SEXP _stochtree_bcf_add_test_cpp(SEXP bcf_wrapper, SEXP X_test_mu, SEXP X_test_tau, SEXP Z_test) { + BEGIN_CPP11 + bcf_add_test_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>>(X_test_mu), cpp11::as_cpp>>(X_test_tau), cpp11::as_cpp>>(Z_test)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void bcf_reset_global_var_samples_cpp(cpp11::external_pointer bcf_wrapper, cpp11::doubles data_vector); +extern "C" SEXP _stochtree_bcf_reset_global_var_samples_cpp(SEXP bcf_wrapper, SEXP data_vector) { + BEGIN_CPP11 + bcf_reset_global_var_samples_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>(data_vector)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void bcf_reset_prognostic_leaf_var_samples_cpp(cpp11::external_pointer bcf_wrapper, cpp11::doubles data_vector); +extern "C" SEXP _stochtree_bcf_reset_prognostic_leaf_var_samples_cpp(SEXP bcf_wrapper, SEXP data_vector) { + BEGIN_CPP11 + bcf_reset_prognostic_leaf_var_samples_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>(data_vector)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void bcf_reset_treatment_leaf_var_samples_cpp(cpp11::external_pointer bcf_wrapper, cpp11::doubles data_vector); +extern "C" SEXP _stochtree_bcf_reset_treatment_leaf_var_samples_cpp(SEXP bcf_wrapper, SEXP data_vector) { + BEGIN_CPP11 + bcf_reset_treatment_leaf_var_samples_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>(data_vector)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void bcf_reset_treatment_coding_samples_cpp(cpp11::external_pointer bcf_wrapper, cpp11::doubles data_vector); +extern "C" SEXP _stochtree_bcf_reset_treatment_coding_samples_cpp(SEXP bcf_wrapper, SEXP data_vector) { + BEGIN_CPP11 + bcf_reset_treatment_coding_samples_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>(data_vector)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void bcf_reset_control_coding_samples_cpp(cpp11::external_pointer bcf_wrapper, cpp11::doubles data_vector); +extern "C" SEXP _stochtree_bcf_reset_control_coding_samples_cpp(SEXP bcf_wrapper, SEXP data_vector) { + BEGIN_CPP11 + bcf_reset_control_coding_samples_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>(data_vector)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void bcf_reset_train_prediction_samples_cpp(cpp11::external_pointer bcf_wrapper, cpp11::doubles_matrix<> muhat, cpp11::doubles tauhat, cpp11::doubles_matrix<> yhat, int num_obs, int num_samples, int treatment_dim); +extern "C" SEXP _stochtree_bcf_reset_train_prediction_samples_cpp(SEXP bcf_wrapper, SEXP muhat, SEXP tauhat, SEXP yhat, SEXP num_obs, SEXP num_samples, SEXP treatment_dim) { + BEGIN_CPP11 + bcf_reset_train_prediction_samples_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>>(muhat), cpp11::as_cpp>(tauhat), cpp11::as_cpp>>(yhat), cpp11::as_cpp>(num_obs), cpp11::as_cpp>(num_samples), cpp11::as_cpp>(treatment_dim)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void bcf_reset_test_prediction_samples_cpp(cpp11::external_pointer bcf_wrapper, cpp11::doubles_matrix<> muhat, cpp11::doubles tauhat, cpp11::doubles_matrix<> yhat, int num_obs, int num_samples, int treatment_dim); +extern "C" SEXP _stochtree_bcf_reset_test_prediction_samples_cpp(SEXP bcf_wrapper, SEXP muhat, SEXP tauhat, SEXP yhat, SEXP num_obs, SEXP num_samples, SEXP treatment_dim) { + BEGIN_CPP11 + bcf_reset_test_prediction_samples_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>>(muhat), cpp11::as_cpp>(tauhat), cpp11::as_cpp>>(yhat), cpp11::as_cpp>(num_obs), cpp11::as_cpp>(num_samples), cpp11::as_cpp>(treatment_dim)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void sample_bcf_univariate_cpp(cpp11::external_pointer bcf_wrapper, cpp11::external_pointer forest_samples_mu, cpp11::external_pointer forest_samples_tau, cpp11::external_pointer rng, int cutpoint_grid_size, double sigma_leaf_mu, double sigma_leaf_tau, double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, cpp11::integers feature_types_mu_int, cpp11::integers feature_types_tau_int, int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau); +extern "C" SEXP _stochtree_sample_bcf_univariate_cpp(SEXP bcf_wrapper, SEXP forest_samples_mu, SEXP forest_samples_tau, SEXP rng, SEXP cutpoint_grid_size, SEXP sigma_leaf_mu, SEXP sigma_leaf_tau, SEXP alpha_mu, SEXP alpha_tau, SEXP beta_mu, SEXP beta_tau, SEXP min_samples_leaf_mu, SEXP min_samples_leaf_tau, SEXP nu, SEXP lamb, SEXP a_leaf_mu, SEXP a_leaf_tau, SEXP b_leaf_mu, SEXP b_leaf_tau, SEXP sigma2, SEXP num_trees_mu, SEXP num_trees_tau, SEXP b1, SEXP b0, SEXP feature_types_mu_int, SEXP feature_types_tau_int, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP leaf_init_mu, SEXP leaf_init_tau) { + BEGIN_CPP11 + sample_bcf_univariate_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>>(forest_samples_mu), cpp11::as_cpp>>(forest_samples_tau), cpp11::as_cpp>>(rng), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>(sigma_leaf_mu), cpp11::as_cpp>(sigma_leaf_tau), cpp11::as_cpp>(alpha_mu), cpp11::as_cpp>(alpha_tau), cpp11::as_cpp>(beta_mu), cpp11::as_cpp>(beta_tau), cpp11::as_cpp>(min_samples_leaf_mu), cpp11::as_cpp>(min_samples_leaf_tau), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(a_leaf_mu), cpp11::as_cpp>(a_leaf_tau), cpp11::as_cpp>(b_leaf_mu), cpp11::as_cpp>(b_leaf_tau), cpp11::as_cpp>(sigma2), cpp11::as_cpp>(num_trees_mu), cpp11::as_cpp>(num_trees_tau), cpp11::as_cpp>(b1), cpp11::as_cpp>(b0), cpp11::as_cpp>(feature_types_mu_int), cpp11::as_cpp>(feature_types_tau_int), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(leaf_init_mu), cpp11::as_cpp>(leaf_init_tau)); + return R_NilValue; + END_CPP11 +} +// bcf.cpp +void sample_bcf_multivariate_cpp(cpp11::external_pointer bcf_wrapper, cpp11::external_pointer forest_samples_mu, cpp11::external_pointer forest_samples_tau, cpp11::external_pointer rng, int cutpoint_grid_size, double sigma_leaf_mu, cpp11::doubles_matrix<> sigma_leaf_tau_r, double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, cpp11::integers feature_types_mu_int, cpp11::integers feature_types_tau_int, int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau); +extern "C" SEXP _stochtree_sample_bcf_multivariate_cpp(SEXP bcf_wrapper, SEXP forest_samples_mu, SEXP forest_samples_tau, SEXP rng, SEXP cutpoint_grid_size, SEXP sigma_leaf_mu, SEXP sigma_leaf_tau_r, SEXP alpha_mu, SEXP alpha_tau, SEXP beta_mu, SEXP beta_tau, SEXP min_samples_leaf_mu, SEXP min_samples_leaf_tau, SEXP nu, SEXP lamb, SEXP a_leaf_mu, SEXP a_leaf_tau, SEXP b_leaf_mu, SEXP b_leaf_tau, SEXP sigma2, SEXP num_trees_mu, SEXP num_trees_tau, SEXP b1, SEXP b0, SEXP feature_types_mu_int, SEXP feature_types_tau_int, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP leaf_init_mu, SEXP leaf_init_tau) { + BEGIN_CPP11 + sample_bcf_multivariate_cpp(cpp11::as_cpp>>(bcf_wrapper), cpp11::as_cpp>>(forest_samples_mu), cpp11::as_cpp>>(forest_samples_tau), cpp11::as_cpp>>(rng), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>(sigma_leaf_mu), cpp11::as_cpp>>(sigma_leaf_tau_r), cpp11::as_cpp>(alpha_mu), cpp11::as_cpp>(alpha_tau), cpp11::as_cpp>(beta_mu), cpp11::as_cpp>(beta_tau), cpp11::as_cpp>(min_samples_leaf_mu), cpp11::as_cpp>(min_samples_leaf_tau), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(a_leaf_mu), cpp11::as_cpp>(a_leaf_tau), cpp11::as_cpp>(b_leaf_mu), cpp11::as_cpp>(b_leaf_tau), cpp11::as_cpp>(sigma2), cpp11::as_cpp>(num_trees_mu), cpp11::as_cpp>(num_trees_tau), cpp11::as_cpp>(b1), cpp11::as_cpp>(b0), cpp11::as_cpp>(feature_types_mu_int), cpp11::as_cpp>(feature_types_tau_int), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(leaf_init_mu), cpp11::as_cpp>(leaf_init_tau)); + return R_NilValue; + END_CPP11 +} // data.cpp cpp11::external_pointer create_forest_dataset_cpp(); extern "C" SEXP _stochtree_create_forest_dataset_cpp() { @@ -244,38 +347,51 @@ extern "C" SEXP _stochtree_forest_tracker_cpp(SEXP data, SEXP feature_types, SEX extern "C" { static const R_CallMethodDef CallEntries[] = { - {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1}, - {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, - {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, - {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, - {"_stochtree_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_dataset_has_basis_cpp, 1}, - {"_stochtree_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_dataset_has_variance_weights_cpp, 1}, - {"_stochtree_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_dataset_num_basis_cpp, 1}, - {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, - {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, - {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 3}, - {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, - {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, - {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, - {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, - {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, - {"_stochtree_is_leaf_constant_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_forest_container_cpp, 1}, - {"_stochtree_json_load_forest_container_cpp", (DL_FUNC) &_stochtree_json_load_forest_container_cpp, 2}, - {"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2}, - {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1}, - {"_stochtree_output_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_output_dimension_forest_container_cpp, 1}, - {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, - {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, - {"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, - {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, - {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 13}, - {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 13}, - {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 4}, - {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 5}, - {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, - {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, - {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 3}, - {"_stochtree_update_residual_forest_container_cpp", (DL_FUNC) &_stochtree_update_residual_forest_container_cpp, 7}, + {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1}, + {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, + {"_stochtree_bcf_add_test_cpp", (DL_FUNC) &_stochtree_bcf_add_test_cpp, 4}, + {"_stochtree_bcf_add_train_no_weights_cpp", (DL_FUNC) &_stochtree_bcf_add_train_no_weights_cpp, 6}, + {"_stochtree_bcf_add_train_with_weights_cpp", (DL_FUNC) &_stochtree_bcf_add_train_with_weights_cpp, 7}, + {"_stochtree_bcf_init_cpp", (DL_FUNC) &_stochtree_bcf_init_cpp, 1}, + {"_stochtree_bcf_reset_control_coding_samples_cpp", (DL_FUNC) &_stochtree_bcf_reset_control_coding_samples_cpp, 2}, + {"_stochtree_bcf_reset_global_var_samples_cpp", (DL_FUNC) &_stochtree_bcf_reset_global_var_samples_cpp, 2}, + {"_stochtree_bcf_reset_prognostic_leaf_var_samples_cpp", (DL_FUNC) &_stochtree_bcf_reset_prognostic_leaf_var_samples_cpp, 2}, + {"_stochtree_bcf_reset_test_prediction_samples_cpp", (DL_FUNC) &_stochtree_bcf_reset_test_prediction_samples_cpp, 7}, + {"_stochtree_bcf_reset_train_prediction_samples_cpp", (DL_FUNC) &_stochtree_bcf_reset_train_prediction_samples_cpp, 7}, + {"_stochtree_bcf_reset_treatment_coding_samples_cpp", (DL_FUNC) &_stochtree_bcf_reset_treatment_coding_samples_cpp, 2}, + {"_stochtree_bcf_reset_treatment_leaf_var_samples_cpp", (DL_FUNC) &_stochtree_bcf_reset_treatment_leaf_var_samples_cpp, 2}, + {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, + {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, + {"_stochtree_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_dataset_has_basis_cpp, 1}, + {"_stochtree_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_dataset_has_variance_weights_cpp, 1}, + {"_stochtree_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_dataset_num_basis_cpp, 1}, + {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, + {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, + {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 3}, + {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, + {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, + {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, + {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, + {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, + {"_stochtree_is_leaf_constant_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_forest_container_cpp, 1}, + {"_stochtree_json_load_forest_container_cpp", (DL_FUNC) &_stochtree_json_load_forest_container_cpp, 2}, + {"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2}, + {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1}, + {"_stochtree_output_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_output_dimension_forest_container_cpp, 1}, + {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, + {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, + {"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, + {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, + {"_stochtree_sample_bcf_multivariate_cpp", (DL_FUNC) &_stochtree_sample_bcf_multivariate_cpp, 31}, + {"_stochtree_sample_bcf_univariate_cpp", (DL_FUNC) &_stochtree_sample_bcf_univariate_cpp, 31}, + {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 13}, + {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 13}, + {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 4}, + {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 5}, + {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, + {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, + {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 3}, + {"_stochtree_update_residual_forest_container_cpp", (DL_FUNC) &_stochtree_update_residual_forest_container_cpp, 7}, {NULL, NULL, 0} }; } diff --git a/src/sampler.cpp b/src/sampler.cpp index cf34fa6..e17fac6 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -30,18 +30,18 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer sampler = StochTree::GFRForestSampler(cutpoint_grid_size); sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, pre_initialized); - } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { + } else if (leaf_model_enum == StochTree::ForestLeafModel::kUnivariateRegression) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); StochTree::GFRForestSampler sampler = StochTree::GFRForestSampler(cutpoint_grid_size); sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, pre_initialized); - } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { + } else if (leaf_model_enum == StochTree::ForestLeafModel::kMultivariateRegression) { StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); StochTree::GFRForestSampler sampler = StochTree::GFRForestSampler(cutpoint_grid_size); sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, feature_types_, pre_initialized); @@ -94,18 +94,18 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer sampler = StochTree::MCMCForestSampler(); sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); - } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) { + } else if (leaf_model_enum == StochTree::ForestLeafModel::kUnivariateRegression) { StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale); StochTree::MCMCForestSampler sampler = StochTree::MCMCForestSampler(); sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); - } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) { + } else if (leaf_model_enum == StochTree::ForestLeafModel::kMultivariateRegression) { StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); StochTree::MCMCForestSampler sampler = StochTree::MCMCForestSampler(); sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized); diff --git a/src/stochtree-cpp b/src/stochtree-cpp index 95421bc..798d837 160000 --- a/src/stochtree-cpp +++ b/src/stochtree-cpp @@ -1 +1 @@ -Subproject commit 95421bccb5b38922b6435b16d892a7c905cdb686 +Subproject commit 798d83754c5a5d2099bd7552b128963886df1a59 diff --git a/src/stochtree_types.h b/src/stochtree_types.h index 50ef48a..8f53901 100644 --- a/src/stochtree_types.h +++ b/src/stochtree_types.h @@ -1,12 +1,8 @@ #include +#include #include #include #include #include #include - -enum ForestLeafModel { - kConstant, - kUnivariateRegression, - kMultivariateRegression -}; +#include diff --git a/tools/debug/bcf_demo.R b/tools/debug/bcf_demo.R new file mode 100644 index 0000000..0fe058b --- /dev/null +++ b/tools/debug/bcf_demo.R @@ -0,0 +1,44 @@ +library(stochtree) +n <- 500 +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- as.numeric(rbinom(n,1,0.5)) +x5 <- as.numeric(sample(1:3,n,replace=T)) +X <- cbind(x1,x2,x3,x4,x5) +p <- ncol(X) +g <- function(x) {ifelse(x[,5]==1,2,ifelse(x[,5]==2,-1,4))} +mu1 <- function(x) {1+g(x)+x[,1]*x[,3]} +mu2 <- function(x) {1+g(x)+6*abs(x[,3]-1)} +tau1 <- function(x) {rep(3,nrow(x))} +tau2 <- function(x) {1+2*x[,2]*x[,4]} +mu_x <- mu1(X) +tau_x <- tau2(X) +pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 +Z <- rbinom(n,1,pi_x) +E_XZ <- mu_x + Z*tau_x +snr <- 4 +y <- E_XZ + rnorm(n, 0, 1)*(sd(E_XZ)/snr) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = F)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, + X_test = X_test, Z_test = Z_test, pi_test = pi_test) +plot(rowMeans(bcf_model$mu_hat_test), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") +abline(0,1,col="red",lty=3,lwd=3) +plot(rowMeans(bcf_model$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect") +abline(0,1,col="red",lty=3,lwd=3)