Skip to content

Commit

Permalink
Updated interface
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewherren committed May 9, 2024
1 parent 20c3b4f commit 0556458
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 6 deletions.
4 changes: 2 additions & 2 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
"num_samples" = num_samples,
"has_basis" = !is.null(W_train),
"has_rfx" = has_rfx,
"has_basis_rfx" = has_basis_rfx,
"num_basis_rfx" = num_basis_rfx
"has_rfx_basis" = has_basis_rfx,
"num_rfx_basis" = num_basis_rfx
)
result <- list(
"forests" = forest_samples,
Expand Down
26 changes: 22 additions & 4 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
nu = 3, lambda = NULL, a_leaf_mu = 3, a_leaf_tau = 3, b_leaf_mu = NULL, b_leaf_tau = NULL,
q = 0.9, sigma2 = NULL, num_trees_mu = 250, num_trees_tau = 50, num_gfr = 5,
num_burnin = 0, num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T,
sample_sigma_leaf_tau = T, propensity_covariate = "mu", adaptive_coding = T,
sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T,
b_0 = -0.5, b_1 = 0.5, random_seed = -1) {
# Convert all input data to matrices if not already converted
if ((is.null(dim(X_train))) && (!is.null(X_train))) {
Expand Down Expand Up @@ -414,6 +414,10 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
mu_x_raw_train <- forest_samples_mu$predict_raw_single_forest(forest_dataset_mu_train, i-1)
tau_x_raw_train <- forest_samples_tau$predict_raw_single_forest(forest_dataset_tau_train, i-1)
partial_resid_mu_train <- resid_train - mu_x_raw_train
if (has_rfx) {
rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train)
partial_resid_mu_train <- partial_resid_mu_train - rfx_preds_train
}

# Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z]
s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0))
Expand Down Expand Up @@ -484,15 +488,26 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

# Sample coding parameters (if requested)
if (adaptive_coding) {
# Estimate mu(X) and tau(X) and compute y - mu(X)
mu_x_raw_train <- forest_samples_mu$predict_raw_single_forest(forest_dataset_mu_train, i-1)
tau_x_raw_train <- forest_samples_tau$predict_raw_single_forest(forest_dataset_tau_train, i-1)
partial_resid_mu_train <- resid_train - mu_x_raw_train
if (has_rfx) {
rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train)
partial_resid_mu_train <- partial_resid_mu_train - rfx_preds_train
}

# Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z]
s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0))
s_tt1 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==1))
partial_resid_mu_train <- resid_train - mu_x_raw_train
s_ty0 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==0))
s_ty1 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==1))

# Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z)
current_b_0 <- rnorm(1, (s_ty0/(s_tt0 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt0 + 2*current_sigma2)))
current_b_1 <- rnorm(1, (s_ty1/(s_tt1 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt1 + 2*current_sigma2)))

# Update basis for the leaf regression
tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1
forest_dataset_tau_train$update_basis(tau_basis_train)
b_0_samples[i] <- current_b_0
Expand All @@ -501,6 +516,9 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1
forest_dataset_tau_test$update_basis(tau_basis_test)
}

# TODO Update leaf predictions and residual

}

# Sample variance parameters (if requested)
Expand Down Expand Up @@ -583,8 +601,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
"adaptive_coding" = adaptive_coding,
"num_samples" = num_samples,
"has_rfx" = has_rfx,
"has_basis_rfx" = has_basis_rfx,
"num_basis_rfx" = num_basis_rfx
"has_rfx_basis" = has_basis_rfx,
"num_rfx_basis" = num_basis_rfx
)
result <- list(
"forests_mu" = forest_samples_mu,
Expand Down
4 changes: 4 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ rfx_model_sample_random_effects_cpp <- function(rfx_model, rfx_dataset, residual
invisible(.Call(`_stochtree_rfx_model_sample_random_effects_cpp`, rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, global_variance, rng))
}

rfx_model_predict_cpp <- function(rfx_model, rfx_dataset, rfx_tracker) {
.Call(`_stochtree_rfx_model_predict_cpp`, rfx_model, rfx_dataset, rfx_tracker)
}

rfx_container_predict_cpp <- function(rfx_container, rfx_dataset, label_mapper) {
.Call(`_stochtree_rfx_container_predict_cpp`, rfx_container, rfx_dataset, label_mapper)
}
Expand Down
10 changes: 10 additions & 0 deletions R/random_effects.R
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,16 @@ RandomEffectsModel <- R6::R6Class(
rfx_samples$rfx_container_ptr, global_variance, rng$rng_ptr)
},

