diff --git a/R/bart.R b/R/bart.R index 91c4ef4..c2e3935 100644 --- a/R/bart.R +++ b/R/bart.R @@ -359,8 +359,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, "num_samples" = num_samples, "has_basis" = !is.null(W_train), "has_rfx" = has_rfx, - "has_basis_rfx" = has_basis_rfx, - "num_basis_rfx" = num_basis_rfx + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx ) result <- list( "forests" = forest_samples, diff --git a/R/bcf.R b/R/bcf.R index e3ebfc0..c864268 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -104,7 +104,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU nu = 3, lambda = NULL, a_leaf_mu = 3, a_leaf_tau = 3, b_leaf_mu = NULL, b_leaf_tau = NULL, q = 0.9, sigma2 = NULL, num_trees_mu = 250, num_trees_tau = 50, num_gfr = 5, num_burnin = 0, num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T, - sample_sigma_leaf_tau = T, propensity_covariate = "mu", adaptive_coding = T, + sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5, random_seed = -1) { # Convert all input data to matrices if not already converted if ((is.null(dim(X_train))) && (!is.null(X_train))) { @@ -414,6 +414,10 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU mu_x_raw_train <- forest_samples_mu$predict_raw_single_forest(forest_dataset_mu_train, i-1) tau_x_raw_train <- forest_samples_tau$predict_raw_single_forest(forest_dataset_tau_train, i-1) partial_resid_mu_train <- resid_train - mu_x_raw_train + if (has_rfx) { + rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train) + partial_resid_mu_train <- partial_resid_mu_train - rfx_preds_train + } # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0)) @@ -484,15 +488,26 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU # Sample coding parameters (if requested) if (adaptive_coding) { + # Estimate mu(X) and tau(X) and compute y - mu(X) mu_x_raw_train <- forest_samples_mu$predict_raw_single_forest(forest_dataset_mu_train, i-1) tau_x_raw_train <- forest_samples_tau$predict_raw_single_forest(forest_dataset_tau_train, i-1) + partial_resid_mu_train <- resid_train - mu_x_raw_train + if (has_rfx) { + rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train) + partial_resid_mu_train <- partial_resid_mu_train - rfx_preds_train + } + + # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0)) s_tt1 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==1)) - partial_resid_mu_train <- resid_train - mu_x_raw_train s_ty0 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==0)) s_ty1 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==1)) + + # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) current_b_0 <- rnorm(1, (s_ty0/(s_tt0 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt0 + 2*current_sigma2))) current_b_1 <- rnorm(1, (s_ty1/(s_tt1 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt1 + 2*current_sigma2))) + + # Update basis for the leaf regression tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 forest_dataset_tau_train$update_basis(tau_basis_train) b_0_samples[i] <- current_b_0 @@ -501,6 +516,9 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 forest_dataset_tau_test$update_basis(tau_basis_test) } + + # TODO Update leaf predictions and residual + } # Sample variance parameters (if requested) @@ -583,8 +601,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU "adaptive_coding" = adaptive_coding, "num_samples" = num_samples, "has_rfx" = has_rfx, - "has_basis_rfx" = has_basis_rfx, - "num_basis_rfx" = num_basis_rfx + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx ) result <- list( "forests_mu" = forest_samples_mu, diff --git a/R/cpp11.R b/R/cpp11.R index f8f93f6..8d51619 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -200,6 +200,10 @@ rfx_model_sample_random_effects_cpp <- function(rfx_model, rfx_dataset, residual invisible(.Call(`_stochtree_rfx_model_sample_random_effects_cpp`, rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, global_variance, rng)) } +rfx_model_predict_cpp <- function(rfx_model, rfx_dataset, rfx_tracker) { + .Call(`_stochtree_rfx_model_predict_cpp`, rfx_model, rfx_dataset, rfx_tracker) +} + rfx_container_predict_cpp <- function(rfx_container, rfx_dataset, label_mapper) { .Call(`_stochtree_rfx_container_predict_cpp`, rfx_container, rfx_dataset, label_mapper) } diff --git a/R/random_effects.R b/R/random_effects.R index b9c2f88..938bca4 100644 --- a/R/random_effects.R +++ b/R/random_effects.R @@ -198,6 +198,16 @@ RandomEffectsModel <- R6::R6Class( rfx_samples$rfx_container_ptr, global_variance, rng$rng_ptr) }, + #' @description + #' Predict from (a single sample of a) random effects model. + #' @param rfx_dataset Object of type `RandomEffectsDataset` + #' @param rfx_tracker Object of type `RandomEffectsTracker` + #' @return Vector of predictions with size matching number of observations in rfx_dataset + predict = function(rfx_dataset, rfx_tracker) { + pred <- rfx_model_predict_cpp(self$rfx_model_ptr, rfx_dataset$data_ptr, rfx_tracker$rfx_tracker_ptr) + return(pred) + }, + #' @description #' Set value for the "working parameter." This is typically #' used for initialization, but could also be used to interrupt diff --git a/man/RandomEffectsModel.Rd b/man/RandomEffectsModel.Rd index 9f6c1fa..41c8a03 100644 --- a/man/RandomEffectsModel.Rd +++ b/man/RandomEffectsModel.Rd @@ -23,6 +23,7 @@ sampling from the conditional posterior of each parameter. \itemize{ \item \href{#method-RandomEffectsModel-new}{\code{RandomEffectsModel$new()}} \item \href{#method-RandomEffectsModel-sample_random_effect}{\code{RandomEffectsModel$sample_random_effect()}} +\item \href{#method-RandomEffectsModel-predict}{\code{RandomEffectsModel$predict()}} \item \href{#method-RandomEffectsModel-set_working_parameter}{\code{RandomEffectsModel$set_working_parameter()}} \item \href{#method-RandomEffectsModel-set_group_parameters}{\code{RandomEffectsModel$set_group_parameters()}} \item \href{#method-RandomEffectsModel-set_working_parameter_cov}{\code{RandomEffectsModel$set_working_parameter_cov()}} @@ -91,6 +92,28 @@ None } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RandomEffectsModel-predict}{}}} +\subsection{Method \code{predict()}}{ +Predict from (a single sample of a) random effects model. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RandomEffectsModel$predict(rfx_dataset, rfx_tracker)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{rfx_dataset}}{Object of type \code{RandomEffectsDataset}} + +\item{\code{rfx_tracker}}{Object of type \code{RandomEffectsTracker}} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Vector of predictions with size matching number of observations in rfx_dataset +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-RandomEffectsModel-set_working_parameter}{}}} \subsection{Method \code{set_working_parameter()}}{ diff --git a/src/cpp11.cpp b/src/cpp11.cpp index a875c03..6cba548 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -372,6 +372,13 @@ extern "C" SEXP _stochtree_rfx_model_sample_random_effects_cpp(SEXP rfx_model, S END_CPP11 } // random_effects.cpp +cpp11::writable::doubles rfx_model_predict_cpp(cpp11::external_pointer rfx_model, cpp11::external_pointer rfx_dataset, cpp11::external_pointer rfx_tracker); +extern "C" SEXP _stochtree_rfx_model_predict_cpp(SEXP rfx_model, SEXP rfx_dataset, SEXP rfx_tracker) { + BEGIN_CPP11 + return cpp11::as_sexp(rfx_model_predict_cpp(cpp11::as_cpp>>(rfx_model), cpp11::as_cpp>>(rfx_dataset), cpp11::as_cpp>>(rfx_tracker))); + END_CPP11 +} +// random_effects.cpp cpp11::writable::doubles rfx_container_predict_cpp(cpp11::external_pointer rfx_container, cpp11::external_pointer rfx_dataset, cpp11::external_pointer label_mapper); extern "C" SEXP _stochtree_rfx_container_predict_cpp(SEXP rfx_container, SEXP rfx_dataset, SEXP label_mapper) { BEGIN_CPP11 @@ -740,6 +747,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_label_mapper_from_json_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_cpp, 2}, {"_stochtree_rfx_label_mapper_to_list_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_to_list_cpp, 1}, {"_stochtree_rfx_model_cpp", (DL_FUNC) &_stochtree_rfx_model_cpp, 2}, + {"_stochtree_rfx_model_predict_cpp", (DL_FUNC) &_stochtree_rfx_model_predict_cpp, 3}, {"_stochtree_rfx_model_sample_random_effects_cpp", (DL_FUNC) &_stochtree_rfx_model_sample_random_effects_cpp, 7}, {"_stochtree_rfx_model_set_group_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameter_covariance_cpp, 2}, {"_stochtree_rfx_model_set_group_parameters_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameters_cpp, 2}, diff --git a/src/random_effects.cpp b/src/random_effects.cpp index 3f73cee..463b8e9 100644 --- a/src/random_effects.cpp +++ b/src/random_effects.cpp @@ -107,6 +107,14 @@ void rfx_model_sample_random_effects_cpp(cpp11::external_pointerAddSample(*rfx_model); } +[[cpp11::register]] +cpp11::writable::doubles rfx_model_predict_cpp(cpp11::external_pointer rfx_model, + cpp11::external_pointer rfx_dataset, + cpp11::external_pointer rfx_tracker) { + std::vector output = rfx_model->Predict(*rfx_dataset, *rfx_tracker); + return output; +} + [[cpp11::register]] cpp11::writable::doubles rfx_container_predict_cpp(cpp11::external_pointer rfx_container, cpp11::external_pointer rfx_dataset,