Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate core BCF computation to C++ #9

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
490 changes: 200 additions & 290 deletions R/bcf.R

Large diffs are not rendered by default.

52 changes: 52 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,57 @@
# Generated by cpp11: do not edit by hand

# Thin cpp11-generated binding: forwards `univariate_treatment` unchanged to
# the compiled routine `_stochtree_bcf_init_cpp` and returns whatever that
# routine returns (visibly).
bcf_init_cpp <- function(univariate_treatment) {
.Call(`_stochtree_bcf_init_cpp`, univariate_treatment)
}

# Thin cpp11-generated binding: passes all seven arguments, in order, to the
# compiled routine `_stochtree_bcf_add_train_with_weights_cpp`. The result of
# the call is returned invisibly, i.e. the function is used for its side
# effect on `bcf_wrapper`.
bcf_add_train_with_weights_cpp <- function(bcf_wrapper, X_train_mu, X_train_tau, Z_train, y_train, weights_train, treatment_binary) {
invisible(.Call(`_stochtree_bcf_add_train_with_weights_cpp`, bcf_wrapper, X_train_mu, X_train_tau, Z_train, y_train, weights_train, treatment_binary))
}

# Thin cpp11-generated binding: same as the weighted variant but without a
# `weights_train` argument. Forwards its arguments to
# `_stochtree_bcf_add_train_no_weights_cpp` and returns the result invisibly.
bcf_add_train_no_weights_cpp <- function(bcf_wrapper, X_train_mu, X_train_tau, Z_train, y_train, treatment_binary) {
invisible(.Call(`_stochtree_bcf_add_train_no_weights_cpp`, bcf_wrapper, X_train_mu, X_train_tau, Z_train, y_train, treatment_binary))
}

# Thin cpp11-generated binding: forwards the wrapper handle and the three test
# inputs to `_stochtree_bcf_add_test_cpp`; the call result is returned
# invisibly.
bcf_add_test_cpp <- function(bcf_wrapper, X_test_mu, X_test_tau, Z_test) {
invisible(.Call(`_stochtree_bcf_add_test_cpp`, bcf_wrapper, X_test_mu, X_test_tau, Z_test))
}

# Thin cpp11-generated binding: forwards `bcf_wrapper` and `data_vector` to
# `_stochtree_bcf_reset_global_var_samples_cpp`; result returned invisibly.
bcf_reset_global_var_samples_cpp <- function(bcf_wrapper, data_vector) {
invisible(.Call(`_stochtree_bcf_reset_global_var_samples_cpp`, bcf_wrapper, data_vector))
}

# Thin cpp11-generated binding: forwards `bcf_wrapper` and `data_vector` to
# `_stochtree_bcf_reset_prognostic_leaf_var_samples_cpp`; result returned
# invisibly.
bcf_reset_prognostic_leaf_var_samples_cpp <- function(bcf_wrapper, data_vector) {
invisible(.Call(`_stochtree_bcf_reset_prognostic_leaf_var_samples_cpp`, bcf_wrapper, data_vector))
}

# Thin cpp11-generated binding: forwards `bcf_wrapper` and `data_vector` to
# `_stochtree_bcf_reset_treatment_leaf_var_samples_cpp`; result returned
# invisibly.
bcf_reset_treatment_leaf_var_samples_cpp <- function(bcf_wrapper, data_vector) {
invisible(.Call(`_stochtree_bcf_reset_treatment_leaf_var_samples_cpp`, bcf_wrapper, data_vector))
}

# Thin cpp11-generated binding: forwards `bcf_wrapper` and `data_vector` to
# `_stochtree_bcf_reset_treatment_coding_samples_cpp`; result returned
# invisibly.
bcf_reset_treatment_coding_samples_cpp <- function(bcf_wrapper, data_vector) {
invisible(.Call(`_stochtree_bcf_reset_treatment_coding_samples_cpp`, bcf_wrapper, data_vector))
}

