From b9453e0b42eaf387dc8cc37df33bbeb79471d3dc Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Fri, 22 Sep 2023 23:43:00 -0400 Subject: [PATCH] oob functions allowed --- R/RcppExports.R | 4 +- R/check.R | 22 +++-- R/oobag_c_harrell.R | 49 ++--------- R/orsf.R | 39 +++++---- R/orsf_attr.R | 2 +- src/Forest.cpp | 175 +++++++++++++++++++++++++------------ src/Forest.h | 40 +++++---- src/ForestSurvival.cpp | 34 ++++--- src/ForestSurvival.h | 10 +-- src/RcppExports.cpp | 9 +- src/Tree.cpp | 6 +- src/TreeSurvival.cpp | 1 + src/orsf_oop.cpp | 18 ++-- src/utility.cpp | 12 +-- tests/testthat/test-orsf.R | 53 ++++++----- vignettes/oobag.Rmd | 37 ++++---- 16 files changed, 282 insertions(+), 229 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 6fcd3178..00761fac 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -13,7 +13,7 @@ compute_cstat_exported_uvec <- function(y, w, g, pred_is_risklike) { .Call(`_aorsf_compute_cstat_exported_uvec`, y, w, g, pred_is_risklike) } -orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, oobag, oobag_eval_type_R, oobag_eval_every, n_thread, write_forest) { - .Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, oobag, oobag_eval_type_R, oobag_eval_every, n_thread, write_forest) 
+orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, oobag, oobag_eval_type_R, oobag_eval_every, n_thread, write_forest, run_forest) { + .Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, oobag, oobag_eval_type_R, oobag_eval_every, n_thread, write_forest, run_forest) } diff --git a/R/check.R b/R/check.R index c653b329..04ee0216 100644 --- a/R/check.R +++ b/R/check.R @@ -1616,8 +1616,8 @@ check_oobag_fun <- function(oobag_fun){ oobag_fun_args <- names(formals(oobag_fun)) - if(length(oobag_fun_args) != 2) stop( - "oobag_fun should have 2 input arguments but instead has ", + if(length(oobag_fun_args) != 3) stop( + "oobag_fun should have 3 input arguments but instead has ", length(oobag_fun_args), call. = FALSE ) @@ -1628,8 +1628,14 @@ check_oobag_fun <- function(oobag_fun){ call. = FALSE ) - if(oobag_fun_args[2] != 's_vec') stop( - "the second input argument of oobag_fun should be named 's_vec' ", + if(oobag_fun_args[2] != 'w_vec') stop( + "the second input argument of oobag_fun should be named 'w_vec' ", + "but is instead named '", oobag_fun_args[2], "'", + call. 
= FALSE + ) + + if(oobag_fun_args[3] != 's_vec') stop( + "the third input argument of oobag_fun should be named 's_vec' ", - "but is instead named '", oobag_fun_args[2], "'", + "but is instead named '", oobag_fun_args[3], "'", call. = FALSE ) @@ -1638,9 +1644,12 @@ check_oobag_fun <- function(oobag_fun){ test_status <- rep(c(0,1), each = 50) .y_mat <- cbind(time = test_time, status = test_status) + .w_vec <- rep(1, times = 100) .s_vec <- seq(0.9, 0.1, length.out = 100) - test_output <- try(oobag_fun(y_mat = .y_mat, s_vec = .s_vec), + test_output <- try(oobag_fun(y_mat = .y_mat, + w_vec = .w_vec, + s_vec = .s_vec), silent = FALSE) if(is_error(test_output)){ @@ -1650,8 +1659,9 @@ check_oobag_fun <- function(oobag_fun){ "test_time <- seq(from = 1, to = 5, length.out = 100)\n", "test_status <- rep(c(0,1), each = 50)\n\n", "y_mat <- cbind(time = test_time, status = test_status)\n", + "w_vec <- rep(1, times = 100)\n", "s_vec <- seq(0.9, 0.1, length.out = 100)\n\n", - "test_output <- oobag_fun(y_mat = y_mat, s_vec = s_vec)\n\n", + "test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n", "test_output should be a numeric value of length 1", call. 
= FALSE) diff --git a/R/oobag_c_harrell.R b/R/oobag_c_harrell.R index 3e22ca23..03734cb8 100644 --- a/R/oobag_c_harrell.R +++ b/R/oobag_c_harrell.R @@ -12,50 +12,13 @@ #' @noRd #' -oobag_c_harrell <- function(y_mat, s_vec){ +oobag_c_survival <- function(y_mat, w_vec, s_vec){ - sorted <- order(y_mat[, 1], -y_mat[, 2]) + survival::concordancefit( + y = survival::Surv(y_mat), + x = s_vec + )$concordance - y_mat <- y_mat[sorted, ] - s_vec <- s_vec[sorted] - - time = y_mat[, 1] - status = y_mat[, 2] - events = which(status == 1) - - k = nrow(y_mat) - - total <- 0 - concordant <- 0 - - for(i in events){ - - if(i+1 <= k){ - - for(j in seq(i+1, k)){ - - if(time[j] > time[i]){ - - total <- total + 1 - - if(s_vec[j] > s_vec[i]){ - - concordant <- concordant + 1 - - } else if (s_vec[j] == s_vec[i]){ - - concordant <- concordant + 0.5 - - } - - } - - } - - } - - } +} - concordant / total -} diff --git a/R/orsf.R b/R/orsf.R index 7c015bab..e5aff569 100644 --- a/R/orsf.R +++ b/R/orsf.R @@ -678,10 +678,11 @@ orsf <- function(data, collapse::radixorder(y[, 1], # order this way for risk sets -y[, 2]) # order this way for oob C-statistic. 
+ if(is.null(weights)) weights <- rep(1, nrow(x)) + x_sort <- x[sorted, , drop = FALSE] y_sort <- y[sorted, , drop = FALSE] - - if(is.null(weights)) weights <- rep(1, nrow(x)) + w_sort <- weights[sorted] if(length(tree_seeds) == 1) set.seed(tree_seeds) @@ -690,13 +691,13 @@ orsf <- function(data, vi_max_pvalue = 0.01 - orsf_out <- orsf_cpp(x = x, - y = y, - w = weights, + orsf_out <- orsf_cpp(x = x_sort, + y = y_sort, + w = w_sort, tree_type_R = 3, tree_seeds = as.integer(tree_seeds), loaded_forest = list(), - n_tree = if(no_fit) 0 else n_tree, + n_tree = n_tree, mtry = mtry, vi_type_R = switch(importance, "none" = 0, @@ -745,21 +746,20 @@ orsf <- function(data, 'user' = 2), oobag_eval_every = oobag_eval_every, n_thread = n_thread, - write_forest = TRUE) - - # browser() + write_forest = TRUE, + run_forest = !no_fit) # if someone says no_fit and also says don't attach the data, # give them a warning but also do the right thing for them. orsf_out$data <- if(attach_data) data else NULL - if(importance != 'none'){ + if(importance != 'none' && !no_fit){ rownames(orsf_out$importance) <- colnames(x) orsf_out$importance <- rev(orsf_out$importance[order(orsf_out$importance), , drop=TRUE]) } - if(oobag_pred){ + if(oobag_pred && !no_fit){ # put the oob predictions into the same order as the training data. unsorted <- collapse::radixorder(sorted) @@ -833,7 +833,7 @@ orsf <- function(data, attr(orsf_out, 'split_rule') <- split_rule attr(orsf_out, 'n_thread') <- n_thread - attr(orsf_out, 'tree_seeds') <- if(is.null(tree_seeds)) c() else tree_seeds + attr(orsf_out, 'tree_seeds') <- tree_seeds #' @srrstats {ML5.0a} *orsf output has its own class* class(orsf_out) <- "orsf_fit" @@ -1037,17 +1037,17 @@ orsf_train_ <- function(object, -y[, 2]) # order this way for oob C-statistic. 
} + weights <- get_weights_user(object) x_sort <- x[sorted, ] y_sort <- y[sorted, ] + w_sort <- weights[sorted] oobag_eval_every <- min(n_tree, get_oobag_eval_every(object)) - weights <- get_weights_user(object) - - orsf_out <- orsf_cpp(x = x, - y = y, - w = weights, + orsf_out <- orsf_cpp(x = x_sort, + y = y_sort, + w = w_sort, tree_type_R = 3, tree_seeds = get_tree_seeds(object), loaded_forest = list(), @@ -1100,9 +1100,10 @@ orsf_train_ <- function(object, 'none' = 0, 'cstat' = 1, 'user' = 2), - oobag_eval_every = get_oobag_eval_every(object), + oobag_eval_every = oobag_eval_every, n_thread = get_n_thread(object), - write_forest = TRUE) + write_forest = TRUE, + run_forest = TRUE) object$pred_oobag <- orsf_out$pred_oobag diff --git a/R/orsf_attr.R b/R/orsf_attr.R index 64bdc97b..d42fd1c4 100644 --- a/R/orsf_attr.R +++ b/R/orsf_attr.R @@ -82,7 +82,7 @@ is_trained <- function(object) attr(object, 'trained') #' #' @noRd #' -contains_oobag <- function(object) {!is_empty(object$pred_oobag)} +contains_oobag <- function(object) {!is_empty(object$eval_oobag$stat_values)} #' Determine whether object has variable importance estimates #' diff --git a/src/Forest.cpp b/src/Forest.cpp index d7535b09..7dbdc791 100644 --- a/src/Forest.cpp +++ b/src/Forest.cpp @@ -72,14 +72,15 @@ void Forest::init(std::unique_ptr input_data, if(vi_type != VI_NONE){ vi_numer.zeros(data->get_n_cols()); - if(vi_type == VI_ANOVA){ vi_denom.zeros(data->get_n_cols()); } - } - if(VERBOSITY > 0){ + // oobag denominator tracks the number of times an obs is oobag + oobag_denom.zeros(data->get_n_rows()); + + if(VERBOSITY > 0){ Rcout << "------------ input data dimensions ------------" << std::endl; Rcout << "N obs total: " << data->get_n_rows() << std::endl; Rcout << "N columns total: " << data->get_n_cols() << std::endl; @@ -181,12 +182,12 @@ void Forest::grow() { vi_numer_threads[i].zeros(data->n_cols); if(vi_type == VI_ANOVA) vi_denom_threads[i].zeros(data->n_cols); - 
threads.emplace_back(&Forest::grow_in_threads, this, i, + threads.emplace_back(&Forest::grow_multi_thread, this, i, &(vi_numer_threads[i]), &(vi_denom_threads[i])); } - showProgress("Growing trees...", n_tree); + show_progress("Growing trees...", n_tree); for (auto &thread : threads) { thread.join(); @@ -227,7 +228,7 @@ void Forest::grow_single_thread(vec* vi_numer_ptr, } -void Forest::grow_in_threads(uint thread_idx, +void Forest::grow_multi_thread(uint thread_idx, vec* vi_numer_ptr, uvec* vi_denom_ptr) { @@ -266,6 +267,12 @@ void Forest::compute_oobag_vi() { // show progress from threads progress = 0; + if(n_thread == 1){ + vec* vi_numer_ptr = &vi_numer; + compute_oobag_vi_single_thread(vi_numer_ptr); + return; + } + std::vector threads; std::vector vi_numer_threads(n_thread); // no denominator b/c it is equal to n_tree for all oob vi methods @@ -276,11 +283,11 @@ void Forest::compute_oobag_vi() { vi_numer_threads[i].zeros(data->n_cols); - threads.emplace_back(&Forest::compute_oobag_vi_in_threads, + threads.emplace_back(&Forest::compute_oobag_vi_multi_thread, this, i, &(vi_numer_threads[i])); } - showProgress("Computing variable importance...", n_tree); + show_progress("Computing variable importance...", n_tree); for (auto &thread : threads) { thread.join(); @@ -296,8 +303,19 @@ void Forest::compute_oobag_vi() { } +void Forest::compute_oobag_vi_single_thread(vec* vi_numer_ptr) { + + for(uint i = 0; i < n_tree; ++i){ + + trees[i]->compute_oobag_vi(vi_numer_ptr, vi_type); + + Rcpp::checkUserInterrupt(); + + } + +} -void Forest::compute_oobag_vi_in_threads(uint thread_idx, vec* vi_numer_ptr) { +void Forest::compute_oobag_vi_multi_thread(uint thread_idx, vec* vi_numer_ptr) { if (thread_ranges.size() > thread_idx + 1) { @@ -324,10 +342,23 @@ void Forest::compute_oobag_vi_in_threads(uint thread_idx, vec* vi_numer_ptr) { } +void Forest::compute_prediction_accuracy(Data* prediction_data, + arma::mat& prediction_values, + arma::uword row_fill){ + + uvec valid_observations 
= find(oobag_denom > 0); + + mat y_valid = prediction_data->y_rows(valid_observations); + vec w_valid = prediction_data->w_subvec(valid_observations); + mat p_valid = prediction_values(valid_observations); + + compute_prediction_accuracy(y_valid, w_valid, p_valid, row_fill); + +} + mat Forest::predict(bool oobag) { mat result; - vec oob_denom; // No. of cols in pred mat depend on the type of forest resize_pred_mat(result); @@ -336,73 +367,76 @@ mat Forest::predict(bool oobag) { // (needs to be resized even if !oobag) resize_oobag_eval(); - // oobag denominator tracks the number of times an obs is oobag - if(oobag){ - oob_denom.zeros(data->n_rows); - } - progress = 0; aborted = false; aborted_threads = 0; - std::vector threads; - std::vector result_threads(n_thread); - std::vector oob_denom_threads(n_thread); + if(n_thread == 1){ + // ensure safe usage of R functions + predict_single_thread(data.get(), oobag, result); - threads.reserve(n_thread); + } else { - for (uint i = 0; i < n_thread; ++i) { + std::vector threads; + std::vector result_threads(n_thread); + std::vector oobag_denom_threads(n_thread); - resize_pred_mat(result_threads[i]); - if(oobag) oob_denom_threads[i].zeros(data->n_rows); + threads.reserve(n_thread); - threads.emplace_back(&Forest::predict_in_threads, - this, i, data.get(), oobag, - &(result_threads[i]), - &(oob_denom_threads[i])); - } + for (uint i = 0; i < n_thread; ++i) { - showProgress("Predicting...", n_tree); + resize_pred_mat(result_threads[i]); + if(oobag) oobag_denom_threads[i].zeros(data->n_rows); - // wait for all threads to finish before proceeding - for (auto &thread : threads) { - thread.join(); - } + threads.emplace_back(&Forest::predict_multi_thread, + this, i, data.get(), oobag, + &(result_threads[i]), + &(oobag_denom_threads[i])); + } - for(uint i = 0; i < n_thread; ++i){ + show_progress("Predicting...", n_tree); - result += result_threads[i]; + // wait for all threads to finish before proceeding + for (auto &thread : threads) 
{ + thread.join(); + } - if(oobag){ + for(uint i = 0; i < n_thread; ++i){ + + result += result_threads[i]; - oob_denom += oob_denom_threads[i]; + if(oobag){ - // evaluate oobag error after joining each thread - // (only safe to do this when the condition below holds) - if(n_tree/oobag_eval_every == n_thread && n_thread>1 && i1 && ipredict_leaf(prediction_data, oobag); + + trees[i]->predict_value(&result, &oobag_denom, pred_type, oobag); + + progress++; + + // if user wants to track oobag error over time: + if(oobag && (progress % oobag_eval_every == 0) ){ + + uword eval_row = (progress / oobag_eval_every) - 1; + + + mat preds = result.each_col() / oobag_denom; + compute_prediction_accuracy(prediction_data, preds, eval_row); + + } + + } + +} + +void Forest::predict_multi_thread(uint thread_idx, + Data* prediction_data, + bool oobag, + mat* result_ptr, + vec* denom_ptr) { if (thread_ranges.size() > thread_idx + 1) { @@ -448,7 +509,7 @@ void Forest::predict_in_threads(uint thread_idx, mat preds = (*result_ptr); preds.each_col() /= (*denom_ptr); - compute_prediction_accuracy(prediction_data, eval_row, preds); + compute_prediction_accuracy(prediction_data, preds, eval_row); } @@ -468,6 +529,8 @@ arma::uword Forest::find_max_eval_steps(){ if(n_evals > n_tree) n_evals = n_tree; + if(n_evals < 1) n_evals = 1; + return(n_evals); } @@ -480,7 +543,7 @@ void Forest::resize_oobag_eval(){ } -void Forest::showProgress(std::string operation, size_t max_progress) { +void Forest::show_progress(std::string operation, size_t max_progress) { using std::chrono::steady_clock; using std::chrono::duration_cast; diff --git a/src/Forest.h b/src/Forest.h index d7fedf91..ac2e6f05 100644 --- a/src/Forest.h +++ b/src/Forest.h @@ -69,16 +69,16 @@ class Forest { // void run(bool verbose, bool oobag); virtual void compute_prediction_accuracy( - Data* prediction_data, - arma::uword row_fill, - arma::mat& predictions - ) = 0; + Data* prediction_data, + arma::mat& prediction_values, + arma::uword 
row_fill + ); virtual void compute_prediction_accuracy( arma::mat& y, arma::vec& w, - arma::uword row_fill, - arma::mat& predictions + arma::mat& predictions, + arma::uword row_fill ) = 0; std::vector> get_cutpoint() { @@ -202,24 +202,27 @@ class Forest { void grow_single_thread(vec* vi_numer_ptr, uvec* vi_denom_ptr); - void grow_in_threads(uint thread_idx, - vec* vi_numer_ptr, - uvec* vi_denom_ptr); - + void grow_multi_thread(uint thread_idx, + vec* vi_numer_ptr, + uvec* vi_denom_ptr); + void predict_single_thread(Data* prediction_data, + bool oobag, + mat& result); - void predict_in_threads(uint thread_idx, - Data* prediction_data, - bool oobag, - mat* result_ptr, - vec* denom_ptr); + void predict_multi_thread(uint thread_idx, + Data* prediction_data, + bool oobag, + mat* result_ptr, + vec* denom_ptr); void compute_oobag_vi(); - void compute_oobag_vi_in_threads(uint thread_idx, - vec* vi_numer_ptr); + void compute_oobag_vi_single_thread(vec* vi_numer_ptr); + + void compute_oobag_vi_multi_thread(uint thread_idx, vec* vi_numer_ptr); - void showProgress(std::string operation, size_t max_progress); + void show_progress(std::string operation, size_t max_progress); virtual void resize_pred_mat(arma::mat& p) = 0; @@ -276,6 +279,7 @@ class Forest { // out-of-bag bool oobag_pred; + arma::vec oobag_denom; arma::mat oobag_eval; EvalType oobag_eval_type; arma::uword oobag_eval_every; diff --git a/src/ForestSurvival.cpp b/src/ForestSurvival.cpp index 59e2acfc..52739af1 100644 --- a/src/ForestSurvival.cpp +++ b/src/ForestSurvival.cpp @@ -137,26 +137,34 @@ void ForestSurvival::resize_oobag_eval(){ } -void ForestSurvival::compute_prediction_accuracy(Data* prediction_data, - arma::uword row_fill, - arma::mat& predictions){ - - mat y = prediction_data->get_y(); - vec w = prediction_data->get_w(); - - compute_prediction_accuracy(y, w, row_fill, predictions); - -} - void ForestSurvival::compute_prediction_accuracy(arma::mat& y, arma::vec& w, - arma::uword row_fill, - arma::mat& 
predictions){ + arma::mat& predictions, + arma::uword row_fill){ bool pred_is_risklike = true; if(pred_type == PRED_SURVIVAL) pred_is_risklike = false; + + if(oobag_eval_type == EVAL_R_FUNCTION){ + + // initialize function from tree object + // (Functions can't be stored in C++ classes, but Robjects can) + Function f_oobag_eval = as(oobag_R_function); + NumericMatrix y_ = wrap(y); + NumericVector w_ = wrap(w); + + for(arma::uword i = 0; i < oobag_eval.n_cols; ++i){ + vec p = predictions.col(i); + NumericVector p_ = wrap(p); + NumericVector R_result = f_oobag_eval(y_, w_, p_); + oobag_eval(row_fill, i) = R_result[0]; + } + return; + } + + for(arma::uword i = 0; i < oobag_eval.n_cols; ++i){ vec p = predictions.unsafe_col(i); oobag_eval(row_fill, i) = compute_cstat(y, w, p, pred_is_risklike); diff --git a/src/ForestSurvival.h b/src/ForestSurvival.h index e1dfa69e..2161d0ad 100644 --- a/src/ForestSurvival.h +++ b/src/ForestSurvival.h @@ -41,17 +41,11 @@ class ForestSurvival: public Forest { // growInternal() in ranger void plant() override; - void compute_prediction_accuracy( - Data* prediction_data, - arma::uword row_fill, - arma::mat& predictions - ) override; - void compute_prediction_accuracy( arma::mat& y, arma::vec& w, - arma::uword row_fill, - arma::mat& predictions + arma::mat& predictions, + arma::uword row_fill ) override; protected: diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 578c81bd..e09c8c63 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -56,8 +56,8 @@ BEGIN_RCPP END_RCPP } // orsf_cpp -List orsf_cpp(arma::mat& x, arma::mat& y, arma::vec& w, arma::uword tree_type_R, Rcpp::IntegerVector& tree_seeds, Rcpp::List loaded_forest, Rcpp::RObject lincomb_R_function, Rcpp::RObject oobag_R_function, arma::uword n_tree, arma::uword mtry, arma::uword vi_type_R, double vi_max_pvalue, double leaf_min_events, double leaf_min_obs, arma::uword split_rule_R, double split_min_events, double split_min_obs, double split_min_stat, arma::uword 
split_max_cuts, arma::uword split_max_retry, arma::uword lincomb_type_R, double lincomb_eps, arma::uword lincomb_iter_max, bool lincomb_scale, double lincomb_alpha, arma::uword lincomb_df_target, arma::uword lincomb_ties_method, bool pred_mode, arma::uword pred_type_R, arma::vec pred_horizon, bool oobag, arma::uword oobag_eval_type_R, arma::uword oobag_eval_every, unsigned int n_thread, bool write_forest); -RcppExport SEXP _aorsf_orsf_cpp(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP tree_type_RSEXP, SEXP tree_seedsSEXP, SEXP loaded_forestSEXP, SEXP lincomb_R_functionSEXP, SEXP oobag_R_functionSEXP, SEXP n_treeSEXP, SEXP mtrySEXP, SEXP vi_type_RSEXP, SEXP vi_max_pvalueSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP, SEXP split_rule_RSEXP, SEXP split_min_eventsSEXP, SEXP split_min_obsSEXP, SEXP split_min_statSEXP, SEXP split_max_cutsSEXP, SEXP split_max_retrySEXP, SEXP lincomb_type_RSEXP, SEXP lincomb_epsSEXP, SEXP lincomb_iter_maxSEXP, SEXP lincomb_scaleSEXP, SEXP lincomb_alphaSEXP, SEXP lincomb_df_targetSEXP, SEXP lincomb_ties_methodSEXP, SEXP pred_modeSEXP, SEXP pred_type_RSEXP, SEXP pred_horizonSEXP, SEXP oobagSEXP, SEXP oobag_eval_type_RSEXP, SEXP oobag_eval_everySEXP, SEXP n_threadSEXP, SEXP write_forestSEXP) { +List orsf_cpp(arma::mat& x, arma::mat& y, arma::vec& w, arma::uword tree_type_R, Rcpp::IntegerVector& tree_seeds, Rcpp::List loaded_forest, Rcpp::RObject lincomb_R_function, Rcpp::RObject oobag_R_function, arma::uword n_tree, arma::uword mtry, arma::uword vi_type_R, double vi_max_pvalue, double leaf_min_events, double leaf_min_obs, arma::uword split_rule_R, double split_min_events, double split_min_obs, double split_min_stat, arma::uword split_max_cuts, arma::uword split_max_retry, arma::uword lincomb_type_R, double lincomb_eps, arma::uword lincomb_iter_max, bool lincomb_scale, double lincomb_alpha, arma::uword lincomb_df_target, arma::uword lincomb_ties_method, bool pred_mode, arma::uword pred_type_R, arma::vec pred_horizon, bool oobag, arma::uword 
oobag_eval_type_R, arma::uword oobag_eval_every, unsigned int n_thread, bool write_forest, bool run_forest); +RcppExport SEXP _aorsf_orsf_cpp(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP tree_type_RSEXP, SEXP tree_seedsSEXP, SEXP loaded_forestSEXP, SEXP lincomb_R_functionSEXP, SEXP oobag_R_functionSEXP, SEXP n_treeSEXP, SEXP mtrySEXP, SEXP vi_type_RSEXP, SEXP vi_max_pvalueSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP, SEXP split_rule_RSEXP, SEXP split_min_eventsSEXP, SEXP split_min_obsSEXP, SEXP split_min_statSEXP, SEXP split_max_cutsSEXP, SEXP split_max_retrySEXP, SEXP lincomb_type_RSEXP, SEXP lincomb_epsSEXP, SEXP lincomb_iter_maxSEXP, SEXP lincomb_scaleSEXP, SEXP lincomb_alphaSEXP, SEXP lincomb_df_targetSEXP, SEXP lincomb_ties_methodSEXP, SEXP pred_modeSEXP, SEXP pred_type_RSEXP, SEXP pred_horizonSEXP, SEXP oobagSEXP, SEXP oobag_eval_type_RSEXP, SEXP oobag_eval_everySEXP, SEXP n_threadSEXP, SEXP write_forestSEXP, SEXP run_forestSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -96,7 +96,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< arma::uword >::type oobag_eval_every(oobag_eval_everySEXP); Rcpp::traits::input_parameter< unsigned int >::type n_thread(n_threadSEXP); Rcpp::traits::input_parameter< bool >::type write_forest(write_forestSEXP); - rcpp_result_gen = Rcpp::wrap(orsf_cpp(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, oobag, oobag_eval_type_R, oobag_eval_every, n_thread, write_forest)); + Rcpp::traits::input_parameter< bool >::type run_forest(run_forestSEXP); + rcpp_result_gen = Rcpp::wrap(orsf_cpp(x, y, w, tree_type_R, tree_seeds, loaded_forest, 
lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, oobag, oobag_eval_type_R, oobag_eval_every, n_thread, write_forest, run_forest)); return rcpp_result_gen; END_RCPP } @@ -105,7 +106,7 @@ static const R_CallMethodDef CallEntries[] = { {"_aorsf_coxph_fit_exported", (DL_FUNC) &_aorsf_coxph_fit_exported, 6}, {"_aorsf_compute_cstat_exported_vec", (DL_FUNC) &_aorsf_compute_cstat_exported_vec, 4}, {"_aorsf_compute_cstat_exported_uvec", (DL_FUNC) &_aorsf_compute_cstat_exported_uvec, 4}, - {"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 35}, + {"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 36}, {NULL, NULL, 0} }; diff --git a/src/Tree.cpp b/src/Tree.cpp index c66622d4..bf1b6a7f 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -857,7 +857,7 @@ } - if(n_retry == split_max_retry){ + if(n_retry >= split_max_retry){ sprout_leaf(*node); break; } @@ -955,6 +955,10 @@ } + if(VERBOSITY > 0){ + Rcout << "---- done with leaf predictions ----" << std::endl; + } + } double Tree::compute_prediction_accuracy(arma::vec& preds){ diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp index c65768fe..bec3d1c9 100644 --- a/src/TreeSurvival.cpp +++ b/src/TreeSurvival.cpp @@ -559,6 +559,7 @@ vec leaf_times, leaf_values; vec temp_vec((*pred_horizon).size()); + double temp_dbl; do { diff --git a/src/orsf_oop.cpp b/src/orsf_oop.cpp index 373a57f6..3f78d5f6 100644 --- a/src/orsf_oop.cpp +++ b/src/orsf_oop.cpp @@ -105,7 +105,8 @@ arma::uword oobag_eval_type_R, arma::uword oobag_eval_every, unsigned int n_thread, - bool write_forest){ + bool write_forest, + bool run_forest){ List result; @@ -202,7 +203,10 @@ } - forest->run(false, oobag); + if(run_forest){ + forest->run(false, oobag); + } + 
if(pred_mode){ @@ -248,10 +252,12 @@ vec vi_output; - if(vi_type == VI_ANOVA){ - vi_output = forest->get_vi_numer() / forest->get_vi_denom(); - } else { - vi_output = forest->get_vi_numer() / n_tree; + if(run_forest){ + if(vi_type == VI_ANOVA){ + vi_output = forest->get_vi_numer() / forest->get_vi_denom(); + } else { + vi_output = forest->get_vi_numer() / n_tree; + } } result.push_back(vi_output, "importance"); diff --git a/src/utility.cpp b/src/utility.cpp index dff45ce2..572a0683 100644 --- a/src/utility.cpp +++ b/src/utility.cpp @@ -161,9 +161,9 @@ } double compute_cstat(arma::mat& y, - arma::vec& w, - arma::vec& p, - bool pred_is_risklike){ + arma::vec& w, + arma::vec& p, + bool pred_is_risklike){ vec y_time = y.unsafe_col(0); vec y_status = y.unsafe_col(1); @@ -209,9 +209,9 @@ double compute_cstat(arma::mat& y, - arma::vec& w, - arma::uvec& g, - bool pred_is_risklike){ + arma::vec& w, + arma::uvec& g, + bool pred_is_risklike){ // note: g must have only values of 0 and 1 to use this. // note: this is a little different in its approach than diff --git a/tests/testthat/test-orsf.R b/tests/testthat/test-orsf.R index 213c64b2..a443555c 100644 --- a/tests/testthat/test-orsf.R +++ b/tests/testthat/test-orsf.R @@ -3,13 +3,6 @@ library(survival) # for Surv # misc functions used for tests ---- -cstat_bcj <- function(y_mat, s_vec){ - - sorted <- order( y_mat[, 1], -y_mat[, 2]) - oobag_c_harrell_testthat(y_mat[sorted, ], s_vec[sorted, ]) - -} - no_miss_list <- function(l){ sapply(l, function(x){ @@ -535,24 +528,24 @@ for(i in vars){ fit_orsf <- orsf(pbc_orsf, Surv(time, status) ~ . - id, n_thread = 1, - n_tree = 10, - tree_seeds = 1:10) + n_tree = 100, + tree_seeds = 1:100) fit_orsf_2 <- orsf(pbc_orsf, Surv(time, status) ~ . - id, n_thread = 5, - n_tree = 10, - tree_seeds = 1:10) + n_tree = 100, + tree_seeds = 1:100) fit_orsf_noise <- orsf(pbc_noise, Surv(time, status) ~ . 
- id, - n_tree = 10, - tree_seeds = 1:10) + n_tree = 100, + tree_seeds = 1:100) fit_orsf_scale <- orsf(pbc_scale, Surv(time, status) ~ . - id, - n_tree = 10, - tree_seeds = 1:10) + n_tree = 100, + tree_seeds = 1:100) #' @srrstats {ML7.1} *Demonstrate effect of numeric scaling of input data.* test_that( @@ -614,9 +607,6 @@ test_that( expect_equal(fit_orsf$forest$leaf_summary, fit_orsf_scale$forest$leaf_summary) - expect_equal(fit_orsf$forest$leaf_summary, - fit_orsf_noise$forest$leaf_summary) - } ) @@ -625,8 +615,8 @@ test_that( code = { object <- orsf(pbc_orsf, Surv(time, status) ~ . - id, - n_tree = 10, - tree_seeds = 1:10, + n_tree = 100, + tree_seeds = 1:100, no_fit = TRUE) fit_orsf_3 <- orsf_train(object) @@ -654,13 +644,13 @@ test_that( desc = 'oob rows identical with same tree seeds, oob error correct for user-specified function', code = { - tree_seeds = sample.int(n = 50000, size = 10) + tree_seeds = sample.int(n = 50000, size = 100) bad_tree_seeds <- c(1,2,3) expect_error( orsf(data = pbc_orsf, formula = time+status~.-id, - n_tree = 10, + n_tree = 100, mtry = 2, tree_seeds = bad_tree_seeds), regexp = 'the number of trees' @@ -668,26 +658,33 @@ test_that( fit_1 <- orsf(data = pbc_orsf, formula = time+status~.-id, - n_tree = 10, + n_tree = 100, mtry = 2, tree_seeds = tree_seeds) fit_2 <- orsf(data = pbc_orsf, formula = time+status~.-id, - n_tree = 10, + n_tree = 100, mtry = 6, tree_seeds = tree_seeds) + expect_equal(fit_1$forest$rows_oobag, + fit_2$forest$rows_oobag) + fit_3 <- orsf(data = pbc_orsf, formula = time+status~.-id, - n_tree = 10, + n_tree = 100, mtry = 6, - oobag_fun = oobag_c_harrell, + oobag_fun = oobag_c_survival, tree_seeds = tree_seeds) expect_equal( - fit_2$eval_oobag$stat_values, - fit_3$eval_oobag$stat_values + oobag_c_survival( + y_mat = as.matrix(pbc_orsf[,c("time", "status")]), + w_vec = rep(1, nrow(pbc_orsf)), + s_vec = fit_3$pred_oobag + ), + as.numeric(fit_3$eval_oobag$stat_values) ) } diff --git a/vignettes/oobag.Rmd 
b/vignettes/oobag.Rmd index 72dfb5b4..bb08eafe 100644 --- a/vignettes/oobag.Rmd +++ b/vignettes/oobag.Rmd @@ -18,7 +18,7 @@ knitr::opts_chunk$set( ```{r setup} -library(aorsf) +# library(aorsf) library(survival) library(SurvMetrics) @@ -39,7 +39,7 @@ Let's fit an oblique random survival forest and plot the distribution of the ens fit <- orsf(data = pbc_orsf, formula = Surv(time, status) ~ . - id, oobag_pred_type = 'surv', - oobag_pred_horizon = 3500) + oobag_pred_horizon = 2000) hist(fit$pred_oobag, - main = 'Ensemble out-of-bag survival predictions at t=3,500') + main = 'Ensemble out-of-bag survival predictions at t=2,000') @@ -69,7 +69,8 @@ As each out-of-bag data set contains about one-third of the training set, the ou fit <- orsf(data = pbc_orsf, formula = Surv(time, status) ~ . - id, n_tree = 50, - oobag_pred_horizon = 3500, + oobag_pred_type = 'surv', + oobag_pred_horizon = 2000, oobag_eval_every = 1) plot( @@ -90,7 +91,7 @@ In some cases, you may want more control over how out-of-bag error is estimated. ```{r} -oobag_fun_brier <- function(y_mat, s_vec){ +oobag_fun_brier <- function(y_mat, w_vec, s_vec){ # output is numeric vector of length 1 as.numeric( @@ -98,7 +99,7 @@ oobag_fun_brier <- function(y_mat, s_vec){ object = Surv(time = y_mat[, 1], event = y_mat[, 2]), pre_sp = s_vec, # t_star in Brier() should match oob_pred_horizon in orsf() - t_star = 3500 + t_star = 2000 ) ) @@ -109,7 +110,7 @@ There are two ways to apply your own function to compute out-of-bag error. First ```{r} -oobag_fun_brier(y_mat = fit$data[, c('time', 'status')], +oobag_fun_brier(y_mat = pbc_orsf[,c('time', 'status')], s_vec = fit$pred_oobag) ``` @@ -120,13 +121,13 @@ Second, you can pass your function into `orsf()`, and it will be used in place o fit <- orsf(data = pbc_orsf, formula = Surv(time, status) ~ . 
- id, - n_tree = 50, - oobag_pred_horizon = 3500, + n_tree = 500, + oobag_pred_horizon = 2000, oobag_fun = oobag_fun_brier, - oobag_eval_every = 1) + oobag_eval_every = 25) plot( - x = seq(1, 50, by = 1), + x = seq(25, 500, by = 25), y = fit$eval_oobag$stat_values, main = 'Out-of-bag error computed after each new tree is grown.', sub = 'For the Brier score, lower values indicate more accurate predictions', @@ -140,13 +141,13 @@ We can also compute a time-dependent C-statistic instead of Harrell's C-statisti ```{r} -oobag_fun_tdep_cstat <- function(y_mat, s_vec){ +oobag_fun_tdep_cstat <- function(y_mat, w_vec, s_vec){ as.numeric( SurvMetrics::Cindex( object = Surv(time = y_mat[, 1], event = y_mat[, 2]), predicted = s_vec, - t_star = 3500 + t_star = 2000 ) ) @@ -155,7 +156,7 @@ oobag_fun_tdep_cstat <- function(y_mat, s_vec){ fit <- orsf(data = pbc_orsf, formula = Surv(time, status) ~ . - id, n_tree = 50, - oobag_pred_horizon = 3500, + oobag_pred_horizon = 2000, oobag_fun = oobag_fun_tdep_cstat, oobag_eval_every = 1) @@ -210,11 +211,11 @@ Negation importance is based on the out-of-bag error, so of course you may be cu ```{r} fit_tdep_cstat <- orsf(data = pbc_orsf, - formula = Surv(time, status) ~ . - id, - n_tree = 500, - oobag_pred_horizon = 3500, - oobag_fun = oobag_fun_tdep_cstat, - importance = 'negate') + formula = Surv(time, status) ~ . - id, + n_tree = 50, + oobag_pred_horizon = 2000, + oobag_fun = oobag_fun_tdep_cstat, + importance = 'negate') fit_tdep_cstat$importance