diff --git a/R/bart.R b/R/bart.R index 91c4ef4..c2e3935 100644 --- a/R/bart.R +++ b/R/bart.R @@ -359,8 +359,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, "num_samples" = num_samples, "has_basis" = !is.null(W_train), "has_rfx" = has_rfx, - "has_basis_rfx" = has_basis_rfx, - "num_basis_rfx" = num_basis_rfx + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx ) result <- list( "forests" = forest_samples, diff --git a/R/bcf.R b/R/bcf.R index e3ebfc0..c864268 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -104,7 +104,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU nu = 3, lambda = NULL, a_leaf_mu = 3, a_leaf_tau = 3, b_leaf_mu = NULL, b_leaf_tau = NULL, q = 0.9, sigma2 = NULL, num_trees_mu = 250, num_trees_tau = 50, num_gfr = 5, num_burnin = 0, num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T, - sample_sigma_leaf_tau = T, propensity_covariate = "mu", adaptive_coding = T, + sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5, random_seed = -1) { # Convert all input data to matrices if not already converted if ((is.null(dim(X_train))) && (!is.null(X_train))) { @@ -414,6 +414,10 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU mu_x_raw_train <- forest_samples_mu$predict_raw_single_forest(forest_dataset_mu_train, i-1) tau_x_raw_train <- forest_samples_tau$predict_raw_single_forest(forest_dataset_tau_train, i-1) partial_resid_mu_train <- resid_train - mu_x_raw_train + if (has_rfx) { + rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train) + partial_resid_mu_train <- partial_resid_mu_train - rfx_preds_train + } # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0)) @@ -484,15 +488,26 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU # Sample coding parameters (if requested) if (adaptive_coding) { + # Estimate mu(X) and tau(X) and compute y - mu(X) mu_x_raw_train <- forest_samples_mu$predict_raw_single_forest(forest_dataset_mu_train, i-1) tau_x_raw_train <- forest_samples_tau$predict_raw_single_forest(forest_dataset_tau_train, i-1) + partial_resid_mu_train <- resid_train - mu_x_raw_train + if (has_rfx) { + rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train) + partial_resid_mu_train <- partial_resid_mu_train - rfx_preds_train + } + + # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0)) s_tt1 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==1)) - partial_resid_mu_train <- resid_train - mu_x_raw_train s_ty0 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==0)) s_ty1 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==1)) + + # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) current_b_0 <- rnorm(1, (s_ty0/(s_tt0 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt0 + 2*current_sigma2))) current_b_1 <- rnorm(1, (s_ty1/(s_tt1 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt1 + 2*current_sigma2))) + + # Update basis for the leaf regression tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 forest_dataset_tau_train$update_basis(tau_basis_train) b_0_samples[i] <- current_b_0 @@ -501,6 +516,9 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 forest_dataset_tau_test$update_basis(tau_basis_test) } + + # TODO Update leaf predictions and residual + } # Sample variance parameters (if requested) @@ -583,8 +601,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU "adaptive_coding" = adaptive_coding, "num_samples" = num_samples, "has_rfx" = has_rfx, - "has_basis_rfx" = has_basis_rfx, - "num_basis_rfx" = num_basis_rfx + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx ) result <- list( "forests_mu" = forest_samples_mu, diff --git a/R/cpp11.R b/R/cpp11.R index f8f93f6..8d51619 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -200,6 +200,10 @@ rfx_model_sample_random_effects_cpp <- function(rfx_model, rfx_dataset, residual invisible(.Call(`_stochtree_rfx_model_sample_random_effects_cpp`, rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, global_variance, rng)) } +rfx_model_predict_cpp <- function(rfx_model, rfx_dataset, rfx_tracker) { + .Call(`_stochtree_rfx_model_predict_cpp`, rfx_model, rfx_dataset, rfx_tracker) +} + rfx_container_predict_cpp <- function(rfx_container, rfx_dataset, label_mapper) { .Call(`_stochtree_rfx_container_predict_cpp`, rfx_container, rfx_dataset, label_mapper) } diff --git a/R/random_effects.R b/R/random_effects.R index b9c2f88..938bca4 100644 --- a/R/random_effects.R +++ b/R/random_effects.R @@ -198,6 +198,16 @@ RandomEffectsModel <- R6::R6Class( rfx_samples$rfx_container_ptr, global_variance, rng$rng_ptr) }, + #' @description + #' Predict from (a single sample of a) random effects model. + #' @param rfx_dataset Object of type `RandomEffectsDataset` + #' @param rfx_tracker Object of type `RandomEffectsTracker` + #' @return Vector of predictions with size matching number of observations in rfx_dataset + predict = function(rfx_dataset, rfx_tracker) { + pred <- rfx_model_predict_cpp(self$rfx_model_ptr, rfx_dataset$data_ptr, rfx_tracker$rfx_tracker_ptr) + return(pred) + }, + #' @description #' Set value for the "working parameter." This is typically #' used for initialization, but could also be used to interrupt diff --git a/man/RandomEffectsModel.Rd b/man/RandomEffectsModel.Rd index 9f6c1fa..41c8a03 100644 --- a/man/RandomEffectsModel.Rd +++ b/man/RandomEffectsModel.Rd @@ -23,6 +23,7 @@ sampling from the conditional posterior of each parameter. \itemize{ \item \href{#method-RandomEffectsModel-new}{\code{RandomEffectsModel$new()}} \item \href{#method-RandomEffectsModel-sample_random_effect}{\code{RandomEffectsModel$sample_random_effect()}} +\item \href{#method-RandomEffectsModel-predict}{\code{RandomEffectsModel$predict()}} \item \href{#method-RandomEffectsModel-set_working_parameter}{\code{RandomEffectsModel$set_working_parameter()}} \item \href{#method-RandomEffectsModel-set_group_parameters}{\code{RandomEffectsModel$set_group_parameters()}} \item \href{#method-RandomEffectsModel-set_working_parameter_cov}{\code{RandomEffectsModel$set_working_parameter_cov()}} @@ -91,6 +92,28 @@ None } } \if{html}{\out{