Skip to content

Commit

Permalink
Merge pull request #14 from StochasticTree/interface_updates
Browse files Browse the repository at this point in the history
Updated interface
  • Loading branch information
andrewherren committed May 11, 2024
2 parents 20c3b4f + c3e6016 commit 4b8ac00
Show file tree
Hide file tree
Showing 49 changed files with 3,994 additions and 169 deletions.
16 changes: 9 additions & 7 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ LinkingTo:
cpp11
Suggests:
knitr,
rmarkdown,
Matrix,
tgp,
MASS,
mvtnorm,
ggplot2,
latex2exp
rmarkdown,
Matrix,
tgp,
MASS,
mvtnorm,
ggplot2,
latex2exp,
testthat (>= 3.0.0)
VignetteBuilder: knitr
SystemRequirements: C++17
Imports:
R6
URL: https://stochastictree.github.io/stochtree-r/
Config/testthat/edition: 3
15 changes: 15 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
# Generated by roxygen2: do not edit by hand

S3method(convertToJson,bcf)
S3method(getRandomEffectSamples,bartmodel)
S3method(getRandomEffectSamples,bcf)
S3method(predict,bartmodel)
S3method(predict,bcf)
S3method(saveToJsonFile,bcf)
export(bart)
export(bcf)
export(computeForestKernels)
export(computeForestLeafIndices)
export(convertToJson)
export(createBCFModelFromJson)
export(createBCFModelFromJsonFile)
export(createCppJson)
export(createCppJsonFile)
export(createForestContainer)
export(createForestCovariates)
export(createForestCovariatesFromMetadata)
export(createForestDataset)
export(createForestKernel)
export(createForestModel)
Expand All @@ -22,6 +30,13 @@ export(createRandomEffectsTracker)
export(getRandomEffectSamples)
export(loadForestContainerJson)
export(loadRandomEffectSamplesJson)
export(loadScalarJson)
export(loadVectorJson)
export(oneHotEncode)
export(oneHotInitializeAndEncode)
export(orderedCatInitializeAndPreprocess)
export(orderedCatPreprocess)
export(sample_sigma2_one_iteration)
export(sample_tau_one_iteration)
export(saveToJsonFile)
useDynLib(stochtree, .registration = TRUE)
4 changes: 2 additions & 2 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
"num_samples" = num_samples,
"has_basis" = !is.null(W_train),
"has_rfx" = has_rfx,
"has_basis_rfx" = has_basis_rfx,
"num_basis_rfx" = num_basis_rfx
"has_rfx_basis" = has_basis_rfx,
"num_rfx_basis" = num_basis_rfx
)
result <- list(
"forests" = forest_samples,
Expand Down
567 changes: 541 additions & 26 deletions R/bcf.R

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ rfx_model_sample_random_effects_cpp <- function(rfx_model, rfx_dataset, residual
invisible(.Call(`_stochtree_rfx_model_sample_random_effects_cpp`, rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, global_variance, rng))
}

rfx_model_predict_cpp <- function(rfx_model, rfx_dataset, rfx_tracker) {
.Call(`_stochtree_rfx_model_predict_cpp`, rfx_model, rfx_dataset, rfx_tracker)
}

rfx_container_predict_cpp <- function(rfx_container, rfx_dataset, label_mapper) {
.Call(`_stochtree_rfx_container_predict_cpp`, rfx_container, rfx_dataset, label_mapper)
}
Expand Down Expand Up @@ -304,6 +308,14 @@ json_add_double_cpp <- function(json_ptr, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_double_cpp`, json_ptr, field_name, field_value))
}

json_add_bool_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_bool_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_value))
}

json_add_bool_cpp <- function(json_ptr, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_bool_cpp`, json_ptr, field_name, field_value))
}

json_add_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_vector) {
invisible(.Call(`_stochtree_json_add_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_vector))
}
Expand All @@ -312,6 +324,22 @@ json_add_vector_cpp <- function(json_ptr, field_name, field_vector) {
invisible(.Call(`_stochtree_json_add_vector_cpp`, json_ptr, field_name, field_vector))
}

json_add_string_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_vector) {
invisible(.Call(`_stochtree_json_add_string_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_vector))
}

json_add_string_vector_cpp <- function(json_ptr, field_name, field_vector) {
invisible(.Call(`_stochtree_json_add_string_vector_cpp`, json_ptr, field_name, field_vector))
}

