Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first update variable split #18

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ License: MIT + file LICENSE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
LinkingTo:
LinkingTo:
stochvol,
cpp11
Suggests:
knitr,
Expand All @@ -25,6 +26,7 @@ Suggests:
mvtnorm,
ggplot2,
latex2exp,
invgamma,
testthat (>= 3.0.0)
VignetteBuilder: knitr
SystemRequirements: C++17
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ export(oneHotEncode)
export(oneHotInitializeAndEncode)
export(orderedCatInitializeAndPreprocess)
export(orderedCatPreprocess)
export(preprocessPredictionDataFrame)
export(preprocessTrainDataFrame)
export(sample_sigma2_one_iteration)
export(sample_tau_one_iteration)
export(saveBCFModelToJsonFile)
export(varbart)
useDynLib(stochtree, .registration = TRUE)
117 changes: 111 additions & 6 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default TRUE. Ignored if num_mcmc = 0.
#'
#' @param Sparse Whether to place a sparsity-inducing Dirichlet prior on the variable split probabilities. Default FALSE.
#' @param Theta_Update Whether to update the concentration parameter theta of the Dirichlet prior. Default FALSE. Ignored unless `Sparse = TRUE`.
#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
#' @export
#'
Expand Down Expand Up @@ -76,7 +77,9 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL,
q = 0.9, sigma2_init = NULL, num_trees = 200, num_gfr = 5,
num_burnin = 0, num_mcmc = 100, sample_sigma = T,
sample_tau = T, random_seed = -1, keep_burnin = F, keep_gfr = F){
sample_tau = T, random_seed = -1, keep_burnin = F, keep_gfr = F,
Sparse = F,
Theta_Update = F){
# Preprocess covariates
if (!is.data.frame(X_train)) {
stop("X_train must be a dataframe")
Expand Down Expand Up @@ -291,13 +294,17 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
# Variable selection weights
variable_weights <- rep(1/ncol(X_train), ncol(X_train))

#Variable Selection Splits
variable_count_splits <- as.integer(rep(0, ncol(X_train)))
var_count_matrix = matrix(NA, nrow = num_samples, ncol = ncol(X_train))

# Run GFR (warm start) if specified
if (num_gfr > 0){
gfr_indices = 1:num_gfr
for (i in 1:num_gfr) {
forest_model$sample_one_iteration(
forest_dataset_train, outcome_train, forest_samples, rng, feature_types,
leaf_model, current_leaf_scale, variable_weights,
leaf_model, current_leaf_scale, variable_weights, variable_count_splits,
current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = F
)
if (sample_sigma) {
Expand All @@ -314,6 +321,9 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
}
}

#Dirichlet Initialization
theta = 1

# Run MCMC
if (num_burnin + num_mcmc > 0) {
if (num_burnin > 0) {
Expand All @@ -323,11 +333,20 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
mcmc_indices = (num_gfr+num_burnin+1):(num_gfr+num_burnin+num_mcmc)
}
for (i in (num_gfr+1):num_samples) {
forest_model$sample_one_iteration(
variable_count_splits = forest_model$sample_one_iteration(
forest_dataset_train, outcome_train, forest_samples, rng, feature_types,
leaf_model, current_leaf_scale, variable_weights,
leaf_model, current_leaf_scale, variable_weights, variable_count_splits,
current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = F
)
if(Sparse == TRUE){
lpv = draw_s(variable_count_splits, theta)
variable_weights = exp(lpv)
if(Theta_Update == TRUE){
theta = draw_theta0(theta, lpv, 0.5, 1, rho = length(lpv))
}

}
var_count_matrix[i,] = variable_count_splits
if (sample_sigma) {
global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda)
current_sigma2 <- global_var_samples[i]
Expand All @@ -339,6 +358,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
if (has_rfx) {
rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, current_sigma2, rng)
}

}
}

Expand Down Expand Up @@ -420,7 +440,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
"has_rfx_basis" = has_basis_rfx,
"num_rfx_basis" = num_basis_rfx,
"sample_sigma" = sample_sigma,
"sample_tau" = sample_tau
"sample_tau" = sample_tau,
"variable_count_splits" =var_count_matrix
)
result <- list(
"forests" = forest_samples,
Expand Down Expand Up @@ -644,3 +665,87 @@ getRandomEffectSamples.bartmodel <- function(object, ...){

return(result)
}





#' Numerically stable log(sum(exp(v))).
#'
#' Subtracts the maximum before exponentiating so that very large (or very
#' negative) inputs do not overflow/underflow.
#'
#' @param v Numeric vector of log-scale values.
#' @return Scalar: log(sum(exp(v))).
log_sum_exp <- function(v) {
  mx <- max(v)
  # Vectorized replacement for the elementwise accumulation loop.
  mx + log(sum(exp(v - mx)))
}

#' Draw log(X) for X ~ Gamma(shape, 1), stable for small shape.
#'
#' Uses the identity that if Y ~ Gamma(shape + 1) and U ~ Uniform(0, 1), then
#' Y * U^(1/shape) ~ Gamma(shape); working on the log scale avoids underflow
#' when shape is tiny.
#'
#' @param shape Gamma shape parameter (> 0).
#' @return Scalar: one draw of log(X).
log_gamma <- function(shape) {
  boosted_draw <- rgamma(1, shape + 1)
  boosted_log <- log(boosted_draw)
  power_adjust <- log(runif(1)) / shape
  boosted_log + power_adjust
}

#' Draw one sample from Dirichlet(alpha) on the log scale.
#'
#' Samples each coordinate as an independent log-Gamma variate and normalizes
#' by subtracting the log of their sum, so exp(result) sums to 1. Operating on
#' the log scale keeps tiny concentration parameters from underflowing.
#' (Note: function name retains the original "dirichilet" spelling because it
#' is referenced by callers.)
#'
#' @param alpha Numeric vector of Dirichlet concentration parameters (> 0).
#' @return Numeric vector the same length as `alpha`: log of the sampled
#'   probability vector.
log_dirichilet <- function(alpha) {
  # vapply replaces the elementwise fill loop; result type is pinned.
  draw <- vapply(alpha, log_gamma, numeric(1))
  # Vectorized log-scale normalization replaces the second loop.
  draw - log_sum_exp(draw)
}


#' Draw log split probabilities from the Dirichlet full conditional.
#'
#' Given per-variable split counts `nv` and concentration `theta`, samples
#' from Dirichlet(theta/p + nv_1, ..., theta/p + nv_p) on the log scale.
#'
#' @param nv Integer vector of split counts per covariate.
#' @param theta Dirichlet concentration parameter. Default 1.
#' @return Numeric vector of log split probabilities (same length as `nv`).
draw_s <- function(nv, theta = 1) {
  # Vectorized posterior concentration replaces the elementwise loop.
  posterior_alpha <- theta / length(nv) + nv
  log_dirichilet(posterior_alpha)
}



#' Sample one index from a discrete distribution.
#'
#' Draws a single multinomial trial with probabilities `wts` and returns the
#' 1-based index of the sampled category.
#'
#' Bug fix: the previous implementation returned 0 whenever category 1 was
#' drawn (it only accumulated `j * vOut[j]` for j >= 2), and the caller
#' `draw_theta0` then indexed `theta_g[0]`, yielding `numeric(0)` and
#' corrupting theta. All categories now map to their 1-based index.
#'
#' @param wts Numeric vector of (possibly unnormalized) category weights.
#' @return Integer index in 1..length(wts) of the sampled category.
discrete <- function(wts) {
  draw <- rmultinom(1, size = 1, prob = wts)
  # Exactly one entry of `draw` is 1; return its 1-based position.
  which(draw == 1)[1]
}

#' Update the Dirichlet concentration parameter theta via griddy Gibbs.
#'
#' Places a Beta(a, b) prior on lambda = theta / (theta + rho), evaluates the
#' unnormalized log posterior on a fixed 1000-point grid of lambda values
#' (mapped to theta = lambda * rho / (1 - lambda)), and samples theta from the
#' resulting discrete approximation.
#'
#' @param theta Current value of theta. Unused (the draw depends only on the
#'   grid), kept for call-site compatibility.
#' @param lpv Numeric vector of log split probabilities.
#' @param a Beta prior shape-1 parameter for lambda.
#' @param b Beta prior shape-2 parameter for lambda.
#' @param rho Scale constant in the lambda-to-theta map (callers pass
#'   `length(lpv)`).
#' @return Scalar: the newly sampled theta.
draw_theta0 <- function(theta, lpv, a, b, rho) {
  p <- length(lpv)
  sumlpv <- sum(lpv)

  # Grid over lambda in (0, 1), mapped to theta values.
  lambda_g <- seq(1 / 1001, 1000 / 1001, length.out = 1000)
  theta_g <- lambda_g * rho / (1 - lambda_g)

  # Vectorized log posterior over the whole grid (replaces the k-loop):
  # Dirichlet terms that depend on theta, plus the Beta(a, b) log prior.
  theta_log_lik <- lgamma(theta_g) - p * lgamma(theta_g / p) + (theta_g / p) * sumlpv
  beta_log_prior <- (a - 1) * log(lambda_g) + (b - 1) * log(1 - lambda_g)
  lwt_g <- theta_log_lik + beta_log_prior

  # Normalize on the log scale for stability, then renormalize exactly.
  weights <- exp(lwt_g - log_sum_exp(lwt_g))
  weights <- weights / sum(weights)

  theta_g[discrete(weights)]
}




9 changes: 5 additions & 4 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
}

# Set variable weights for the prognostic and treatment effect forests
variable_selection_splits = NULL
variable_weights_mu = rep(1/ncol(X_train_mu), ncol(X_train_mu))
variable_weights_tau = rep(1/ncol(X_train_tau), ncol(X_train_tau))

Expand Down Expand Up @@ -404,7 +405,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
# Sample the prognostic forest
forest_model_mu$sample_one_iteration(
forest_dataset_mu_train, outcome_train, forest_samples_mu, rng, feature_types_mu,
0, current_leaf_scale_mu, variable_weights_mu,
0, current_leaf_scale_mu, variable_weights_mu, variable_selection_splits,
current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T
)

Expand All @@ -421,7 +422,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
# Sample the treatment forest
forest_model_tau$sample_one_iteration(
forest_dataset_tau_train, outcome_train, forest_samples_tau, rng, feature_types_tau,
1, current_leaf_scale_tau, variable_weights_tau,
1, current_leaf_scale_tau, variable_weights_tau, variable_selection_splits,
current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T
)

Expand Down Expand Up @@ -488,7 +489,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
# Sample the prognostic forest
forest_model_mu$sample_one_iteration(
forest_dataset_mu_train, outcome_train, forest_samples_mu, rng, feature_types_mu,
0, current_leaf_scale_mu, variable_weights_mu,
0, current_leaf_scale_mu, variable_weights_mu, variable_selection_splits,
current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T
)

Expand All @@ -505,7 +506,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
# Sample the treatment forest
forest_model_tau$sample_one_iteration(
forest_dataset_tau_train, outcome_train, forest_samples_tau, rng, feature_types_tau,
1, current_leaf_scale_tau, variable_weights_tau,
1, current_leaf_scale_tau, variable_weights_tau, variable_selection_splits,
current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T
)

Expand Down
4 changes: 2 additions & 2 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized))
}

sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized) {
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized))
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, variable_count_splits, global_variance, leaf_model_int, pre_initialized) {
.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, variable_count_splits, global_variance, leaf_model_int, pre_initialized)
}

sample_sigma2_one_iteration_cpp <- function(residual, rng, nu, lambda) {
Expand Down
5 changes: 3 additions & 2 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ ForestModel <- R6::R6Class(
#' @param leaf_model_int Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression)
#' @param leaf_model_scale Scale parameter used in the leaf node model (should be a q x q matrix where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`)
#' @param variable_weights Vector specifying sampling probability for all p covariates in `forest_dataset`
#' @param variable_count_splits Integer vector with one entry per covariate, recording the number of tree splits on that covariate (initialized to zeros; the MCMC sampler returns the updated counts) — confirm semantics against the C++ sampler
#' @param global_scale Global variance parameter
#' @param cutpoint_grid_size (Optional) Number of unique cutpoints to consider (default: 500, currently only used when `GFR = TRUE`)
#' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm
#' @param pre_initialized (Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: F.
sample_one_iteration = function(forest_dataset, residual, forest_samples, rng, feature_types,
leaf_model_int, leaf_model_scale, variable_weights,
leaf_model_int, leaf_model_scale, variable_weights, variable_count_splits,
global_scale, cutpoint_grid_size = 500, gfr = T,
pre_initialized = F) {
if (gfr) {
Expand All @@ -87,7 +88,7 @@ ForestModel <- R6::R6Class(
forest_dataset$data_ptr, residual$data_ptr,
forest_samples$forest_container_ptr, self$tracker_ptr, self$tree_prior_ptr,
rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
variable_weights, global_scale, leaf_model_int, pre_initialized
variable_weights, variable_count_splits, global_scale, leaf_model_int, pre_initialized
)
}
}
Expand Down
Loading