# Thin cpp11-generated binding: forwards `bcf_wrapper` and `data_vector` to
# `_stochtree_bcf_reset_control_coding_samples_cpp`; result returned
# invisibly.
bcf_reset_control_coding_samples_cpp <- function(bcf_wrapper, data_vector) {
invisible(.Call(`_stochtree_bcf_reset_control_coding_samples_cpp`, bcf_wrapper, data_vector))
}

# Thin cpp11-generated binding: forwards the wrapper handle, the three
# prediction containers (`muhat`, `tauhat`, `yhat`) and the three dimension
# arguments to `_stochtree_bcf_reset_train_prediction_samples_cpp`; result
# returned invisibly.
bcf_reset_train_prediction_samples_cpp <- function(bcf_wrapper, muhat, tauhat, yhat, num_obs, num_samples, treatment_dim) {
invisible(.Call(`_stochtree_bcf_reset_train_prediction_samples_cpp`, bcf_wrapper, muhat, tauhat, yhat, num_obs, num_samples, treatment_dim))
}

# Thin cpp11-generated binding: test-set counterpart of the train variant.
# Forwards the same argument list to
# `_stochtree_bcf_reset_test_prediction_samples_cpp`; result returned
# invisibly.
bcf_reset_test_prediction_samples_cpp <- function(bcf_wrapper, muhat, tauhat, yhat, num_obs, num_samples, treatment_dim) {
invisible(.Call(`_stochtree_bcf_reset_test_prediction_samples_cpp`, bcf_wrapper, muhat, tauhat, yhat, num_obs, num_samples, treatment_dim))
}

# Thin cpp11-generated binding: forwards every argument positionally, in the
# order declared, to `_stochtree_sample_bcf_univariate_cpp`; result returned
# invisibly. The argument order here must match the compiled routine exactly
# because nothing is passed by name.
sample_bcf_univariate_cpp <- function(bcf_wrapper, forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu_int, feature_types_tau_int, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau) {
invisible(.Call(`_stochtree_sample_bcf_univariate_cpp`, bcf_wrapper, forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu_int, feature_types_tau_int, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau))
}

# Thin cpp11-generated binding: same argument list as the univariate sampler
# except that the treatment leaf scale is passed as `sigma_leaf_tau_r`.
# Forwards every argument positionally to
# `_stochtree_sample_bcf_multivariate_cpp`; result returned invisibly.
sample_bcf_multivariate_cpp <- function(bcf_wrapper, forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau_r, alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu_int, feature_types_tau_int, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau) {
invisible(.Call(`_stochtree_sample_bcf_multivariate_cpp`, bcf_wrapper, forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau_r, alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu_int, feature_types_tau_int, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau))
}

# Thin cpp11-generated binding: calls the compiled routine
# `_stochtree_create_forest_dataset_cpp` with no arguments and returns its
# result (visibly).
create_forest_dataset_cpp <- function() {
.Call(`_stochtree_create_forest_dataset_cpp`)
}
Expand Down
1 change: 1 addition & 0 deletions src/Makevars
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ CPP_PKGROOT=stochtree-cpp
PKG_CPPFLAGS= -I$(CPP_PKGROOT)/include -I$(CPP_PKGROOT)/dependencies/boost_math/include -I$(CPP_PKGROOT)/dependencies/eigen

OBJECTS = \
bcf.o \
data.o \
predictor.o \
sampler.o \
Expand Down
303 changes: 303 additions & 0 deletions src/bcf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
#include <cpp11.hpp>
#include "stochtree_types.h"
#include <stochtree/cpp_api.h>
#include <stochtree/leaf_model.h>
#include <functional>
#include <memory>
#include <variant>
#include <vector>

[[cpp11::register]]
cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_init_cpp(bool univariate_treatment = true) {
    // Build the wrapper under unique_ptr ownership first, then hand the raw
    // pointer over to the R-visible external pointer.
    // NOTE(review): relies on cpp11's external_pointer taking ownership of the
    // released pointer -- confirm against the cpp11 version in use.
    using WrapperHandle = cpp11::external_pointer<StochTree::BCFModelWrapper>;
    auto owned_wrapper = std::make_unique<StochTree::BCFModelWrapper>(univariate_treatment);
    StochTree::BCFModelWrapper* raw_wrapper = owned_wrapper.release();
    return WrapperHandle(raw_wrapper);
}

