Skip to content

Commit

Permalink
Merge pull request #13 from StochasticTree/deserialization
Browse files Browse the repository at this point in the history
Complete JSON serialization / deserialization for draws from stochastic tree ensemble models
  • Loading branch information
andrewherren authored May 8, 2024
2 parents bedef67 + d3e1318 commit 20c3b4f
Show file tree
Hide file tree
Showing 20 changed files with 1,110 additions and 288 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ export(createRandomEffectsDataset)
export(createRandomEffectsModel)
export(createRandomEffectsTracker)
export(getRandomEffectSamples)
export(loadForestContainerJson)
export(loadRandomEffectSamplesJson)
export(sample_sigma2_one_iteration)
export(sample_tau_one_iteration)
useDynLib(stochtree, .registration = TRUE)
152 changes: 110 additions & 42 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,70 @@ rfx_dataset_add_weights_cpp <- function(dataset_ptr, weights) {
invisible(.Call(`_stochtree_rfx_dataset_add_weights_cpp`, dataset_ptr, weights))
}

forest_container_cpp <- function(num_trees, output_dimension, is_leaf_constant) {
.Call(`_stochtree_forest_container_cpp`, num_trees, output_dimension, is_leaf_constant)
}

forest_container_from_json_cpp <- function(json_ptr, forest_label) {
.Call(`_stochtree_forest_container_from_json_cpp`, json_ptr, forest_label)
}

num_samples_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_num_samples_forest_container_cpp`, forest_samples)
}

num_trees_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_num_trees_forest_container_cpp`, forest_samples)
}

json_save_forest_container_cpp <- function(forest_samples, json_filename) {
invisible(.Call(`_stochtree_json_save_forest_container_cpp`, forest_samples, json_filename))
}

json_load_forest_container_cpp <- function(forest_samples, json_filename) {
invisible(.Call(`_stochtree_json_load_forest_container_cpp`, forest_samples, json_filename))
}

output_dimension_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_output_dimension_forest_container_cpp`, forest_samples)
}

is_leaf_constant_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_is_leaf_constant_forest_container_cpp`, forest_samples)
}

all_roots_forest_container_cpp <- function(forest_samples, forest_num) {
.Call(`_stochtree_all_roots_forest_container_cpp`, forest_samples, forest_num)
}

add_sample_forest_container_cpp <- function(forest_samples) {
invisible(.Call(`_stochtree_add_sample_forest_container_cpp`, forest_samples))
}

set_leaf_value_forest_container_cpp <- function(forest_samples, leaf_value) {
invisible(.Call(`_stochtree_set_leaf_value_forest_container_cpp`, forest_samples, leaf_value))
}

set_leaf_vector_forest_container_cpp <- function(forest_samples, leaf_vector) {
invisible(.Call(`_stochtree_set_leaf_vector_forest_container_cpp`, forest_samples, leaf_vector))
}

update_residual_forest_container_cpp <- function(data, residual, forest_samples, tracker, requires_basis, forest_num, add) {
invisible(.Call(`_stochtree_update_residual_forest_container_cpp`, data, residual, forest_samples, tracker, requires_basis, forest_num, add))
}

predict_forest_cpp <- function(forest_samples, dataset) {
.Call(`_stochtree_predict_forest_cpp`, forest_samples, dataset)
}

predict_forest_raw_cpp <- function(forest_samples, dataset) {
.Call(`_stochtree_predict_forest_raw_cpp`, forest_samples, dataset)
}

predict_forest_raw_single_forest_cpp <- function(forest_samples, dataset, forest_num) {
.Call(`_stochtree_predict_forest_raw_single_forest_cpp`, forest_samples, dataset, forest_num)
}

forest_kernel_cpp <- function() {
.Call(`_stochtree_forest_kernel_cpp`)
}
Expand Down Expand Up @@ -104,20 +168,20 @@ forest_kernel_compute_kernel_train_test_cpp <- function(forest_kernel, covariate
.Call(`_stochtree_forest_kernel_compute_kernel_train_test_cpp`, forest_kernel, covariates_train, covariates_test, forest_container, forest_num)
}

