Skip to content

Commit

Permalink
Added progress bar / verbose option and demo for continuous BCF
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewherren committed Jun 12, 2024
1 parent 5afa2ea commit a694940
Show file tree
Hide file tree
Showing 6 changed files with 327 additions and 30 deletions.
25 changes: 24 additions & 1 deletion R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
#' @param verbose Whether or not to print progress during the sampling loops. Default: FALSE.
#'
#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
#' @export
Expand Down Expand Up @@ -82,7 +83,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL,
q = 0.9, sigma2_init = NULL, num_trees = 200, num_gfr = 5,
num_burnin = 0, num_mcmc = 100, sample_sigma = T,
sample_tau = T, random_seed = -1, keep_burnin = F, keep_gfr = F){
sample_tau = T, random_seed = -1, keep_burnin = F,
keep_gfr = F, verbose = F){
# Preprocess covariates
if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
stop("X_train must be a matrix or dataframe")
Expand Down Expand Up @@ -301,6 +303,13 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
if (num_gfr > 0){
gfr_indices = 1:num_gfr
for (i in 1:num_gfr) {
# Print progress
if (verbose) {
if ((i %% 10 == 0) || (i == num_gfr)) {
cat("Sampling", i, "out of", num_gfr, "XBART (grow-from-root) draws\n")
}
}

forest_model$sample_one_iteration(
forest_dataset_train, outcome_train, forest_samples, rng, feature_types,
leaf_model, current_leaf_scale, variable_weights,
Expand Down Expand Up @@ -329,6 +338,20 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
mcmc_indices = (num_gfr+num_burnin+1):(num_gfr+num_burnin+num_mcmc)
}
for (i in (num_gfr+1):num_samples) {
# Print progress for the burn-in / MCMC phase. `i` counts every draw taken so
# far (GFR draws included), so the within-phase counter is obtained by
# subtracting the draws that belong to earlier phases. The two branches are
# mutually exclusive so that burn-in messages are never emitted during MCMC
# draws (and vice versa).
if (verbose) {
  if ((num_burnin > 0) && (i <= num_gfr + num_burnin)) {
    burnin_iter <- i - num_gfr
    if ((burnin_iter %% 100 == 0) || (burnin_iter == num_burnin)) {
      # Report the total number of burn-in draws (not the number of GFR draws)
      cat("Sampling", burnin_iter, "out of", num_burnin, "BART burn-in draws\n")
    }
  } else if ((num_mcmc > 0) && (i > num_gfr + num_burnin)) {
    mcmc_iter <- i - num_gfr - num_burnin
    if ((mcmc_iter %% 100 == 0) || (i == num_samples)) {
      cat("Sampling", mcmc_iter, "out of", num_mcmc, "BART MCMC draws\n")
    }
  }
}


forest_model$sample_one_iteration(
forest_dataset_train, outcome_train, forest_samples, rng, feature_types,
leaf_model, current_leaf_scale, variable_weights,
Expand Down
41 changes: 38 additions & 3 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
#' @param verbose Whether or not to print progress during the sampling loops. Default: FALSE.
#'
#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
#' @export
Expand Down Expand Up @@ -113,7 +114,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
q = 0.9, sigma2 = NULL, num_trees_mu = 250, num_trees_tau = 50, num_gfr = 5,
num_burnin = 0, num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T,
sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T,
b_0 = -0.5, b_1 = 0.5, random_seed = -1, keep_burnin = F, keep_gfr = F) {
b_0 = -0.5, b_1 = 0.5, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F) {
# Preprocess covariates
if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
stop("X_train must be a matrix or dataframe")
Expand Down Expand Up @@ -168,6 +169,13 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
}
}

# Check that outcome and treatment are numeric
if (!is.numeric(y_train)) stop("y_train must be numeric")
if (!is.numeric(Z_train)) stop("Z_train must be numeric")
if (!is.null(Z_test)) {
if (!is.numeric(Z_test)) stop("Z_test must be numeric")
}

# Data consistency checks
if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) {
stop("X_train and X_test must have the same number of columns")
Expand Down Expand Up @@ -223,6 +231,11 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
rfx_basis_test <- matrix(rep(1,nrow(X_test)), nrow = nrow(X_test), ncol = 1)
}
}

# Check that number of samples are all nonnegative
stopifnot(num_gfr >= 0)
stopifnot(num_burnin >= 0)
stopifnot(num_mcmc >= 0)

# Determine whether a test set is provided
has_test = !is.null(X_test)
Expand All @@ -232,8 +245,9 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
y_train <- as.matrix(y_train)
}

# Check whether treatment is binary
# Check whether treatment is binary (specifically 0-1 binary)
# Treat Z_train as binary only when it takes exactly two distinct values and
# those values are 0 and 1. The length check short-circuits first, so the
# element-wise comparison against c(0, 1) is never recycled over a longer
# vector (which would emit a "longer object length is not a multiple of
# shorter object length" warning for, e.g., a three-level treatment).
treatment_values <- unique(Z_train)
binary_treatment <- (length(treatment_values) == 2) &&
  all(sort(treatment_values) == c(0, 1))