[[cpp11::register]]
void bcf_add_train_with_weights_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper, cpp11::doubles_matrix<> X_train_mu,
    cpp11::doubles_matrix<> X_train_tau, cpp11::doubles_matrix<> Z_train,
    cpp11::doubles y_train, cpp11::doubles weights_train, bool treatment_binary
) {
    // Loads weighted training data into the BCF wrapper by handing it raw
    // pointers to the R-owned buffers.
    //
    // bcf_wrapper      external pointer to the BCFModelWrapper being populated
    // X_train_mu       covariate matrix passed as the "mu" design
    // X_train_tau      covariate matrix passed as the "tau" design
    // Z_train          treatment matrix
    // y_train          outcome vector
    // weights_train    observation weight vector
    // treatment_binary forwarded unchanged to LoadTrain
    //
    // Raises an R error (via cpp11::stop) when the inputs disagree on the
    // number of observations.

    // Data dimensions
    int n = X_train_mu.nrow();
    int X_train_mu_cols = X_train_mu.ncol();
    int X_train_tau_cols = X_train_tau.ncol();
    int Z_train_cols = Z_train.ncol();

    // LoadTrain receives a single row count for every buffer, so a shorter
    // input would be read past its end. Validate before touching the protect
    // stack so the error path leaves it balanced.
    if (X_train_tau.nrow() != n || Z_train.nrow() != n ||
        y_train.size() != n || weights_train.size() != n) {
        cpp11::stop("X_train_mu, X_train_tau, Z_train, y_train and weights_train must all have the same number of observations");
    }

    // Pointers to R data
    double* X_train_mu_data_ptr = REAL(PROTECT(X_train_mu));
    double* X_train_tau_data_ptr = REAL(PROTECT(X_train_tau));
    double* Z_train_data_ptr = REAL(PROTECT(Z_train));
    double* y_train_data_ptr = REAL(PROTECT(y_train));
    double* weights_train_data_ptr = REAL(PROTECT(weights_train));

    // Load training data into BCF model
    bcf_wrapper->LoadTrain(
        y_train_data_ptr, n, X_train_mu_data_ptr, X_train_mu_cols,
        X_train_tau_data_ptr, X_train_tau_cols, Z_train_data_ptr,
        Z_train_cols, treatment_binary, weights_train_data_ptr
    );

    // UNPROTECT the SEXPs created to point to the R data
    UNPROTECT(5);
}

[[cpp11::register]]
void bcf_add_train_no_weights_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper, cpp11::doubles_matrix<> X_train_mu,
    cpp11::doubles_matrix<> X_train_tau, cpp11::doubles_matrix<> Z_train,
    cpp11::doubles y_train, bool treatment_binary
) {
    // Loads unweighted training data into the BCF wrapper by handing it raw
    // pointers to the R-owned buffers.
    //
    // bcf_wrapper      external pointer to the BCFModelWrapper being populated
    // X_train_mu       covariate matrix passed as the "mu" design
    // X_train_tau      covariate matrix passed as the "tau" design
    // Z_train          treatment matrix
    // y_train          outcome vector
    // treatment_binary forwarded unchanged to LoadTrain
    //
    // Raises an R error (via cpp11::stop) when the inputs disagree on the
    // number of observations.

    // Data dimensions
    int n = X_train_mu.nrow();
    int X_train_mu_cols = X_train_mu.ncol();
    int X_train_tau_cols = X_train_tau.ncol();
    int Z_train_cols = Z_train.ncol();

    // LoadTrain receives a single row count for every buffer, so a shorter
    // input would be read past its end. Validate before touching the protect
    // stack so the error path leaves it balanced.
    if (X_train_tau.nrow() != n || Z_train.nrow() != n || y_train.size() != n) {
        cpp11::stop("X_train_mu, X_train_tau, Z_train and y_train must all have the same number of observations");
    }

    // Pointers to R data
    double* X_train_mu_data_ptr = REAL(PROTECT(X_train_mu));
    double* X_train_tau_data_ptr = REAL(PROTECT(X_train_tau));
    double* Z_train_data_ptr = REAL(PROTECT(Z_train));
    double* y_train_data_ptr = REAL(PROTECT(y_train));

    // Load training data into BCF model
    bcf_wrapper->LoadTrain(
        y_train_data_ptr, n, X_train_mu_data_ptr, X_train_mu_cols,
        X_train_tau_data_ptr, X_train_tau_cols, Z_train_data_ptr,
        Z_train_cols, treatment_binary
    );

    // UNPROTECT the SEXPs created to point to the R data
    UNPROTECT(4);
}