predict_forest_cpp <- function(forest_samples, dataset) {
.Call(`_stochtree_predict_forest_cpp`, forest_samples, dataset)
rfx_container_cpp <- function(num_components, num_groups) {
.Call(`_stochtree_rfx_container_cpp`, num_components, num_groups)
}

predict_forest_raw_cpp <- function(forest_samples, dataset) {
.Call(`_stochtree_predict_forest_raw_cpp`, forest_samples, dataset)
rfx_container_from_json_cpp <- function(json_ptr, rfx_label) {
.Call(`_stochtree_rfx_container_from_json_cpp`, json_ptr, rfx_label)
}

predict_forest_raw_single_forest_cpp <- function(forest_samples, dataset, forest_num) {
.Call(`_stochtree_predict_forest_raw_single_forest_cpp`, forest_samples, dataset, forest_num)
rfx_label_mapper_from_json_cpp <- function(json_ptr, rfx_label) {
.Call(`_stochtree_rfx_label_mapper_from_json_cpp`, json_ptr, rfx_label)
}

rfx_container_cpp <- function(num_components, num_groups) {
.Call(`_stochtree_rfx_container_cpp`, num_components, num_groups)
rfx_group_ids_from_json_cpp <- function(json_ptr, rfx_label) {
.Call(`_stochtree_rfx_group_ids_from_json_cpp`, json_ptr, rfx_label)
}

rfx_model_cpp <- function(num_components, num_groups) {
Expand Down Expand Up @@ -220,72 +284,76 @@ rng_cpp <- function(random_seed) {
.Call(`_stochtree_rng_cpp`, random_seed)
}

forest_container_cpp <- function(num_trees, output_dimension, is_leaf_constant) {
.Call(`_stochtree_forest_container_cpp`, num_trees, output_dimension, is_leaf_constant)
tree_prior_cpp <- function(alpha, beta, min_samples_leaf) {
.Call(`_stochtree_tree_prior_cpp`, alpha, beta, min_samples_leaf)
}

num_samples_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_num_samples_forest_container_cpp`, forest_samples)
forest_tracker_cpp <- function(data, feature_types, num_trees, n) {
.Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n)
}

num_trees_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_num_trees_forest_container_cpp`, forest_samples)
init_json_cpp <- function() {
.Call(`_stochtree_init_json_cpp`)
}

json_save_forest_container_cpp <- function(forest_samples, json_filename) {
invisible(.Call(`_stochtree_json_save_forest_container_cpp`, forest_samples, json_filename))
json_add_double_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_double_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_value))
}

json_load_forest_container_cpp <- function(forest_samples, json_filename) {
invisible(.Call(`_stochtree_json_load_forest_container_cpp`, forest_samples, json_filename))
json_add_double_cpp <- function(json_ptr, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_double_cpp`, json_ptr, field_name, field_value))
}

output_dimension_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_output_dimension_forest_container_cpp`, forest_samples)
json_add_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_vector) {
invisible(.Call(`_stochtree_json_add_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_vector))
}

is_leaf_constant_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_is_leaf_constant_forest_container_cpp`, forest_samples)
json_add_vector_cpp <- function(json_ptr, field_name, field_vector) {
invisible(.Call(`_stochtree_json_add_vector_cpp`, json_ptr, field_name, field_vector))
}

all_roots_forest_container_cpp <- function(forest_samples, forest_num) {
.Call(`_stochtree_all_roots_forest_container_cpp`, forest_samples, forest_num)
json_contains_field_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_contains_field_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}

add_sample_forest_container_cpp <- function(forest_samples) {
invisible(.Call(`_stochtree_add_sample_forest_container_cpp`, forest_samples))
json_contains_field_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_contains_field_cpp`, json_ptr, field_name)
}

set_leaf_value_forest_container_cpp <- function(forest_samples, leaf_value) {
invisible(.Call(`_stochtree_set_leaf_value_forest_container_cpp`, forest_samples, leaf_value))
json_extract_double_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_extract_double_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}

set_leaf_vector_forest_container_cpp <- function(forest_samples, leaf_vector) {
invisible(.Call(`_stochtree_set_leaf_vector_forest_container_cpp`, forest_samples, leaf_vector))
json_extract_double_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_double_cpp`, json_ptr, field_name)
}

update_residual_forest_container_cpp <- function(data, residual, forest_samples, tracker, requires_basis, forest_num, add) {
invisible(.Call(`_stochtree_update_residual_forest_container_cpp`, data, residual, forest_samples, tracker, requires_basis, forest_num, add))
json_extract_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_extract_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}

tree_prior_cpp <- function(alpha, beta, min_samples_leaf) {
.Call(`_stochtree_tree_prior_cpp`, alpha, beta, min_samples_leaf)
json_extract_vector_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_vector_cpp`, json_ptr, field_name)
}

forest_tracker_cpp <- function(data, feature_types, num_trees, n) {
.Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n)
json_add_forest_cpp <- function(json_ptr, forest_samples) {
.Call(`_stochtree_json_add_forest_cpp`, json_ptr, forest_samples)
}

init_json_cpp <- function() {
.Call(`_stochtree_init_json_cpp`)
json_increment_rfx_count_cpp <- function(json_ptr) {
invisible(.Call(`_stochtree_json_increment_rfx_count_cpp`, json_ptr))
}

json_add_forest_cpp <- function(json_ptr, forest_samples) {
invisible(.Call(`_stochtree_json_add_forest_cpp`, json_ptr, forest_samples))
json_add_rfx_container_cpp <- function(json_ptr, rfx_samples) {
.Call(`_stochtree_json_add_rfx_container_cpp`, json_ptr, rfx_samples)
}

json_add_rfx_label_mapper_cpp <- function(json_ptr, label_mapper) {
.Call(`_stochtree_json_add_rfx_label_mapper_cpp`, json_ptr, label_mapper)
}

json_add_rfx_cpp <- function(json_ptr, rfx_samples) {
invisible(.Call(`_stochtree_json_add_rfx_cpp`, json_ptr, rfx_samples))
json_add_rfx_groupids_cpp <- function(json_ptr, groupids) {
.Call(`_stochtree_json_add_rfx_groupids_cpp`, json_ptr, groupids)
}

json_save_cpp <- function(json_ptr, filename) {
Expand Down
22 changes: 22 additions & 0 deletions R/forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ ForestSamples <- R6::R6Class(
self$forest_container_ptr <- forest_container_cpp(num_trees, output_dimension, is_leaf_constant)
},

#' @description
#' Create a new ForestContainer object from a json object
#' @param json_object Object of class `CppJson`
#' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy
#' @return A new `ForestContainer` object.
load_from_json = function(json_object, json_forest_label) {
self$forest_container_ptr <- forest_container_from_json_cpp(json_object$json_ptr, json_forest_label)
},

#' @description
#' Predict every tree ensemble on every sample in `forest_dataset`
#' @param forest_dataset `ForestDataset` R class
Expand Down Expand Up @@ -170,3 +179,16 @@ createForestContainer <- function(num_trees, output_dimension=1, is_leaf_constan
ForestSamples$new(num_trees, output_dimension, is_leaf_constant)
)))
}

#' Load a container of forest samples from json
#'
#' @param json_object Object of class `CppJson`
#' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy
#'
#' @return `ForestSamples` object
#' @export
loadForestContainerJson <- function(json_object, json_forest_label) {
invisible(output <- ForestSamples$new(0,1,T))
output$load_from_json(json_object, json_forest_label)
return(output)
}
57 changes: 47 additions & 10 deletions R/random_effects.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,54 @@ RandomEffectSamples <- R6::R6Class(

#' @description
#' Create a new RandomEffectSamples object.
#' @return A new `RandomEffectSamples` object.
initialize = function() {},

#' @description
#' Construct RandomEffectSamples object from other "in-session" R objects
#' @param num_components Number of "components" or bases defining the random effects regression
#' @param num_groups Number of random effects groups
#' @param random_effects_tracker Object of type `RandomEffectsTracker`
#' @return A new `RandomEffectSamples` object.
initialize = function(num_components, num_groups, random_effects_tracker) {
#' @return NULL
load_in_session = function(num_components, num_groups, random_effects_tracker) {
# Initialize
self$rfx_container_ptr <- rfx_container_cpp(num_components, num_groups)
self$label_mapper_ptr <- rfx_label_mapper_cpp(random_effects_tracker$rfx_tracker_ptr)
self$training_group_ids <- rfx_tracker_get_unique_group_ids_cpp(random_effects_tracker$rfx_tracker_ptr)
},

#' @description
#' Construct RandomEffectSamples object from a json object
#' @param json_object Object of class `CppJson`
#' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy
#' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy
#' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy
#' @return A new `RandomEffectSamples` object.
load_from_json = function(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) {
self$rfx_container_ptr <- rfx_container_from_json_cpp(json_object$json_ptr, json_rfx_container_label)
self$label_mapper_ptr <- rfx_label_mapper_from_json_cpp(json_object$json_ptr, json_rfx_mapper_label)
self$training_group_ids <- rfx_group_ids_from_json_cpp(json_object$json_ptr, json_rfx_groupids_label)
},

#' @description
#' Predict random effects for each observation implied by `rfx_group_ids` and `rfx_basis`.
#' If a random effects model is "intercept-only" the `rfx_basis` will be a vector of ones of size `length(rfx_group_ids)`.
#' @param rfx_group_ids Indices of random effects groups in a prediction set
#' @param rfx_basis Basis used for random effects prediction
#' @param rfx_basis (Optional ) Basis used for random effects prediction
#' @return Matrix with as many rows as observations provided and as many columns as samples drawn of the model.
predict = function(rfx_group_ids, rfx_basis) {
num_observations = length(rfx_group_ids)
predict = function(rfx_group_ids, rfx_basis = NULL) {
num_obs = length(rfx_group_ids)
if (is.null(rfx_basis)) rfx_basis <- matrix(rep(1,num_obs), ncol = 1)
num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr)
num_components = rfx_container_num_components_cpp(self$rfx_container_ptr)
num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr)
rfx_group_ids_int <- as.integer(rfx_group_ids)
stopifnot(sum(abs(rfx_group_ids_int-rfx_group_ids)) < 1e-6)
stopifnot(sum(!(rfx_group_ids %in% self$training_group_ids)) == 0)
stopifnot(ncol(rfx_basis) == num_components)
rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis)
rfx_dataset <- createRandomEffectsDataset(rfx_group_ids_int, rfx_basis)
output <- rfx_container_predict_cpp(self$rfx_container_ptr, rfx_dataset$data_ptr, self$label_mapper_ptr)
dim(output) <- c(num_observations, num_samples)
dim(output) <- c(num_obs, num_samples)
return(output)
},

Expand Down Expand Up @@ -264,9 +285,9 @@ RandomEffectsModel <- R6::R6Class(
#' @return `RandomEffectSamples` object
#' @export
createRandomEffectSamples <- function(num_components, num_groups, random_effects_tracker) {
return(invisible((
RandomEffectSamples$new(num_components, num_groups, random_effects_tracker)
)))
invisible(output <- RandomEffectSamples$new())
output$load_in_session(num_components, num_groups, random_effects_tracker)
return(output)
}

#' Create a `RandomEffectsTracker` object
Expand All @@ -291,3 +312,19 @@ createRandomEffectsModel <- function(num_components, num_groups) {
RandomEffectsModel$new(num_components, num_groups)
)))
}

#' Load a container of forest samples from json
#'
#' @param json_object Object of class `CppJson`
#' @param json_rfx_num Integer index indicating the position of the random effects term to be unpacked
#'
#' @return `RandomEffectSamples` object
#' @export
loadRandomEffectSamplesJson <- function(json_object, json_rfx_num) {
json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num)
json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num)
json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num)
invisible(output <- RandomEffectSamples$new())
output$load_from_json(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label)
return(output)
}
Loading

0 comments on commit 20c3b4f

Please sign in to comment.