# Adaptive coding will be ignored for continuous / ordered categorical treatments
if ((!binary_treatment) && (adaptive_coding)) {
Expand Down Expand Up @@ -402,11 +416,18 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
update_residual_forest_container_cpp(forest_dataset_tau_train$data_ptr, outcome_train$data_ptr,
forest_samples_tau$forest_container_ptr, forest_model_tau$tracker_ptr,
T, 0, F)

# Run GFR (warm start) if specified
if (num_gfr > 0){
gfr_indices = 1:num_gfr
for (i in 1:num_gfr) {
# Print progress
if (verbose) {
if ((i %% 10 == 0) || (i == num_gfr)) {
cat("Sampling", i, "out of", num_gfr, "XBCF (grow-from-root) draws\n")
}
}

# Sample the prognostic forest
forest_model_mu$sample_one_iteration(
forest_dataset_mu_train, outcome_train, forest_samples_mu, rng, feature_types_mu,
Expand Down Expand Up @@ -491,6 +512,20 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
mcmc_indices = (num_gfr+num_burnin+1):(num_gfr+num_burnin+num_mcmc)
}
for (i in (num_gfr+1):num_samples) {
# Print progress for the burn-in / MCMC phase. `i` counts every draw taken so
# far (GFR draws included), so the within-phase counter is obtained by
# subtracting the draws that belong to earlier phases. The two branches are
# mutually exclusive so that burn-in messages are never emitted during MCMC
# draws (and vice versa).
if (verbose) {
  if ((num_burnin > 0) && (i <= num_gfr + num_burnin)) {
    burnin_iter <- i - num_gfr
    if ((burnin_iter %% 100 == 0) || (burnin_iter == num_burnin)) {
      # Report the total number of burn-in draws (not the number of GFR draws)
      cat("Sampling", burnin_iter, "out of", num_burnin, "BCF burn-in draws\n")
    }
  } else if ((num_mcmc > 0) && (i > num_gfr + num_burnin)) {
    mcmc_iter <- i - num_gfr - num_burnin
    if ((mcmc_iter %% 100 == 0) || (i == num_samples)) {
      cat("Sampling", mcmc_iter, "out of", num_mcmc, "BCF MCMC draws\n")
    }
  }
}


# Sample the prognostic forest
forest_model_mu$sample_one_iteration(
forest_dataset_mu_train, outcome_train, forest_samples_mu, rng, feature_types_mu,
Expand Down
5 changes: 4 additions & 1 deletion man/BART.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/bcf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 61 additions & 0 deletions tools/debug/continuous_treatment_bcf.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Debugging demo: fit BCF with a continuous (non-binary) treatment on simulated
# data and compare the posterior-mean estimates of the prognostic function and
# the treatment effect function against the simulated truth.
library(stochtree)

# Generate data with a continuous treatment
n <- 500
snr <- 3
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
x4 <- rnorm(n)
x5 <- rnorm(n)
X <- cbind(x1, x2, x3, x4, x5)
p <- ncol(X)
# Prognostic function and (heterogeneous) treatment effect function
mu_x <- 1 + 2 * x1 - 4 * (x2 < 0) + 4 * (x2 >= 0) + 3 * (abs(x3) - sqrt(2 / pi))
tau_x <- 1 + 2 * x4
# Treatment is continuous: its conditional mean pi_x depends on mu_x
u <- runif(n)
pi_x <- ((mu_x - 1) / 4) + 4 * (u - 0.5)
Z <- pi_x + rnorm(n, 0, 1)
# Outcome noise is scaled so that sd(signal) / sd(noise) equals snr
E_XZ <- mu_x + Z * tau_x
y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)
X <- as.data.frame(X)

# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]

# Run continuous treatment BCF
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
num_samples <- num_gfr + num_burnin + num_mcmc
bcf_model_warmstart <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE, verbose = TRUE
)

# Inspect results: average each fitted quantity across retained draws and plot
# it against the truth, with the 45-degree line as a reference
mu_hat_train <- rowMeans(bcf_model_warmstart$mu_hat_train)
tau_hat_train <- rowMeans(bcf_model_warmstart$tau_hat_train)
mu_hat_test <- rowMeans(bcf_model_warmstart$mu_hat_test)
tau_hat_test <- rowMeans(bcf_model_warmstart$tau_hat_test)
plot(mu_train, mu_hat_train)
abline(0, 1, lwd = 3, lty = 3, col = "red")
plot(tau_train, tau_hat_train)
abline(0, 1, lwd = 3, lty = 3, col = "red")
plot(mu_test, mu_hat_test)
abline(0, 1, lwd = 3, lty = 3, col = "red")
plot(tau_test, tau_hat_test)
abline(0, 1, lwd = 3, lty = 3, col = "red")
Loading

0 comments on commit a694940

Please sign in to comment.