[[cpp11::register]]
void bcf_add_test_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper, cpp11::doubles_matrix<> X_test_mu,
    cpp11::doubles_matrix<> X_test_tau, cpp11::doubles_matrix<> Z_test
) {
    // Loads test data into the BCF wrapper by handing it raw pointers to the
    // R-owned buffers.
    //
    // bcf_wrapper  external pointer to the BCFModelWrapper being populated
    // X_test_mu    covariate matrix passed as the "mu" design
    // X_test_tau   covariate matrix passed as the "tau" design
    // Z_test       treatment matrix
    //
    // Raises an R error (via cpp11::stop) when the matrices disagree on the
    // number of observations.

    // Data dimensions
    int n = X_test_mu.nrow();
    int X_test_mu_cols = X_test_mu.ncol();
    int X_test_tau_cols = X_test_tau.ncol();
    int Z_test_cols = Z_test.ncol();

    // LoadTest receives a single row count for all three matrices, so a
    // shorter input would be read past its end. Validate before touching the
    // protect stack so the error path leaves it balanced.
    if (X_test_tau.nrow() != n || Z_test.nrow() != n) {
        cpp11::stop("X_test_mu, X_test_tau and Z_test must all have the same number of observations");
    }

    // Pointers to R data
    double* X_test_mu_data_ptr = REAL(PROTECT(X_test_mu));
    double* X_test_tau_data_ptr = REAL(PROTECT(X_test_tau));
    double* Z_test_data_ptr = REAL(PROTECT(Z_test));

    // Load test data into BCF model
    bcf_wrapper->LoadTest(
        X_test_mu_data_ptr, n, X_test_mu_cols,
        X_test_tau_data_ptr, X_test_tau_cols,
        Z_test_data_ptr, Z_test_cols
    );

    // UNPROTECT the SEXPs created to point to the R data
    UNPROTECT(3);
}

[[cpp11::register]]
void bcf_reset_global_var_samples_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper,
    cpp11::doubles data_vector
) {
    // Hands the wrapper the raw double buffer behind `data_vector` together
    // with its length.
    // NOTE(review): the wrapper presumably keeps pointing at this R-owned
    // memory after the call -- confirm the R side keeps the vector alive.
    SEXP buffer_sexp = PROTECT(data_vector);
    double* const buffer = REAL(buffer_sexp);
    const int buffer_length = data_vector.size();

    bcf_wrapper->ResetGlobalVarSamples(buffer, buffer_length);

    // Balance the single PROTECT above
    UNPROTECT(1);
}

[[cpp11::register]]
void bcf_reset_prognostic_leaf_var_samples_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper,
    cpp11::doubles data_vector
) {
    // Hands the wrapper the raw double buffer behind `data_vector` together
    // with its length.
    // NOTE(review): the wrapper presumably keeps pointing at this R-owned
    // memory after the call -- confirm the R side keeps the vector alive.
    SEXP buffer_sexp = PROTECT(data_vector);
    double* const buffer = REAL(buffer_sexp);
    const int buffer_length = data_vector.size();

    bcf_wrapper->ResetPrognosticLeafVarSamples(buffer, buffer_length);

    // Balance the single PROTECT above
    UNPROTECT(1);
}