json_add_string_subfolder_cpp <- function(json_ptr, subfolder_name, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_string_subfolder_cpp`, json_ptr, subfolder_name, field_name, field_value))
}

json_add_string_cpp <- function(json_ptr, field_name, field_value) {
invisible(.Call(`_stochtree_json_add_string_cpp`, json_ptr, field_name, field_value))
}

json_contains_field_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_contains_field_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}
Expand All @@ -328,6 +356,22 @@ json_extract_double_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_double_cpp`, json_ptr, field_name)
}

json_extract_bool_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_extract_bool_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}

json_extract_bool_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_bool_cpp`, json_ptr, field_name)
}

json_extract_string_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_extract_string_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}

json_extract_string_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_string_cpp`, json_ptr, field_name)
}

json_extract_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_extract_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}
Expand All @@ -336,6 +380,14 @@ json_extract_vector_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_vector_cpp`, json_ptr, field_name)
}

json_extract_string_vector_subfolder_cpp <- function(json_ptr, subfolder_name, field_name) {
.Call(`_stochtree_json_extract_string_vector_subfolder_cpp`, json_ptr, subfolder_name, field_name)
}

json_extract_string_vector_cpp <- function(json_ptr, field_name) {
.Call(`_stochtree_json_extract_string_vector_cpp`, json_ptr, field_name)
}

json_add_forest_cpp <- function(json_ptr, forest_samples) {
.Call(`_stochtree_json_add_forest_cpp`, json_ptr, forest_samples)
}
Expand All @@ -359,3 +411,7 @@ json_add_rfx_groupids_cpp <- function(json_ptr, groupids) {
json_save_cpp <- function(json_ptr, filename) {
invisible(.Call(`_stochtree_json_save_cpp`, json_ptr, filename))
}

json_load_cpp <- function(json_ptr, filename) {
invisible(.Call(`_stochtree_json_load_cpp`, json_ptr, filename))
}
13 changes: 0 additions & 13 deletions R/forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,3 @@ createForestContainer <- function(num_trees, output_dimension=1, is_leaf_constan
ForestSamples$new(num_trees, output_dimension, is_leaf_constant)
)))
}

#' Load a container of forest samples from json
#'
#' @param json_object Object of class `CppJson`
#' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy
#'
#' @return `ForestSamples` object
#' @export
loadForestContainerJson <- function(json_object, json_forest_label) {
invisible(output <- ForestSamples$new(0,1,T))
output$load_from_json(json_object, json_forest_label)
return(output)
}
14 changes: 14 additions & 0 deletions R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,17 @@
#' @return List of random effect samples
#' @export
getRandomEffectSamples <- function(object, ...) UseMethod("getRandomEffectSamples")

#' Convert a model object (BCF, BART, etc...) to JSON
#'
#' @return Object of type `CppJson` which can be saved to disk with the `save_file(filename)`
#' method
#' @export
convertToJson <- function(object, ...) UseMethod("convertToJson")

#' Convert a model object (BCF, BART, etc...) to JSON and save it to a file
#' with a json suffix named `filename`
#'
#' @return NULL
#' @export
saveToJsonFile <- function(object, filename, ...) UseMethod("saveToJsonFile")
26 changes: 10 additions & 16 deletions R/random_effects.R
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,16 @@ RandomEffectsModel <- R6::R6Class(
rfx_samples$rfx_container_ptr, global_variance, rng$rng_ptr)
},

#' @description
#' Predict from (a single sample of a) random effects model.
#' @param rfx_dataset Object of type `RandomEffectsDataset`
#' @param rfx_tracker Object of type `RandomEffectsTracker`
#' @return Vector of predictions with size matching number of observations in rfx_dataset
predict = function(rfx_dataset, rfx_tracker) {
pred <- rfx_model_predict_cpp(self$rfx_model_ptr, rfx_dataset$data_ptr, rfx_tracker$rfx_tracker_ptr)
return(pred)
},

#' @description
#' Set value for the "working parameter." This is typically
#' used for initialization, but could also be used to interrupt
Expand Down Expand Up @@ -312,19 +322,3 @@ createRandomEffectsModel <- function(num_components, num_groups) {
RandomEffectsModel$new(num_components, num_groups)
)))
}

#' Load a container of forest samples from json
#'
#' @param json_object Object of class `CppJson`
#' @param json_rfx_num Integer index indicating the position of the random effects term to be unpacked
#'
#' @return `RandomEffectSamples` object
#' @export
loadRandomEffectSamplesJson <- function(json_object, json_rfx_num) {
json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num)
json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num)
json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num)
invisible(output <- RandomEffectSamples$new())
output$load_from_json(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label)
return(output)
}
Loading

0 comments on commit 4b8ac00

Please sign in to comment.