#' @description
#' Predict from (a single sample of a) random effects model.
#' @param rfx_dataset Object of type `RandomEffectsDataset`
#' @param rfx_tracker Object of type `RandomEffectsTracker`
#' @return Vector of predictions with size matching number of observations in rfx_dataset
predict = function(rfx_dataset, rfx_tracker) {
pred <- rfx_model_predict_cpp(self$rfx_model_ptr, rfx_dataset$data_ptr, rfx_tracker$rfx_tracker_ptr)
return(pred)
},

#' @description
#' Set value for the "working parameter." This is typically
#' used for initialization, but could also be used to interrupt
Expand Down
23 changes: 23 additions & 0 deletions man/RandomEffectsModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,13 @@ extern "C" SEXP _stochtree_rfx_model_sample_random_effects_cpp(SEXP rfx_model, S
END_CPP11
}
// random_effects.cpp
cpp11::writable::doubles rfx_model_predict_cpp(cpp11::external_pointer<StochTree::MultivariateRegressionRandomEffectsModel> rfx_model, cpp11::external_pointer<StochTree::RandomEffectsDataset> rfx_dataset, cpp11::external_pointer<StochTree::RandomEffectsTracker> rfx_tracker);
extern "C" SEXP _stochtree_rfx_model_predict_cpp(SEXP rfx_model, SEXP rfx_dataset, SEXP rfx_tracker) {
BEGIN_CPP11
return cpp11::as_sexp(rfx_model_predict_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::MultivariateRegressionRandomEffectsModel>>>(rfx_model), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::RandomEffectsDataset>>>(rfx_dataset), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::RandomEffectsTracker>>>(rfx_tracker)));
END_CPP11
}
// random_effects.cpp
cpp11::writable::doubles rfx_container_predict_cpp(cpp11::external_pointer<StochTree::RandomEffectsContainer> rfx_container, cpp11::external_pointer<StochTree::RandomEffectsDataset> rfx_dataset, cpp11::external_pointer<StochTree::LabelMapper> label_mapper);
extern "C" SEXP _stochtree_rfx_container_predict_cpp(SEXP rfx_container, SEXP rfx_dataset, SEXP label_mapper) {
BEGIN_CPP11
Expand Down Expand Up @@ -740,6 +747,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_stochtree_rfx_label_mapper_from_json_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_cpp, 2},
{"_stochtree_rfx_label_mapper_to_list_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_to_list_cpp, 1},
{"_stochtree_rfx_model_cpp", (DL_FUNC) &_stochtree_rfx_model_cpp, 2},
{"_stochtree_rfx_model_predict_cpp", (DL_FUNC) &_stochtree_rfx_model_predict_cpp, 3},
{"_stochtree_rfx_model_sample_random_effects_cpp", (DL_FUNC) &_stochtree_rfx_model_sample_random_effects_cpp, 7},
{"_stochtree_rfx_model_set_group_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameter_covariance_cpp, 2},
{"_stochtree_rfx_model_set_group_parameters_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameters_cpp, 2},
Expand Down
8 changes: 8 additions & 0 deletions src/random_effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ void rfx_model_sample_random_effects_cpp(cpp11::external_pointer<StochTree::Mult
rfx_container->AddSample(*rfx_model);
}

[[cpp11::register]]
cpp11::writable::doubles rfx_model_predict_cpp(cpp11::external_pointer<StochTree::MultivariateRegressionRandomEffectsModel> rfx_model,
cpp11::external_pointer<StochTree::RandomEffectsDataset> rfx_dataset,
cpp11::external_pointer<StochTree::RandomEffectsTracker> rfx_tracker) {
std::vector<double> output = rfx_model->Predict(*rfx_dataset, *rfx_tracker);
return output;
}

[[cpp11::register]]
cpp11::writable::doubles rfx_container_predict_cpp(cpp11::external_pointer<StochTree::RandomEffectsContainer> rfx_container,
cpp11::external_pointer<StochTree::RandomEffectsDataset> rfx_dataset,
Expand Down

0 comments on commit 0556458

Please sign in to comment.