diff --git a/R/RcppExports.R b/R/RcppExports.R
index c229af13..9d325433 100644
--- a/R/RcppExports.R
+++ b/R/RcppExports.R
@@ -37,8 +37,8 @@ compute_pred_prob_exported <- function(y, w) {
.Call(`_aorsf_compute_pred_prob_exported`, y, w)
}
-expand_y_clsf <- function(y, n_class) {
- .Call(`_aorsf_expand_y_clsf`, y, n_class)
+compute_var_reduction_exported <- function(y_node, w_node, g_node) {
+ .Call(`_aorsf_compute_var_reduction_exported`, y_node, w_node, g_node)
}
is_col_splittable_exported <- function(x, y, r, j) {
@@ -73,11 +73,15 @@ cph_scale <- function(x, w) {
.Call(`_aorsf_cph_scale`, x, w)
}
-orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity) {
- .Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity)
+expand_y_clsf <- function(y, n_class) {
+ .Call(`_aorsf_expand_y_clsf`, y, n_class)
+}
+
+compute_mse_exported <- function(y, w, p) {
+ .Call(`_aorsf_compute_mse_exported`, y, w, p)
}
-compute_var_reduction <- function(y_node, w_node, g_node) {
- .Call(`_aorsf_compute_var_reduction`, y_node, w_node, g_node)
+orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity) {
+ .Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity)
}
diff --git a/R/orsf.R b/R/orsf.R
index 0d5dcbab..9fc61f25 100644
--- a/R/orsf.R
+++ b/R/orsf.R
@@ -371,6 +371,7 @@ orsf <- function(data,
pred_type = oobag_pred_type,
oobag_pred_horizon = oobag_pred_horizon,
oobag_eval_every = oobag_eval_every,
+ oobag_fun = oobag_fun,
importance_type = importance,
importance_max_pvalue = importance_max_pvalue,
importance_group_factors = group_factors,
@@ -391,7 +392,7 @@ orsf <- function(data,
tree_type,
'survival' = do.call(ObliqueForestSurvival$new, args = args),
'classification' = do.call(ObliqueForestClassification$new, args = args),
- 'regression' = stop("not ready yet")
+ 'regression' = do.call(ObliqueForestRegression$new, args = args)
)
if(no_fit) return(object)
diff --git a/R/orsf_R6.R b/R/orsf_R6.R
index c0d472bf..7be88194 100644
--- a/R/orsf_R6.R
+++ b/R/orsf_R6.R
@@ -181,6 +181,9 @@ ObliqueForest <- R6::R6Class(
if(self$tree_type == 'survival')
paste0(' N events: ', private$n_events ),
+ if(self$tree_type == 'classification')
+ paste0(' N classes: ', self$n_class ),
+
paste0(' N trees: ', self$n_tree ),
paste0(' N predictors total: ', n_predictors ),
paste0(' N predictors per node: ', self$mtry ),
@@ -272,17 +275,16 @@ ObliqueForest <- R6::R6Class(
private$check_na_action(new = TRUE, na_action = na_action)
private$check_var_missing(new = TRUE, data = new_data, na_action)
private$check_units(data = new_data)
- private$check_pred_type(oobag = FALSE, pred_type = pred_type)
+ private$check_boundary_checks(boundary_checks)
private$check_n_thread(n_thread)
private$check_verbose_progress(verbose_progress)
private$check_pred_aggregate(pred_aggregate)
- if(self$tree_type == 'survival')
- private$check_pred_horizon(pred_horizon, boundary_checks)
-
+ # check and/or set self$pred_horizon and self$pred_type
+ # with defaults depending on tree type
+ private$init_pred(pred_horizon, pred_type, boundary_checks)
+ # set the rest
self$data <- new_data
- self$pred_horizon <- pred_horizon
- self$pred_type <- pred_type
self$na_action <- na_action
self$n_thread <- n_thread
self$verbose_progress <- verbose_progress
@@ -1971,30 +1973,7 @@ ObliqueForest <- R6::R6Class(
call. = FALSE
)
- test_time <- seq(from = 1, to = 5, length.out = 100)
- test_status <- rep(c(0,1), each = 50)
-
- .y_mat <- cbind(time = test_time, status = test_status)
- .w_vec <- rep(1, times = 100)
- .s_vec <- seq(0.9, 0.1, length.out = 100)
-
- test_output <- try(input(y_mat = .y_mat, w_vec = .w_vec, s_vec = .s_vec),
- silent = FALSE)
-
- if(is_error(test_output)){
-
- stop("oobag_fun encountered an error when it was tested. ",
- "Please make sure your oobag_fun works for this case:\n\n",
- "test_time <- seq(from = 1, to = 5, length.out = 100)\n",
- "test_status <- rep(c(0,1), each = 50)\n\n",
- "y_mat <- cbind(time = test_time, status = test_status)\n",
- "w_vec <- rep(1, times = 100)\n",
- "s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
- "test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
- "test_output should be a numeric value of length 1",
- call. = FALSE)
-
- }
+ test_output <- private$check_oobag_eval_function_internal(input)
if(!is.numeric(test_output)) stop(
"oobag_fun should return a numeric output but instead returns ",
@@ -2375,7 +2354,8 @@ ObliqueForest <- R6::R6Class(
split_rule_R = switch(self$split_rule,
"logrank" = 1,
"cstat" = 2,
- "gini" = 3),
+ "gini" = 3,
+ "variance" = 4),
split_min_events = .dots$split_min_events %||% self$split_min_events %||% 1,
split_min_obs = .dots$split_min_obs %||% self$split_min_obs,
split_min_stat = .dots$split_min_stat %||% self$split_min_stat,
@@ -2401,6 +2381,7 @@ ObliqueForest <- R6::R6Class(
"surv" = 2,
"chf" = 3,
"mort" = 4,
+ "mean" = 5,
"prob" = 6,
"class" = 7,
"leaf" = 8),
@@ -2414,7 +2395,9 @@ ObliqueForest <- R6::R6Class(
"none" = 0,
"harrell's c-index" = 1,
"auc-roc" = 1,
- "user-specified function" = 2
+ "user-specified function" = 2,
+ "mse" = 3,
+ "rsq" = 4
),
oobag_eval_every = .dots$oobag_eval_every %||% self$oobag_eval_every,
# switch(pd_type, "none" = 0L, "smry" = 1L, "ice" = 2L)
@@ -2698,6 +2681,39 @@ ObliqueForestSurvival <- R6::R6Class(
},
+ check_oobag_eval_function_internal = function(oobag_fun){
+
+ test_time <- seq(from = 1, to = 5, length.out = 100)
+ test_status <- rep(c(0,1), each = 50)
+
+ .y_mat <- cbind(time = test_time, status = test_status)
+ .w_vec <- rep(1, times = 100)
+ .s_vec <- seq(0.9, 0.1, length.out = 100)
+
+ test_output <- try(oobag_fun(y_mat = .y_mat,
+ w_vec = .w_vec,
+ s_vec = .s_vec),
+ silent = FALSE)
+
+ if(is_error(test_output)){
+
+ stop("oobag_fun encountered an error when it was tested. ",
+ "Please make sure your oobag_fun works for this case:\n\n",
+ "test_time <- seq(from = 1, to = 5, length.out = 100)\n",
+ "test_status <- rep(c(0,1), each = 50)\n\n",
+ "y_mat <- cbind(time = test_time, status = test_status)\n",
+ "w_vec <- rep(1, times = 100)\n",
+ "s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
+ "test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
+ "test_output should be a numeric value of length 1",
+ call. = FALSE)
+
+ }
+
+ test_output
+
+ },
+
sort_inputs = function(sort_x = TRUE,
sort_y = TRUE,
sort_w = TRUE){
@@ -2790,6 +2806,50 @@ ObliqueForestSurvival <- R6::R6Class(
}
},
+
+ init_pred = function(pred_horizon = NULL, pred_type = NULL,
+ boundary_checks = TRUE){
+
+ pred_type_supplied <- !is.null(pred_type)
+ pred_horizon_supplied <- !is.null(pred_horizon)
+
+ if(pred_type_supplied){
+ private$check_pred_type(oobag = FALSE, pred_type = pred_type)
+ } else {
+ pred_type <- self$pred_type %||% "risk"
+ }
+
+ if(pred_horizon_supplied){
+ private$check_pred_horizon(pred_horizon, boundary_checks, pred_type)
+ } else {
+ pred_horizon <- self$pred_horizon
+ if(is.null(pred_horizon)){
+ stop("pred_horizon was not specified and could not be found in object.",
+ call. = FALSE)
+ }
+ }
+
+ if(pred_type_supplied &&
+ pred_horizon_supplied &&
+ pred_type %in% c('leaf', 'mort')){
+
+ extra_text <- if(length(pred_horizon)>1){
+ " Predictions at each value of pred_horizon will be identical."
+ } else {
+ ""
+ }
+
+ warning("pred_horizon does not impact predictions",
+ " when pred_type is '", pred_type, "'.",
+ extra_text, call. = FALSE)
+
+ }
+
+ self$pred_horizon <- pred_horizon
+ self$pred_type <- pred_type
+
+ },
+
prep_y_internal = function(placeholder = FALSE){
@@ -2894,6 +2954,7 @@ ObliqueForestSurvival <- R6::R6Class(
private$y <- as_matrix(y)
},
+
clean_pred_oobag_internal = function(){
@@ -2915,7 +2976,12 @@ ObliqueForestSurvival <- R6::R6Class(
},
clean_pred_new_internal = function(preds){
- # output in the same order as user's pred_horizon vector
+ # don't let multiple pred horizon values through for mort
+ if(self$pred_type == 'mort'){
+ return(preds[, 1, drop = FALSE])
+ }
+
+ # output in the same order as user's pred_horizon vector
preds <- preds[, order(private$pred_horizon_order), drop = FALSE]
preds
@@ -2953,6 +3019,10 @@ ObliqueForestSurvival <- R6::R6Class(
}
+ # all components are the same if pred type is mort
+ # (user also gets a warning if they ask for this)
+ if(self$pred_type %in% c('mort', 'leaf')) return(results[[1]])
+
return(simplify2array(results))
}
@@ -3010,6 +3080,37 @@ ObliqueForestClassification <- R6::R6Class(
},
+ check_oobag_eval_function_internal = function(oobag_fun){
+
+ test_y <- rep(c(0,1), each = 50)
+
+ .y_mat <- matrix(test_y, ncol = 1)
+ .w_vec <- rep(1, times = 100)
+ .s_vec <- seq(0.9, 0.1, length.out = 100)
+
+ test_output <- try(oobag_fun(y_mat = .y_mat,
+ w_vec = .w_vec,
+ s_vec = .s_vec),
+ silent = FALSE)
+
+ if(is_error(test_output)){
+
+ stop("oobag_fun encountered an error when it was tested. ",
+ "Please make sure your oobag_fun works for this case:\n\n",
+ "test_y <- rep(c(0,1), each = 50)\n",
+ "y_mat <- matrix(test_y, ncol = 1)\n",
+ "w_vec <- rep(1, times = 100)\n",
+ "s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
+ "test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
+ "test_output should be a numeric value of length 1",
+ call. = FALSE)
+
+ }
+
+ test_output
+
+ },
+
init_control = function(){
self$control <- orsf_control_classification(method = 'glm',
@@ -3049,22 +3150,49 @@ ObliqueForestClassification <- R6::R6Class(
+ },
+
+ init_pred = function(pred_horizon = NULL, pred_type = NULL,
+ boundary_checks = TRUE){
+
+ if(!is.null(pred_horizon)){
+ warning("pred_horizon does not impact predictions",
+ " for classification forests", call. = FALSE)
+ }
+
+ if(!is.null(pred_type)){
+ private$check_pred_type(oobag = FALSE, pred_type = pred_type)
+ } else {
+ pred_type <- self$pred_type %||% "prob"
+ }
+
+ self$pred_type <- pred_type
+
},
prep_y_internal = function(placeholder = FALSE){
if(placeholder){
- private$y <- matrix(0, ncol = self$n_class-1, nrow = 1)
+ private$y <- matrix(0, ncol = self$n_class, nrow = 1)
return()
}
# y is always 1 column for classification (right?)
y <- private$y[[1]]
- if(!is.factor(y)) y <- as.factor(y)
+ input_was_numeric <- !is.factor(y)
+
+ if(input_was_numeric) y <- as.factor(y)
n_class <- length(levels(y))
+ if(n_class > 5 && input_was_numeric){
+ stop("The outcome is numeric and has > 5 unique values.",
+ " Did you mean to use orsf_control_regression()? If not,",
+ " please convert ", private$data_names$y, " to a factor and re-run",
+ call. = FALSE)
+ }
+
y <- as.numeric(y) - 1
private$y <- expand_y_clsf(as_matrix(y), n_class)
@@ -3074,7 +3202,176 @@ ObliqueForestClassification <- R6::R6Class(
predict_internal = function(){
# resize y to have the right number of columns
- private$y <- matrix(0, ncol = self$n_class-1)
+ private$y <- matrix(0, ncol = self$n_class)
+
+ cpp_args = private$prep_cpp_args(x = private$x,
+ y = private$y,
+ w = private$w,
+ importance_type = 'none',
+ pred_type = self$pred_type,
+ pred_aggregate = self$pred_aggregate,
+ oobag_pred = FALSE,
+ pred_mode = TRUE,
+ write_forest = FALSE,
+ run_forest = TRUE)
+
+ # no further cleaning needed
+ do.call(orsf_cpp, args = cpp_args)$pred_new
+
+ }
+
+ )
+)
+
+
+# ObliqueForestRegression class ----
+
+ObliqueForestRegression <- R6::R6Class(
+ "ObliqueForestRegression",
+ inherit = ObliqueForest,
+ cloneable = FALSE,
+ public = list(
+
+ n_class = NULL,
+
+ class_levels = NULL
+
+ ),
+ private = list(
+
+ check_split_rule_internal = function(){
+
+ check_arg_is_valid(arg_value = self$split_rule,
+ arg_name = 'split_rule',
+ valid_options = c("variance"))
+
+ },
+
+ check_pred_type_internal = function(oobag, pred_type = NULL){
+
+ input <- pred_type %||% self$pred_type
+
+ arg_name <- if(oobag) 'oobag_pred_type' else 'pred_type'
+
+ check_arg_is_valid(arg_value = input,
+ arg_name = arg_name,
+ valid_options = c("none", "mean", "leaf"))
+
+ },
+
+ check_pred_horizon = function(pred_horizon = NULL,
+ boundary_checks = TRUE,
+ pred_type = NULL){
+
+ # nothing to check
+ NULL
+
+ },
+
+ check_oobag_eval_function_internal = function(oobag_fun){
+
+
+ test_y <- seq(0, 1, length.out = 100)
+
+ .y_mat <- matrix(test_y, ncol = 1)
+ .w_vec <- rep(1, times = 100)
+ .s_vec <- seq(0.9, 0.1, length.out = 100)
+
+ test_output <- try(oobag_fun(y_mat = .y_mat,
+ w_vec = .w_vec,
+ s_vec = .s_vec),
+ silent = FALSE)
+
+ if(is_error(test_output)){
+
+ stop("oobag_fun encountered an error when it was tested. ",
+ "Please make sure your oobag_fun works for this case:\n\n",
+ "test_y <- seq(0, 1, length.out = 100)\n",
+ "y_mat <- matrix(test_y, ncol = 1)\n",
+ "w_vec <- rep(1, times = 100)\n",
+ "s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
+ "test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
+ "test_output should be a numeric value of length 1",
+ call. = FALSE)
+
+ }
+
+ test_output
+
+ },
+
+ init_control = function(){
+
+ self$control <- orsf_control_regression(method = 'glm',
+ scale_x = FALSE,
+ max_iter = 1)
+
+ },
+
+ init_internal = function(){
+
+ self$tree_type <- "regression"
+
+ if(is.factor(self$data[[private$data_names$y]])){
+ stop("Cannot fit regression trees to outcome ",
+ private$data_names$y, " because it is a factor.",
+ " Did you mean to use orsf_control_classification()?",
+ call. = FALSE)
+ }
+
+ if(!is.function(self$control$lincomb_R_function) &&
+ self$control$lincomb_type == 'net'){
+ self$control$lincomb_R_function <- penalized_linreg
+ }
+
+ self$split_rule <- self$split_rule %||% 'variance'
+ self$pred_type <- self$pred_type %||% 'mean'
+ self$split_min_stat <- self$split_min_stat %||%
+ switch(self$split_rule, 'variance' = 0)
+
+ # use default if eval type was not specified by user
+ if(self$oobag_pred_mode && is.null(self$oobag_eval_type)){
+ self$oobag_eval_type <- "RSQ"
+ }
+
+ },
+
+ init_pred = function(pred_horizon = NULL, pred_type = NULL,
+ boundary_checks = TRUE){
+
+ if(!is.null(pred_horizon)){
+ warning("pred_horizon does not impact predictions",
+ " for regression forests", call. = FALSE)
+ }
+
+ if(!is.null(pred_type)){
+ private$check_pred_type(oobag = FALSE, pred_type = pred_type)
+ } else {
+ pred_type <- self$pred_type %||% "mean"
+ }
+
+ self$pred_type <- pred_type
+
+ },
+
+ prep_y_internal = function(placeholder = FALSE){
+
+ if(placeholder){
+ private$y <- matrix(0, ncol = 1, nrow = 1)
+ return()
+ }
+
+ # y is always 1 column for regression (for now)
+ y <- private$y[[1]]
+
+ private$y <- as_matrix(y)
+
+ },
+
+ predict_internal = function(){
+
+ # resize y to have the right number of columns
+ private$y <- matrix(0, ncol = 1)
cpp_args = private$prep_cpp_args(x = private$x,
y = private$y,
diff --git a/R/orsf_predict.R b/R/orsf_predict.R
index 48727c51..c18cab6c 100644
--- a/R/orsf_predict.R
+++ b/R/orsf_predict.R
@@ -81,7 +81,7 @@
predict.ObliqueForest <- function(object,
new_data,
pred_horizon = NULL,
- pred_type = 'risk',
+ pred_type = NULL,
na_action = 'fail',
boundary_checks = TRUE,
n_thread = 1,
@@ -93,22 +93,6 @@ predict.ObliqueForest <- function(object,
# these arguments are mistaken input names since ... isn't used.
check_dots(list(...), .f = predict.ObliqueForest)
- if(pred_type %in% c('leaf', 'mort') && !is.null(pred_horizon)){
-
- extra_text <- if(length(pred_horizon)>1){
- " Predictions at each value of pred_horizon will be identical."
- } else {
- ""
- }
-
- warning("pred_horizon does not impact predictions",
- " when pred_type is '", pred_type, "'.",
- extra_text, call. = FALSE)
-
- }
-
- pred_horizon <- infer_pred_horizon(object, pred_type, pred_horizon)
-
out <- object$predict(new_data = new_data,
pred_horizon = pred_horizon,
pred_type = pred_type,
@@ -122,141 +106,3 @@ predict.ObliqueForest <- function(object,
}
-
-
-
-
-
-
-
-
-
-
-
-# old code in case the infer function fails me:
-# args_tmp <- list(x = x_new,
-# y = matrix(1, ncol=2),
-# w = rep(1, nrow(x_new)),
-# tree_type_R = get_tree_type(object),
-# tree_seeds = get_tree_seeds(object),
-# loaded_forest = object$forest,
-# n_tree = get_n_tree(object),
-# mtry = get_mtry(object),
-# sample_with_replacement = get_sample_with_replacement(object),
-# sample_fraction = get_sample_fraction(object),
-# vi_type_R = 0,
-# vi_max_pvalue = get_importance_max_pvalue(object),
-# oobag_R_function = get_f_oobag_eval(object),
-# leaf_min_events = get_leaf_min_events(object),
-# leaf_min_obs = get_leaf_min_obs(object),
-# split_rule_R = switch(get_split_rule(object),
-# "logrank" = 1,
-# "cstat" = 2),
-# split_min_events = get_split_min_events(object),
-# split_min_obs = get_split_min_obs(object),
-# split_min_stat = get_split_min_stat(object),
-# split_max_cuts = get_n_split(object),
-# split_max_retry = get_n_retry(object),
-# lincomb_R_function = control$lincomb_R_function,
-# lincomb_type_R = switch(control$lincomb_type,
-# 'glm' = 1,
-# 'random' = 2,
-# 'net' = 3,
-# 'custom' = 4),
-# lincomb_eps = control$lincomb_eps,
-# lincomb_iter_max = control$lincomb_iter_max,
-# lincomb_scale = control$lincomb_scale,
-# lincomb_alpha = control$lincomb_alpha,
-# lincomb_df_target = control$lincomb_df_target,
-# lincomb_ties_method = switch(tolower(control$lincomb_ties_method),
-# 'breslow' = 0,
-# 'efron' = 1),
-# pred_type_R = switch(pred_type,
-# "risk" = 1,
-# "surv" = 2,
-# "chf" = 3,
-# "mort" = 4,
-# "leaf" = 8),
-# pred_mode = TRUE,
-# pred_aggregate = pred_aggregate,
-# pred_horizon = pred_horizon_ordered,
-# oobag = FALSE,
-# oobag_eval_type_R = 0,
-# oobag_eval_every = get_n_tree(object),
-# pd_type_R = 0,
-# pd_x_vals = list(matrix(0, ncol=1, nrow=1)),
-# pd_x_cols = list(matrix(1L, ncol=1, nrow=1)),
-# pd_probs = c(0),
-# n_thread = n_thread,
-# write_forest = FALSE,
-# run_forest = TRUE,
-# verbosity = as.integer(verbose_progress))
-#
-# checkout <- c()
-#
-# for(i in names(args)){
-# print(i)
-# if(!is.list(args[[i]]) && !is.function(args[[i]])){
-# if(!all(args[[i]] == args_tmp[[i]]))
-# checkout <- c(checkout, i)
-# }
-# }
-# browser()
-# orsf_out <- orsf_cpp(x = x_new,
-# y = matrix(1, ncol=2),
-# w = rep(1, nrow(x_new)),
-# tree_type_R = get_tree_type(object),
-# tree_seeds = get_tree_seeds(object),
-# loaded_forest = object$forest,
-# n_tree = get_n_tree(object),
-# mtry = get_mtry(object),
-# sample_with_replacement = get_sample_with_replacement(object),
-# sample_fraction = get_sample_fraction(object),
-# vi_type_R = 0,
-# vi_max_pvalue = get_importance_max_pvalue(object),
-# oobag_R_function = get_f_oobag_eval(object),
-# leaf_min_events = get_leaf_min_events(object),
-# leaf_min_obs = get_leaf_min_obs(object),
-# split_rule_R = switch(get_split_rule(object),
-# "logrank" = 1,
-# "cstat" = 2),
-# split_min_events = get_split_min_events(object),
-# split_min_obs = get_split_min_obs(object),
-# split_min_stat = get_split_min_stat(object),
-# split_max_cuts = get_n_split(object),
-# split_max_retry = get_n_retry(object),
-# lincomb_R_function = control$lincomb_R_function,
-# lincomb_type_R = switch(control$lincomb_type,
-# 'glm' = 1,
-# 'random' = 2,
-# 'net' = 3,
-# 'custom' = 4),
-# lincomb_eps = control$lincomb_eps,
-# lincomb_iter_max = control$lincomb_iter_max,
-# lincomb_scale = control$lincomb_scale,
-# lincomb_alpha = control$lincomb_alpha,
-# lincomb_df_target = control$lincomb_df_target,
-# lincomb_ties_method = switch(tolower(control$lincomb_ties_method),
-# 'breslow' = 0,
-# 'efron' = 1),
-# pred_type_R = switch(pred_type,
-# "risk" = 1,
-# "surv" = 2,
-# "chf" = 3,
-# "mort" = 4,
-# "leaf" = 8),
-# pred_mode = TRUE,
-# pred_aggregate = pred_aggregate,
-# pred_horizon = pred_horizon_ordered,
-# oobag = FALSE,
-# oobag_eval_type_R = 0,
-# oobag_eval_every = get_n_tree(object),
-# pd_type_R = 0,
-# pd_x_vals = list(matrix(0, ncol=1, nrow=1)),
-# pd_x_cols = list(matrix(1L, ncol=1, nrow=1)),
-# pd_probs = c(0),
-# n_thread = n_thread,
-# write_forest = FALSE,
-# run_forest = TRUE,
-# verbosity = as.integer(verbose_progress))
-
diff --git a/R/penalized_cph.R b/R/penalized_cph.R
index 33663b0a..210de66d 100644
--- a/R/penalized_cph.R
+++ b/R/penalized_cph.R
@@ -59,7 +59,8 @@ penalized_logreg <- function(x_node,
alpha,
df_target){
- y_node <- as.factor(y_node)
+ col <- sample(ncol(y_node), 1)
+ y_node <- as.factor(y_node[, col])
w_node <- as.numeric(w_node)
penalized_fitter(x_node = x_node,
@@ -71,6 +72,23 @@ penalized_logreg <- function(x_node,
}
+penalized_linreg <- function(x_node,
+ y_node,
+ w_node,
+ alpha,
+ df_target){
+
+ w_node <- as.numeric(w_node)
+
+ penalized_fitter(x_node = x_node,
+ y_node = y_node,
+ w_node = w_node,
+ alpha = alpha,
+ df_target = df_target,
+ family = "gaussian")
+
+}
+
penalized_fitter <- function(x_node,
y_node,
w_node,
diff --git a/man/orsf.Rd b/man/orsf.Rd
index ece67cfc..f6a74d23 100644
--- a/man/orsf.Rd
+++ b/man/orsf.Rd
@@ -8,7 +8,7 @@
orsf(
data,
formula,
- control = orsf_control_fast(),
+ control = NULL,
weights = NULL,
n_tree = 500,
n_split = 5,
diff --git a/man/orsf_control_custom.Rd b/man/orsf_control_custom.Rd
index 921ae356..64cdc2a5 100644
--- a/man/orsf_control_custom.Rd
+++ b/man/orsf_control_custom.Rd
@@ -46,26 +46,17 @@ are shown here.
\}
}\if{html}{\out{}}
-We can plug \code{f_rando} into \code{orsf_control_custom()}, and then pass the
+We can plug \code{f_rando} into \code{orsf_control_survival()}, and then pass the
result into \code{orsf()}:
\if{html}{\out{
}}\preformatted{library(aorsf)
fit_rando <- orsf(pbc_orsf,
Surv(time, status) ~ . - id,
- control = orsf_control_custom(beta_fun = f_rando),
+ control = orsf_control_survival(method = f_rando),
n_tree = 500)
-}\if{html}{\out{
}}
-
-\if{html}{\out{}}\preformatted{## Warning: `orsf_control_custom()` was deprecated in aorsf 0.1.2.
-## i Please use the appropriate survival, classification, or regression control
-## function instead. E.g., `orsf_control_survival(method = your_function)`
-## This warning is displayed once every 8 hours.
-## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
-## generated.
-}\if{html}{\out{
}}
-\if{html}{\out{}}\preformatted{fit_rando
+fit_rando
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{## ---------- Oblique random survival forest
@@ -101,12 +92,12 @@ Follow the same steps as above, starting with the custom function:
\}
}\if{html}{\out{
}}
-Then plug the function into \code{orsf_control_custom()} and pass the result
-into \code{orsf()}:
+Then plug the function into \code{orsf_control_survival()} and pass the
+result into \code{orsf()}:
\if{html}{\out{}}\preformatted{fit_pca <- orsf(pbc_orsf,
Surv(time, status) ~ . - id,
- control = orsf_control_custom(beta_fun = f_pca),
+ control = orsf_control_survival(method = f_pca),
n_tree = 500)
}\if{html}{\out{
}}
}
diff --git a/man/orsf_ice_oob.Rd b/man/orsf_ice_oob.Rd
index e7788b77..ee28a776 100644
--- a/man/orsf_ice_oob.Rd
+++ b/man/orsf_ice_oob.Rd
@@ -112,3 +112,60 @@ You can compute individual conditional expectations three ways using a random fo
See examples for more details
}
+\section{Examples}{
+Begin by fitting an ORSF ensemble
+
+\if{html}{\out{}}\preformatted{library(aorsf)
+
+set.seed(329)
+
+fit <- orsf(data = pbc_orsf, formula = Surv(time, status) ~ . - id)
+
+fit
+}\if{html}{\out{
}}
+
+\if{html}{\out{}}\preformatted{## ---------- Oblique random survival forest
+##
+## Linear combinations: Accelerated Cox regression
+## N observations: 276
+## N events: 111
+## N trees: 500
+## N predictors total: 17
+## N predictors per node: 5
+## Average leaves per tree: 21.026
+## Min observations in leaf: 5
+## Min events in leaf: 1
+## OOB stat value: 0.84
+## OOB stat type: Harrell's C-index
+## Variable importance: anova
+##
+## -----------------------------------------
+}\if{html}{\out{
}}
+
+Use the ensemble to compute ICE values using out-of-bag predictions:
+
+\if{html}{\out{}}\preformatted{pred_spec <- list(bili = seq(1, 10, length.out = 25))
+
+ice_oob <- orsf_ice_oob(fit, pred_spec, boundary_checks = FALSE)
+
+ice_oob
+}\if{html}{\out{
}}
+
+\if{html}{\out{}}\preformatted{## id_variable id_row pred_horizon bili pred
+## 1: 1 1 1788 1 0.1264442
+## 2: 1 2 1788 1 0.1739727
+## 3: 1 3 1788 1 0.3904517
+## 4: 1 4 1788 1 0.2874752
+## 5: 1 5 1788 1 0.4398522
+## ---
+## 6896: 25 272 1788 10 0.3076971
+## 6897: 25 273 1788 10 0.4942110
+## 6898: 25 274 1788 10 0.6407498
+## 6899: 25 275 1788 10 0.3871298
+## 6900: 25 276 1788 10 0.6479179
+}\if{html}{\out{
}}
+
+Much more detailed examples are given in the
+\href{https://docs.ropensci.org/aorsf/articles/pd.html#individual-conditional-expectations-ice}{vignette}
+}
+
diff --git a/man/orsf_pd_oob.Rd b/man/orsf_pd_oob.Rd
index f5a4672c..b046293c 100644
--- a/man/orsf_pd_oob.Rd
+++ b/man/orsf_pd_oob.Rd
@@ -133,6 +133,85 @@ See examples for more details
\details{
Partial dependence has a number of \href{https://christophm.github.io/interpretable-ml-book/pdp.html#disadvantages-5}{known limitations and assumptions} that users should be aware of (see Hooker, 2021). In particular, partial dependence is less intuitive when >2 predictors are examined jointly, and it is assumed that the feature(s) for which the partial dependence is computed are not correlated with other features (this is likely not true in many cases). Accumulated local effect plots can be used (see \href{https://christophm.github.io/interpretable-ml-book/ale.html}{here}) in the case where feature independence is not a valid assumption.
}
+\section{Examples}{
+Begin by fitting an ORSF ensemble:
+
+\if{html}{\out{}}\preformatted{library(aorsf)
+
+set.seed(329730)
+
+index_train <- sample(nrow(pbc_orsf), 150)
+
+pbc_orsf_train <- pbc_orsf[index_train, ]
+pbc_orsf_test <- pbc_orsf[-index_train, ]
+
+fit <- orsf(data = pbc_orsf_train,
+ formula = Surv(time, status) ~ . - id,
+ oobag_pred_horizon = 365.25 * 5)
+}\if{html}{\out{
}}
+\subsection{Three ways to compute PD and ICE}{
+
+You can compute partial dependence and ICE three ways with \code{aorsf}:
+\itemize{
+\item using in-bag predictions for the training data
+
+\if{html}{\out{}}\preformatted{pd_train <- orsf_pd_inb(fit, pred_spec = list(bili = 1:5))
+
+pd_train
+}\if{html}{\out{
}}
+
+\if{html}{\out{}}\preformatted{## pred_horizon bili mean lwr medn upr
+## 1: 1826.25 1 0.7932390 0.2177461 0.9060625 0.9816153
+## 2: 1826.25 2 0.7642403 0.1988035 0.8717127 0.9710504
+## 3: 1826.25 3 0.7240284 0.1770122 0.8303501 0.9480047
+## 4: 1826.25 4 0.6744615 0.1615326 0.7599508 0.9088882
+## 5: 1826.25 5 0.6313355 0.1553589 0.7152580 0.8658139
+}\if{html}{\out{
}}
+\item using out-of-bag predictions for the training data
+
+\if{html}{\out{}}\preformatted{pd_train <- orsf_pd_oob(fit, pred_spec = list(bili = 1:5))
+
+pd_train
+}\if{html}{\out{
}}
+
+\if{html}{\out{}}\preformatted{## pred_horizon bili mean lwr medn upr
+## 1: 1826.25 1 0.7840481 0.2727537 0.8694252 0.9809905
+## 2: 1826.25 2 0.7549406 0.2525478 0.8333524 0.9693362
+## 3: 1826.25 3 0.7158234 0.2364582 0.7890158 0.9461864
+## 4: 1826.25 4 0.6656823 0.2260407 0.7158336 0.9151153
+## 5: 1826.25 5 0.6225353 0.2071656 0.6734005 0.8681677
+}\if{html}{\out{
}}
+\item using predictions for a new set of data
+
+\if{html}{\out{}}\preformatted{pd_test <- orsf_pd_new(fit,
+ new_data = pbc_orsf_test,
+ pred_spec = list(bili = 1:5))
+
+pd_test
+}\if{html}{\out{
}}
+
+\if{html}{\out{}}\preformatted{## pred_horizon bili mean lwr medn upr
+## 1: 1826.25 1 0.7524101 0.1868769 0.8121185 0.9803382
+## 2: 1826.25 2 0.7234050 0.1759562 0.7754099 0.9653244
+## 3: 1826.25 3 0.6816975 0.1581292 0.7224945 0.9403449
+## 4: 1826.25 4 0.6339907 0.1467816 0.6598026 0.9000773
+## 5: 1826.25 5 0.5911775 0.1387876 0.6186801 0.8504577
+}\if{html}{\out{
}}
+\item in-bag partial dependence indicates relationships that the model has
+learned during training. This is helpful if your goal is to interpret
+the model.
+\item out-of-bag partial dependence indicates relationships that the model
+has learned during training but using the out-of-bag data simulates
+application of the model to new data. if you want to test your model’s
+reliability or fairness in new data but you don’t have access to a
+large testing set.
+\item new data partial dependence shows how the model predicts outcomes for
+observations it has not seen. This is helpful if you want to test your
+model’s reliability or fairness.
+}
+}
+}
+
\references{
Giles Hooker, Lucas Mentch, Siyu Zhou. Unrestricted Permutation forces Extrapolation: Variable Importance Requires at least One More Model, or There Is No Free Variable Importance. \emph{arXiv e-prints} 2021 Oct; arXiv-1905. URL: https://doi.org/10.48550/arXiv.1905.03151
}
diff --git a/man/orsf_vi.Rd b/man/orsf_vi.Rd
index 30a8db43..f0bfaa8e 100644
--- a/man/orsf_vi.Rd
+++ b/man/orsf_vi.Rd
@@ -216,24 +216,24 @@ orsf_vi_negate(fit_no_vi)
}\if{html}{\out{}}
\if{html}{\out{}}\preformatted{## bili copper sex protime albumin age
-## 0.122403975 0.047873793 0.035948277 0.023733810 0.021851376 0.021517486
+## 0.122344895 0.047850279 0.035986359 0.023711711 0.021831451 0.021503160
## stage ascites chol ast spiders hepato
-## 0.019864687 0.012572127 0.011124625 0.009886485 0.007602310 0.007068509
+## 0.019718835 0.012550534 0.011115307 0.009845811 0.007601474 0.007055077
## edema trt alk.phos trig platelet
-## 0.006428910 0.003703479 0.002386061 0.001258447 -0.001165639
+## 0.006411580 0.003666224 0.002388178 0.001156845 -0.001214167
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{orsf_vi_permute(fit_no_vi)
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{## bili copper protime age ascites
-## 5.515572e-02 2.183554e-02 1.248239e-02 1.196435e-02 1.176055e-02
+## 5.513908e-02 2.181846e-02 1.246900e-02 1.192659e-02 1.176139e-02
## albumin stage chol spiders edema
-## 1.174498e-02 9.481247e-03 6.227106e-03 5.822340e-03 4.968156e-03
+## 1.175554e-02 9.479348e-03 6.215674e-03 5.752179e-03 4.960035e-03
## ast hepato sex trig alk.phos
-## 4.674158e-03 3.607108e-03 2.501568e-03 1.272238e-03 6.658467e-05
+## 4.647971e-03 3.594325e-03 2.477936e-03 1.162558e-03 6.778008e-05
## platelet trt
-## -1.123911e-03 -1.350953e-03
+## -1.132546e-03 -1.376816e-03
}\if{html}{\out{
}}
}
@@ -250,13 +250,13 @@ orsf_vi_permute(fit_permute_vi)
}\if{html}{\out{}}
\if{html}{\out{}}\preformatted{## bili copper protime albumin age
-## 0.0581641417 0.0256260470 0.0130351708 0.0129893317 0.0121071920
+## 0.0582460338 0.0255992039 0.0130100780 0.0129532316 0.0121027391
## ascites stage chol ast edema
-## 0.0119305736 0.0084039200 0.0070760386 0.0053563176 0.0051454936
+## 0.0119289124 0.0084185175 0.0071302967 0.0053592731 0.0051471990
## spiders hepato sex trig alk.phos
-## 0.0047310213 0.0037519912 0.0026191668 0.0023658449 0.0013233351
+## 0.0046418826 0.0036776097 0.0026334550 0.0024978806 0.0013078222
## platelet trt
-## 0.0003419879 -0.0013496757
+## 0.0003504423 -0.0013892173
}\if{html}{\out{
}}
You can still get negation VI from this fit, but it needs to be computed
@@ -265,11 +265,11 @@ You can still get negation VI from this fit, but it needs to be computed
}\if{html}{\out{}}
\if{html}{\out{}}\preformatted{## bili copper sex protime age albumin
-## 0.1256871293 0.0508556782 0.0362476130 0.0235826879 0.0232958521 0.0225289676
-## stage ascites chol ast spiders edema
-## 0.0211386811 0.0142001708 0.0140190400 0.0108956519 0.0073909899 0.0070418539
-## alk.phos hepato trig trt platelet
-## 0.0049111530 0.0048882405 0.0043590290 0.0038793435 0.0007487968
+## 0.1259391254 0.0507141085 0.0363834330 0.0235136073 0.0233592840 0.0225371677
+## stage chol ascites ast spiders edema
+## 0.0211978251 0.0141956334 0.0141890702 0.0108977272 0.0073762768 0.0070333453
+## hepato alk.phos trig trt platelet
+## 0.0050661672 0.0048879157 0.0044980321 0.0039418881 0.0007189274
}\if{html}{\out{
}}
}
}
diff --git a/man/predict.ObliqueForest.Rd b/man/predict.ObliqueForest.Rd
index 27bdd0da..fad24e11 100644
--- a/man/predict.ObliqueForest.Rd
+++ b/man/predict.ObliqueForest.Rd
@@ -8,7 +8,7 @@
object,
new_data,
pred_horizon = NULL,
- pred_type = "risk",
+ pred_type = NULL,
na_action = "fail",
boundary_checks = TRUE,
n_thread = 1,
@@ -123,7 +123,7 @@ predict(fit,
## [1,] 0.45965512 0.73309199 0.89715078
## [2,] 0.03235764 0.09091330 0.18045864
## [3,] 0.12091603 0.25919883 0.39403239
-## [4,] 0.01568893 0.03825896 0.07691412
+## [4,] 0.01488893 0.03745896 0.07571412
## [5,] 0.01279842 0.02623832 0.06015808
}\if{html}{\out{}}
@@ -138,7 +138,7 @@ predict(fit,
## [1,] 0.5403449 0.2669080 0.1028492
## [2,] 0.9676424 0.9090867 0.8195414
## [3,] 0.8790840 0.7408012 0.6059676
-## [4,] 0.9843111 0.9617410 0.9230859
+## [4,] 0.9851111 0.9625410 0.9242859
## [5,] 0.9872016 0.9737617 0.9398419
}\if{html}{\out{}}
@@ -154,7 +154,7 @@ predict(fit,
## [1,] 0.65381651 1.28606246 1.75476570
## [2,] 0.03531788 0.10967272 0.24697387
## [3,] 0.15371784 0.36989220 0.65462524
-## [4,] 0.01629537 0.04309610 0.09499160
+## [4,] 0.01549537 0.04229610 0.09352493
## [5,] 0.01290261 0.02687956 0.06916273
}\if{html}{\out{}}
@@ -172,7 +172,7 @@ prediction horizon
## [1,] 79.795533
## [2,] 22.393743
## [3,] 38.749709
-## [4,] 13.607722
+## [4,] 13.552788
## [5,] 9.984989
}\if{html}{\out{}}
}
diff --git a/src/Forest.cpp b/src/Forest.cpp
index 0f6f9ec6..1bb81e1c 100644
--- a/src/Forest.cpp
+++ b/src/Forest.cpp
@@ -493,23 +493,6 @@ void Forest::compute_prediction_accuracy(arma::mat& y,
arma::mat& predictions,
arma::uword row_fill){
- if(oobag_eval_type == EVAL_R_FUNCTION){
-
- // initialize function from tree object
- // (Functions can't be stored in C++ classes, but Robjects can)
- Rcpp::Function f_oobag_eval = Rcpp::as(oobag_R_function);
- Rcpp::NumericMatrix y_ = Rcpp::wrap(y);
- Rcpp::NumericVector w_ = Rcpp::wrap(w);
-
- for(uword i = 0; i < oobag_eval.n_cols; ++i){
- vec p = predictions.unsafe_col(i);
- Rcpp::NumericVector p_ = Rcpp::wrap(p);
- Rcpp::NumericVector R_result = f_oobag_eval(y_, w_, p_);
- oobag_eval(row_fill, i) = R_result[0];
- }
- return;
- }
-
compute_prediction_accuracy_internal(y, w, predictions, row_fill);
}
@@ -769,6 +752,7 @@ void Forest::predict_single_thread(Data* prediction_data,
uword eval_row = (progress / oobag_eval_every) - 1;
// mat preds = result.each_col() / oobag_denom;
+
compute_prediction_accuracy(prediction_data, result, eval_row);
}
diff --git a/src/ForestClassification.cpp b/src/ForestClassification.cpp
index 3611ea87..5ece7421 100644
--- a/src/ForestClassification.cpp
+++ b/src/ForestClassification.cpp
@@ -88,18 +88,51 @@ void ForestClassification::compute_prediction_accuracy_internal(
arma::uword row_fill
) {
- double cstat_sum = 0;
+ double result = 0;
- vec y_0 = 1 - sum(y, 1);
- mat y_augment = join_horiz(y_0, y);
+ if(oobag_eval_type == EVAL_R_FUNCTION){
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but Robjects can)
+ Rcpp::Function f_oobag_eval = Rcpp::as(oobag_R_function);
+
+
+ // go through all columns if multi-class y,
+ // but only go through one column if y is binary
+ // uword start = 0;
+ // if(n_class == 2) start = 1;
+
+ Rcpp::NumericVector w_ = Rcpp::wrap(w);
+
+ for(uword i = 0; i < predictions.n_cols; ++i){
+
+ vec y_i = y.unsafe_col(i);
+ vec p_i = predictions.unsafe_col(i);
+
+ Rcpp::NumericVector y_ = Rcpp::wrap(y_i);
+ Rcpp::NumericVector p_ = Rcpp::wrap(p_i);
+
+ Rcpp::NumericVector R_result = f_oobag_eval(y_, w_, p_);
+
+ double result_addon = R_result[0];
+
+ result += result_addon;
+
+ }
+
+ oobag_eval(row_fill, 0) = result / predictions.n_cols;
+
+ return;
+
+ }
for(uword i = 0; i < predictions.n_cols; i++){
- vec y_i = y_augment.unsafe_col(i);
+ vec y_i = y.unsafe_col(i);
vec p_i = predictions.unsafe_col(i);
- cstat_sum += compute_cstat_clsf(y_i, w, p_i);
+ result += compute_cstat_clsf(y_i, w, p_i);
}
- oobag_eval(row_fill, 0) = cstat_sum / predictions.n_cols;
+ oobag_eval(row_fill, 0) = result / predictions.n_cols;
}
diff --git a/src/ForestRegression.cpp b/src/ForestRegression.cpp
new file mode 100644
index 00000000..8878f679
--- /dev/null
+++ b/src/ForestRegression.cpp
@@ -0,0 +1,151 @@
+// Forest.cpp
+
+#include
+#include "ForestRegression.h"
+#include "TreeRegression.h"
+
+#include
+
+using namespace arma;
+using namespace Rcpp;
+
+namespace aorsf {
+
+ForestRegression::ForestRegression() { }
+
+void ForestRegression::resize_pred_mat_internal(arma::mat& p){
+
+ p.zeros(data->n_rows, 1);
+
+ if(verbosity > 3){
+ Rcout << " -- pred mat size: " << p.n_rows << " rows by ";
+ Rcout << p.n_cols << " columns." << std::endl << std::endl;
+ }
+
+}
+
+void ForestRegression::load(
+ arma::uword n_tree,
+ arma::uword n_obs,
+ std::vector& forest_rows_oobag,
+ std::vector>& forest_cutpoint,
+ std::vector>& forest_child_left,
+ std::vector>& forest_coef_values,
+ std::vector>& forest_coef_indices,
+ std::vector>& forest_leaf_pred_prob,
+ std::vector>& forest_leaf_summary,
+ PartialDepType pd_type,
+ std::vector& pd_x_vals,
+ std::vector& pd_x_cols,
+ arma::vec& pd_probs
+) {
+
+ this->n_tree = n_tree;
+ this->pd_type = pd_type;
+ this->pd_x_vals = pd_x_vals;
+ this->pd_x_cols = pd_x_cols;
+ this->pd_probs = pd_probs;
+
+ if(verbosity > 2){
+ Rcout << "---- loading forest from input list ----";
+ Rcout << std::endl << std::endl;
+ }
+
+ // Create trees
+ trees.reserve(n_tree);
+
+ for (uword i = 0; i < n_tree; ++i) {
+ trees.push_back(
+ std::make_unique(n_obs,
+ forest_rows_oobag[i],
+ forest_cutpoint[i],
+ forest_child_left[i],
+ forest_coef_values[i],
+ forest_coef_indices[i],
+ forest_leaf_pred_prob[i],
+ forest_leaf_summary[i])
+ );
+ }
+
+ if(n_thread > 1){
+ // Create thread ranges
+ equalSplit(thread_ranges, 0, n_tree - 1, n_thread);
+ }
+
+
+}
+
+void ForestRegression::compute_prediction_accuracy_internal(
+ arma::mat& y,
+ arma::vec& w,
+ arma::mat& predictions,
+ arma::uword row_fill
+) {
+
+ if(oobag_eval_type == EVAL_R_FUNCTION){
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but Robjects can)
+ Rcpp::Function f_oobag_eval = Rcpp::as(oobag_R_function);
+ Rcpp::NumericMatrix y_ = Rcpp::wrap(y);
+ Rcpp::NumericVector w_ = Rcpp::wrap(w);
+
+ for(uword i = 0; i < oobag_eval.n_cols; ++i){
+ vec p = predictions.unsafe_col(i);
+ Rcpp::NumericVector p_ = Rcpp::wrap(p);
+ Rcpp::NumericVector R_result = f_oobag_eval(y_, w_, p_);
+ oobag_eval(row_fill, i) = R_result[0];
+ }
+
+ return;
+
+ }
+
+ double result = 0;
+
+ for(uword i = 0; i < predictions.n_cols; i++){
+
+ vec y_i = y.unsafe_col(i);
+ vec p_i = predictions.unsafe_col(i);
+
+ if(oobag_eval_type == EVAL_MSE){
+ result += compute_mse(y_i, w, p_i);
+ } else if (oobag_eval_type == EVAL_RSQ){
+ result += compute_rsq(y_i, w, p_i);
+ }
+
+ }
+
+ oobag_eval(row_fill, 0) = result / predictions.n_cols;
+
+}
+
+// growInternal() in ranger
+void ForestRegression::plant() {
+
+ trees.reserve(n_tree);
+
+ for (arma::uword i = 0; i < n_tree; ++i) {
+ trees.push_back(std::make_unique());
+ }
+
+}
+
+std::vector> ForestRegression::get_leaf_pred_prob() {
+
+ std::vector> result;
+
+ result.reserve(n_tree);
+
+ for (auto& tree : trees) {
+ auto& temp = dynamic_cast(*tree);
+ result.push_back(temp.get_leaf_pred_prob());
+ }
+
+ return result;
+
+}
+
+}
+
+
diff --git a/src/ForestRegression.h b/src/ForestRegression.h
new file mode 100644
index 00000000..b5707a3e
--- /dev/null
+++ b/src/ForestRegression.h
@@ -0,0 +1,62 @@
+
+// ForestRegression.h
+
+#ifndef FORESTREGRESSION_H
+#define FORESTREGRESSION_H
+
+#include "Data.h"
+#include "globals.h"
+#include "Forest.h"
+
+namespace aorsf {
+
+class ForestRegression: public Forest {
+
+public:
+
+ ForestRegression();
+
+ virtual ~ForestRegression() override = default;
+
+ ForestRegression(const ForestRegression&) = delete;
+ ForestRegression& operator=(const ForestRegression&) = delete;
+
+ void load(
+ arma::uword n_tree,
+ arma::uword n_obs,
+ std::vector& forest_rows_oobag,
+ std::vector>& forest_cutpoint,
+ std::vector>& forest_child_left,
+ std::vector>& forest_coef_values,
+ std::vector>& forest_coef_indices,
+ std::vector>& forest_leaf_pred_prob,
+ std::vector>& forest_leaf_summary,
+ PartialDepType pd_type,
+ std::vector& pd_x_vals,
+ std::vector& pd_x_cols,
+ arma::vec& pd_probs
+ );
+
+ void resize_pred_mat_internal(arma::mat& p) override;
+
+ void compute_prediction_accuracy_internal(
+ arma::mat& y,
+ arma::vec& w,
+ arma::mat& predictions,
+ arma::uword row_fill
+ ) override;
+
+ // growInternal() in ranger
+ void plant() override;
+
+ std::vector> get_leaf_pred_prob();
+
+ uword n_class;
+
+};
+
+}
+
+
+
+#endif /* ForestRegression_H */
diff --git a/src/ForestSurvival.cpp b/src/ForestSurvival.cpp
index 7bf2a543..10a8ae10 100644
--- a/src/ForestSurvival.cpp
+++ b/src/ForestSurvival.cpp
@@ -163,6 +163,23 @@ void ForestSurvival::compute_prediction_accuracy_internal(
arma::uword row_fill
) {
+ if(oobag_eval_type == EVAL_R_FUNCTION){
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but Robjects can)
+ Rcpp::Function f_oobag_eval = Rcpp::as(oobag_R_function);
+ Rcpp::NumericMatrix y_ = Rcpp::wrap(y);
+ Rcpp::NumericVector w_ = Rcpp::wrap(w);
+
+ for(uword i = 0; i < oobag_eval.n_cols; ++i){
+ vec p = predictions.unsafe_col(i);
+ Rcpp::NumericVector p_ = Rcpp::wrap(p);
+ Rcpp::NumericVector R_result = f_oobag_eval(y_, w_, p_);
+ oobag_eval(row_fill, i) = R_result[0];
+ }
+ return;
+ }
+
bool pred_is_risklike = true;
if(pred_type == PRED_SURVIVAL) pred_is_risklike = false;
diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp
index 876aae8f..77bc6e1a 100644
--- a/src/RcppExports.cpp
+++ b/src/RcppExports.cpp
@@ -138,15 +138,16 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
-// expand_y_clsf
-arma::mat expand_y_clsf(arma::vec& y, arma::uword n_class);
-RcppExport SEXP _aorsf_expand_y_clsf(SEXP ySEXP, SEXP n_classSEXP) {
+// compute_var_reduction_exported
+double compute_var_reduction_exported(arma::vec& y_node, arma::vec& w_node, arma::uvec& g_node);
+RcppExport SEXP _aorsf_compute_var_reduction_exported(SEXP y_nodeSEXP, SEXP w_nodeSEXP, SEXP g_nodeSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< arma::vec& >::type y(ySEXP);
- Rcpp::traits::input_parameter< arma::uword >::type n_class(n_classSEXP);
- rcpp_result_gen = Rcpp::wrap(expand_y_clsf(y, n_class));
+ Rcpp::traits::input_parameter< arma::vec& >::type y_node(y_nodeSEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type w_node(w_nodeSEXP);
+ Rcpp::traits::input_parameter< arma::uvec& >::type g_node(g_nodeSEXP);
+ rcpp_result_gen = Rcpp::wrap(compute_var_reduction_exported(y_node, w_node, g_node));
return rcpp_result_gen;
END_RCPP
}
@@ -262,6 +263,31 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
+// expand_y_clsf
+arma::mat expand_y_clsf(arma::vec& y, arma::uword n_class);
+RcppExport SEXP _aorsf_expand_y_clsf(SEXP ySEXP, SEXP n_classSEXP) {
+BEGIN_RCPP
+ Rcpp::RObject rcpp_result_gen;
+ Rcpp::RNGScope rcpp_rngScope_gen;
+ Rcpp::traits::input_parameter< arma::vec& >::type y(ySEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type n_class(n_classSEXP);
+ rcpp_result_gen = Rcpp::wrap(expand_y_clsf(y, n_class));
+ return rcpp_result_gen;
+END_RCPP
+}
+// compute_mse_exported
+double compute_mse_exported(arma::vec& y, arma::vec& w, arma::vec& p);
+RcppExport SEXP _aorsf_compute_mse_exported(SEXP ySEXP, SEXP wSEXP, SEXP pSEXP) {
+BEGIN_RCPP
+ Rcpp::RObject rcpp_result_gen;
+ Rcpp::RNGScope rcpp_rngScope_gen;
+ Rcpp::traits::input_parameter< arma::vec& >::type y(ySEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type p(pSEXP);
+ rcpp_result_gen = Rcpp::wrap(compute_mse_exported(y, w, p));
+ return rcpp_result_gen;
+END_RCPP
+}
// orsf_cpp
List orsf_cpp(arma::mat& x, arma::mat& y, arma::vec& w, arma::uword tree_type_R, Rcpp::IntegerVector& tree_seeds, Rcpp::List& loaded_forest, Rcpp::RObject lincomb_R_function, Rcpp::RObject oobag_R_function, arma::uword n_tree, arma::uword mtry, bool sample_with_replacement, double sample_fraction, arma::uword vi_type_R, double vi_max_pvalue, double leaf_min_events, double leaf_min_obs, arma::uword split_rule_R, double split_min_events, double split_min_obs, double split_min_stat, arma::uword split_max_cuts, arma::uword split_max_retry, arma::uword lincomb_type_R, double lincomb_eps, arma::uword lincomb_iter_max, bool lincomb_scale, double lincomb_alpha, arma::uword lincomb_df_target, arma::uword lincomb_ties_method, bool pred_mode, arma::uword pred_type_R, arma::vec pred_horizon, bool pred_aggregate, bool oobag, arma::uword oobag_eval_type_R, arma::uword oobag_eval_every, int pd_type_R, std::vector& pd_x_vals, std::vector& pd_x_cols, arma::vec& pd_probs, unsigned int n_thread, bool write_forest, bool run_forest, int verbosity);
RcppExport SEXP _aorsf_orsf_cpp(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP tree_type_RSEXP, SEXP tree_seedsSEXP, SEXP loaded_forestSEXP, SEXP lincomb_R_functionSEXP, SEXP oobag_R_functionSEXP, SEXP n_treeSEXP, SEXP mtrySEXP, SEXP sample_with_replacementSEXP, SEXP sample_fractionSEXP, SEXP vi_type_RSEXP, SEXP vi_max_pvalueSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP, SEXP split_rule_RSEXP, SEXP split_min_eventsSEXP, SEXP split_min_obsSEXP, SEXP split_min_statSEXP, SEXP split_max_cutsSEXP, SEXP split_max_retrySEXP, SEXP lincomb_type_RSEXP, SEXP lincomb_epsSEXP, SEXP lincomb_iter_maxSEXP, SEXP lincomb_scaleSEXP, SEXP lincomb_alphaSEXP, SEXP lincomb_df_targetSEXP, SEXP lincomb_ties_methodSEXP, SEXP pred_modeSEXP, SEXP pred_type_RSEXP, SEXP pred_horizonSEXP, SEXP pred_aggregateSEXP, SEXP oobagSEXP, SEXP oobag_eval_type_RSEXP, SEXP oobag_eval_everySEXP, SEXP pd_type_RSEXP, SEXP pd_x_valsSEXP, SEXP pd_x_colsSEXP, SEXP pd_probsSEXP, SEXP n_threadSEXP, SEXP write_forestSEXP, SEXP run_forestSEXP, SEXP verbositySEXP) {
@@ -316,19 +342,6 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
-// compute_var_reduction
-double compute_var_reduction(arma::vec& y_node, arma::vec& w_node, arma::uvec& g_node);
-RcppExport SEXP _aorsf_compute_var_reduction(SEXP y_nodeSEXP, SEXP w_nodeSEXP, SEXP g_nodeSEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< arma::vec& >::type y_node(y_nodeSEXP);
- Rcpp::traits::input_parameter< arma::vec& >::type w_node(w_nodeSEXP);
- Rcpp::traits::input_parameter< arma::uvec& >::type g_node(g_nodeSEXP);
- rcpp_result_gen = Rcpp::wrap(compute_var_reduction(y_node, w_node, g_node));
- return rcpp_result_gen;
-END_RCPP
-}
static const R_CallMethodDef CallEntries[] = {
{"_aorsf_coxph_fit_exported", (DL_FUNC) &_aorsf_coxph_fit_exported, 6},
@@ -340,7 +353,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_aorsf_compute_logrank_exported", (DL_FUNC) &_aorsf_compute_logrank_exported, 3},
{"_aorsf_compute_gini_exported", (DL_FUNC) &_aorsf_compute_gini_exported, 3},
{"_aorsf_compute_pred_prob_exported", (DL_FUNC) &_aorsf_compute_pred_prob_exported, 2},
- {"_aorsf_expand_y_clsf", (DL_FUNC) &_aorsf_expand_y_clsf, 2},
+ {"_aorsf_compute_var_reduction_exported", (DL_FUNC) &_aorsf_compute_var_reduction_exported, 3},
{"_aorsf_is_col_splittable_exported", (DL_FUNC) &_aorsf_is_col_splittable_exported, 4},
{"_aorsf_find_cuts_survival_exported", (DL_FUNC) &_aorsf_find_cuts_survival_exported, 6},
{"_aorsf_sprout_node_survival_exported", (DL_FUNC) &_aorsf_sprout_node_survival_exported, 2},
@@ -349,8 +362,9 @@ static const R_CallMethodDef CallEntries[] = {
{"_aorsf_x_submat_mult_beta_pd_exported", (DL_FUNC) &_aorsf_x_submat_mult_beta_pd_exported, 8},
{"_aorsf_scale_x_exported", (DL_FUNC) &_aorsf_scale_x_exported, 2},
{"_aorsf_cph_scale", (DL_FUNC) &_aorsf_cph_scale, 2},
+ {"_aorsf_expand_y_clsf", (DL_FUNC) &_aorsf_expand_y_clsf, 2},
+ {"_aorsf_compute_mse_exported", (DL_FUNC) &_aorsf_compute_mse_exported, 3},
{"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 44},
- {"_aorsf_compute_var_reduction", (DL_FUNC) &_aorsf_compute_var_reduction, 3},
{NULL, NULL, 0}
};
diff --git a/src/Tree.cpp b/src/Tree.cpp
index 48671c76..9039320b 100644
--- a/src/Tree.cpp
+++ b/src/Tree.cpp
@@ -756,6 +756,18 @@
mat out;
return(out);
}
+
+ arma::mat Tree::glmnet_fit(){
+ Rcpp::stop("default glmnet fit function called");
+ mat out;
+ return(out);
+ }
+
+ arma::mat Tree::user_fit(){
+ Rcpp::stop("default user fit function called");
+ mat out;
+ return(out);
+ }
// # nocov end
void Tree::grow(arma::vec* vi_numer,
@@ -876,11 +888,8 @@
switch (lincomb_type) {
case LC_GLM: {
-
beta = glm_fit();
-
break;
-
}
case LC_RANDOM_COEFS: {
@@ -898,41 +907,13 @@
}
case LC_GLMNET: {
-
- NumericMatrix xx = wrap(x_node);
- NumericMatrix yy = wrap(y_node);
- NumericVector ww = wrap(w_node);
-
- // initialize function from tree object
- // (Functions can't be stored in C++ classes, but RObjects can)
- Function f_beta = as(lincomb_R_function);
-
- NumericMatrix beta_R = f_beta(xx, yy, ww,
- lincomb_alpha,
- lincomb_df_target);
-
- beta = mat(beta_R.begin(), beta_R.nrow(), beta_R.ncol(), false);
-
+ beta = glmnet_fit();
break;
-
}
case LC_R_FUNCTION: {
-
- NumericMatrix xx = wrap(x_node);
- NumericMatrix yy = wrap(y_node);
- NumericVector ww = wrap(w_node);
-
- // initialize function from tree object
- // (Functions can't be stored in C++ classes, but RObjects can)
- Function f_beta = as(lincomb_R_function);
-
- NumericMatrix beta_R = f_beta(xx, yy, ww);
-
- beta = mat(beta_R.begin(), beta_R.nrow(), beta_R.ncol(), false);
-
+ beta = user_fit();
break;
-
}
} // end switch lincomb_type
diff --git a/src/Tree.h b/src/Tree.h
index b3a5af78..f04598c1 100644
--- a/src/Tree.h
+++ b/src/Tree.h
@@ -219,6 +219,8 @@
void find_rows_inbag(arma::uword n_obs);
virtual arma::mat glm_fit();
+ virtual arma::mat glmnet_fit();
+ virtual arma::mat user_fit();
virtual uword get_n_col_vi()=0;
diff --git a/src/TreeClassification.cpp b/src/TreeClassification.cpp
index 351bcdf9..96e71310 100644
--- a/src/TreeClassification.cpp
+++ b/src/TreeClassification.cpp
@@ -52,11 +52,9 @@
case SPLIT_GINI: {
- for(uword i = 0; i < y_node.n_cols; i++){
- vec y_i = y_node.unsafe_col(i);
- result += compute_gini(y_i, w_node, g_node);
- }
- result /= y_node.n_cols;
+ vec y_i = y_node.unsafe_col(y_col_split);
+ result = compute_gini(y_i, w_node, g_node);
+
// gini index: lower is better, so
// transform to make consistent with other stats
result = (result-1) * -1;
@@ -66,13 +64,10 @@
case SPLIT_CONCORD: {
- for(uword i = 0; i < y_node.n_cols; i++){
- vec y_i = y_node.unsafe_col(i);
- result += compute_cstat_clsf(y_i, w_node, g_node);
- }
- result /= y_node.n_cols;
-
+ vec y_i = y_node.unsafe_col(y_col_split);
+ result = compute_cstat_clsf(y_i, w_node, g_node);
break;
+
}
default:
@@ -140,17 +135,7 @@
arma::mat TreeClassification::glm_fit(){
- vec y_col;
-
- if(splittable_y_cols.size() > 1){
- std::uniform_int_distribution udist_ycol(0, splittable_y_cols.size() - 1);
- uword j = udist_ycol(random_number_generator);
- uword k = splittable_y_cols[j];
- y_col = y_node.unsafe_col(k);
- } else {
- y_col = y_node.unsafe_col(0);
- }
-
+ vec y_col = y_node.unsafe_col(y_col_split);
mat out = logreg_fit(x_node,
y_col,
@@ -163,18 +148,57 @@
}
+ arma::mat TreeClassification::glmnet_fit(){
+
+ arma::vec y_col = y_node.unsafe_col(y_col_split);
+
+ NumericMatrix xx = wrap(x_node);
+ NumericMatrix yy = wrap(y_col);
+ NumericVector ww = wrap(w_node);
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but RObjects can)
+ Function f_beta = as(lincomb_R_function);
+
+ NumericMatrix beta_R = f_beta(xx, yy, ww,
+ lincomb_alpha,
+ lincomb_df_target);
+
+ mat beta = mat(beta_R.begin(), beta_R.nrow(), beta_R.ncol(), false);
+
+ return(beta);
+
+ }
+
+ arma::mat TreeClassification::user_fit(){
+
+ vec y_col = y_node.unsafe_col(y_col_split);
+
+ NumericMatrix xx = wrap(x_node);
+ NumericMatrix yy = wrap(y_col);
+ NumericVector ww = wrap(w_node);
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but RObjects can)
+ Function f_beta = as(lincomb_R_function);
+
+ NumericMatrix beta_R = f_beta(xx, yy, ww);
+
+ mat beta = mat(beta_R.begin(), beta_R.nrow(), beta_R.ncol(), false);
+
+ return(beta);
+
+ }
+
double TreeClassification::compute_prediction_accuracy_internal(
arma::mat& preds
){
double cstat_sum = 0;
- // note: preds includes a column for the non-case, but y does not.
- // That is why the preds column is ahead by 1 here.
-
for(uword i = 0; i < y_oobag.n_cols; i++){
vec y_i = y_oobag.unsafe_col(i);
- vec p_i = preds.unsafe_col(i+1);
+ vec p_i = preds.unsafe_col(i);
cstat_sum += compute_cstat_clsf(y_i, w_oobag, p_i);
}
@@ -186,71 +210,84 @@
double safer_mtry = mtry;
- if(lincomb_type == LC_GLM ||
- lincomb_type == LC_GLMNET){
+ // conditions to split a column:
+ // >= 3 events per predictor
+ // >= 3 non-events per predictor
- // conditions to split a column:
- // >= 3 events per predictor
- // >= 3 non-events per predictor
+ double n = y_node.n_rows;
+ vec y_sum_cases = sum(y_node, 0).t();
+ vec y_sum_ctrls = n - y_sum_cases;
- double n = y_node.n_rows;
- vec y_sum_cases = sum(y_node, 0).t();
- vec y_sum_ctrls = n - y_sum_cases;
+ if(verbosity > 3){
- if(verbosity > 3){
-
- for(uword i = 0; i < y_sum_cases.size(); ++i){
- Rcout << " -- For column " << i << ": ";
- Rcout << y_sum_cases[i] << " cases, ";
- Rcout << y_sum_ctrls[i] << " controls (unweighted)" << std::endl;
- }
+ for(uword i = 0; i < y_sum_cases.size(); ++i){
+ Rcout << " -- For column " << i << ": ";
+ Rcout << y_sum_cases[i] << " cases, ";
+ Rcout << y_sum_ctrls[i] << " controls (unweighted)" << std::endl;
}
+ }
- splittable_y_cols.zeros(y_node.n_cols);
- uword counter = 0;
+ splittable_y_cols.zeros(y_node.n_cols);
- for(uword i = 0; i < y_node.n_cols; ++i){
+ uword counter = 0;
- if(y_sum_cases[i] >= 3 && y_sum_ctrls[i] >= 3){
- splittable_y_cols[counter] = i;
- counter++;
- }
+ for(uword i = 0; i < y_node.n_cols; ++i){
+ if(y_sum_cases[i] >= 3 && y_sum_ctrls[i] >= 3){
+ splittable_y_cols[counter] = i;
+ counter++;
}
- splittable_y_cols.resize(counter);
+ }
- if(counter == 0){
+ splittable_y_cols.resize(counter);
- if(verbosity > 3){
- Rcout << " -- No y columns are splittable" << std::endl << std::endl;
- }
+ if(counter == 0){
- return counter;
+ if(verbosity > 3){
+ Rcout << " -- No y columns are splittable" << std::endl << std::endl;
}
- if(verbosity > 3){
- for(auto &i : splittable_y_cols){
- Rcout << " -- Y column " << i << " is splittable" << std::endl;
- }
+ return counter;
+
+ }
+
+ if(verbosity > 3){
+ for(auto &i : splittable_y_cols){
+ Rcout << " -- Y column " << i << " is splittable" << std::endl;
}
+ }
- // glmnet can handle higher dimension x,
- // but regular glm cannot.
- if(lincomb_type == LC_GLM){
+ uword best_count = 0;
- for (auto& i : splittable_y_cols){
+ for(auto& ycol : splittable_y_cols){
- while (y_sum_cases[i] / safer_mtry < 3 ||
- y_sum_ctrls[i] / safer_mtry < 3){
- --safer_mtry;
- }
+ uword min_count;
- }
+ if(y_sum_cases[ycol] <= y_sum_ctrls[ycol]){
+ min_count = y_sum_cases[ycol];
+ } else {
+ min_count = y_sum_ctrls[ycol];
+ }
+ if(min_count > best_count){
+ y_col_split = ycol;
+ best_count = min_count;
}
+ }
+
+ if(verbosity > 3){
+ Rcout << " -- Most splittable Y column: " << y_col_split << std::endl;
+ }
+ // glmnet can handle higher dimension x,
+ // but other methods probably cannot.
+ if(lincomb_type != LC_GLM){
+
+ while (best_count / safer_mtry < 3){
+ --safer_mtry;
+ }
}
diff --git a/src/TreeClassification.h b/src/TreeClassification.h
index 903792b1..f2f2535e 100644
--- a/src/TreeClassification.h
+++ b/src/TreeClassification.h
@@ -53,6 +53,8 @@
double compute_prediction_accuracy_internal(arma::mat& preds) override;
arma::mat glm_fit() override;
+ arma::mat glmnet_fit() override;
+ arma::mat user_fit() override;
uword get_n_col_vi() override;
@@ -65,6 +67,7 @@
arma::uword n_class;
arma::uvec splittable_y_cols;
+ arma::uword y_col_split;
// prob holds the predicted prob for each class
std::vector leaf_pred_prob;
diff --git a/src/TreeRegression.cpp b/src/TreeRegression.cpp
new file mode 100644
index 00000000..17fe3f11
--- /dev/null
+++ b/src/TreeRegression.cpp
@@ -0,0 +1,241 @@
+/*-----------------------------------------------------------------------------
+ This file is part of aorsf.
+ Author: Byron C Jaeger
+ aorsf may be modified and distributed under the terms of the MIT license.
+#----------------------------------------------------------------------------*/
+
+#include
+#include "TreeRegression.h"
+#include "Coxph.h"
+#include "utility.h"
+// #include "NodeSplitStats.h"
+
+ using namespace arma;
+ using namespace Rcpp;
+
+ namespace aorsf {
+
+ TreeRegression::TreeRegression() { }
+
+ TreeRegression::TreeRegression(arma::uword n_obs,
+ arma::uvec& rows_oobag,
+ std::vector& cutpoint,
+ std::vector& child_left,
+ std::vector& coef_values,
+ std::vector& coef_indices,
+ std::vector& leaf_pred_prob,
+ std::vector& leaf_summary) :
+ Tree(rows_oobag, cutpoint, child_left, coef_values, coef_indices, leaf_summary),
+ leaf_pred_prob(leaf_pred_prob){
+
+ find_rows_inbag(n_obs);
+
+ }
+
+ void TreeRegression::resize_leaves(arma::uword new_size){
+
+ leaf_pred_prob.resize(new_size);
+ leaf_summary.resize(new_size);
+
+ }
+
+
+ double TreeRegression::compute_split_score(){
+
+ double result=0;
+
+ switch (split_rule) {
+
+ case SPLIT_VARIANCE: {
+
+ for(uword i = 0; i < y_node.n_cols; i++){
+ vec y_i = y_node.unsafe_col(i);
+ result += compute_var_reduction(y_i, w_node, g_node);
+ }
+
+ result /= y_node.n_cols;
+ break;
+
+ }
+
+ default:
+ Rcpp::stop("invalid split rule");
+ break;
+
+ }
+
+ return(result);
+
+ }
+
+ void TreeRegression::sprout_leaf_internal(uword node_id){
+
+ double pred_mean = compute_pred_mean(y_node, w_node);
+
+ leaf_summary[node_id] = pred_mean;
+
+ vec quant_probs = {0.25, 0.50, 0.75};
+
+ // TODO: make weighted version
+ vec quant_vals = quantile(y_node, quant_probs);
+
+ leaf_pred_prob[node_id] = quant_vals;
+
+ }
+
+ arma::uword TreeRegression::predict_value_internal(
+ arma::uvec& pred_leaf_sort,
+ arma::mat& pred_output,
+ arma::vec& pred_denom,
+ PredType pred_type,
+ bool oobag
+ ){
+
+ uword n_preds_made = 0;
+
+ if(pred_type == PRED_PROBABILITY){
+
+ for(auto& it : pred_leaf_sort){
+
+ uword leaf_id = pred_leaf[it];
+ if(leaf_id == max_nodes) break;
+ pred_output.row(it) += leaf_pred_prob[leaf_id].t();
+
+ n_preds_made++;
+ if(oobag) pred_denom[it]++;
+
+ }
+
+ } else if(pred_type == PRED_MEAN){
+
+ for(auto& it : pred_leaf_sort){
+
+ uword leaf_id = pred_leaf[it];
+ if(leaf_id == max_nodes) break;
+
+ pred_output.at(it, 0) += leaf_summary[leaf_id];
+
+ n_preds_made++;
+ if(oobag) pred_denom[it]++;
+
+ }
+
+ }
+
+ return(n_preds_made);
+
+ }
+
+ arma::mat TreeRegression::glm_fit(){
+
+ vec y_col = y_node.unsafe_col(0);
+
+ mat out = linreg_fit(x_node,
+ y_col,
+ w_node,
+ lincomb_scale,
+ lincomb_eps,
+ lincomb_iter_max);
+
+ return(out);
+
+ }
+
+ arma::mat TreeRegression::glmnet_fit(){
+
+ NumericMatrix xx = wrap(x_node);
+ NumericMatrix yy = wrap(y_node);
+ NumericVector ww = wrap(w_node);
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but RObjects can)
+ Function f_beta = as(lincomb_R_function);
+
+ NumericMatrix beta_R = f_beta(xx, yy, ww,
+ lincomb_alpha,
+ lincomb_df_target);
+
+ mat beta = mat(beta_R.begin(), beta_R.nrow(), beta_R.ncol(), false);
+
+ return(beta);
+
+ }
+
+ arma::mat TreeRegression::user_fit(){
+
+ NumericMatrix xx = wrap(x_node);
+ NumericMatrix yy = wrap(y_node);
+ NumericVector ww = wrap(w_node);
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but RObjects can)
+ Function f_beta = as(lincomb_R_function);
+
+ NumericMatrix beta_R = f_beta(xx, yy, ww);
+
+ mat beta = mat(beta_R.begin(), beta_R.nrow(), beta_R.ncol(), false);
+
+ return(beta);
+
+ }
+
+
+ double TreeRegression::compute_prediction_accuracy_internal(
+ arma::mat& preds
+ ){
+
+ double mse_sum = 0;
+
+ for(uword i = 0; i < y_oobag.n_cols; i++){
+ vec y_i = y_oobag.unsafe_col(i);
+ vec p_i = preds.unsafe_col(i);
+ mse_sum += compute_mse(y_i, w_oobag, p_i);
+ }
+
+ return mse_sum / preds.n_cols;
+
+ }
+
+ arma::uword TreeRegression::find_safe_mtry(){
+
+ double safer_mtry = mtry;
+
+ if(lincomb_type == LC_GLM ||
+ lincomb_type == LC_GLMNET){
+
+ // conditions to split a column:
+ // >= 3 non-events per predictor
+
+ double n = y_node.n_rows;
+
+ if(verbosity > 3){
+ Rcout << " -- N obs (unweighted): " << n << std::endl;
+ }
+
+ while (n / safer_mtry < 3){
+ --safer_mtry;
+ if(safer_mtry == 0) break;
+ }
+
+ }
+
+ uword out = safer_mtry;
+
+ return(out);
+
+ }
+
+ uword TreeRegression::get_n_col_vi(){
+ return(1);
+ }
+
+ void TreeRegression::fill_pred_values_vi(mat& pred_values){
+
+ for(uword i = 0; i < pred_values.n_rows; ++i){
+ pred_values.at(i, 0) = leaf_summary[pred_leaf[i]];
+ }
+
+ }
+
+ } // namespace aorsf
+
diff --git a/src/TreeRegression.h b/src/TreeRegression.h
new file mode 100644
index 00000000..68ced389
--- /dev/null
+++ b/src/TreeRegression.h
@@ -0,0 +1,77 @@
+/*-----------------------------------------------------------------------------
+ This file is part of aorsf.
+ Author: Byron C Jaeger
+ aorsf may be modified and distributed under the terms of the MIT license.
+#----------------------------------------------------------------------------*/
+
+#ifndef TREEREGRESSION_H_
+#define TREEREGRESSION_H_
+
+
+#include "Data.h"
+#include "globals.h"
+#include "Tree.h"
+
+ namespace aorsf {
+
+ class TreeRegression: public Tree {
+
+ public:
+
+ TreeRegression();
+
+ TreeRegression(const TreeRegression&) = delete;
+ TreeRegression& operator=(const TreeRegression&) = delete;
+
+ virtual ~TreeRegression() override = default;
+
+ TreeRegression(arma::uword n_obs,
+ arma::uvec& rows_oobag,
+ std::vector& cutpoint,
+ std::vector& child_left,
+ std::vector& coef_values,
+ std::vector& coef_indices,
+ std::vector& leaf_pred_prob,
+ std::vector& leaf_summary);
+
+ void resize_leaves(arma::uword new_size) override;
+
+ double compute_split_score() override;
+
+ void sprout_leaf_internal(arma::uword node_id) override;
+
+ arma::uword predict_value_internal(arma::uvec& pred_leaf_sort,
+ arma::mat& pred_output,
+ arma::vec& pred_denom,
+ PredType pred_type,
+ bool oobag) override;
+
+ arma::uword find_safe_mtry() override;
+
+ double compute_prediction_accuracy_internal(arma::mat& preds) override;
+
+ arma::mat glm_fit() override;
+ arma::mat glmnet_fit() override;
+ arma::mat user_fit() override;
+
+ uword get_n_col_vi() override;
+
+ void fill_pred_values_vi(arma::mat& pred_values) override;
+
+ std::vector& get_leaf_pred_prob(){
+ return(leaf_pred_prob);
+ }
+
+ arma::uword n_class;
+
+ arma::uvec splittable_y_cols;
+
+ // prob holds the predicted prob for each class
+ std::vector leaf_pred_prob;
+ // summary (see Tree.h) holds class vote
+
+ };
+
+ } // namespace aorsf
+
+#endif /* TREERegression_H_ */
diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp
index 2068b47f..c46a78b5 100644
--- a/src/TreeSurvival.cpp
+++ b/src/TreeSurvival.cpp
@@ -693,6 +693,44 @@
}
+ arma::mat TreeSurvival::glmnet_fit(){
+
+ NumericMatrix xx = wrap(x_node);
+ NumericMatrix yy = wrap(y_node);
+ NumericVector ww = wrap(w_node);
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but RObjects can)
+ Function f_beta = as(lincomb_R_function);
+
+ NumericMatrix beta_R = f_beta(xx, yy, ww,
+ lincomb_alpha,
+ lincomb_df_target);
+
+ mat beta = mat(beta_R.begin(), beta_R.nrow(), beta_R.ncol(), false);
+
+ return(beta);
+
+ }
+
+ arma::mat TreeSurvival::user_fit(){
+
+ NumericMatrix xx = wrap(x_node);
+ NumericMatrix yy = wrap(y_node);
+ NumericVector ww = wrap(w_node);
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but RObjects can)
+ Function f_beta = as(lincomb_R_function);
+
+ NumericMatrix beta_R = f_beta(xx, yy, ww);
+
+ mat beta = mat(beta_R.begin(), beta_R.nrow(), beta_R.ncol(), false);
+
+ return(beta);
+
+ }
+
uword TreeSurvival::get_n_col_vi(){
return(1);
}
diff --git a/src/TreeSurvival.h b/src/TreeSurvival.h
index 3292f51c..bd8e0e49 100644
--- a/src/TreeSurvival.h
+++ b/src/TreeSurvival.h
@@ -93,6 +93,8 @@
double compute_prediction_accuracy_internal(arma::mat& preds) override;
arma::mat glm_fit() override;
+ arma::mat glmnet_fit() override;
+ arma::mat user_fit() override;
// indx holds the times
std::vector leaf_pred_indx;
diff --git a/src/globals.h b/src/globals.h
index e69b4109..6e8f4ff8 100644
--- a/src/globals.h
+++ b/src/globals.h
@@ -31,13 +31,16 @@
enum SplitRule {
SPLIT_LOGRANK = 1,
SPLIT_CONCORD = 2,
- SPLIT_GINI = 3
+ SPLIT_GINI = 3,
+ SPLIT_VARIANCE = 4
};
enum EvalType {
EVAL_NONE = 0,
EVAL_CONCORD = 1,
- EVAL_R_FUNCTION = 2
+ EVAL_R_FUNCTION = 2,
+ EVAL_MSE = 3,
+ EVAL_RSQ = 4,
};
enum PartialDepType {
diff --git a/src/orsf_oop.cpp b/src/orsf_oop.cpp
index 7a607fd9..8fd508da 100644
--- a/src/orsf_oop.cpp
+++ b/src/orsf_oop.cpp
@@ -19,6 +19,7 @@
#include "Forest.h"
#include "ForestSurvival.h"
#include "ForestClassification.h"
+#include "ForestRegression.h"
#include "Coxph.h"
#include "utility.h"
@@ -122,23 +123,15 @@
){ return compute_pred_prob(y, w); }
// [[Rcpp::export]]
- arma::mat expand_y_clsf(arma::vec& y,
- arma::uword n_class){
-
- arma::mat out(y.n_rows, n_class - 1, arma::fill::zeros);
-
- for(arma::uword i = 0; i < y.n_rows; ++i){
-
- double yval = y[i];
-
- if(yval > 0){ out.at(i, yval-1) = 1; }
-
- }
+ double compute_var_reduction_exported(arma::vec& y_node,
+ arma::vec& w_node,
+ arma::uvec& g_node){
- return(out);
+ return(compute_var_reduction(y_node, w_node, g_node));
}
+
// [[Rcpp::export]]
bool is_col_splittable_exported(arma::mat& x,
arma::mat& y,
@@ -324,6 +317,30 @@
}
+ // [[Rcpp::export]]
+ arma::mat expand_y_clsf(arma::vec& y,
+ arma::uword n_class){
+
+ arma::mat out(y.n_rows, n_class, arma::fill::zeros);
+
+ for(arma::uword i = 0; i < y.n_rows; ++i){
+ out.at(i, y[i]) = 1;
+ }
+
+ return(out);
+
+ }
+
+// [[Rcpp::export]]
+double compute_mse_exported(arma::vec& y,
+ arma::vec& w,
+ arma::vec& p){
+
+ return(compute_mse(y, w, p));
+
+}
+
+
// [[Rcpp::export]]
List orsf_cpp(arma::mat& x,
arma::mat& y,
@@ -389,14 +406,6 @@
uword n_obs = data->get_n_rows();
-
- if(tree_type == TREE_REGRESSION){
-
- stop("that tree type is not ready yet");
-
- }
-
-
if(n_thread == 0){
n_thread = std::thread::hardware_concurrency();
}
@@ -436,11 +445,22 @@
case TREE_CLASSIFICATION:
- forest = std::make_unique(data->n_cols_y + 1);
+ forest = std::make_unique(data->n_cols_y);
if(verbosity > 3){
Rcout << "initializing classification forest" << std::endl;
- Rcout << " -- n_class: " << data->n_cols_y + 1 << std::endl;
+ Rcout << " -- n_class: " << data->n_cols_y << std::endl;
+ Rcout << std::endl << std::endl;
+ }
+
+ break;
+
+ case TREE_REGRESSION:
+
+ forest = std::make_unique();
+
+ if(verbosity > 3){
+ Rcout << "initializing regression forest" << std::endl;
Rcout << std::endl << std::endl;
}
@@ -448,12 +468,12 @@
default:
- Rcpp::stop("only survival and classification trees are currently implemented");
+ Rcpp::stop("unrecognized tree type");
break;
}
- // does the forest need to be grown or is it already grown?
+ // does the forest need to be grown?
bool grow_mode = loaded_forest.size() == 0;
forest->init(std::move(data),
@@ -529,6 +549,17 @@
pd_type, pd_x_vals, pd_x_cols, pd_probs);
+ } else if (tree_type == TREE_REGRESSION){
+
+ std::vector> leaf_pred_prob = loaded_forest["leaf_pred_prob"];
+
+ auto& temp = dynamic_cast(*forest);
+
+ temp.load(n_tree, n_obs, rows_oobag, cutpoint, child_left,
+ coef_values, coef_indices, leaf_pred_prob, leaf_summary,
+ pd_type, pd_x_vals, pd_x_cols, pd_probs);
+
+
}
}
@@ -562,13 +593,22 @@
forest_out.push_back(forest->get_leaf_summary(), "leaf_summary");
if(tree_type == TREE_SURVIVAL){
+
auto& temp = dynamic_cast(*forest);
forest_out.push_back(temp.get_leaf_pred_indx(), "leaf_pred_indx");
forest_out.push_back(temp.get_leaf_pred_prob(), "leaf_pred_prob");
forest_out.push_back(temp.get_leaf_pred_chaz(), "leaf_pred_chaz");
+
} else if (tree_type == TREE_CLASSIFICATION){
+
auto& temp = dynamic_cast(*forest);
forest_out.push_back(temp.get_leaf_pred_prob(), "leaf_pred_prob");
+
+ } else if (tree_type == TREE_REGRESSION){
+
+ auto& temp = dynamic_cast(*forest);
+ forest_out.push_back(temp.get_leaf_pred_prob(), "leaf_pred_prob");
+
}
result.push_back(forest_out, "forest");
@@ -596,51 +636,4 @@
return(result);
}
-
-
- // [[Rcpp::export]]
- double compute_var_reduction(arma::vec& y_node,
- arma::vec& w_node,
- arma::uvec& g_node){
-
- double root_mean = 0, left_mean = 0, right_mean = 0;
- double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;
-
- for(arma::uword i = 0; i < y_node.n_rows; ++i){
-
- double w_i = w_node[i];
- double y_i = y_node[i] * w_i;
-
- root_w_sum += w_i;
- root_mean += y_i;
-
- if(g_node[i] == 1){
- right_w_sum += w_i;
- right_mean += y_i;
- } else {
- left_w_sum += w_i;
- left_mean += y_i;
- }
-
- }
-
- root_mean /= root_w_sum;
- left_mean /= left_w_sum;
- right_mean /= right_w_sum;
-
- double ans = 0;
-
- for(arma::uword i = 0; i < y_node.n_rows; ++i){
-
- double w_i = w_node[i];
- double y_i = y_node[i];
- double g_i = g_node[i];
- double obs_mean = g_i*right_mean + (1 - g_i)*left_mean;
-
- ans += w_i * pow(y_i - root_mean, 2) - w_i * pow(y_i - obs_mean, 2);
-
- }
- ans /= root_w_sum;
- return(ans);
- }
-
+
diff --git a/src/utility.cpp b/src/utility.cpp
index d46d53be..b4a84d48 100644
--- a/src/utility.cpp
+++ b/src/utility.cpp
@@ -474,25 +474,128 @@
}
+ double compute_var_reduction(arma::vec& y,
+ arma::vec& w,
+ arma::uvec& g){
+
+ double root_mean = 0, left_mean = 0, right_mean = 0;
+ double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;
+
+ for(arma::uword i = 0; i < y.n_rows; ++i){
+
+ double w_i = w[i];
+ double y_i = y[i] * w_i;
+
+ root_w_sum += w_i;
+ root_mean += y_i;
+
+ if(g[i] == 1){
+ right_w_sum += w_i;
+ right_mean += y_i;
+ } else {
+ left_w_sum += w_i;
+ left_mean += y_i;
+ }
+
+ }
+
+ root_mean /= root_w_sum;
+ left_mean /= left_w_sum;
+ right_mean /= right_w_sum;
+
+ double ans = 0;
+
+ for(arma::uword i = 0; i < y.n_rows; ++i){
+
+ double w_i = w[i];
+ double y_i = y[i];
+ double g_i = g[i];
+ double obs_mean = g_i*right_mean + (1 - g_i)*left_mean;
+
+ ans += w_i * pow(y_i - root_mean, 2) - w_i * pow(y_i - obs_mean, 2);
+
+ }
+ ans /= root_w_sum;
+ return(ans);
+ }
+
+ double compute_mse(arma::vec& y,
+ arma::vec& w,
+ arma::vec& p){
+
+ double numer = 0;
+ double denom = 0;
+
+ for(uword i = 0; i < p.size(); ++i){
+
+ numer += (y[i] - p[i]) * (y[i] - p[i]) * w[i];
+ denom += w[i];
+
+ }
+
+ return(numer/denom);
+
+ }
+
+ double compute_rsq(arma::vec& y,
+ arma::vec& w,
+ arma::vec& p){
+
+ double truth_mean = compute_pred_mean(y, w);
+
+ double SS_residuals = 0, SS_total = 0;
+
+ for(uword i = 0; i < p.size(); ++i){
+
+ SS_residuals += (y[i] - p[i]) * (y[i] - p[i]) * w[i];
+ SS_total += (y[i] - truth_mean) * (y[i] - truth_mean) * w[i];
+
+ }
+
+ return(1 - (SS_residuals/SS_total));
+
+ }
+
vec compute_pred_prob(mat& y, vec& w){
double n_wtd = 0;
+
vec pred_prob(y.n_cols, fill::zeros);
for(uword i = 0; i < y.n_rows; ++i){
n_wtd += w[i];
+
for(uword j = 0; j < y.n_cols; ++j){
pred_prob[j] += (y.at(i, j) * w[i]);
}
+
}
pred_prob /= n_wtd;
- vec pred_0 = vec {1 - sum(pred_prob)};
- pred_prob = join_vert(pred_0, pred_prob);
+
+ // vec pred_0 = vec {1 - sum(pred_prob)};
+ // pred_prob = join_vert(pred_0, pred_prob);
+
return(pred_prob);
}
+ double compute_pred_mean(mat& y, vec& w){
+
+ double numer = 0;
+ double denom = 0;
+
+ for(uword i = 0; i < y.size(); ++i){
+
+ numer += (y.at(i, 0) * w[i]);
+ denom += w[i];
+
+ }
+
+ return(numer / denom);
+
+ }
+
arma::mat linreg_fit(arma::mat& x_node,
arma::mat& y_node,
arma::vec& w_node,
@@ -500,6 +603,10 @@
double epsilon,
arma::uword iter_max){
+ mat x_transforms;
+
+ if(do_scale) x_transforms = scale_x(x_node, w_node);
+
// Add an intercept column to the design matrix
vec intercept(x_node.n_rows, fill::ones);
mat X = join_horiz(intercept, x_node);
@@ -510,17 +617,38 @@
uword resid_df = X.n_rows - X.n_cols;
- vec beta = solve(X.t() * diagmat(w_node) * X, X.t() * (w_node % y_node));
+ vec beta;
+
+ bool nonsingular = solve(beta,
+ X.t() * diagmat(w_node) * X,
+ X.t() * (w_node % y_node),
+ solve_opts::no_approx);
+
+ if(!nonsingular){
+ mat result(beta.size(), 2, fill::zeros);
+ return(result);
+ }
vec resid = y_node - X * beta;
double s2 = as_scalar(trans(resid) * (w_node % resid) / (resid_df));
- mat beta_cov = s2 * inv(X.t() * diagmat(w_node) * X);
+ mat xtx_inverse;
- vec se = sqrt(diagvec(beta_cov));
+ bool invertible = inv(xtx_inverse, X.t() * diagmat(w_node) * X);
- vec tscores = beta / se;
+ if(!invertible) {
+ mat result(beta.size(), 2, fill::zeros);
+ return(result);
+ }
+
+ mat beta_cov = s2 * xtx_inverse;
+
+ vec beta_var = diagvec(beta_cov);
+
+ if(do_scale) unscale_outputs(x_node, beta, beta_var, x_transforms);
+
+ vec tscores = beta / sqrt(beta_var);
// Calculate two-tailed p-values
vec pvalues(X.n_cols);
@@ -533,7 +661,6 @@
}
-
mat result = join_horiz(beta, pvalues);
return(result.rows(1, result.n_rows-1));
diff --git a/src/utility.h b/src/utility.h
index c0fe330b..bfdfc09a 100644
--- a/src/utility.h
+++ b/src/utility.h
@@ -85,13 +85,28 @@ aorsf may be modified and distributed under the terms of the MIT license.
arma::vec& w,
arma::uvec& g);
+ double compute_mse(arma::vec& y,
+ arma::vec& w,
+ arma::vec& p);
+
+ double compute_rsq(arma::vec& y,
+ arma::vec& w,
+ arma::vec& p);
+
double compute_gini(arma::mat& y,
arma::vec& w,
arma::uvec& g);
+ double compute_var_reduction(arma::vec& y,
+ arma::vec& w,
+ arma::uvec& g);
+
arma::vec compute_pred_prob(arma::mat& y,
arma::vec& w);
+ double compute_pred_mean(arma::mat& y,
+ arma::vec& w);
+
arma::mat linreg_fit(arma::mat& x_node,
arma::mat& y_node,
arma::vec& w_node,
diff --git a/tests/testthat/helper-orsf.R b/tests/testthat/helper-orsf.R
index c3515781..efa0021e 100644
--- a/tests/testthat/helper-orsf.R
+++ b/tests/testthat/helper-orsf.R
@@ -27,6 +27,67 @@ change_scale <- function(x, mult_by = 1/2){
x * mult_by
}
+# R version written using matrixStats
+
+weighted_variance <- function (x, w = NULL, idxs = NULL,
+ na.rm = FALSE, center = NULL,
+ ...) {
+ n <- length(x)
+ if (is.null(w)) {
+ w <- rep(1, times = n)
+ }
+ else if (length(w) != n) {
+ stop(sprintf("The number of elements in arguments '%s' and '%s' does not match: %.0f != %.0f",
+ "w", "x", length(w), n))
+ }
+ else if (!is.null(idxs)) {
+ w <- w[idxs]
+ }
+ if (!is.null(idxs)) {
+ x <- x[idxs]
+ n <- length(x)
+ }
+ na_value <- NA
+ storage.mode(na_value) <- storage.mode(x)
+ tmp <- (is.na(w) | w > 0)
+ if (!all(tmp)) {
+ x <- .subset(x, tmp)
+ w <- .subset(w, tmp)
+ n <- length(x)
+ }
+ tmp <- NULL
+ if (na.rm) {
+ keep <- which(!is.na(x))
+ x <- .subset(x, keep)
+ w <- .subset(w, keep)
+ n <- length(x)
+ keep <- NULL
+ }
+
+ tmp <- is.infinite(w)
+ if (any(tmp)) {
+ keep <- tmp
+ x <- .subset(x, keep)
+ n <- length(x)
+ w <- rep(1, times = n)
+ keep <- NULL
+ }
+ tmp <- NULL
+ if (n <= 1L)
+ return(na_value)
+ wsum <- sum(w)
+ if (is.null(center)) {
+ center <- sum(w * x)/wsum
+ }
+ x <- x - center
+ x <- x^2
+ lambda <- 1/(wsum - 1)
+ sigma2 <- lambda * sum(w * x)
+ x <- w <- NULL
+
+ sigma2
+}
+
#' Find cut-point boundaries (R version)
#'
#' Used to test the cpp version for finding cutpoints
@@ -182,7 +243,10 @@ f_pca <- function(x_node, y_node, w_node) {
pca <- stats::prcomp(x_node, rank. = 2)
# use a random principal component to split the node
- pca$rotation[, 2, drop = FALSE]
+
+ col <- sample(ncol(pca$rotation), 1)
+
+ pca$rotation[, col, drop = FALSE]
}
@@ -194,10 +258,10 @@ expect_equal_leaf_summary <- function(x, y){
tolerance = 1e-9)
}
-expect_equal_oobag_eval <- function(x, y){
+expect_equal_oobag_eval <- function(x, y, tolerance = 1e-9){
expect_equal(x$eval_oobag$stat_values,
y$eval_oobag$stat_values,
- tolerance = 1e-9)
+ tolerance = tolerance)
}
expect_no_missing <- function(x){
@@ -275,9 +339,12 @@ prep_test_matrices <- function(data, outcomes = c("time", "status")){
if(length(outcomes) > 1){
y <- prep_y_surv(data, names_y_data)
sorted <- collapse::radixorder(y[, 1], -y[, 2])
- } else {
+ } else if(is.factor(data[[names_y_data]])) {
y <- prep_y_clsf(data, names_y_data)
sorted <- collapse::seq_row(data)
+ } else {
+ y <- matrix(data[[names_y_data]], ncol = 1)
+ sorted <- collapse::seq_row(data)
}
x <- prep_x(data, fi, names_x_data, means, standard_deviations)
@@ -286,8 +353,8 @@ prep_test_matrices <- function(data, outcomes = c("time", "status")){
return(
list(
- x = x[sorted, ],
- y = y[sorted, ],
+ x = x[sorted, , drop=FALSE],
+ y = y[sorted, , drop=FALSE],
w = w[sorted]
)
)
diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R
index a250deb6..6c96e93b 100644
--- a/tests/testthat/setup.R
+++ b/tests/testthat/setup.R
@@ -50,13 +50,6 @@ for(i in vars){
# make sorted x and y matrices for testing internal cpp functions
pbc_mats <- prep_test_matrices(pbc, outcomes = c("time", "status"))
-# data lists ----
-
-data_list_pbc <- list(pbc_standard = pbc,
- pbc_status_12 = pbc_status_12,
- pbc_scaled = pbc_scale,
- pbc_noised = pbc_noise)
-
# penguins ----
penguins <- penguins_orsf
@@ -64,10 +57,8 @@ penguins <- penguins_orsf
penguins_scale <- penguins_noise <- penguins
-vars <- c("bill_length_mm",
- "bill_depth_mm",
- "flipper_length_mm",
- "body_mass_g")
+vars <- c("bill_length_mm", "bill_depth_mm",
+ "flipper_length_mm", "body_mass_g")
for(i in vars){
penguins_noise[[i]] <- add_noise(penguins_noise[[i]])
@@ -77,13 +68,40 @@ for(i in vars){
# make sorted x and y matrices for testing internal cpp functions
penguins_mats <- prep_test_matrices(penguins, outcomes = c("species"))
+# mtcars ----
+
+mtcars_scale <- mtcars_noise <- mtcars
+
+vars <- c("drat", "wt", "qsec", "disp")
+
+for(i in vars){
+ mtcars_noise[[i]] <- add_noise(mtcars_noise[[i]])
+ mtcars_scale[[i]] <- change_scale(mtcars_scale[[i]])
+}
+
+# make sorted x and y matrices for testing internal cpp functions
+mtcars_mats <- prep_test_matrices(mtcars, outcomes = c("mpg"))
+
# data lists ----
+data_list_pbc <- list(pbc_standard = pbc,
+ pbc_status_12 = pbc_status_12,
+ pbc_scaled = pbc_scale,
+ pbc_noised = pbc_noise)
+
data_list_penguins <- list(penguins_standard = penguins,
penguins_scaled = penguins_scale,
penguins_noised = penguins_noise)
-# matric lists ----
+data_list_mtcars <- list(mtcars_standard = mtcars,
+ mtcars_scaled = mtcars_scale,
+ mtcars_noised = mtcars_noise)
+
+
+
+
+
+# matrix lists ----
mat_list_surv <- list(pbc = pbc_mats,
flc = flc_mats,
@@ -97,9 +115,7 @@ seeds_standard <- 329
n_tree_test <- 5
controls_surv <- list(
- fast = orsf_control_survival(method = 'glm',
- scale_x = FALSE,
- max_iter = 1),
+ fast = orsf_control_survival(method = 'glm', scale_x = FALSE, max_iter = 1),
net = orsf_control_survival(method = 'net'),
custom = orsf_control_survival(method = f_pca)
)
@@ -115,6 +131,40 @@ fit_standard_pbc <- lapply(
}
)
+controls_clsf <- list(
+ fast = orsf_control_classification(method = 'glm', scale_x = FALSE, max_iter = 1),
+ net = orsf_control_classification(method = 'net'),
+ custom = orsf_control_classification(method = f_pca)
+)
+
+fit_standard_penguins <- lapply(
+ controls_clsf,
+ function(cntrl){
+ orsf(penguins,
+ formula = species ~ .,
+ n_tree = n_tree_test,
+ control = cntrl,
+ tree_seed = seeds_standard)
+ }
+)
+
+controls_regr <- list(
+ fast = orsf_control_regression(method = 'glm', scale_x = FALSE, max_iter = 1),
+ net = orsf_control_regression(method = 'net'),
+ custom = orsf_control_regression(method = f_pca)
+)
+
+fit_standard_mtcars <- lapply(
+ controls_regr,
+ function(cntrl){
+ orsf(mtcars,
+ formula = mpg ~ .,
+ n_tree = n_tree_test,
+ control = cntrl,
+ tree_seed = seeds_standard)
+ }
+)
+
# training and testing data ----
pred_types_surv <- c(risk = 'risk',
@@ -123,9 +173,22 @@ pred_types_surv <- c(risk = 'risk',
mort = 'mort',
leaf = 'leaf')
-pbc_train_rows <- sample(nrow(pbc_orsf), size = 170)
+pred_types_clsf <- c(prob = 'prob',
+ class = 'class',
+ leaf = 'leaf')
+
+pred_types_regr <- c(mean = 'mean',
+ leaf = 'leaf')
+pbc_train_rows <- sample(nrow(pbc_orsf), size = 170)
pbc_train <- pbc[pbc_train_rows, ]
pbc_test <- pbc[-pbc_train_rows, ]
+penguins_train_rows <- sample(nrow(penguins_orsf), size = 180)
+penguins_train <- penguins[penguins_train_rows, ]
+penguins_test <- penguins[-penguins_train_rows, ]
+
+mtcars_train_rows <- sample(nrow(mtcars), size = 16)
+mtcars_train <- mtcars[mtcars_train_rows, ]
+mtcars_test <- mtcars[-mtcars_train_rows, ]
diff --git a/tests/testthat/test-compute_pred_prob.R b/tests/testthat/test-compute_pred_prob.R
index b22fe28e..ecd0dd55 100644
--- a/tests/testthat/test-compute_pred_prob.R
+++ b/tests/testthat/test-compute_pred_prob.R
@@ -8,7 +8,6 @@ test_that(
w <- sample(1:5, length(y), replace = TRUE)
y_probs_wtd <- compute_pred_prob_exported(y_expand, w)
target_probs_wtd <- apply(y_expand, 2, weighted.mean, w)
- target_probs_wtd <- c(1-sum(target_probs_wtd), target_probs_wtd)
expect_equal(y_probs_wtd, matrix(target_probs_wtd, ncol = 1))
}
)
diff --git a/tests/testthat/test-compute_var_reduction.R b/tests/testthat/test-compute_var_reduction.R
index 44dfd07f..2b974f17 100644
--- a/tests/testthat/test-compute_var_reduction.R
+++ b/tests/testthat/test-compute_var_reduction.R
@@ -1,72 +1,14 @@
-# R version written using matrixStats
-weightedVar <- function (x, w = NULL, idxs = NULL, na.rm = FALSE, center = NULL,
- ...) {
- n <- length(x)
- if (is.null(w)) {
- w <- rep(1, times = n)
- }
- else if (length(w) != n) {
- stop(sprintf("The number of elements in arguments '%s' and '%s' does not match: %.0f != %.0f",
- "w", "x", length(w), n))
- }
- else if (!is.null(idxs)) {
- w <- w[idxs]
- }
- if (!is.null(idxs)) {
- x <- x[idxs]
- n <- length(x)
- }
- na_value <- NA
- storage.mode(na_value) <- storage.mode(x)
- tmp <- (is.na(w) | w > 0)
- if (!all(tmp)) {
- x <- .subset(x, tmp)
- w <- .subset(w, tmp)
- n <- length(x)
- }
- tmp <- NULL
- if (na.rm) {
- keep <- which(!is.na(x))
- x <- .subset(x, keep)
- w <- .subset(w, keep)
- n <- length(x)
- keep <- NULL
- }
-
- tmp <- is.infinite(w)
- if (any(tmp)) {
- keep <- tmp
- x <- .subset(x, keep)
- n <- length(x)
- w <- rep(1, times = n)
- keep <- NULL
- }
- tmp <- NULL
- if (n <= 1L)
- return(na_value)
- wsum <- sum(w)
- if (is.null(center)) {
- center <- sum(w * x)/wsum
- }
- x <- x - center
- x <- x^2
- lambda <- 1/(wsum - 1)
- sigma2 <- lambda * sum(w * x)
- x <- w <- NULL
-
- sigma2
-}
var_reduction_R <- function(y, w, g){
- (sum(w) - 1)/sum(w) * weightedVar(y, w = w) -
- (sum(w*g) - 1)/(sum(w))*weightedVar(y, w = w, idxs = which(g == 1)) -
- (sum(w*(1-g)) - 1)/(sum(w))*weightedVar(y, w = w, idxs = which(g == 0))
+ (sum(w) - 1)/sum(w) * weighted_variance(y, w = w) -
+ (sum(w*g) - 1)/(sum(w))*weighted_variance(y, w = w, idxs = which(g == 1)) -
+ (sum(w*(1-g)) - 1)/(sum(w))*weighted_variance(y, w = w, idxs = which(g == 0))
}
test_that(
- desc = 'computed variance reduction close to matrixStats::weightedVar',
+ desc = 'computed variance reduction close to matrixStats::weighted_variance',
code = {
n_runs <- 100
@@ -78,20 +20,18 @@ test_that(
y <- rnorm(100)
w <- runif(100, 0, 2)
g <- rbinom(100, 1, 0.5)
- diffs_vec[i] <- abs(compute_var_reduction(y, w, g) -
+ diffs_vec[i] <- abs(compute_var_reduction_exported(y, w, g) -
var_reduction_R(y, w, g))
}
- # unweighted is basically identical to cstat from survival
- expect_lt(mean(diffs_vec), 1e-6)
+ # basically identical to R version
+ expect_equal(diffs_vec, rep(0, length(diffs_vec)), tolerance = 1e-6)
}
)
-# # The cpp implementation is 80+ times faster than the implementation using
-# # matrixStats::weightedVar
# microbenchmark::microbenchmark(
-# cpp = compute_var_reduction(y, w, g),
+# cpp = compute_var_reduction_exported(y, w, g),
# r = var_reduction_R(y, w, g),
# times = 10000
# )
diff --git a/tests/testthat/test-expand_y_clsf.R b/tests/testthat/test-expand_y_clsf.R
index f2c4f584..ae02cebd 100644
--- a/tests/testthat/test-expand_y_clsf.R
+++ b/tests/testthat/test-expand_y_clsf.R
@@ -9,13 +9,20 @@ test_that(
ones <- which(y == 1)
twos <- which(y == 2)
- expect_true(all(y_expand[zeros, ] == 0))
+ # zeros should be 1 in column 1, 0 o.w.
+ expect_true(all(y_expand[zeros, 1] == 1))
+ expect_true(all(y_expand[zeros, 2] == 0))
+ expect_true(all(y_expand[zeros, 3] == 0))
+
# ones should be 1 in column 1, 0 o.w.
- expect_true(all(y_expand[ones, 1] == 1))
- expect_true(all(y_expand[ones, 2] == 0))
+ expect_true(all(y_expand[ones, 1] == 0))
+ expect_true(all(y_expand[ones, 2] == 1))
+ expect_true(all(y_expand[ones, 3] == 0))
+
# twos should be 1 in column 2, 0 o.w.
expect_true(all(y_expand[twos, 1] == 0))
- expect_true(all(y_expand[twos, 2] == 1))
+ expect_true(all(y_expand[twos, 2] == 0))
+ expect_true(all(y_expand[twos, 3] == 1))
}
)
diff --git a/tests/testthat/test-lincomb_linreg.R b/tests/testthat/test-lincomb_linreg.R
index 419637e1..3f9b31cb 100644
--- a/tests/testthat/test-lincomb_linreg.R
+++ b/tests/testthat/test-lincomb_linreg.R
@@ -1,41 +1,49 @@
-# test_that(
-# desc = "linreg_fit with weights approximately equal to glm()",
-# code = {
-#
-# nrows <- 1000
-# ncols <- 20
-#
-# X <- matrix(data = rnorm(nrows*ncols), nrow = nrows, ncol = ncols)
-#
-# # X <- cbind(1, X)
-#
-# colnames(X) <- c(
-# # "intercept",
-# paste0("x", seq(ncols))
-# )
-#
-# Y <- matrix(rnorm(nrows), ncol = 1)
-#
-# glm_data <- as.data.frame(cbind(y=as.numeric(Y), X))
-#
-# # Fit logistic regression using the custom function
-#
-# W <- sample(1:3, nrow(X), replace=TRUE)
-#
-# cpp = linreg_fit_exported(X, Y, W, do_scale = TRUE,
-# epsilon = 1e-9, iter_max = 20)
-#
-# R = lm(y ~ ., weights = as.integer(W), data = glm_data)
-#
-# R_summary <- summary(R)
-#
-# R_beta_est <- as.numeric(R_summary$coefficients[-1, 'Estimate'])
-# R_beta_pvalues <- as.numeric(R_summary$coefficients[-1, 'Pr(>|t|)'])
-#
-# expect_equal(cpp[,1], R_beta_est, tolerance = 1e-9)
-# expect_equal(cpp[,2], R_beta_pvalues, tolerance = 1e-9)
-#
-#
-# }
-# )
+test_that(
+ desc = "linreg_fit with weights approximately equal to glm()",
+ code = {
+
+ nrows <- 1000
+ ncols <- 20
+
+ X <- matrix(data = rnorm(nrows*ncols), nrow = nrows, ncol = ncols)
+
+ # X <- cbind(1, X)
+
+ colnames(X) <- c(
+ # "intercept",
+ paste0("x", seq(ncols))
+ )
+
+ Y <- matrix(rnorm(nrows), ncol = 1)
+
+ glm_data <- as.data.frame(cbind(y=as.numeric(Y), X))
+
+ # Fit logistic regression using the custom function
+
+ W <- sample(1:3, nrow(X), replace=TRUE)
+
+ cpp = linreg_fit_exported(X, Y, W, do_scale = FALSE,
+ epsilon = 1e-9, iter_max = 20)
+
+ cpp_scale = linreg_fit_exported(X, Y, W, do_scale = TRUE,
+ epsilon = 1e-9, iter_max = 20)
+
+ R = lm(y ~ ., weights = as.integer(W), data = glm_data)
+
+ R_summary <- summary(R)
+
+ R_beta_est <- as.numeric(R_summary$coefficients[-1, 'Estimate'])
+ R_beta_pvalues <- as.numeric(R_summary$coefficients[-1, 'Pr(>|t|)'])
+
+ expect_equal(cpp[,1], R_beta_est, tolerance = 1e-9)
+ # scaling changes estimates by a little (that's expected) but not too much
+ expect_equal(cpp_scale[,1], R_beta_est, tolerance = .1)
+
+ expect_equal(cpp[,2], R_beta_pvalues, tolerance = 1e-9)
+ # scaling does not change p-values at all
+ expect_equal(cpp_scale[,2], R_beta_pvalues, tolerance = 1e-9)
+
+
+ }
+)
diff --git a/tests/testthat/test-orsf.R b/tests/testthat/test-orsf.R
index 98e04cb1..d93121ab 100644
--- a/tests/testthat/test-orsf.R
+++ b/tests/testthat/test-orsf.R
@@ -1,11 +1,11 @@
-
-f <- time + status ~ .
-
test_that(
desc = 'non-formula inputs are vetted',
code = {
+ # correct formula
+ f <- time + status ~ .
+
expect_error(orsf(pbc, f, n_tree = 0), "should be >= 1")
expect_error(orsf(pbc, f, n_split = "3"), "should have type")
expect_error(orsf(pbc, f, mtry = 5000), 'should be <=')
@@ -13,6 +13,7 @@ test_that(
expect_error(orsf(pbc, f, leaf_min_obs = 5000), 'should be <=')
expect_error(orsf(pbc, f, attachData = TRUE), 'attach_data?')
expect_error(orsf(pbc, f, Control = 0), 'control?')
+ expect_error(orsf(pbc, f, tree_seeds = c(1,2,3)), 'number of trees')
expect_error(orsf(pbc, f, sample_fraction = 1, oobag_pred_type = 'risk'),
'no samples are out-of-bag')
expect_error(orsf(pbc, f, split_rule = 'cstat', split_min_stat = 1),
@@ -25,12 +26,46 @@ test_that(
}
)
+test_that(
+ desc = 'outcome type can be guessed',
+ code = {
+
+
+ fit_regr <- orsf(mtcars, mpg ~ ., no_fit = TRUE)
+ fit_clsf <- orsf(penguins, species ~ ., no_fit = TRUE)
+ fit_surv <- orsf(pbc, time + status ~ ., no_fit = TRUE)
+
+ expect_s3_class(fit_regr, "ObliqueForestRegression")
+ expect_s3_class(fit_clsf, "ObliqueForestClassification")
+ expect_s3_class(fit_surv, "ObliqueForestSurvival")
+
+ }
+)
+
+test_that(
+ desc = 'potential user-errors with outcome types are caught',
+ code = {
+
+ expect_error(
+ orsf(penguins, species ~., control = orsf_control_regression()),
+ "it is a factor"
+ )
+
+ expect_error(
+ orsf(mtcars, mpg ~., control = orsf_control_classification()),
+ "please convert mpg to a factor"
+ )
+
+ }
+)
+
+
test_that(
desc = 'target_df too high is caught',
code = {
cntrl <- orsf_control_survival(method = 'net', target_df = 10)
- expect_error(orsf(pbc_orsf, formula = f, control = cntrl), 'should be <=')
+ expect_error(orsf(pbc, time + status ~ ., control = cntrl), 'should be <=')
}
)
@@ -47,9 +82,61 @@ test_that(
expect_equal_leaf_summary(fit_dt, fit_standard_pbc$fast)
+ fit_dt <- orsf(as.data.table(penguins),
+ formula = species ~ .,
+ n_tree = n_tree_test,
+ control = controls_clsf$fast,
+ tree_seed = seeds_standard)
+
+ expect_equal_leaf_summary(fit_dt, fit_standard_penguins$fast)
+
+ fit_dt <- orsf(as.data.table(mtcars),
+ formula = mpg ~ .,
+ n_tree = n_tree_test,
+ control = controls_regr$fast,
+ tree_seed = seeds_standard)
+
+ expect_equal_leaf_summary(fit_dt, fit_standard_mtcars$fast)
+
}
)
+test_that(
+ desc = "orsf runs with lists and recipes",
+ code = {
+
+ pbc_list <- as.list(pbc_orsf)
+ pbc_list_bad <- pbc_list
+ pbc_list_bad$trt <- pbc_list_bad$trt[1:3]
+ pbc_list_bad$age <- pbc_list_bad$age[1:5]
+
+ skip_on_cran() # I don't want to list recipes in suggests
+
+ recipe <- recipes::recipe(pbc_orsf, formula = time + status ~ .) %>%
+ recipes::step_rm(id)
+
+ recipe_prepped <- recipes::prep(recipe)
+
+ fit_recipe <- orsf(recipe_prepped, Surv(time, status) ~ .,
+ n_tree = n_tree_test,
+ tree_seeds = seeds_standard)
+
+ expect_equal_leaf_summary(fit_recipe, fit_standard_pbc$fast)
+
+ fit_list <- orsf(pbc_list,
+ Surv(time, status) ~ . - id,
+ n_tree = n_tree_test,
+ tree_seeds = seeds_standard)
+
+ expect_equal_leaf_summary(fit_list, fit_standard_pbc$fast)
+
+ expect_error(
+ orsf(pbc_list_bad, Surv(time, status) ~ .),
+ regexp = 'unable to cast data'
+ )
+
+ }
+)
test_that(
desc = "blank and non-standard names trigger an error",
@@ -81,15 +168,70 @@ test_that(
}
)
+test_that(
+ desc = 'if oobag time is unspecified, pred horizon = median(time)',
+ code = {
+
+ fit_1 <- orsf(data = pbc_orsf,
+ formula = time + status ~ . - id,
+ n_tree = 1)
+
+ fit_2 <- orsf(data = pbc_orsf,
+ formula = time + status ~ . - id,
+ n_tree = 1,
+ oobag_pred_type = 'none')
+
+ expect_equal(fit_1$pred_horizon, median(pbc_orsf$time))
+ expect_equal(fit_1$pred_horizon, fit_2$pred_horizon)
+
+ }
+)
+
+
+test_that(
+ desc = 'list columns are not allowed',
+ code = {
+
+ pbc_temp <- pbc_orsf
+ pbc_temp$list_col <- list(list(a=1))
+
+ expect_error(
+ orsf(pbc_temp, time + status ~ . - id),
+ regexp = ''
+ )
+ }
+)
+
test_that(
desc = "algorithm grows more accurate with higher number of iterations",
code = {
+ eval_every <- max(round(n_tree_test/5), 1)
+
fit <- orsf(pbc,
formula = Surv(time, status) ~ .,
- n_tree = 50,
- oobag_eval_every = 5)
+ n_tree = n_tree_test,
+ tree_seeds = seeds_standard,
+ oobag_eval_every = eval_every)
+
+ expect_lt(fit$eval_oobag$stat_values[1],
+ last_value(fit$eval_oobag$stat_values))
+
+ fit <- orsf(penguins,
+ formula = species ~ .,
+ n_tree = n_tree_test,
+ tree_seeds = seeds_standard,
+ oobag_eval_every = eval_every)
+
+ expect_lt(fit$eval_oobag$stat_values[1],
+ last_value(fit$eval_oobag$stat_values))
+
+ fit <- orsf(mtcars,
+ formula = mpg ~ .,
+ n_tree = n_tree_test*3, # just needs a bit extra
+ tree_seeds = seeds_standard,
+ oobag_eval_every = eval_every)
expect_lt(fit$eval_oobag$stat_values[1],
last_value(fit$eval_oobag$stat_values))
@@ -99,7 +241,7 @@ test_that(
test_that(
- desc = 'Boundary case: empty training data throw an error',
+ desc = 'Empty training data throw an error',
code = {
expect_error(
@@ -115,12 +257,13 @@ test_that(
}
)
-pbc_temp <- pbc_orsf
-pbc_temp[, 'bili'] <- NA_real_
-
test_that(
desc = "Data with all-`NA` fields or columns are rejected",
code = {
+
+ pbc_temp <- pbc
+ pbc_temp[, 'bili'] <- NA_real_
+
expect_error(orsf(pbc_temp, time + status ~ . - id,
na_action = 'omit'),
'complete data')
@@ -132,31 +275,29 @@ test_that(
}
)
-pbc_temp$bili[1:10] <- 12
-
test_that(
desc = "data with missing values are rejected when na_action is fail",
code = {
+ pbc_temp <- pbc
+ pbc_temp[1, 'bili'] <- NA_real_
+
expect_error(orsf(pbc_temp, time + status ~ . - id),
'missing values')
-
}
)
-pbc_temp <- copy(pbc_orsf)
-pbc_temp[1:10, 'bili'] <- NA_real_
-pbc_temp_orig <- copy(pbc_temp)
-
test_that(
desc = 'missing data are dropped when na_action is omit',
code = {
+ pbc_temp <- pbc
+ pbc_temp[1, 'bili'] <- NA_real_
+
fit_omit <- orsf(pbc_temp, time + status ~ .-id, na_action = 'omit')
- expect_equal(fit_omit$n_obs,
- nrow(stats::na.omit(pbc_temp)))
+ expect_equal(fit_omit$n_obs, nrow(stats::na.omit(pbc_temp)))
}
)
@@ -165,169 +306,128 @@ test_that(
desc = 'missing data are imputed when na_action is impute_meanmode',
code = {
- fit_impute <- orsf(pbc_temp,
- time + status ~ .,
+ mtcars_temp <- mtcars
+ mtcars_temp$disp[1] <- NA
+
+ fit_impute <- orsf(mtcars_temp, mpg ~ .,
na_action = 'impute_meanmode')
- expect_equal(fit_impute$n_obs, nrow(pbc_temp))
+ expect_equal(fit_impute$n_obs, nrow(mtcars_temp))
+
+ # users data are not modified by imputation
+ expect_true(is.na(mtcars_temp$disp[1]))
+ expect_identical(mtcars_temp, fit_impute$data)
}
)
test_that(
- "data are not unintentionally modified by reference when imputed",
+ desc = 'robust to threading, outcome formats, scaling, and noising',
code = {
- expect_identical(pbc_temp, pbc_temp_orig)
+
+ fits_surv <- lapply(data_list_pbc[-1], function(data){
+ orsf(data,
+ formula = time + status ~ .,
+ n_thread = 2,
+ n_tree = n_tree_test,
+ tree_seeds = seeds_standard)
+ })
+
+ expect_equal_leaf_summary(fits_surv$pbc_status_12,
+ fit_standard_pbc$fast)
+
+ expect_equal_oobag_eval(fits_surv$pbc_scaled, fit_standard_pbc$fast)
+ expect_equal_oobag_eval(fits_surv$pbc_noised, fit_standard_pbc$fast)
+
+ fits_clsf <- lapply(data_list_penguins[-1], function(data){
+ orsf(data,
+ formula = species ~ .,
+ n_thread = 2,
+ n_tree = n_tree_test,
+ tree_seeds = seeds_standard)
+ })
+
+ expect_equal_oobag_eval(fits_clsf$penguins_scaled, fit_standard_penguins$fast)
+ expect_equal_oobag_eval(fits_clsf$penguins_noised, fit_standard_penguins$fast)
+
+ fits_regr <- lapply(data_list_mtcars[-1], function(data){
+ orsf(data,
+ formula = mpg ~ .,
+ n_thread = 2,
+ n_tree = n_tree_test,
+ tree_seeds = seeds_standard)
+ })
+
+ expect_equal_oobag_eval(fits_regr$mtcars_scaled, fit_standard_mtcars$fast)
+ expect_equal_oobag_eval(fits_regr$mtcars_noised, fit_standard_mtcars$fast)
+
}
)
-pbc_noise <- data_list_pbc$pbc_noised
-pbc_scale <- data_list_pbc$pbc_scaled
-
-n_tree_robust <- 500
-
-fit_orsf <-
- orsf(pbc,
- Surv(time, status) ~ .,
- n_thread = 1,
- n_tree = n_tree_robust,
- tree_seeds = seeds_standard)
-
-fit_orsf_2 <-
- orsf(pbc,
- Surv(time, status) ~ .,
- n_thread = 5,
- n_tree = n_tree_robust,
- tree_seeds = seeds_standard)
-
-fit_orsf_noise <-
- orsf(pbc_noise,
- Surv(time, status) ~ .,
- n_tree = n_tree_robust,
- tree_seeds = seeds_standard)
-
-fit_orsf_scale <-
- orsf(pbc_scale,
- Surv(time, status) ~ .,
- n_tree = n_tree_robust,
- tree_seeds = seeds_standard)
-
-#' @srrstats {ML7.1} *Demonstrate effect of numeric scaling of input data.*
+
test_that(
- desc = 'outputs are robust to multi-threading, scaling, and noising',
+ desc = 'oob error correct for user-specified function',
code = {
- expect_lt(
- abs(
- fit_orsf$eval_oobag$stat_values -
- fit_orsf_scale$eval_oobag$stat_values
- ),
- 0.01
- )
+ fit <- orsf(data = pbc,
+ formula = time + status ~ . -id,
+ n_tree = n_tree_test,
+ oobag_fun = oobag_c_survival,
+ tree_seeds = seeds_standard)
- expect_lt(
- abs(
- fit_orsf$eval_oobag$stat_values -
- fit_orsf_2$eval_oobag$stat_values
- ),
- 0.01
- )
+ expect_equal_oobag_eval(fit, fit_standard_pbc$fast)
- expect_lt(
- abs(
- fit_orsf$eval_oobag$stat_values -
- fit_orsf_noise$eval_oobag$stat_values
+ # can also reproduce it from the oobag predictions
+ expect_equal(
+ oobag_c_survival(
+ y_mat = as.matrix(pbc_orsf[,c("time", "status")]),
+ w_vec = rep(1, nrow(pbc_orsf)),
+ s_vec = fit$pred_oobag
),
- 0.01
+ as.numeric(fit$eval_oobag$stat_values)
)
+ skip_on_cran() # don't want to suggest yardstick or Hmisc
- expect_lt(
- max(abs(fit_orsf$pred_oobag - fit_orsf_scale$pred_oobag)),
- 0.1
- )
+ oobag_rsq_eval <- function(y_mat, w_vec, s_vec){
- expect_lt(
- max(abs(fit_orsf$pred_oobag - fit_orsf_2$pred_oobag)),
- 0.1
- )
+ yardstick::rsq_trad_vec(truth = as.numeric(y_mat),
+ estimate = as.numeric(s_vec),
+ case_weights = as.numeric(w_vec))
+ }
- expect_lt(
- max(abs(fit_orsf$pred_oobag - fit_orsf_noise$pred_oobag)),
- 0.1
- )
+ fit <- orsf(data = mtcars,
+ formula = mpg ~ .,
+ n_tree = n_tree_test,
+ oobag_fun = oobag_rsq_eval,
+ tree_seeds = seeds_standard)
- expect_lt(
- mean(abs(fit_orsf$importance - fit_orsf_noise$importance)),
- 0.1
+ expect_equal(
+ fit$eval_oobag$stat_values[1,1],
+ yardstick::rsq_trad_vec(truth = as.numeric(mtcars$mpg),
+ estimate = as.numeric(fit$pred_oobag),
+ case_weights = rep(1, nrow(mtcars)))
)
- expect_equal(fit_orsf$forest,
- fit_orsf_2$forest)
-
- expect_equal(fit_orsf$importance,
- fit_orsf_2$importance)
-
- expect_equal(fit_orsf$forest$rows_oobag,
- fit_orsf_noise$forest$rows_oobag)
-
- expect_equal(fit_orsf$forest$rows_oobag,
- fit_orsf_scale$forest$rows_oobag)
-
- expect_equal(fit_orsf$forest$leaf_summary,
- fit_orsf_scale$forest$leaf_summary)
-
- }
-)
-
-
-test_that(
- desc = 'oob rows identical with same tree seeds, oob error correct for user-specified function',
- code = {
-
- tree_seeds = sample.int(n = 50000, size = 100)
- bad_tree_seeds <- c(1,2,3)
+ oobag_cstat_clsf <- function(y_mat, w_vec, s_vec){
- expect_error(
- orsf(data = pbc_orsf,
- formula = time+status~.-id,
- n_tree = 100,
- mtry = 2,
- tree_seeds = bad_tree_seeds),
- regexp = 'the number of trees'
- )
-
- fit_1 <- orsf(data = pbc_orsf,
- formula = time+status~.-id,
- n_tree = 100,
- mtry = 2,
- tree_seeds = tree_seeds)
+ y_vec = as.numeric(y_mat)
+ cstat <- Hmisc::somers2(x = s_vec,
+ y = y_vec,
+ weights = w_vec)['C']
+ cstat
- fit_2 <- orsf(data = pbc_orsf,
- formula = time+status~.-id,
- n_tree = 100,
- mtry = 6,
- tree_seeds = tree_seeds)
+ }
- expect_equal(fit_1$forest$rows_oobag,
- fit_2$forest$rows_oobag)
+ fit <- orsf(data = penguins,
+ formula = species ~ .,
+ n_tree = n_tree_test,
+ oobag_fun = oobag_cstat_clsf,
+ tree_seeds = seeds_standard)
- fit_3 <- orsf(data = pbc_orsf,
- formula = time+status~.-id,
- n_tree = 100,
- mtry = 6,
- oobag_fun = oobag_c_survival,
- tree_seeds = tree_seeds)
+ expect_equal_oobag_eval(fit, fit_standard_penguins$fast)
- expect_equal(
- oobag_c_survival(
- y_mat = as.matrix(pbc_orsf[,c("time", "status")]),
- w_vec = rep(1, nrow(pbc_orsf)),
- s_vec = fit_3$pred_oobag
- ),
- as.numeric(fit_3$eval_oobag$stat_values)
- )
}
)
@@ -363,68 +463,35 @@ test_that(
)
test_that(
- desc = "results are similar after adding trivial noise",
+ desc = 'orsf_fit objects can be saved and loaded with saveRDS and readRDS',
code = {
- expect_true(
- abs(fit_orsf$eval_oobag$stat_values - fit_orsf_noise$eval_oobag$stat_values) < 0.01
- )
-
- expect_true(
- mean(abs(fit_orsf$pred_oobag-fit_orsf_noise$pred_oobag)) < 0.1
- )
+ skip_on_cran()
- }
+ fil <- tempfile("fit_orsf", fileext = ".rds")
-)
+ ## save a single object to file
+ saveRDS(fit_standard_pbc$fast, fil)
+ ## restore it under a different name
+ fit <- readRDS(fil)
-# test_that(
-# desc = 'orsf_fit objects can be saved and loaded with saveRDS and readRDS',
-# code = {
-#
-# fil <- tempfile("fit_orsf", fileext = ".rds")
-#
-# ## save a single object to file
-# saveRDS(fit_orsf, fil)
-# ## restore it under a different name
-# fit_orsf_read_in <- readRDS(fil)
-#
-# # NULL these attributes because they are functions
-# # the env of functions in fit_orsf_read_in will not be identical to the env
-# # of functions in fit_orsf. Everything else should be identical.
-#
-# attr(fit_orsf, 'f_beta') <- NULL
-# attr(fit_orsf_read_in, 'f_beta') <- NULL
-#
-# attr(fit_orsf, 'f_oobag_eval') <- NULL
-# attr(fit_orsf_read_in, 'f_oobag_eval') <- NULL
-#
-# expect_equal(fit_orsf, fit_orsf_read_in)
-#
-# p1=predict(fit_orsf,
-# new_data = fit_orsf$data,
-# pred_horizon = 1000)
-#
-# p2=predict(fit_orsf_read_in,
-# new_data = fit_orsf_read_in$data,
-# pred_horizon = 1000)
-#
-# expect_equal(p1, p2)
-#
-# }
-# )
+ p1 <- predict(fit_standard_pbc$fast, new_data = pbc_test)
+ p2 <- predict(fit, new_data = pbc_test)
+ expect_equal(p1, p2)
+ }
+)
test_that(
- desc = 'orsf() runs as intended for valid inputs',
+ desc = 'oblique survival forests run as intended for valid inputs',
code = {
# just takes forever.
skip_on_cran()
inputs <- expand.grid(
- data_format = c('plain', 'tibble', 'data.table'),
+ data_format = c('plain'),
n_tree = 1,
n_split = 1,
n_retry = 0,
@@ -435,7 +502,7 @@ test_that(
split_rule = c("logrank", "cstat"),
split_min_events = 5,
split_min_obs = 15,
- oobag_pred_type = c('none', 'risk', 'surv', 'chf', 'mort'),
+ oobag_pred_type = c('none', 'risk', 'mort'),
oobag_pred_horizon = c(1,2,3),
orsf_control = c('cph', 'net', 'custom'),
stringsAsFactors = FALSE
@@ -483,7 +550,7 @@ test_that(
oobag_pred_type = inputs$oobag_pred_type[i],
oobag_pred_horizon = pred_horizon)
- expect_s3_class(fit, class = 'ObliqueForest')
+ expect_s3_class(fit, class = 'ObliqueForestSurvival')
# data are not unintentionally modified by reference,
expect_identical(data_fun(pbc_orsf), fit$data)
@@ -551,103 +618,247 @@ test_that(
)
test_that(
- desc = 'if oobag time is unspecified, pred horizon = median(time)',
+ desc = 'oblique classification forests run as intended for valid inputs',
code = {
- fit_1 <- orsf(data = pbc_orsf,
- formula = time + status ~ . - id,
- n_tree = 1)
+ # just takes forever.
+ skip_on_cran()
- fit_2 <- orsf(data = pbc_orsf,
- formula = time + status ~ . - id,
- n_tree = 1,
- oobag_pred_type = 'none')
+ inputs <- expand.grid(
+ data_format = c('plain'),
+ n_tree = 1,
+ n_split = 1,
+ n_retry = 0,
+ mtry = 3,
+ sample_with_replacement = c(TRUE, FALSE),
+ leaf_min_obs = 10,
+ split_rule = c("gini", "cstat"),
+ split_min_obs = 15,
+ oobag_pred_type = c('none', 'prob'),
+ orsf_control = c('glm', 'net', 'custom'),
+ stringsAsFactors = FALSE
+ )
- expect_equal(fit_1$pred_horizon, fit_2$pred_horizon)
+ for(i in seq(nrow(inputs))){
- }
-)
+ data_fun <- switch(
+ as.character(inputs$data_format[i]),
+ 'plain' = function(x) x,
+ 'tibble' = tibble::as_tibble,
+ 'data.table' = as.data.table
+ )
-pbc_temp <- pbc_orsf
-pbc_temp$list_col <- list(list(a=1))
+ control <- switch(inputs$orsf_control[i],
+ 'glm' = orsf_control_classification(method = 'glm'),
+ 'net' = orsf_control_classification(method = 'net'),
+ 'custom' = orsf_control_classification(method = f_pca))
-#' @srrstats {G2.12} *pre-processing identifies list columns and throws informative error*
+ if(inputs$sample_with_replacement[i]){
+ sample_fraction <- 0.632
+ } else {
+ sample_fraction <- runif(n = 1, min = .25, max = .75)
+ }
-test_that(
- desc = 'list columns are not allowed',
- code = {
- expect_error(
- orsf(pbc_temp, time + status ~ . - id),
- regexp = ''
- )
- }
-)
+ fit <- orsf(data = data_fun(penguins_orsf),
+ formula = species ~ .,
+ control = control,
+ sample_with_replacement = inputs$sample_with_replacement[i],
+ sample_fraction = sample_fraction,
+ n_tree = inputs$n_tree[i],
+ n_split = inputs$n_split[i],
+ n_retry = inputs$n_retry[i],
+ mtry = inputs$mtry[i],
+ leaf_min_events = inputs$leaf_min_events[i],
+ leaf_min_obs = inputs$leaf_min_obs[i],
+ split_rule = inputs$split_rule[i],
+ split_min_events = inputs$split_min_events[i],
+ split_min_obs = inputs$split_min_obs[i],
+ oobag_pred_type = inputs$oobag_pred_type[i])
-fit_unwtd <- orsf(pbc_orsf, Surv(time, status) ~ . - id)
+ expect_s3_class(fit, class = 'ObliqueForestClassification')
-fit_wtd <- orsf(pbc_orsf, Surv(time, status) ~ . - id,
- weights = rep(2, nrow(pbc_orsf)))
+ # data are not unintentionally modified by reference,
+ expect_identical(data_fun(penguins_orsf), fit$data)
-test_that(
- desc = 'weights work as intended',
- code = {
- # using weights should make the trees much deeper:
- expect_gt(fit_wtd$get_mean_leaves_per_tree(),
- fit_unwtd$get_mean_leaves_per_tree())
+ expect_no_missing(fit$forest)
+ expect_no_missing(fit$importance)
+
+ expect_length(fit$forest$rows_oobag, n = fit$n_tree)
+ expect_length(fit$forest$cutpoint, n = fit$n_tree)
+ expect_length(fit$forest$child_left, n = fit$n_tree)
+ expect_length(fit$forest$coef_indices, n = fit$n_tree)
+ expect_length(fit$forest$coef_values, n = fit$n_tree)
+ expect_length(fit$forest$leaf_summary, n = fit$n_tree)
+
+ if(!inputs$sample_with_replacement[i]){
+ expect_equal(
+ 1 - length(fit$forest$rows_oobag[[1]]) / fit$n_obs,
+ sample_fraction,
+ tolerance = 0.025
+ )
+ }
+
+ if(inputs$oobag_pred_type[i] != 'none'){
+
+ expect_length(fit$eval_oobag$stat_values, 1)
+
+ expect_equal(nrow(fit$pred_oobag), fit$n_obs)
+
+ # these lengths should match for n_tree=1
+ # b/c only the oobag rows of the first tree
+ # will get a prediction value. Note that the
+ # vectors themselves aren't equal b/c rows_oobag
+ # corresponds to the sorted version of the data.
+ expect_equal(
+ length(which(complete.cases(fit$pred_oobag))),
+ length(fit$forest$rows_oobag[[1]])
+ )
+
+ oobag_preds <- na.omit(fit$pred_oobag)
+
+ expect_true(all(apply(oobag_preds, 1, sum) == 1))
+ expect_true(all(oobag_preds >= 0))
+ expect_true(all(oobag_preds <= 1))
+
+ } else {
+
+ expect_equal(dim(fit$eval_oobag$stat_values), c(0, 0))
+
+ }
+
+ }
}
)
test_that(
- desc = "lists can be plugged into orsf",
+ desc = 'oblique regression forests run as intended for valid inputs',
code = {
- pbc_list <- as.list(pbc_orsf)
- pbc_list_bad <- pbc_list
- pbc_list_bad$trt <- pbc_list_bad$trt[1:3]
- pbc_list_bad$age <- pbc_list_bad$age[1:5]
+ # just takes forever.
+ skip_on_cran()
- # # only run locally - I don't want to list recipes in suggests
- # recipe <- recipes::recipe(pbc_orsf, formula = time + status ~ .) %>%
- # recipes::step_rm(id) %>%
- # recipes::step_scale(recipes::all_numeric_predictors())
- #
- # recipe_prepped <- recipes::prep(recipe)
- #
- # fit_recipe <- orsf(recipe_prepped, Surv(time, status) ~ .)
- #
- # expect_s3_class(fit_recipe, 'ObliqueForest')
+ inputs <- expand.grid(
+ data_format = c('plain'),
+ n_tree = 1,
+ n_split = 1,
+ n_retry = 0,
+ mtry = 3,
+ sample_with_replacement = c(TRUE, FALSE),
+ leaf_min_obs = 3,
+ split_rule = c("variance"),
+ split_min_obs = 6,
+ oobag_pred_type = c('none', 'mean'),
+ orsf_control = c('glm', 'net', 'custom'),
+ stringsAsFactors = FALSE
+ )
- fit_list <- orsf(pbc_list, Surv(time, status) ~ .)
+ for(i in seq(nrow(inputs))){
- expect_s3_class(fit_list, 'ObliqueForest')
+ data_fun <- switch(
+ as.character(inputs$data_format[i]),
+ 'plain' = function(x) x,
+ 'tibble' = tibble::as_tibble,
+ 'data.table' = as.data.table
+ )
- expect_error(
- orsf(pbc_list_bad, Surv(time, status) ~ .),
- regexp = 'unable to cast data'
- )
+ control <- switch(inputs$orsf_control[i],
+ 'glm' = orsf_control_regression(method = 'glm'),
+ 'net' = orsf_control_regression(method = 'net'),
+ 'custom' = orsf_control_regression(method = f_pca))
+
+ if(inputs$sample_with_replacement[i]){
+ sample_fraction <- 0.632
+ } else {
+ sample_fraction <- runif(n = 1, min = .25, max = .75)
+ }
+
+ fit <- orsf(data = data_fun(mtcars),
+ formula = mpg ~ .,
+ control = control,
+ sample_with_replacement = inputs$sample_with_replacement[i],
+ sample_fraction = sample_fraction,
+ n_tree = inputs$n_tree[i],
+ n_split = inputs$n_split[i],
+ n_retry = inputs$n_retry[i],
+ mtry = inputs$mtry[i],
+ leaf_min_events = inputs$leaf_min_events[i],
+ leaf_min_obs = inputs$leaf_min_obs[i],
+ split_rule = inputs$split_rule[i],
+ split_min_events = inputs$split_min_events[i],
+ split_min_obs = inputs$split_min_obs[i],
+ oobag_pred_type = inputs$oobag_pred_type[i])
+
+ expect_s3_class(fit, class = 'ObliqueForestRegression')
+
+ # data are not unintentionally modified by reference,
+ expect_identical(data_fun(mtcars), fit$data)
+
+
+ expect_no_missing(fit$forest)
+ expect_no_missing(fit$importance)
+
+ expect_length(fit$forest$rows_oobag, n = fit$n_tree)
+ expect_length(fit$forest$cutpoint, n = fit$n_tree)
+ expect_length(fit$forest$child_left, n = fit$n_tree)
+ expect_length(fit$forest$coef_indices, n = fit$n_tree)
+ expect_length(fit$forest$coef_values, n = fit$n_tree)
+ expect_length(fit$forest$leaf_summary, n = fit$n_tree)
+
+ if(!inputs$sample_with_replacement[i]){
+ expect_equal(
+ 1 - length(fit$forest$rows_oobag[[1]]) / fit$n_obs,
+ sample_fraction,
+ # bigger tolerance b/c sample size is small
+ tolerance = 0.075
+ )
+ }
+
+ if(inputs$oobag_pred_type[i] != 'none'){
+
+ expect_length(fit$eval_oobag$stat_values, 1)
+
+ expect_equal(nrow(fit$pred_oobag), fit$n_obs)
+
+ # these lengths should match for n_tree=1
+ # b/c only the oobag rows of the first tree
+ # will get a prediction value. Note that the
+ # vectors themselves aren't equal b/c rows_oobag
+ # corresponds to the sorted version of the data.
+ expect_equal(
+ length(which(complete.cases(fit$pred_oobag))),
+ length(fit$forest$rows_oobag[[1]])
+ )
+
+
+ } else {
+
+ expect_equal(dim(fit$eval_oobag$stat_values), c(0, 0))
+
+ }
+
+ }
}
)
test_that(
- desc = 'oobag error works w/oobag_eval_every & custom oobag fun works',
+ desc = 'weights work as intended',
code = {
- fit_custom_oobag <- orsf(pbc,
- formula = Surv(time, status) ~ .,
- n_tree = n_tree_test,
- oobag_eval_every = 1,
- oobag_fun = oobag_c_survival,
- tree_seeds = seeds_standard)
+ fit_unwtd <- orsf(pbc_orsf,
+ Surv(time, status) ~ . - id,
+ n_tree = n_tree_test)
- expect_equal_leaf_summary(fit_custom_oobag, fit_standard_pbc$fast)
+ fit_wtd <- orsf(pbc_orsf,
+ Surv(time, status) ~ . - id,
+ weights = rep(2, nrow(pbc_orsf)),
+ n_tree = n_tree_test)
- expect_equal(
- get_last_oob_stat_value(fit_standard_pbc$fast),
- get_last_oob_stat_value(fit_custom_oobag)
- )
+ # using weights should make the trees much deeper:
+ expect_gt(fit_wtd$get_mean_leaves_per_tree(),
+ fit_unwtd$get_mean_leaves_per_tree())
}
)
@@ -655,23 +866,8 @@ test_that(
-# Similar to obliqueRSF?
-# suppressPackageStartupMessages({
-# library(obliqueRSF)
-# })
-#
-# set.seed(50)
-#
-# fit_aorsf <- orsf(pbc_orsf,
-# formula = Surv(time, status) ~ . - id,
-# n_tree = 100)
-# fit_obliqueRSF <- ORSF(pbc_orsf, ntree = 100, verbose = FALSE)
-#
-#
-# risk_aorsf <- predict(fit_aorsf, new_data = pbc_orsf, pred_horizon = 3500)
-# risk_obliqueRSF <- 1-predict(fit_obliqueRSF, newdata = pbc_orsf, times = 3500)
-#
-# cor(risk_obliqueRSF, risk_aorsf)
-# plot(risk_obliqueRSF, risk_aorsf)
+
+
+
diff --git a/vignettes/aorsf.Rmd b/vignettes/aorsf.Rmd
index d831c6b1..d6a3d36b 100644
--- a/vignettes/aorsf.Rmd
+++ b/vignettes/aorsf.Rmd
@@ -23,7 +23,7 @@ This article covers core features of the `aorsf` package.
## Background
-The oblique random forest (RF) is an extension of the axis-based RF. Instead of using a single variable to split data and grow new branches, trees in the oblique RF use a weighted combination of multiple variables.
+The oblique random forest (RF) is an extension of the traditional (axis-based) RF. Instead of using a single variable to split data and grow new branches, trees in the oblique RF use a weighted combination of multiple variables.
## Oblique RFs for survival, classification, and regression
@@ -49,6 +49,14 @@ penguin_fit <- orsf(data = penguins_orsf,
penguin_fit
+# An oblique regression RF
+
+cars_fit <- orsf(data = mtcars,
+ n_tree = 5,
+ formula = mpg ~ .)
+
+cars_fit
+
```
you may notice that the first input of `aorsf` is `data`. This is a design choice that makes it easier to use `orsf` with pipes (i.e., `%>%` or `|>`). For instance,
@@ -70,9 +78,9 @@ pbc_fit <- pbc_orsf |>
### Variable importance
-`aorsf` provides multiple ways to compute variable importance.
+There are multiple methods to compute variable importance.
-- To compute negation importance, ORSF multiplies each coefficient of that variable by -1 and then re-computes the out-of-sample (sometimes referred to as out-of-bag) accuracy of the ORSF model.
+- To compute *negation* importance, ORSF multiplies each coefficient of that variable by -1 and then re-computes the out-of-sample (sometimes referred to as out-of-bag) accuracy of the ORSF model.
```{r}
@@ -80,7 +88,7 @@ pbc_fit <- pbc_orsf |>
```
-- You can also compute variable importance using permutation, a more classical approach.
+- You can also compute variable importance using *permutation*, a more classical approach that noises up a predictor and then assigned the resulting degradation in prediction accuracy to be the importance of that predictor.
```{r}