diff --git a/DESCRIPTION b/DESCRIPTION
index 6697717..f5647fd 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -14,7 +14,8 @@ License: MIT + file LICENSE
 Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
 RoxygenNote: 7.3.1
-LinkingTo: 
+LinkingTo: 
+    stochvol,
     cpp11
 Suggests: 
     knitr,
@@ -25,6 +26,7 @@ Suggests: 
     mvtnorm,
     ggplot2,
     latex2exp,
+    invgamma,
     testthat (>= 3.0.0)
 VignetteBuilder: knitr
 SystemRequirements: C++17
diff --git a/NAMESPACE b/NAMESPACE
index ce0feea..6567e78 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -34,7 +34,10 @@ export(oneHotEncode)
 export(oneHotInitializeAndEncode)
 export(orderedCatInitializeAndPreprocess)
 export(orderedCatPreprocess)
+export(preprocessPredictionDataFrame)
+export(preprocessTrainDataFrame)
 export(sample_sigma2_one_iteration)
 export(sample_tau_one_iteration)
 export(saveBCFModelToJsonFile)
+export(varbart)
 useDynLib(stochtree, .registration = TRUE)
diff --git a/R/bart.R b/R/bart.R
index 22b6d1f..b5f6d65 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -40,7 +40,8 @@
 #' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
 #' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
 #' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default TRUE. Ignored if num_mcmc = 0.
-#'
+#' @param Sparse Whether or not to place a sparsity-inducing Dirichlet prior on the variable splitting probabilities. Default FALSE.
+#' @param Theta_Update Whether or not to update the theta parameter of the Dirichlet prior. Ignored unless `Sparse = TRUE`. Default FALSE.
 #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
 #' @export
 #'
@@ -76,7 +77,9 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
                  nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL,
                  q = 0.9, sigma2_init = NULL, num_trees = 200, num_gfr = 5,
                  num_burnin = 0, num_mcmc = 100, sample_sigma = T,
-                 sample_tau = T, random_seed = -1, keep_burnin = F, keep_gfr = F){
+                 sample_tau = T, random_seed = -1, keep_burnin = F, keep_gfr = F,
+                 Sparse = F,
+                 Theta_Update = F){
   # Preprocess covariates
   if (!is.data.frame(X_train)) {
     stop("X_train must be a dataframe")
@@ -291,13 +294,17 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
   # Variable selection weights
   variable_weights <- rep(1/ncol(X_train), ncol(X_train))
   
+  # Variable selection split counts, tracked across MCMC iterations
+  variable_count_splits <- as.integer(rep(0, ncol(X_train)))
+  var_count_matrix <- matrix(NA, nrow = num_samples, ncol = ncol(X_train))
+  
   # Run GFR (warm start) if specified
   if (num_gfr > 0){
     gfr_indices = 1:num_gfr
     for (i in 1:num_gfr) {
       forest_model$sample_one_iteration(
         forest_dataset_train, outcome_train, forest_samples, rng, feature_types,
-        leaf_model, current_leaf_scale, variable_weights,
+        leaf_model, current_leaf_scale, variable_weights, variable_count_splits,
         current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = F
       )
       if (sample_sigma) {
@@ -314,6 +321,9 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
     }
   }
   
+  # Initialize the Dirichlet concentration parameter
+  theta = 1
+  
   # Run MCMC
   if (num_burnin + num_mcmc > 0) {
     if (num_burnin > 0) {
@@ -323,11 +333,20 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
       mcmc_indices = (num_gfr+num_burnin+1):(num_gfr+num_burnin+num_mcmc)
     }
     for (i in (num_gfr+1):num_samples) {
-      forest_model$sample_one_iteration(
+      variable_count_splits = forest_model$sample_one_iteration(
         forest_dataset_train, outcome_train, forest_samples, rng, feature_types,
-        leaf_model, current_leaf_scale, variable_weights,
+        leaf_model, current_leaf_scale, variable_weights, variable_count_splits,
         current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = F
       )
+      if (Sparse == TRUE) {
+        # Draw log splitting probabilities from the Dirichlet full conditional
+        lpv = draw_s(variable_count_splits, theta)
+        variable_weights = exp(lpv)
+        if (Theta_Update == TRUE) {
+          theta = draw_theta0(theta, lpv, 0.5, 1, rho = length(lpv))
+        }
+      }
+      var_count_matrix[i,] = variable_count_splits
       if (sample_sigma) {
         global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda)
         current_sigma2 <- global_var_samples[i]
@@ -339,6 +358,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
       if (has_rfx) {
         rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, current_sigma2, rng)
       }
+      
     }
   }
@@ -420,7 +440,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
     "has_rfx_basis" = has_basis_rfx,
     "num_rfx_basis" = num_basis_rfx,
     "sample_sigma" = sample_sigma,
-    "sample_tau" = sample_tau
+    "sample_tau" = sample_tau,
+    "variable_count_splits" = var_count_matrix
   )
   result <- list(
     "forests" = forest_samples,
@@ -644,3 +665,87 @@ getRandomEffectSamples.bartmodel <- function(object, ...){
   return(result)
 }
+
+# Numerically stable computation of log(sum(exp(v)))
+log_sum_exp = function(v){
+  n = length(v)
+  mx = max(v)
+  sm = 0
+  for (i in 1:n){
+    sm = sm + exp(v[i] - mx)
+  }
+  return(mx + log(sm))
+}
+
+# Draw log(X) where X ~ Gamma(shape, 1); numerically stable for small shapes
+log_gamma = function(shape){
+  y = log(rgamma(1, shape + 1))
+  z = log(runif(1))/shape
+  return(y + z)
+}
+
+# Draw a Dirichlet(alpha) random vector on the log scale
+log_dirichlet = function(alpha){
+  k = length(alpha)
+  draw = rep(0, k)
+  for (j in 1:k){
+    draw[j] = log_gamma(alpha[j])
+  }
+  lse = log_sum_exp(draw)
+  for (j in 1:k){
+    draw[j] = draw[j] - lse
+  }
+  return(draw)
+}
+
+# Sample log splitting probabilities given split counts nv and concentration theta
+draw_s = function(nv, theta = 1){
+  n = length(nv)
+  theta_ = rep(0, n)
+  for (i in 1:n){
+    theta_[i] = theta/n + nv[i]
+  }
+  lpv = log_dirichlet(theta_)
+  return(lpv)
+}
+
+# Draw a single index from 1:length(wts) with probabilities proportional to wts
+# (returning which() of the multinomial draw avoids the zero-index case when
+# the first category is selected)
+discrete = function(wts) {
+  vOut <- rmultinom(1, size = 1, prob = wts)
+  return(which(vOut == 1))
+}
+
+# Griddy-Gibbs update of the Dirichlet concentration parameter theta, under the
+# prior theta / (theta + rho) ~ Beta(a, b)
+draw_theta0 = function(theta, lpv, a, b, rho) {
+  p = length(lpv)
+  sumlpv = sum(lpv)
+  lambda_g <- seq(1 / 1001, 1000 / 1001, length.out = 1000)
+  theta_g <- lambda_g * rho / (1 - lambda_g)
+  lwt_g = rep(0, 1000)
+  
+  for (k in 1:1000) {
+    theta_log_lik = lgamma(theta_g[k]) - p * lgamma(theta_g[k] / p) + (theta_g[k] / p) * sumlpv
+    beta_log_prior = (a - 1) * log(lambda_g[k]) + (b - 1) * log(1 - lambda_g[k])
+    lwt_g[k] = theta_log_lik + beta_log_prior
+  }
+  
+  lse <- log_sum_exp(lwt_g)
+  lwt_g <- exp(lwt_g - lse)
+  weights <- lwt_g / sum(lwt_g)
+  theta <- theta_g[discrete(weights)]
+  
+  return(theta)
+}
diff --git a/R/bcf.R b/R/bcf.R
index efea77d..f3c1414 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -290,6 +290,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
   }
   
   # Set variable weights for the prognostic and treatment effect forests
+  variable_selection_splits = NULL
   variable_weights_mu = rep(1/ncol(X_train_mu), ncol(X_train_mu))
   variable_weights_tau = rep(1/ncol(X_train_tau), ncol(X_train_tau))
@@ -404,7 +405,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
     # Sample the prognostic forest
     forest_model_mu$sample_one_iteration(
       forest_dataset_mu_train, outcome_train, forest_samples_mu, rng, feature_types_mu,
-      0, current_leaf_scale_mu, variable_weights_mu,
+      0, current_leaf_scale_mu, variable_weights_mu, variable_selection_splits,
       current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T
     )
@@ -421,7 +422,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
     # Sample the treatment forest
     forest_model_tau$sample_one_iteration(
       forest_dataset_tau_train, outcome_train, forest_samples_tau, rng, feature_types_tau,
-      1, current_leaf_scale_tau, variable_weights_tau,
+      1, current_leaf_scale_tau, variable_weights_tau, variable_selection_splits,
       current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = T
     )
@@ -488,7 +489,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
     # Sample the prognostic forest
     forest_model_mu$sample_one_iteration(
       forest_dataset_mu_train, outcome_train, forest_samples_mu, rng, feature_types_mu,
-      0, current_leaf_scale_mu, variable_weights_mu,
+      0, current_leaf_scale_mu, variable_weights_mu, variable_selection_splits,
       current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T
     )
@@ -505,7 +506,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
     # Sample the treatment forest
     forest_model_tau$sample_one_iteration(
       forest_dataset_tau_train, outcome_train, forest_samples_tau, rng, feature_types_tau,
-      1, current_leaf_scale_tau, variable_weights_tau,
+      1, current_leaf_scale_tau, variable_weights_tau, variable_selection_splits,
       current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = T
     )
diff --git a/R/cpp11.R b/R/cpp11.R
index 6f04624..10cbc15 100644
--- a/R/cpp11.R
+++ b/R/cpp11.R
@@ -272,8 +272,8 @@ sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker
   invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized))
 }
 
-sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized) {
-  invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized))
+sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, variable_count_splits, global_variance, leaf_model_int, pre_initialized) {
+  .Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, variable_count_splits, global_variance, leaf_model_int, pre_initialized)
 }
 
 sample_sigma2_one_iteration_cpp <- function(residual, rng, nu, lambda) {
diff --git a/R/model.R b/R/model.R
index b39ae12..b172621 100644
--- a/R/model.R
+++ b/R/model.R
@@ -67,12 +67,13 @@ ForestModel <- R6::R6Class(
     #' @param leaf_model_int Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression)
     #' @param leaf_model_scale Scale parameter used in the leaf node model (should be a q x q matrix where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`)
    #' @param variable_weights Vector specifying sampling probability for all p covariates in `forest_dataset`
+    #' @param variable_count_splits Integer vector of counts of the number of times each covariate has been chosen as a splitting variable, updated by the MCMC sampler
     #' @param global_scale Global variance parameter
     #' @param cutpoint_grid_size (Optional) Number of unique cutpoints to consider (default: 500, currently only used when `GFR = TRUE`)
     #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm
     #' @param pre_initialized (Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: F.
     sample_one_iteration = function(forest_dataset, residual, forest_samples, rng, feature_types,
-                                    leaf_model_int, leaf_model_scale, variable_weights,
+                                    leaf_model_int, leaf_model_scale, variable_weights, variable_count_splits,
                                     global_scale, cutpoint_grid_size = 500,
                                     gfr = T, pre_initialized = F) {
       if (gfr) {
@@ -87,7 +88,7 @@ ForestModel <- R6::R6Class(
         forest_dataset$data_ptr, residual$data_ptr, forest_samples$forest_container_ptr,
         self$tracker_ptr, self$tree_prior_ptr, rng$rng_ptr,
         feature_types, cutpoint_grid_size, leaf_model_scale,
-        variable_weights, global_scale, leaf_model_int, pre_initialized
+        variable_weights, variable_count_splits, global_scale, leaf_model_int, pre_initialized
       )
     }
   }
diff --git a/R/varbart.R b/R/varbart.R
new file mode 100644
index 0000000..0bf8153
--- /dev/null
+++ b/R/varbart.R
@@ -0,0 +1,293 @@
+#' Run the BART-based vector autoregression (VAR-BART) sampler, equation by
+#' equation, with an optional stochastic volatility model for the shocks.
+#'
+#' @param y_train Matrix of outcomes (T x M), one column per endogenous variable.
+#' @param X_train Matrix of covariates (e.g. lagged outcomes) used to split trees in the ensemble.
+#' @param sv Whether or not to model the structural shocks with stochastic volatility. Default: FALSE.
+#' @param W_train (Optional) Bases used to define a leaf regression. Currently unused.
+#' @param group_ids_train (Optional) Group labels for an additive random effects model. Currently unused.
+#' @param rfx_basis_train (Optional) Basis for "random-slope" regression in an additive random effects model. Currently unused.
+#' @param X_test (Optional) Test set covariates. Currently unused.
+#' @param W_test (Optional) Test set bases. Currently unused.
+#' @param group_ids_test (Optional) Test set group labels. Currently unused.
+#' @param rfx_basis_test (Optional) Test set random effects basis. Currently unused.
+#' @param ordered_cat_vars (Optional) Vector of names of ordered categorical variables. Currently unused.
+#' @param unordered_cat_vars (Optional) Vector of names of unordered categorical variables. Currently unused.
+#' @param cutpoint_grid_size Maximum size of the "grid" of potential cutpoints to consider. Default: 100.
+#' @param tau_init Starting value of the leaf node scale parameter. Calibrated internally if not set here.
+#' @param alpha Prior probability of splitting for a tree of depth 0. Default: 0.95.
+#' @param beta Exponent that decreases split probabilities for nodes of depth > 0. Default: 2.0.
+#' @param min_samples_leaf Minimum allowable size of a leaf. Default: 5.
+#' @param output_dimension Dimension of the leaf node parameters. Default: 1.
+#' @param is_leaf_constant Whether or not the leaf model is constant. Default: TRUE.
+#' @param leaf_regression Whether or not to use a leaf regression model. Currently unused. Default: FALSE.
+#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance prior. Default: 3.
+#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. Calibrated internally if not set here.
+#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node variance prior. Default: 3.
+#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node variance prior. Calibrated internally if not set here.
+#' @param q Quantile used to calibrate `lambda`. Default: 0.9.
+#' @param sigma2_init Starting value of the global error variance parameters. Calibrated internally if not set here.
+#' @param num_trees Number of trees in the ensemble for each equation. Default: 250.
+#' @param num_gfr Number of "warm-start" grow-from-root iterations. Default: 0.
+#' @param num_burnin Number of "burn-in" MCMC iterations. Default: 0.
+#' @param num_mcmc Number of MCMC iterations. Default: 100.
+#' @param random_seed Integer parameterizing the C++ random number generator. Default: -1.
+#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default: FALSE.
+#'
+#' @return List containing arrays of fitted values (`Y.fitted`) and volatilities (`H.fitted`), the standardization constants (`Ysd`, `Ymu`), stochastic volatility parameter draws (`sv.mcmc`), and the final draw of the contemporaneous covariance matrix (`A0`).
+#' @export
+varbart = function(y_train, X_train, sv = FALSE, W_train = NULL, group_ids_train = NULL,
+                   rfx_basis_train = NULL, X_test = NULL, W_test = NULL,
+                   group_ids_test = NULL, rfx_basis_test = NULL,
+                   ordered_cat_vars = NULL, unordered_cat_vars = NULL,
+                   cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95,
+                   beta = 2.0, min_samples_leaf = 5, output_dimension = 1,
+                   is_leaf_constant = T, leaf_regression = F, nu = 3,
+                   lambda = NULL, a_leaf = 3, b_leaf = NULL, q = 0.9,
+                   sigma2_init = NULL, num_trees = 250, num_gfr = 0,
+                   num_burnin = 0, num_mcmc = 100, random_seed = -1,
+                   keep_burnin = F){
+
+  num_samples = num_burnin + num_mcmc
+
+  M = ncol(y_train)
+  TT = nrow(y_train)
+  K = ncol(X_train)
+
+  # Standardize the outcomes and work with matrices internally; Ymu and Ysd
+  # are used below to return the stored draws on the original scale
+  Ymu = apply(y_train, 2, mean)
+  Ysd = apply(y_train, 2, sd)
+  Y = scale(y_train)
+  X = as.matrix(X_train)
+  # colnames(X) = paste(rep(colnames(Yraw), p),
+  #                     sort(rep(paste(".t-", sprintf("%02d", 1:p), sep = ""), each = 2)), sep = "")
+
+  # Calibrate priors for sigma^2 and tau
+  beta_ols <- solve(crossprod(X)) %*% crossprod(X, Y)
+  sigma2hat = crossprod(Y - X %*% beta_ols)/TT
+  quantile_cutoff <- 0.9
+  if (is.null(lambda)) {
+    lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu
+  }
+  if (is.null(sigma2_init)) sigma2_init <- sigma2hat
+  if (is.null(b_leaf)) b_leaf <- var(Y)/(2*num_trees)
+  if (is.null(tau_init)) tau_init <- var(Y)/(num_trees)
+  current_leaf_scale <- as.matrix(tau_init)
+  current_sigma2 <- sigma2_init
+
+  # Data containers (one forest dataset per equation)
+  forest_dataset_train_ls = c()
+
+  # Random number generator (std::mt19937)
+  if (is.null(random_seed)) random_seed = sample(1:10000, 1, F)
+  rng <- createRNG(random_seed)
+
+  # Sampling data structures: all covariates are treated as numeric here,
+  # with equal prior weight, and a constant leaf model (leaf_model = 0)
+  feature_types <- as.integer(rep(0, K))
+  variable_weights <- rep(1/K, K)
+  leaf_model <- 0
+
+  # Containers of forest samples and samplers
+  forest_samples_ls = c()
+  bart_sampler_ls = c()
+
+  for(mm in 1:M){
+    forest_dataset_train_ls = c(forest_dataset_train_ls, stochtree::createForestDataset(X, basis = NULL, variance_weights = NULL))
+    forest_samples_ls = c(forest_samples_ls, stochtree::createForestContainer(num_trees, output_dimension, is_leaf_constant))
+    bart_sampler_ls = c(bart_sampler_ls, createForestModel(forest_dataset_train_ls[[mm]],
+                                                           feature_types, num_trees, nrow(X), alpha, beta,
+                                                           min_samples_leaf))
+  }
+
+  # Initialization of covariance objects
+  th.A0 = matrix(1, M, M)
+  eta = matrix(NA, TT, M)
+  H = matrix(-3, TT, M)
+  d.A0 = diag(M)
+  sigma_mat = matrix(NA, M, 1)
+  Y.fit.BART = Y*0
+
+  # Initialization for the horseshoe prior
+  lambda.A0 = 1
+  nu.A0 = 1
+  tau.A0 = 1
+  zeta.A0 = 1
+  prior.cov = rep(1, M*(M-1)/2)
+
+  # Stochastic volatility
+  sv_priors = list()
+  sv_draw = list()
+  h_latent = list()
+
+  sv_params_mcmc = array(NA, dim = c(num_samples, 4, M))
+  sv_params_mat <- matrix(NA, M, 4)
+  colnames(sv_params_mat) = c("mu", "phi", "sigma", "h")
+
+  if(sv){
+    for(mm in 1:M){
+      sv_draw[[mm]] = list(mu = 0, phi = 0.99, sigma = 0.01, nu = Inf, rho = 0, beta = NA, latent0 = 0)
+      h_latent[[mm]] = rep(0, TT)
+      sv_priors[[mm]] = stochvol::specify_priors(
+        mu = sv_normal(mean = 0, sd = 10),
+        phi = sv_beta(shape1 = 5, shape2 = 1.5),
+        sigma2 = sv_gamma(shape = 0.5, rate = 10),
+        nu = sv_infinity(),
+        rho = sv_constant(0))
+    }
+  }else{
+    # Degenerate priors that effectively turn stochastic volatility off
+    for(mm in 1:M){
+      sv_draw[[mm]] = list(mu = 0, phi = 0.99, sigma = 0.01, nu = Inf, rho = 0, beta = NA, latent0 = 0)
+      h_latent[[mm]] = rep(0, TT)
+      sv_priors[[mm]] <- stochvol::specify_priors(
+        mu = sv_constant(0),
+        phi = sv_constant(1-1e-12),
+        sigma2 = sv_constant(1e-12),
+        nu = sv_infinity(),
+        rho = sv_constant(0)
+      )
+    }
+  }
+
+  # Containers of variance parameter samples
+  global_var_samples <- matrix(0, num_samples, M)
+  colnames(global_var_samples) = colnames(Y)
+  leaf_scale_samples <- matrix(0, num_samples, M)
+  Y.store = array(NA, dim = c(num_samples, TT, M))
+  H.store = array(NA, dim = c(num_samples, TT, M))
+
+  # Run MCMC
+  pb = txtProgressBar(min = 0, max = num_samples, style = 3)
+  start = Sys.time()
+
+  for (i in (num_gfr+1):num_samples) {
+    for(mm in 1:M){
+
+      # Equation mm: remove the BART fit and the contribution of earlier shocks
+      if(mm > 1){
+        eta_mm = eta[, 1:(mm-1), drop = FALSE]
+        A0_mm = d.A0[mm, 1:(mm-1)]
+        outcome_train = createOutcome(Y[,mm] - Y.fit.BART[,mm] - eta_mm %*% A0_mm)
+      }else{
+        outcome_train = createOutcome(Y[,1] - Y.fit.BART[,1])
+      }
+      variable_tree_splits = as.integer(rep(0, K))
+      forest_samples = forest_samples_ls[[mm]]
+      forest_model = bart_sampler_ls[[mm]]
+      forest_dataset_train = forest_dataset_train_ls[[mm]]
+
+      forest_model$sample_one_iteration(
+        forest_dataset_train, outcome_train, forest_samples, rng,
+        feature_types, leaf_model, current_leaf_scale[mm,mm], variable_weights, variable_tree_splits,
+        current_sigma2[mm,mm], cutpoint_grid_size, gfr = FALSE, pre_initialized = FALSE)
+
+      leaf_scale_samples[i,mm] <- sample_tau_one_iteration(forest_samples, rng, a_leaf, b_leaf[mm,mm], i-1)
+      current_leaf_scale[mm,mm] <- as.matrix(leaf_scale_samples[i,mm])
+      global_var_samples[i,mm] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda[mm,mm])
+      current_sigma2[mm,mm] <- global_var_samples[i,mm]
+      sigma_mat[mm] = sqrt(global_var_samples[i,mm])
+
+      Y.fit.BART[,mm] = forest_samples$predict_raw_single_forest(forest_dataset_train, i-1)
+      eta[,mm] = Y[,mm] - Y.fit.BART[,mm]
+
+      # Update the contemporaneous covariance coefficients for equation mm
+      if(mm > 1){
+        norm.mm = as.numeric(exp(-.5*h_latent[[mm]]) * 1/sigma_mat[mm,])
+        u_mm = eta[, 1:(mm-1), drop = F]*norm.mm
+        eta_mm = eta[,mm]*norm.mm
+        if (mm == 2){
+          v0.inv = 1/th.A0[mm,1]
+        }else{
+          v0.inv = diag(1/th.A0[mm,1:(mm-1)])
+        }
+        V.cov = solve(crossprod(u_mm) + v0.inv)
+        mu.cov = V.cov %*% crossprod(u_mm, eta_mm)
+        mu.cov.draw = mu.cov + t(chol(V.cov)) %*% rnorm(ncol(V.cov))
+        d.A0[mm,1:(mm-1)] = mu.cov.draw
+      }
+    }
+
+    shocks = eta %*% t(solve(d.A0))
+    for (mm in 1:M){
+      if(sv == TRUE){
+        svdraw_mm = stochvol::svsample_general_cpp(shocks[,mm]/sigma_mat[mm],
+                      startpara = sv_draw[[mm]], startlatent = h_latent[[mm]],
+                      priorspec = sv_priors[[mm]])
+        sv_draw[[mm]][c("mu", "phi", "sigma")] = as.list(svdraw_mm$para[, c("mu", "phi", "sigma")])
+        h_latent[[mm]] = svdraw_mm$latent
+        sv_params_mat[mm, ] = c(svdraw_mm$para[, c("mu", "phi", "sigma")], svdraw_mm$latent[TT])
+        weights = as.numeric(exp(svdraw_mm$latent))
+        forest_dataset_train_ls[[mm]] = stochtree::createForestDataset(X, basis = NULL, variance_weights = weights)
+        H[,mm] = log(sigma_mat[mm]^2) + svdraw_mm$latent
+      }else{
+        H[,mm] = log(sigma_mat[mm]^2)
+      }
+    }
+    # Store this iteration's stochastic volatility parameters
+    sv_params_mcmc[i,,] = t(sv_params_mat)
+
+    # Update the horseshoe shrinkage prior on the covariance elements
+    hs_draw = get.hs(bdraw = d.A0[lower.tri(d.A0)], lambda.hs = lambda.A0, nu.hs = nu.A0,
+                     tau.hs = tau.A0, zeta.hs = zeta.A0)
+    lambda.A0 = hs_draw$lambda
+    nu.A0 = hs_draw$nu
+    tau.A0 = hs_draw$tau
+    zeta.A0 = hs_draw$zeta
+    prior.cov = hs_draw$psi
+
+    th.A0[lower.tri(th.A0)] = prior.cov
+    th.A0[th.A0 > 10] = 10
+    th.A0[th.A0 < 1e-8] = 1e-8
+
+    H.store[i,,] = exp(H)
+    Y.store[i,,] = (Y.fit.BART*t(matrix(Ysd, M, TT))) + t(matrix(Ymu, M, TT))
+    iter.update = 250
+    setTxtProgressBar(pb, i)
+    if (i %% iter.update == 0){
+      end = Sys.time()
+      message(paste0("\n Average time for single draw over last ", iter.update, " draws ",
+                     round(as.numeric(end-start)/iter.update, digits = 4), " seconds, currently at draw ", i))
+      start = Sys.time()
+    }
+
+  }
+  dimnames(Y.store) = list(paste0("mcmc", 1:num_samples), rownames(Y), colnames(Y))
+  dimnames(H.store) = list(paste0("mcmc", 1:num_samples), rownames(Y), colnames(Y))
+  return_obj = list("Y.fitted" = Y.store, "H.fitted" = H.store, "Ysd" = Ysd, "Ymu" = Ymu,
+                    "sv.mcmc" = sv_params_mcmc, "A0" = d.A0)
+
+  return(return_obj)
+}
+
+## Sample horseshoe prior hyperparameters for a regression.
+
+get.hs <- function(bdraw, lambda.hs, nu.hs, tau.hs, zeta.hs){
+  k <- length(bdraw)
+  if (is.na(tau.hs)){
+    tau.hs <- 1
+  }else{
+    tau.hs <- invgamma::rinvgamma(1, shape = (k+1)/2, rate = 1/zeta.hs + sum(bdraw^2/lambda.hs)/2)
+  }
+
+  lambda.hs <- invgamma::rinvgamma(k, shape = 1, rate = 1/nu.hs + bdraw^2/(2*tau.hs))
+
+  nu.hs <- invgamma::rinvgamma(k, shape = 1, rate = 1 + 1/lambda.hs)
+  zeta.hs <- invgamma::rinvgamma(1, shape = 1, rate = 1 + 1/tau.hs)
+
+  ret <- list("psi" = (lambda.hs*tau.hs), "lambda" = lambda.hs, "tau" = tau.hs, "nu" = nu.hs, "zeta" = zeta.hs)
+  return(ret)
+}
diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd
index d30159f..263e553 100644
--- a/man/ForestModel.Rd
+++ b/man/ForestModel.Rd
@@ -79,6 +79,7 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR)
   leaf_model_int,
   leaf_model_scale,
   variable_weights,
+  variable_count_splits,
   global_scale,
   cutpoint_grid_size = 500,
   gfr = T,
@@ -105,6 +106,9 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR)
 \item{\code{variable_weights}}{Vector specifying sampling probability for all p covariates in \code{forest_dataset}}
 
+\item{\code{variable_count_splits}}{Integer vector of counts of the number of times each covariate has been chosen as a splitting variable, updated by the MCMC sampler}
+
 \item{\code{global_scale}}{Global variance parameter}
 
 \item{\code{cutpoint_grid_size}}{(Optional) Number of unique cutpoints to consider (default: 500, currently only used when \code{GFR = TRUE})}
diff --git a/man/BART.Rd b/man/bart.Rd
similarity index 96%
rename from man/BART.Rd
rename to man/bart.Rd
index fba18fe..77f81e2 100644
--- a/man/BART.Rd
+++ b/man/bart.Rd
@@ -14,8 +14,6 @@ bart(
   W_test = NULL,
   group_ids_test = NULL,
   rfx_basis_test = NULL,
-  ordered_cat_vars = NULL,
-  unordered_cat_vars = NULL,
   cutpoint_grid_size = 100,
   tau_init = NULL,
   alpha = 0.95,
@@ -36,7 +34,9 @@ bart(
   sample_tau = T,
   random_seed = -1,
   keep_burnin = F,
-  keep_gfr = F
+  keep_gfr = F,
+  Sparse = F,
+  Theta_Update = F
 )
}
\arguments{
@@ -67,10 +67,6 @@ that were not in the training set.}
 \item{rfx_basis_test}{(Optional) Test set basis for "random-slope" regression in additive random effects model.}
 
-\item{ordered_cat_vars}{Vector of names of ordered categorical variables.}
-
-\item{unordered_cat_vars}{Vector of names of unordered categorical variables.}
-
 \item{cutpoint_grid_size}{Maximum size of the "grid" of potential cutpoints to consider. Default: 100.}
 
 \item{tau_init}{Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here.}
@@ -111,6 +107,10 @@ that were not in the training set.}
 \item{keep_gfr}{Whether or not "grow-from-root" samples should be included in cached predictions. Default TRUE. Ignored if num_mcmc = 0.}
 
+\item{Sparse}{Whether or not to place a sparsity-inducing Dirichlet prior on the variable splitting probabilities. Default FALSE.}
+
+\item{Theta_Update}{Whether or not to update the theta parameter of the Dirichlet prior. Ignored unless \code{Sparse = TRUE}. Default FALSE.}
+
 \item{variable_weights}{Vector of length \code{ncol(X_train)} indicating a "weight" placed on each variable for sampling purposes. Default: \code{rep(1/ncol(X_train),ncol(X_train))}.}
}
diff --git a/man/bcf.Rd b/man/bcf.Rd
index 604d872..9cddf3c 100644
--- a/man/bcf.Rd
+++ b/man/bcf.Rd
@@ -16,8 +16,6 @@ bcf(
   pi_test = NULL,
   group_ids_test = NULL,
   rfx_basis_test = NULL,
-  ordered_cat_vars = NULL,
-  unordered_cat_vars = NULL,
   cutpoint_grid_size = 100,
   sigma_leaf_mu = NULL,
   sigma_leaf_tau = NULL,
@@ -53,7 +51,7 @@ bcf(
 )
}
\arguments{
-\item{X_train}{Covariates used to split trees in the ensemble.}
+\item{X_train}{Covariates used to split trees in the ensemble. Must be passed as a dataframe.}
 
 \item{Z_train}{Vector of (continuous or binary) treatment assignments.}
 
@@ -79,10 +77,6 @@ that were not in the training set.}
 \item{rfx_basis_test}{(Optional) Test set basis for "random-slope" regression in additive random effects model.}
 
-\item{ordered_cat_vars}{Vector of names of ordered categorical variables.}
-
-\item{unordered_cat_vars}{Vector of names of unordered categorical variables.}
-
 \item{cutpoint_grid_size}{Maximum size of the "grid" of potential cutpoints to consider. Default: 100.}
 
 \item{sigma_leaf_mu}{Starting value of leaf node scale parameter for the prognostic forest. Calibrated internally as \code{2/num_trees_mu} if not set here.}
@@ -133,7 +127,7 @@ that were not in the training set.}
 \item{sample_sigma_leaf_tau}{Whether or not to update the \code{sigma_leaf_tau} leaf scale variance parameter in the treatment effect forest based on \code{IG(a_leaf_tau, b_leaf_tau)}. Default: T.}
 
-\item{propensity_covariate}{Whether to include the propensity score as a covariate in either or both of the forests. Enter "none" for neither, "mu" for the prognostic forest, "tau" for the treatment forest, and "both" for both forests. If this is not "none" and a propensity score is not provided, it will be estimated from (\code{X_train}, \code{Z_train}) using \code{xgboost}. Default: "mu".}
+\item{propensity_covariate}{Whether to include the propensity score as a covariate in either or both of the forests. Enter "none" for neither, "mu" for the prognostic forest, "tau" for the treatment forest, and "both" for both forests. If this is not "none" and a propensity score is not provided, it will be estimated from (\code{X_train}, \code{Z_train}) using \code{stochtree::bart()}. Default: "mu".}
 
 \item{adaptive_coding}{Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters \code{b_0} and \code{b_1} that attach to the outcome model \verb{[b_0 (1-Z) + b_1 Z] tau(X)}. This is ignored when Z is not binary. Default: T.}
@@ -174,6 +168,9 @@ Z <- rbinom(n,1,pi_x)
 E_XZ <- mu_x + Z*tau_x
 snr <- 4
 y <- E_XZ + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
+X <- as.data.frame(X)
+X$x4 <- factor(X$x4, ordered = T)
+X$x5 <- factor(X$x5, ordered = T)
 test_set_pct <- 0.2
 n_test <- round(test_set_pct*n)
 n_train <- n - n_test
@@ -192,7 +189,7 @@ mu_train <- mu_x[train_inds]
 tau_test <- tau_x[test_inds]
 tau_train <- tau_x[train_inds]
 bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
-                 X_test = X_test, Z_test = Z_test, pi_test = pi_test, ordered_cat_vars = c(4,5))
+                 X_test = X_test, Z_test = Z_test, pi_test = pi_test)
 # plot(rowMeans(bcf_model$mu_hat_test), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function")
 # abline(0,1,col="red",lty=3,lwd=3)
 # plot(rowMeans(bcf_model$tau_hat_test), tau_test, xlab = "predicted", ylab = "actual", main = "Treatment effect")
diff --git a/man/computeForestKernels.Rd b/man/computeForestKernels.Rd
index 7421019..ccabc40 100644
--- a/man/computeForestKernels.Rd
+++ b/man/computeForestKernels.Rd
@@ -11,10 +11,10 @@ computeForestKernels(bart_model, X_train, X_test = NULL, forest_num = NULL)
\arguments{
 \item{bart_model}{Object of type \code{bartmodel} corresponding to a BART model with at least one sample}
 
-\item{X_train}{Matrix of "training" data. In a traditional Gaussian process kriging context, this
+\item{X_train}{"Training" dataframe. In a traditional Gaussian process kriging context, this
 corresponds to the observations for which outcomes are observed.}
 
-\item{X_test}{(Optional) Matrix of "test" data. In a traditional Gaussian process kriging context, this
+\item{X_test}{(Optional) "Test" dataframe. In a traditional Gaussian process kriging context, this
 corresponds to the observations for which outcomes are unobserved and must be estimated
 based on the kernels k(X_test,X_test), k(X_test,X_train), and k(X_train,X_train). If not provided,
 this function will only compute k(X_train, X_train).}
diff --git a/man/convertBCFModelToJson.Rd b/man/convertBCFModelToJson.Rd
index 775898a..8841b6e 100644
--- a/man/convertBCFModelToJson.Rd
+++ b/man/convertBCFModelToJson.Rd
@@ -40,6 +40,9 @@ rfx_coefs <- matrix(c(-1, -1, 1, 1),nrow=2,byrow=T)
 rfx_basis <- cbind(1, runif(n, -1, 1))
 rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)
 y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
+X <- as.data.frame(X)
+X$x4 <- factor(X$x4, ordered = T)
+X$x5 <- factor(X$x5, ordered = T)
 test_set_pct <- 0.2
 n_test <- round(test_set_pct*n)
 n_train <- n - n_test
@@ -67,7 +70,7 @@ bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                  pi_train = pi_train, group_ids_train = group_ids_train,
                  rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test,
                  pi_test = pi_test, group_ids_test = group_ids_test,
-                 rfx_basis_test = rfx_basis_test, ordered_cat_vars = c(4,5),
+                 rfx_basis_test = rfx_basis_test,
                  num_gfr = 100, num_burnin = 0, num_mcmc = 100,
                  sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F)
 # bcf_json <- convertBCFModelToJson(bcf_model)
diff --git a/man/getRandomEffectSamples.bcf.Rd b/man/getRandomEffectSamples.bcf.Rd
index 0842ed5..cccf6e5 100644
--- a/man/getRandomEffectSamples.bcf.Rd
+++ b/man/getRandomEffectSamples.bcf.Rd
@@ -42,6 +42,9 @@ rfx_coefs <- matrix(c(-1, -1, 1, 1),nrow=2,byrow=T)
 rfx_basis <- cbind(1, runif(n, -1, 1))
 rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)
 y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
+X <- as.data.frame(X)
+X$x4 <- factor(X$x4, ordered = T)
+X$x5 <- factor(X$x5, ordered = T)
 test_set_pct <- 0.2
 n_test <- round(test_set_pct*n)
 n_train <- n - n_test
@@ -69,7 +72,7 @@ bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                  pi_train = pi_train, group_ids_train = group_ids_train,
                  rfx_basis_train = rfx_basis_train, X_test = X_test, Z_test = Z_test,
                  pi_test = pi_test, group_ids_test = group_ids_test,
-                 rfx_basis_test = rfx_basis_test, ordered_cat_vars = c(4,5),
+                 rfx_basis_test = rfx_basis_test,
                  num_gfr = 100, num_burnin = 0, num_mcmc = 100,
                  sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F)
 rfx_samples <- getRandomEffectSamples(bcf_model)
diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd
index 3b3f634..1261394 100644
--- a/man/predict.bartmodel.Rd
+++ b/man/predict.bartmodel.Rd
@@ -16,7 +16,7 @@
\arguments{
 \item{bart}{Object of type \code{bart} containing draws of a regression forest and associated sampling outputs.}
 
-\item{X_test}{Covariates used to determine tree leaf predictions for each observation.}
+\item{X_test}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a dataframe.}
 
 \item{W_test}{(Optional) Bases used for prediction (by e.g. dot product with leaf values).
 Default: \code{NULL}.}
diff --git a/man/predict.bcf.Rd b/man/predict.bcf.Rd
index 74e78a9..00aaa57 100644
--- a/man/predict.bcf.Rd
+++ b/man/predict.bcf.Rd
@@ -17,7 +17,7 @@
\arguments{
 \item{bcf}{Object of type \code{bcf} containing draws of a Bayesian causal forest model and associated sampling outputs.}
 
-\item{X_test}{Covariates used to determine tree leaf predictions for each observation.}
+\item{X_test}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a dataframe.}
 
 \item{Z_test}{Treatments used for prediction.}
@@ -58,6 +58,9 @@ Z <- rbinom(n,1,pi_x)
 E_XZ <- mu_x + Z*tau_x
 snr <- 4
 y <- E_XZ + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
+X <- as.data.frame(X)
+X$x4 <- factor(X$x4, ordered = T)
+X$x5 <- factor(X$x5, ordered = T)
 test_set_pct <- 0.2
 n_test <- round(test_set_pct*n)
 n_train <- n - n_test
@@ -75,7 +78,7 @@ mu_test <- mu_x[test_inds]
 mu_train <- mu_x[train_inds]
 tau_test <- tau_x[test_inds]
 tau_train <- tau_x[train_inds]
-bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, ordered_cat_vars = c(4,5))
+bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train)
 preds <- predict(bcf_model, X_test, Z_test, pi_test)
 # plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function")
 # abline(0,1,col="red",lty=3,lwd=3)
diff --git a/man/preprocessPredictionDataFrame.Rd b/man/preprocessPredictionDataFrame.Rd
new file mode 100644
index 0000000..e4c77ab
--- /dev/null
+++ b/man/preprocessPredictionDataFrame.Rd
@@ -0,0 +1,30 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/utils.R
+\name{preprocessPredictionDataFrame}
+\alias{preprocessPredictionDataFrame}
+\title{Preprocess a dataframe of covariate values, converting categorical variables
+to integers and one-hot encoding if need be. Returns a list including a
+matrix of preprocessed covariate values and associated tracking.}
+\usage{
+preprocessPredictionDataFrame(input_df, metadata)
+}
+\arguments{
+\item{input_df}{Dataframe of covariates. Users must pre-process any
+categorical variables as factors (ordered for ordered categorical).}
+
+\item{metadata}{List containing information on variables, including train set
+categories for categorical variables}
+}
+\value{
+Preprocessed data with categorical variables appropriately preprocessed
+}
+\description{
+Preprocess a dataframe of covariate values, converting categorical variables
+to integers and one-hot encoding if need be. Returns a list including a
+matrix of preprocessed covariate values and associated tracking.
+}
+\examples{
+cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10)
+metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, num_numeric_vars = 3)
+X_preprocessed <- preprocessPredictionDataFrame(cov_df, metadata)
+}
diff --git a/man/preprocessTrainDataFrame.Rd b/man/preprocessTrainDataFrame.Rd
new file mode 100644
index 0000000..66cbe41
--- /dev/null
+++ b/man/preprocessTrainDataFrame.Rd
@@ -0,0 +1,29 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/utils.R
+\name{preprocessTrainDataFrame}
+\alias{preprocessTrainDataFrame}
+\title{Preprocess a dataframe of covariate values, converting categorical variables
+to integers and one-hot encoding if need be. Returns a list including a
+matrix of preprocessed covariate values and associated tracking.}
+\usage{
+preprocessTrainDataFrame(input_df)
+}
+\arguments{
+\item{input_df}{Dataframe of covariates. Users must pre-process any
+categorical variables as factors (ordered for ordered categorical).}
+}
+\value{
+List with preprocessed data and details on the number of each type
+of variable, unique categories associated with categorical variables, and the
+vector of feature types needed for calls to BART and BCF.
+}
+\description{
+Preprocess a dataframe of covariate values, converting categorical variables
+to integers and one-hot encoding if need be. Returns a list including a
+matrix of preprocessed covariate values and associated tracking.
+}
+\examples{
+cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10)
+preprocess_list <- preprocessTrainDataFrame(cov_df)
+X <- preprocess_list$X
+}
diff --git a/man/varbart.Rd b/man/varbart.Rd
new file mode 100644
index 0000000..04768ce
--- /dev/null
+++ b/man/varbart.Rd
@@ -0,0 +1,109 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/varbart.R
+\name{varbart}
+\alias{varbart}
+\title{Run the BART-based vector autoregression (VAR-BART) sampler, equation by
+equation, with an optional stochastic volatility model for the shocks.}
+\usage{
+varbart(
+  y_train,
+  X_train,
+  sv = FALSE,
+  W_train = NULL,
+  group_ids_train = NULL,
+  rfx_basis_train = NULL,
+  X_test = NULL,
+  W_test = NULL,
+  group_ids_test = NULL,
+  rfx_basis_test = NULL,
+  ordered_cat_vars = NULL,
+  unordered_cat_vars = NULL,
+  cutpoint_grid_size = 100,
+  tau_init = NULL,
+  alpha = 0.95,
+  beta = 2,
+  min_samples_leaf = 5,
+  output_dimension = 1,
+  is_leaf_constant = T,
+  leaf_regression = F,
+  nu = 3,
+  lambda = NULL,
+  a_leaf = 3,
+  b_leaf = NULL,
+  q = 0.9,
+  sigma2_init = NULL,
+  num_trees = 250,
+  num_gfr = 0,
+  num_burnin = 0,
+  num_mcmc = 100,
+  random_seed = -1,
+  keep_burnin = F
+)
+}
+\arguments{
+\item{y_train}{Matrix of outcomes (T x M), one column per endogenous variable.}
+
+\item{X_train}{Matrix of covariates (e.g. lagged outcomes) used to split trees in the ensemble.}
+
+\item{sv}{Whether or not to model the structural shocks with stochastic volatility. Default: FALSE.}
+
+\item{W_train}{(Optional) Bases used to define a leaf regression. Currently unused.}
+
+\item{group_ids_train}{(Optional) Group labels for an additive random effects model. Currently unused.}
+
+\item{rfx_basis_train}{(Optional) Basis for "random-slope" regression in an additive random effects model. Currently unused.}
+
+\item{X_test}{(Optional) Test set covariates. Currently unused.}
+
+\item{W_test}{(Optional) Test set bases. Currently unused.}
+
+\item{group_ids_test}{(Optional) Test set group labels. Currently unused.}
+
+\item{rfx_basis_test}{(Optional) Test set random effects basis. Currently unused.}
+
+\item{ordered_cat_vars}{(Optional) Vector of names of ordered categorical variables. Currently unused.}
+
+\item{unordered_cat_vars}{(Optional) Vector of names of unordered categorical variables. Currently unused.}
+
+\item{cutpoint_grid_size}{Maximum size of the "grid" of potential cutpoints to consider. Default: 100.}
+
+\item{tau_init}{Starting value of the leaf node scale parameter. Calibrated internally if not set here.}
+
+\item{alpha}{Prior probability of splitting for a tree of depth 0. Default: 0.95.}
+
+\item{beta}{Exponent that decreases split probabilities for nodes of depth > 0. Default: 2.0.}
+
+\item{min_samples_leaf}{Minimum allowable size of a leaf. Default: 5.}
+
+\item{output_dimension}{Dimension of the leaf node parameters. Default: 1.}
+
+\item{is_leaf_constant}{Whether or not the leaf model is constant. Default: TRUE.}
+
+\item{leaf_regression}{Whether or not to use a leaf regression model. Currently unused. Default: FALSE.}
+
+\item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance prior. Default: 3.}
+
+\item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. Calibrated internally if not set here.}
+
+\item{a_leaf}{Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node variance prior. Default: 3.}
+
+\item{b_leaf}{Scale parameter in the \code{IG(a_leaf, b_leaf)} leaf node variance prior. Calibrated internally if not set here.}
+
+\item{q}{Quantile used to calibrate \code{lambda}. Default: 0.9.}
+
+\item{sigma2_init}{Starting value of the global error variance parameters. Calibrated internally if not set here.}
+
+\item{num_trees}{Number of trees in the ensemble for each equation. Default: 250.}
+
+\item{num_gfr}{Number of "warm-start" grow-from-root iterations. Default: 0.}
+
+\item{num_burnin}{Number of "burn-in" MCMC iterations. Default: 0.}
+
+\item{num_mcmc}{Number of MCMC iterations. Default: 100.}
+
+\item{random_seed}{Integer parameterizing the C++ random number generator. Default: -1.}
+
+\item{keep_burnin}{Whether or not "burnin" samples should be included in cached predictions. Default: FALSE.}
+}
+\value{
+List containing arrays of fitted values (\code{Y.fitted}) and volatilities (\code{H.fitted}),
+the standardization constants (\code{Ysd}, \code{Ymu}), stochastic volatility parameter draws
+(\code{sv.mcmc}), and the final draw of the contemporaneous covariance matrix (\code{A0}).
+}
+\description{
+Run the BART-based vector autoregression (VAR-BART) sampler, equation by
+equation, with an optional stochastic volatility model for the shocks.
+}
diff --git a/src/cpp11.cpp b/src/cpp11.cpp
index dcfdad2..37d6db3 100644
--- a/src/cpp11.cpp
+++ b/src/cpp11.cpp
@@ -505,11 +505,10 @@ extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual
   END_CPP11
 }
 // sampler.cpp
-void sample_mcmc_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, cpp11::external_pointer<StochTree::ColumnVector> residual, cpp11::external_pointer<StochTree::ForestContainer> forest_samples, cpp11::external_pointer<StochTree::ForestTracker> tracker, cpp11::external_pointer<StochTree::TreePrior> split_prior, cpp11::external_pointer<std::mt19937> rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double global_variance, int leaf_model_int, bool pre_initialized);
-extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP global_variance, SEXP leaf_model_int, SEXP pre_initialized) {
+cpp11::writable::integers sample_mcmc_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data, cpp11::external_pointer<StochTree::ColumnVector> residual, cpp11::external_pointer<StochTree::ForestContainer> forest_samples, cpp11::external_pointer<StochTree::ForestTracker> tracker, cpp11::external_pointer<StochTree::TreePrior> split_prior, cpp11::external_pointer<std::mt19937> rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, cpp11::writable::integers variable_count_splits, double global_variance, int leaf_model_int, bool pre_initialized);
+extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP variable_count_splits, SEXP global_variance, SEXP leaf_model_int, SEXP pre_initialized) {
   BEGIN_CPP11
-    sample_mcmc_one_iteration_cpp(cpp11::as_cpp<cpp11::external_pointer<StochTree::ForestDataset>>(data), cpp11::as_cpp<cpp11::external_pointer<StochTree::ColumnVector>>(residual), cpp11::as_cpp<cpp11::external_pointer<StochTree::ForestContainer>>(forest_samples), cpp11::as_cpp<cpp11::external_pointer<StochTree::ForestTracker>>(tracker), cpp11::as_cpp<cpp11::external_pointer<StochTree::TreePrior>>(split_prior), cpp11::as_cpp<cpp11::external_pointer<std::mt19937>>(rng), cpp11::as_cpp<cpp11::integers>(feature_types), cpp11::as_cpp<int>(cutpoint_grid_size), cpp11::as_cpp<cpp11::doubles_matrix<>>(leaf_model_scale_input), cpp11::as_cpp<cpp11::doubles>(variable_weights), cpp11::as_cpp<double>(global_variance), cpp11::as_cpp<int>(leaf_model_int), cpp11::as_cpp<bool>(pre_initialized));
-    return R_NilValue;
+    return cpp11::as_sexp(sample_mcmc_one_iteration_cpp(cpp11::as_cpp<cpp11::external_pointer<StochTree::ForestDataset>>(data), cpp11::as_cpp<cpp11::external_pointer<StochTree::ColumnVector>>(residual), cpp11::as_cpp<cpp11::external_pointer<StochTree::ForestContainer>>(forest_samples), cpp11::as_cpp<cpp11::external_pointer<StochTree::ForestTracker>>(tracker), cpp11::as_cpp<cpp11::external_pointer<StochTree::TreePrior>>(split_prior), cpp11::as_cpp<cpp11::external_pointer<std::mt19937>>(rng), cpp11::as_cpp<cpp11::integers>(feature_types), cpp11::as_cpp<int>(cutpoint_grid_size), cpp11::as_cpp<cpp11::doubles_matrix<>>(leaf_model_scale_input), cpp11::as_cpp<cpp11::doubles>(variable_weights), cpp11::as_cpp<cpp11::writable::integers>(variable_count_splits), cpp11::as_cpp<double>(global_variance), cpp11::as_cpp<int>(leaf_model_int), cpp11::as_cpp<bool>(pre_initialized)));
   END_CPP11
 }
 // sampler.cpp
@@ -870,7 +869,7 @@ static const R_CallMethodDef CallEntries[] = {
     {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1},
     {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1},
     {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 13},
-    {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 13},
+    {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 14},
     {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 4},
     {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 5},
    {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2},
diff --git a/src/sampler.cpp b/src/sampler.cpp
index 3347e91..e53dee4 100644
--- a/src/sampler.cpp
+++ b/src/sampler.cpp
@@ -75,7 +75,7 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
-void sample_mcmc_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
+cpp11::writable::integers sample_mcmc_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
                                    cpp11::external_pointer<StochTree::ColumnVector> residual,
                                    cpp11::external_pointer<StochTree::ForestContainer> forest_samples,
                                    cpp11::external_pointer<StochTree::ForestTracker> tracker,
@@ -83,7 +83,8 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
                                    cpp11::external_pointer<std::mt19937> rng,
                                    cpp11::integers feature_types, int cutpoint_grid_size,
                                    cpp11::doubles_matrix<> leaf_model_scale_input,
-                                   cpp11::doubles variable_weights,
+                                   cpp11::doubles variable_weights,
+                                   cpp11::writable::integers variable_count_splits,
                                    double global_variance, int leaf_model_int,
                                    bool pre_initialized = false
 ) {
@@ -122,20 +123,32 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer<StochTree::ForestDataset> data,
+  // Copy the R integer vector of split counts into a std::vector<int>
+  std::vector<int> variable_count_splits_cpp(variable_count_splits.begin(), variable_count_splits.end());
+  // This could be handled internally instead, e.g.:
+  //std::vector<int> variable_count_splits(variable_weights.size());
+
+  // Run one iteration of the sampler
   if (leaf_model_enum == ForestLeafModel::kConstant) {
     StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(leaf_scale);
     StochTree::MCMCForestSampler<StochTree::GaussianConstantLeafModel> sampler = StochTree::MCMCForestSampler<StochTree::GaussianConstantLeafModel>();
-    sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+    sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, variable_count_splits_cpp, global_variance, pre_initialized);
   } else if (leaf_model_enum == ForestLeafModel::kUnivariateRegression) {
     StochTree::GaussianUnivariateRegressionLeafModel leaf_model = StochTree::GaussianUnivariateRegressionLeafModel(leaf_scale);
     StochTree::MCMCForestSampler<StochTree::GaussianUnivariateRegressionLeafModel> sampler = StochTree::MCMCForestSampler<StochTree::GaussianUnivariateRegressionLeafModel>();
-    sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+    sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, variable_count_splits_cpp, global_variance, pre_initialized);
   } else if (leaf_model_enum == ForestLeafModel::kMultivariateRegression) {
     StochTree::GaussianMultivariateRegressionLeafModel leaf_model = StochTree::GaussianMultivariateRegressionLeafModel(leaf_scale_matrix);
     StochTree::MCMCForestSampler<StochTree::GaussianMultivariateRegressionLeafModel> sampler = StochTree::MCMCForestSampler<StochTree::GaussianMultivariateRegressionLeafModel>();
-    sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, global_variance, pre_initialized);
+    sampler.SampleOneIter(*tracker, *forest_samples, leaf_model, *data, *residual, *split_prior, *rng, var_weights_vector, variable_count_splits_cpp, global_variance, pre_initialized);
   }
+
+  // Copy the updated split counts back into the R integer vector and return it
+  for (int i = 0; i < variable_count_splits.size(); i++) {
+    variable_count_splits[i] = variable_count_splits_cpp[i];
+  }
+  return variable_count_splits;
 }
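
Usage note (illustrative, not part of the diff): a minimal sketch of how the new `Sparse` option in `bart()` can be used for variable selection. It assumes the `variable_count_splits` matrix added to the parameter list above is returned on the fitted object under `model_params`; adjust the accessor if that list is stored under a different name.

# Simulate data where only the first two covariates matter
n <- 500
p <- 10
X <- as.data.frame(matrix(runif(n * p), ncol = p))
y <- 5 * X[, 1] + 2 * sin(pi * X[, 2]) + rnorm(n)

# Fit BART with the sparse Dirichlet prior and theta updates enabled
fit <- bart(X_train = X, y_train = y, num_gfr = 0, num_burnin = 100,
            num_mcmc = 500, Sparse = TRUE, Theta_Update = TRUE)

# Average number of splits on each covariate across MCMC draws; under the
# Dirichlet prior these counts should concentrate on the relevant covariates
colMeans(fit$model_params$variable_count_splits, na.rm = TRUE)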