[[cpp11::register]]
void bcf_reset_treatment_leaf_var_samples_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper,
    cpp11::doubles data_vector
) {
    // Hands the wrapper the raw double buffer behind `data_vector` together
    // with its length.
    // NOTE(review): the wrapper presumably keeps pointing at this R-owned
    // memory after the call -- confirm the R side keeps the vector alive.
    SEXP buffer_sexp = PROTECT(data_vector);
    double* const buffer = REAL(buffer_sexp);
    const int buffer_length = data_vector.size();

    bcf_wrapper->ResetTreatmentLeafVarSamples(buffer, buffer_length);

    // Balance the single PROTECT above
    UNPROTECT(1);
}

[[cpp11::register]]
void bcf_reset_treatment_coding_samples_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper,
    cpp11::doubles data_vector
) {
    // Hands the wrapper the raw double buffer behind `data_vector` together
    // with its length. Note the wrapper method is named
    // ResetTreatedCodingSamples, not "Treatment".
    // NOTE(review): the wrapper presumably keeps pointing at this R-owned
    // memory after the call -- confirm the R side keeps the vector alive.
    SEXP buffer_sexp = PROTECT(data_vector);
    double* const buffer = REAL(buffer_sexp);
    const int buffer_length = data_vector.size();

    bcf_wrapper->ResetTreatedCodingSamples(buffer, buffer_length);

    // Balance the single PROTECT above
    UNPROTECT(1);
}

[[cpp11::register]]
void bcf_reset_control_coding_samples_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper,
    cpp11::doubles data_vector
) {
    // Hands the wrapper the raw double buffer behind `data_vector` together
    // with its length.
    // NOTE(review): the wrapper presumably keeps pointing at this R-owned
    // memory after the call -- confirm the R side keeps the vector alive.
    SEXP buffer_sexp = PROTECT(data_vector);
    double* const buffer = REAL(buffer_sexp);
    const int buffer_length = data_vector.size();

    bcf_wrapper->ResetControlCodingSamples(buffer, buffer_length);

    // Balance the single PROTECT above
    UNPROTECT(1);
}

[[cpp11::register]]
void bcf_reset_train_prediction_samples_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper,
    cpp11::doubles_matrix<> muhat, cpp11::doubles tauhat, cpp11::doubles_matrix<> yhat,
    int num_obs, int num_samples, int treatment_dim
) {
    // Passes the raw buffers behind the three train-set prediction containers
    // to the wrapper, along with the caller-supplied dimensions.
    // NOTE(review): nothing here checks that the buffers are large enough for
    // num_obs / num_samples / treatment_dim -- verify against the R caller.
    SEXP muhat_sexp = PROTECT(muhat);
    SEXP tauhat_sexp = PROTECT(tauhat);
    SEXP yhat_sexp = PROTECT(yhat);

    bcf_wrapper->ResetTrainPredictionSamples(
        REAL(muhat_sexp), REAL(tauhat_sexp), REAL(yhat_sexp),
        num_obs, num_samples, treatment_dim
    );

    // Balance the three PROTECT calls above
    UNPROTECT(3);
}

[[cpp11::register]]
void bcf_reset_test_prediction_samples_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper,
    cpp11::doubles_matrix<> muhat, cpp11::doubles tauhat, cpp11::doubles_matrix<> yhat,
    int num_obs, int num_samples, int treatment_dim
) {
    // Passes the raw buffers behind the three test-set prediction containers
    // to the wrapper, along with the caller-supplied dimensions.
    // NOTE(review): nothing here checks that the buffers are large enough for
    // num_obs / num_samples / treatment_dim -- verify against the R caller.
    SEXP muhat_sexp = PROTECT(muhat);
    SEXP tauhat_sexp = PROTECT(tauhat);
    SEXP yhat_sexp = PROTECT(yhat);

    bcf_wrapper->ResetTestPredictionSamples(
        REAL(muhat_sexp), REAL(tauhat_sexp), REAL(yhat_sexp),
        num_obs, num_samples, treatment_dim
    );

    // Balance the three PROTECT calls above
    UNPROTECT(3);
}

[[cpp11::register]]
void sample_bcf_univariate_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper,
    cpp11::external_pointer<StochTree::ForestContainer> forest_samples_mu,
    cpp11::external_pointer<StochTree::ForestContainer> forest_samples_tau,
    cpp11::external_pointer<std::mt19937> rng,
    int cutpoint_grid_size, double sigma_leaf_mu, double sigma_leaf_tau,
    double alpha_mu, double alpha_tau, double beta_mu, double beta_tau,
    int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb,
    double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau,
    double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0,
    cpp11::integers feature_types_mu_int, cpp11::integers feature_types_tau_int,
    int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau
) {
    // Runs the BCF sampler with a scalar treatment leaf scale
    // (`sigma_leaf_tau` is a double here). All scalar arguments are forwarded
    // to SampleBCF unchanged; only the feature-type codes are converted.

    // Translate R integer codes into the StochTree::FeatureType enum
    auto to_feature_types = [](const cpp11::integers& codes) {
        std::vector<StochTree::FeatureType> converted(codes.size());
        for (int idx = 0; idx < codes.size(); ++idx) {
            converted.at(idx) = static_cast<StochTree::FeatureType>(codes.at(idx));
        }
        return converted;
    };
    std::vector<StochTree::FeatureType> feature_types_mu = to_feature_types(feature_types_mu_int);
    std::vector<StochTree::FeatureType> feature_types_tau = to_feature_types(feature_types_tau_int);

    // Run the sampler
    bcf_wrapper->SampleBCF(
        forest_samples_mu.get(), forest_samples_tau.get(), rng.get(),
        cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau,
        alpha_mu, alpha_tau, beta_mu, beta_tau,
        min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb,
        a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau,
        sigma2, num_trees_mu, num_trees_tau, b1, b0,
        feature_types_mu, feature_types_tau,
        num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau
    );
}

[[cpp11::register]]
void sample_bcf_multivariate_cpp(
    cpp11::external_pointer<StochTree::BCFModelWrapper> bcf_wrapper,
    cpp11::external_pointer<StochTree::ForestContainer> forest_samples_mu,
    cpp11::external_pointer<StochTree::ForestContainer> forest_samples_tau,
    cpp11::external_pointer<std::mt19937> rng,
    int cutpoint_grid_size, double sigma_leaf_mu, cpp11::doubles_matrix<> sigma_leaf_tau_r,
    double alpha_mu, double alpha_tau, double beta_mu, double beta_tau,
    int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb,
    double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau,
    double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0,
    cpp11::integers feature_types_mu_int, cpp11::integers feature_types_tau_int,
    int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau
) {
    // Runs the BCF sampler with a matrix-valued treatment leaf scale: the R
    // matrix `sigma_leaf_tau_r` is copied into an Eigen::MatrixXd before the
    // call. All scalar arguments are forwarded to SampleBCF unchanged.

    // Translate R integer codes into the StochTree::FeatureType enum
    auto to_feature_types = [](const cpp11::integers& codes) {
        std::vector<StochTree::FeatureType> converted(codes.size());
        for (int idx = 0; idx < codes.size(); ++idx) {
            converted.at(idx) = static_cast<StochTree::FeatureType>(codes.at(idx));
        }
        return converted;
    };
    std::vector<StochTree::FeatureType> feature_types_mu = to_feature_types(feature_types_mu_int);
    std::vector<StochTree::FeatureType> feature_types_tau = to_feature_types(feature_types_tau_int);

    // Element-wise copy of the R matrix into an Eigen matrix of equal shape
    const int sigma_rows = sigma_leaf_tau_r.nrow();
    const int sigma_cols = sigma_leaf_tau_r.ncol();
    Eigen::MatrixXd sigma_leaf_tau(sigma_rows, sigma_cols);
    for (int col = 0; col < sigma_cols; ++col) {
        for (int row = 0; row < sigma_rows; ++row) {
            sigma_leaf_tau(row, col) = sigma_leaf_tau_r(row, col);
        }
    }

    // Run the sampler
    bcf_wrapper->SampleBCF(
        forest_samples_mu.get(), forest_samples_tau.get(), rng.get(),
        cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau,
        alpha_mu, alpha_tau, beta_mu, beta_tau,
        min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb,
        a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau,
        sigma2, num_trees_mu, num_trees_tau, b1, b0,
        feature_types_mu, feature_types_tau,
        num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau
    );
}
Loading
Loading