diff --git a/R/orsf.R b/R/orsf.R index f6b1f6ea..133e7435 100644 --- a/R/orsf.R +++ b/R/orsf.R @@ -419,7 +419,7 @@ orsf <- function(data, ) if(importance %in% c("permute", "negate") && !oobag_pred){ - oobag_pred <- TRUE # Should I add a warning? + # oobag_pred <- TRUE # Should I add a warning? oobag_pred_type <- 'surv' } @@ -690,11 +690,12 @@ orsf <- function(data, tree_seeds <- sample(x = n_tree*2, size = n_tree, replace = FALSE) vi_max_pvalue = 0.01 + tree_type_R = 3 orsf_out <- orsf_cpp(x = x_sort, y = y_sort, w = w_sort, - tree_type_R = 3, + tree_type_R = tree_type_R, tree_seeds = as.integer(tree_seeds), loaded_forest = list(), n_tree = n_tree, @@ -773,10 +774,11 @@ orsf <- function(data, "1" = "Harrell's C-statistic", "2" = "User-specified function") - #' @srrstats {G2.10} *drop = FALSE for type consistency* orsf_out$pred_oobag <- orsf_out$pred_oobag[unsorted, , drop = FALSE] + orsf_out$pred_oobag[is.nan(orsf_out$pred_oobag)] <- NA_real_ + } orsf_out$pred_horizon <- oobag_pred_horizon @@ -833,6 +835,7 @@ orsf <- function(data, attr(orsf_out, 'vi_max_pvalue') <- vi_max_pvalue attr(orsf_out, 'split_rule') <- split_rule attr(orsf_out, 'n_thread') <- n_thread + attr(orsf_out, 'tree_type') <- tree_type_R attr(orsf_out, 'tree_seeds') <- tree_seeds diff --git a/R/orsf_attr.R b/R/orsf_attr.R index d42fd1c4..6a3e9459 100644 --- a/R/orsf_attr.R +++ b/R/orsf_attr.R @@ -60,6 +60,7 @@ get_verbose_progress <- function(object) attr(object, 'verbose_progress') get_vi_max_pvalue <- function(object) attr(object, 'vi_max_pvalue') get_split_rule <- function(object) attr(object, 'split_rule') get_n_thread <- function(object) attr(object, 'n_thread') +get_tree_type <- function(object) attr(object, 'tree_type') #' ORSF status diff --git a/R/orsf_predict.R b/R/orsf_predict.R index ab784490..218248d4 100644 --- a/R/orsf_predict.R +++ b/R/orsf_predict.R @@ -79,6 +79,7 @@ predict.orsf_fit <- function(object, pred_type = 'risk', na_action = 'fail', boundary_checks = TRUE, + n_thread = 1, ...){ # catch any arguments that didn't match and got relegated to ... @@ -129,22 +130,64 @@ predict.orsf_fit <- function(object, # names_x_data = names_x_data) # ) - pred_type_cpp <- switch( + pred_type_R <- switch( pred_type, - "risk" = "R", - "surv" = "S", - "chf" = "H", - "mort" = "M" + "risk" = 1, + "surv" = 2, + "chf" = 3, + "mort" = 4 ) - out_values <- - if(pred_type_cpp == "M"){ - orsf_pred_mort(object, x_new) - } else if (length(pred_horizon) == 1L) { - orsf_pred_uni(object$forest, x_new, pred_horizon_ordered, pred_type_cpp) - } else { - orsf_pred_multi(object$forest, x_new, pred_horizon_ordered, pred_type_cpp) - } + orsf_out <- orsf_cpp(x = x_new, + y = matrix(1, ncol=2), + w = rep(1, nrow(x_new)), + tree_type_R = get_tree_type(object), + tree_seeds = get_tree_seeds(object), + loaded_forest = object$forest, + n_tree = get_n_tree(object), + mtry = get_mtry(object), + vi_type_R = 0, + vi_max_pvalue = get_vi_max_pvalue(object), + lincomb_R_function = get_f_beta(object), + oobag_R_function = get_f_oobag_eval(object), + leaf_min_events = get_leaf_min_events(object), + leaf_min_obs = get_leaf_min_obs(object), + split_rule_R = switch(get_split_rule(object), + "logrank" = 1, + "cstat" = 2), + split_min_events = get_split_min_events(object), + split_min_obs = get_split_min_obs(object), + split_min_stat = get_split_min_stat(object), + split_max_cuts = get_n_split(object), + split_max_retry = get_n_retry(object), + lincomb_type_R = switch(get_orsf_type(object), + 'fast' = 1, + 'cph' = 1, + 'random' = 2, + 'net' = 3, + 'custom' = 4), + lincomb_eps = get_cph_eps(object), + lincomb_iter_max = get_cph_iter_max(object), + lincomb_scale = get_cph_do_scale(object), + lincomb_alpha = get_net_alpha(object), + lincomb_df_target = get_net_df_target(object), + lincomb_ties_method = switch( + tolower(get_cph_method(object)), + 'breslow' = 0, + 'efron' = 1 + ), + pred_type_R = pred_type_R, + pred_mode = TRUE, + pred_horizon = pred_horizon_ordered, + oobag = FALSE, + oobag_eval_type_R = 0, + oobag_eval_every = get_n_tree(object), + n_thread = n_thread, + write_forest = FALSE, + run_forest = TRUE, + verbosity = 4) + + out_values <- orsf_out$pred_new if(na_action == "pass"){ @@ -164,15 +207,6 @@ predict.orsf_fit <- function(object, } -orsf_pred_mort <- function(object, x_new){ - pred_mat <- orsf_pred_multi(object$forest, - x_new = x_new, - time_vec = get_event_times(object), - pred_type = 'H') - - matrix(apply(pred_mat, MARGIN = 1, FUN = sum), ncol = 1) - -} diff --git a/R/orsf_vi.R b/R/orsf_vi.R index a3590954..93370e5b 100644 --- a/R/orsf_vi.R +++ b/R/orsf_vi.R @@ -211,13 +211,14 @@ orsf_vi_ <- function(object, group_factors, type_vi, oobag_fun = NULL){ #' orsf_vi_oobag_ <- function(object, type_vi, oobag_fun){ - if(!contains_oobag(object)){ - stop("cannot compute ", - switch(type_vi, 'negate' = 'negation', 'permute' = 'permutation'), - " importance if the orsf_fit object does not have out-of-bag error", - " (see oobag_pred in ?orsf).", - call. = FALSE) - } + # can remove this b/c prediction accuracy is now computed at tree level + # if(!contains_oobag(object)){ + # stop("cannot compute ", + # switch(type_vi, 'negate' = 'negation', 'permute' = 'permutation'), + # " importance if the orsf_fit object does not have out-of-bag error", + # " (see oobag_pred in ?orsf).", + # call. = FALSE) + # } if(contains_vi(object) && is.null(oobag_fun) && diff --git a/scratch.R b/scratch.R index ca08e548..a024824d 100644 --- a/scratch.R +++ b/scratch.R @@ -2,19 +2,20 @@ library(tidyverse) library(riskRegression) library(survival) -sink("orsf-output.txt") fit <- orsf(pbc_orsf, Surv(time, status) ~ . - id, - n_tree = 2, + n_tree = 3, + tree_seeds = 1:3, n_thread = 1, mtry = 2, - oobag_pred_type = 'mort', - split_rule = 'logrank', - importance = 'negate', - split_min_stat = 3, - verbose_progress = 4) -sink() -orsf_vi(fit) + oobag_pred_type = 'surv', + split_rule = 'cstat', + importance = 'none', + split_min_stat = 0.4, + verbose_progress = 1) +sink("orsf-output.txt") +prd <- predict(fit, new_data = pbc_orsf, pred_horizon = 1000, pred_type = 'risk') +sink() library(randomForestSRC) diff --git a/src/Forest.cpp b/src/Forest.cpp index f96e8394..9d4a228f 100644 --- a/src/Forest.cpp +++ b/src/Forest.cpp @@ -15,6 +15,7 @@ void Forest::init(std::unique_ptr input_data, Rcpp::IntegerVector& tree_seeds, arma::uword n_tree, arma::uword mtry, + bool grow_mode, VariableImportance vi_type, double vi_max_pvalue, // leaves @@ -48,6 +49,7 @@ void Forest::init(std::unique_ptr input_data, this->tree_seeds = tree_seeds; this->n_tree = n_tree; this->mtry = mtry; + this->grow_mode = grow_mode; this->vi_type = vi_type; this->vi_max_pvalue = vi_max_pvalue; this->leaf_min_obs = leaf_min_obs; @@ -95,13 +97,15 @@ void Forest::init(std::unique_ptr input_data, } -void Forest::run(bool verbose, bool oobag){ +void Forest::run(bool oobag){ if(pred_mode){ + init_trees(); + this->pred_values = predict(oobag); - } else { + } else if (grow_mode) { // initialize the trees plant(); @@ -113,12 +117,11 @@ void Forest::run(bool verbose, bool oobag){ this->pred_values = predict(oobag); } - if(vi_type == VI_PERMUTE || vi_type == VI_NEGATE){ - compute_oobag_vi(); - } - } + if(vi_type == VI_PERMUTE || vi_type == VI_NEGATE){ + compute_oobag_vi(); + } } @@ -154,35 +157,35 @@ void Forest::init_trees(){ void Forest::grow() { + // initialize trees before doing anything else init_trees(); // Create thread ranges equalSplit(thread_ranges, 0, n_tree - 1, n_thread); + // reset progress to 0 + progress = 0; + if(n_thread == 1){ // ensure safe usage of R functions and glmnet - // by growing trees in a single thread. There does - // not need to be a corresponding predict_single_thread - // function b/c the R functions are only called during - // the grow phase of the forest. - vec* vi_numer_ptr = &vi_numer; - uvec* vi_denom_ptr = &vi_denom; - grow_single_thread(vi_numer_ptr, vi_denom_ptr); + // by growing trees in a single thread. + grow_single_thread(&vi_numer, &vi_denom); return; } // catch interrupts from threads aborted = false; aborted_threads = 0; - // show progress from threads - progress = 0; + // containers std::vector threads; std::vector vi_numer_threads(n_thread); std::vector vi_denom_threads(n_thread); + // reserve memory threads.reserve(n_thread); + // begin multi-thread grow for (uint i = 0; i < n_thread; ++i) { vi_numer_threads[i].zeros(data->n_cols); @@ -193,8 +196,11 @@ void Forest::grow() { &(vi_denom_threads[i])); } - show_progress("Growing trees...", n_tree); + if(verbosity == 1){ + show_progress("Growing trees", n_tree); + } + // end multi-thread grow for (auto &thread : threads) { thread.join(); } @@ -203,13 +209,11 @@ void Forest::grow() { throw std::runtime_error("User interrupt."); } - if(vi_type != VI_NONE){ + if(vi_type == VI_ANOVA){ for(uint i = 0; i < n_thread; ++i){ - vi_numer += vi_numer_threads[i]; - if(vi_type == VI_ANOVA) vi_denom += vi_denom_threads[i]; - + vi_denom += vi_denom_threads[i]; } } @@ -219,33 +223,60 @@ void Forest::grow() { void Forest::grow_single_thread(vec* vi_numer_ptr, uvec* vi_denom_ptr){ - for (uint i = 0; i < n_tree; ++i) { - if(verbosity > 1){ - Rcout << "------------ Growing tree " << i << " --------------"; - Rcout << std::endl; - Rcout << std::endl; - } + using std::chrono::steady_clock; + using std::chrono::duration_cast; + using std::chrono::seconds; + + steady_clock::time_point start_time = steady_clock::now(); + steady_clock::time_point last_time = steady_clock::now(); + size_t max_progress = n_tree; + + for (uint i = 0; i < n_tree; ++i) { + + if(verbosity > 1){ + Rcout << "------------ Growing tree " << i << " --------------"; + Rcout << std::endl; + Rcout << std::endl; + } - trees[i]->grow(vi_numer_ptr, vi_denom_ptr); + trees[i]->grow(vi_numer_ptr, vi_denom_ptr); - if(vi_type == VI_PERMUTE || vi_type == VI_NEGATE){ + ++progress; - if(verbosity > 1){ - Rcout << "------------ Computing VI for tree " << i << " -----"; - Rcout << std::endl; - Rcout << std::endl; - } + if(verbosity == 1){ - trees[i]->compute_oobag_vi(vi_numer_ptr, vi_type); + seconds elapsed_time = duration_cast(steady_clock::now() - last_time); + + if ((progress > 0 && elapsed_time.count() > STATUS_INTERVAL) || + (progress == max_progress)) { + + double relative_progress = (double) progress / (double) max_progress; + seconds time_from_start = duration_cast(steady_clock::now() - start_time); + uint remaining_time = (1 / relative_progress - 1) * time_from_start.count(); + + Rcout << "Growing trees: "; + Rcout << round(100 * relative_progress) << "%. "; + + if(progress < max_progress){ + Rcout << "~ time remaining: "; + Rcout << beautifyTime(remaining_time) << "."; } - Rcpp::checkUserInterrupt(); + Rcout << std::endl; + + last_time = steady_clock::now(); } } + Rcpp::checkUserInterrupt(); + + } + +} + void Forest::grow_multi_thread(uint thread_idx, vec* vi_numer_ptr, @@ -306,7 +337,9 @@ void Forest::compute_oobag_vi() { this, i, &(vi_numer_threads[i])); } - show_progress("Computing variable importance...", n_tree); + if(verbosity == 1){ + show_progress("Computing importance", n_tree); + } for (auto &thread : threads) { thread.join(); @@ -324,10 +357,47 @@ void Forest::compute_oobag_vi() { void Forest::compute_oobag_vi_single_thread(vec* vi_numer_ptr) { + using std::chrono::steady_clock; + using std::chrono::duration_cast; + using std::chrono::seconds; + + steady_clock::time_point start_time = steady_clock::now(); + steady_clock::time_point last_time = steady_clock::now(); + size_t max_progress = n_tree; + for(uint i = 0; i < n_tree; ++i){ trees[i]->compute_oobag_vi(vi_numer_ptr, vi_type); + ++progress; + + if(verbosity == 1){ + + seconds elapsed_time = duration_cast(steady_clock::now() - last_time); + + if ((progress > 0 && elapsed_time.count() > STATUS_INTERVAL) || + (progress == max_progress)) { + + double relative_progress = (double) progress / (double) max_progress; + seconds time_from_start = duration_cast(steady_clock::now() - start_time); + uint remaining_time = (1 / relative_progress - 1) * time_from_start.count(); + + Rcout << "Computing importance: "; + Rcout << round(100 * relative_progress) << "%. "; + + if(progress < max_progress){ + Rcout << "~ time remaining: "; + Rcout << beautifyTime(remaining_time) << "."; + } + + Rcout << std::endl; + + last_time = steady_clock::now(); + + } + + } + Rcpp::checkUserInterrupt(); } @@ -365,12 +435,20 @@ void Forest::compute_prediction_accuracy(Data* prediction_data, arma::mat& prediction_values, arma::uword row_fill){ + // avoid dividing by zero uvec valid_observations = find(oobag_denom > 0); + // subset each data input mat y_valid = prediction_data->y_rows(valid_observations); vec w_valid = prediction_data->w_subvec(valid_observations); mat p_valid = prediction_values.rows(valid_observations); + // scale predictions based on how many trees contributed + // (important to note it's different for each oobag obs) + vec valid_denom = oobag_denom(valid_observations); + p_valid.each_col() /= valid_denom; + + // pass along to forest-specific version compute_prediction_accuracy(y_valid, w_valid, p_valid, row_fill); } @@ -413,7 +491,9 @@ mat Forest::predict(bool oobag) { &(oobag_denom_threads[i])); } - show_progress("Predicting...", n_tree); + if(verbosity == 1){ + show_progress("Computing predictions", n_tree); + } // wait for all threads to finish before proceeding for (auto &thread : threads) { @@ -430,33 +510,27 @@ mat Forest::predict(bool oobag) { // evaluate oobag error after joining each thread // (only safe to do this when the condition below holds) - if(n_tree/oobag_eval_every == n_thread && n_thread>1 && ipredict_leaf(prediction_data, oobag); - + Rcout << "made it here" << std::endl; trees[i]->predict_value(&result, &oobag_denom, pred_type, oobag); - + Rcout << "made it here 2" << std::endl; progress++; - // if user wants to track oobag error over time: - if(oobag && (progress % oobag_eval_every == 0) ){ + if(verbosity == 1){ - uword eval_row = (progress / oobag_eval_every) - 1; + seconds elapsed_time = duration_cast(steady_clock::now() - last_time); + + if ((progress > 0 && elapsed_time.count() > STATUS_INTERVAL) || + (progress == max_progress)) { + + double relative_progress = (double) progress / (double) max_progress; + seconds time_from_start = duration_cast(steady_clock::now() - start_time); + uint remaining_time = (1 / relative_progress - 1) * time_from_start.count(); + + Rcout << "Computing predictions: "; + Rcout << round(100 * relative_progress) << "%. "; + + if(progress < max_progress){ + Rcout << "~ time remaining: "; + Rcout << beautifyTime(remaining_time) << "."; + } + + Rcout << std::endl; + + last_time = steady_clock::now(); + + } + + } + // if tracking oobag error over time: + if(oobag && (progress % oobag_eval_every == 0) ){ - mat preds = result.each_col() / oobag_denom; - compute_prediction_accuracy(prediction_data, preds, eval_row); + uword eval_row = (progress / oobag_eval_every) - 1; + // mat preds = result.each_col() / oobag_denom; + compute_prediction_accuracy(prediction_data, result, eval_row); } @@ -520,18 +627,6 @@ void Forest::predict_multi_thread(uint thread_idx, std::unique_lock lock(mutex); ++progress; - // if user wants to track oobag error over time: - if( n_thread==1 && oobag && (progress%oobag_eval_every==0) ){ - - uword eval_row = (progress/oobag_eval_every) - 1; - - mat preds = (*result_ptr); - preds.each_col() /= (*denom_ptr); - - compute_prediction_accuracy(prediction_data, preds, eval_row); - - } - condition_variable.notify_one(); } @@ -574,7 +669,9 @@ void Forest::show_progress(std::string operation, size_t max_progress) { // Wait for message from threads and show output if enough time elapsed while (progress < max_progress) { + condition_variable.wait(lock); + seconds elapsed_time = duration_cast(steady_clock::now() - last_time); // Check for user interrupt @@ -585,16 +682,21 @@ void Forest::show_progress(std::string operation, size_t max_progress) { return; } - if (progress > 0 && elapsed_time.count() > STATUS_INTERVAL) { + if ((progress > 0 && elapsed_time.count() > STATUS_INTERVAL) || + (progress == max_progress)) { double relative_progress = (double) progress / (double) max_progress; seconds time_from_start = duration_cast(steady_clock::now() - start_time); uint remaining_time = (1 / relative_progress - 1) * time_from_start.count(); - Rcout << operation << "Progress: "; + Rcout << operation << ": "; Rcout << round(100 * relative_progress) << "%. "; - Rcout << "Estimated remaining time: "; - Rcout << beautifyTime(remaining_time) << "."; + + if(progress < max_progress){ + Rcout << "~ time remaining: "; + Rcout << beautifyTime(remaining_time) << "."; + } + Rcout << std::endl; last_time = steady_clock::now(); diff --git a/src/Forest.h b/src/Forest.h index 3bfd0b2c..67017d58 100644 --- a/src/Forest.h +++ b/src/Forest.h @@ -37,6 +37,7 @@ class Forest { Rcpp::IntegerVector& tree_seeds, arma::uword n_tree, arma::uword mtry, + bool grow_mode, VariableImportance vi_type, double vi_max_pvalue, // leaves @@ -188,7 +189,7 @@ class Forest { return(pred_values); } - void run(bool verbose, bool oobag); + void run(bool oobag); virtual void plant() = 0; @@ -276,6 +277,9 @@ class Forest { bool pred_mode; PredType pred_type; + // is forest already grown? + bool grow_mode; + arma::mat pred_values; // out-of-bag diff --git a/src/Tree.cpp b/src/Tree.cpp index 9928377e..f11bb8ee 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -497,6 +497,10 @@ double stat, stat_best = 0; + if(verbosity > 3){ + Rcout << " -- cutpoint (score)" << std::endl; + } + for(it = cuts_sampled.begin(); it != cuts_sampled.end(); ++it){ // flip node assignments from left to right, up to the next cutpoint @@ -510,7 +514,7 @@ it_start = *it; if(verbosity > 3){ - Rcout << " ---- cutpoint (score): "; + Rcout << " --- "; Rcout << lincomb.at(lincomb_sort(*it)); Rcout << " (" << stat << "), "; Rcout << "N = " << sum(g_node % w_node) << " moving right"; @@ -521,7 +525,7 @@ if(verbosity > 3){ Rcout << std::endl; - Rcout << " ---- best stat: " << stat_best; + Rcout << " -- best stat: " << stat_best; Rcout << ", min to split: " << split_min_stat; Rcout << std::endl; Rcout << std::endl; @@ -548,8 +552,9 @@ void Tree::sprout_leaf(uword node_id){ - if(verbosity > 3){ + if(verbosity > 2){ Rcout << "-- sprouting node " << node_id << " into a leaf"; + Rcout << " (N = " << sum(w_node) << ")"; Rcout << std::endl; Rcout << std::endl; } @@ -769,10 +774,9 @@ // find all valid cutpoints for lincomb cuts_all = find_cutpoints(); - if(verbosity > 3){ + if(verbosity > 3 && cuts_all.is_empty()){ - Rcout << " ---- no. of cutpoints identified: " << cuts_all.size(); - Rcout << std::endl; + Rcout << " -- no cutpoints identified"; Rcout << std::endl; } @@ -790,6 +794,10 @@ // 1. a split of the node is guaranteed // 2. the method used for lincombs allows it + if(verbosity > 3){ + Rcout << " -- p-values:" << std::endl; + } + vec beta_var = beta.unsafe_col(1); double pvalue; @@ -802,12 +810,26 @@ pvalue = R::pchisq(pow(beta_est[i],2)/beta_var[i], 1, false, false); + if(verbosity > 3){ + + Rcout << " --- column " << cols_node[i] << ": "; + Rcout << pvalue; + if(pvalue < 0.05) Rcout << "*"; + if(pvalue < 0.01) Rcout << "*"; + if(pvalue < 0.001) Rcout << "*"; + if(pvalue < vi_max_pvalue) Rcout << " [+1 to VI numerator]"; + Rcout << std::endl; + + } + if(pvalue < vi_max_pvalue){ (*vi_numer)[cols_node[i]]++; } } } + if(verbosity > 3){ Rcout << std::endl; } + } // make new nodes if a valid cutpoint was found @@ -875,6 +897,7 @@ } uvec obs_in_node; + // it iterates over the observations in a node uvec::iterator it; @@ -883,8 +906,7 @@ for(i = 0; i < coef_values.size(); i++){ - if(VERBOSITY > 0) - Rcout << "moving obs in node " << i << std::endl; + Rcout << "moving obs in node " << i << std::endl; // if child_left == 0, it's a leaf (no need to find next child) if(child_left[i] != 0){ @@ -901,8 +923,6 @@ lincomb = prediction_data->x_submat(obs_in_node, coef_indices[i]) * coef_values[i]; - if(lincomb.size() == 0) stop("sommin wrong"); - it = obs_in_node.begin(); for(j = 0; j < lincomb.size(); ++j, ++it){ @@ -919,8 +939,6 @@ } - if(VERBOSITY > 0){ - uvec in_left = find(pred_leaf == child_left[i]); uvec in_right = find(pred_leaf == child_left[i]+1); @@ -929,17 +947,15 @@ Rcout << "No. to node " << child_left[i]+1 << ": "; Rcout << in_right.size() << std::endl << std::endl; - } - } } } - if(VERBOSITY > 0){ - Rcout << "---- done with leaf predictions ----" << std::endl; - } + if(oobag) pred_leaf.elem(rows_inbag).fill(max_nodes); + + Rcout << "---- done with leaf predictions ----" << std::endl; } @@ -950,11 +966,13 @@ void Tree::negate_coef(arma::uword pred_col){ for(uint j = 0; j < coef_indices.size(); ++j){ + for(uword k = 0; k < coef_indices[j].size(); ++k){ if(coef_indices[j][k] == pred_col){ coef_values[j][k] *= (-1); } } + } } diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp index b1e4d8a7..f23d309d 100644 --- a/src/TreeSurvival.cpp +++ b/src/TreeSurvival.cpp @@ -396,7 +396,7 @@ void TreeSurvival::sprout_leaf(uword node_id){ - if(verbosity > 3){ + if(verbosity > 2){ Rcout << "-- sprouting node " << node_id << " into a leaf"; Rcout << " (N = " << sum(w_node) << ")"; Rcout << std::endl; @@ -493,7 +493,11 @@ } while (i < leaf_data.n_rows); - if(VERBOSITY > 1) print_mat(leaf_data, "leaf_data", 10, 5); + if(verbosity > 3){ + mat tmp_mat = join_horiz(y_node, w_node); + print_mat(tmp_mat, "time & status & weights in this node", 10, 10); + print_mat(leaf_data, "leaf_data (showing up to 5 rows)", 5, 5); + } leaf_pred_indx[node_id] = leaf_data.col(0); leaf_pred_prob[node_id] = leaf_data.col(1); @@ -529,21 +533,7 @@ uvec::iterator it = pred_leaf_sort.begin(); - // oobag leaf prediction has zeros for inbag rows - // TODO: Change this to be max_nodes+1 - // (0 is a valid leaf for empty tree) - if(oobag){ - while(pred_leaf(*it) == 0 && it < pred_leaf_sort.end()){ - ++it; - } - } - - if(it == pred_leaf_sort.end()){ - if(VERBOSITY > 0){ - Rcout << "Tree was empty, no predictions were made" << std::endl; - } - return; - } + uword leaf_id = pred_leaf[*it]; double pred_t0; @@ -565,13 +555,15 @@ do { - uword leaf_id = pred_leaf[*it]; + Rcout << "leaf_id: " << leaf_id << std::endl; // copies of leaf data using same aux memory leaf_times = vec(leaf_pred_indx[leaf_id].begin(), leaf_pred_indx[leaf_id].size(), false); + Rcout << "leaf_times: " << leaf_times.t() << std::endl; + switch (pred_type) { case PRED_RISK: case PRED_SURVIVAL: { @@ -608,6 +600,8 @@ } + Rcout << "leaf_values: " << leaf_values.t() << std::endl; + // don't reset i in the loop b/c leaf_times ascend i = 0; @@ -669,6 +663,22 @@ if(pred_type == PRED_RISK) temp_vec = 1 - temp_vec; + // while (it < pred_leaf_sort.end()-1 && leaf_id == pred_leaf(*it)) { + // + // Rcout << "*it: " << *it << std::endl; + // + // (*pred_output).row(*it) += temp_vec.t(); + // + // Rcout << "it was safe" << std::endl; + // + // if(oobag) (*pred_denom)[*it]++; + // + // Rcout << "it was safe 2" << std::endl; + // + // ++it; + // + // }; + (*pred_output).row(*it) += temp_vec.t(); if(oobag) (*pred_denom)[*it]++; ++it; @@ -686,7 +696,12 @@ } - } while (it < pred_leaf_sort.end()); + leaf_id = pred_leaf(*it); + Rcout << "new pred_leaf: " << leaf_id << std::endl; + + } while (it < pred_leaf_sort.end() && (!oobag || leaf_id < max_nodes)); + + Rcout << "Made it out"; } diff --git a/src/globals.h b/src/globals.h index 3314053a..2174e531 100644 --- a/src/globals.h +++ b/src/globals.h @@ -96,7 +96,7 @@ const int VERBOSITY = 0; // Interval to print progress in seconds - const double STATUS_INTERVAL = 15.0; + const double STATUS_INTERVAL = 1.0; } // namespace aorsf diff --git a/src/orsf_oop.cpp b/src/orsf_oop.cpp index f4e0257e..db5aa4a5 100644 --- a/src/orsf_oop.cpp +++ b/src/orsf_oop.cpp @@ -152,10 +152,14 @@ } + // does the forest need to be grown or is it already grown? + bool grow_mode = loaded_forest.size() == 0; + forest->init(std::move(data), tree_seeds, n_tree, mtry, + grow_mode, vi_type, vi_max_pvalue, leaf_min_obs, @@ -181,93 +185,89 @@ n_thread, verbosity); - // Load forest object if in prediction mode - if(pred_mode){ + // Load forest object if it was already grown + if(!grow_mode){ - std::vector> cutpoint = loaded_forest["cutpoint"]; - std::vector> child_left = loaded_forest["child_left"]; - std::vector> coef_values = loaded_forest["coef_values"]; - std::vector> coef_indices = loaded_forest["coef_indices"]; - std::vector> leaf_summary = loaded_forest["leaf_summary"]; + std::vector> cutpoint = loaded_forest["cutpoint"]; + std::vector> child_left = loaded_forest["child_left"]; + std::vector> coef_values = loaded_forest["coef_values"]; + std::vector> coef_indices = loaded_forest["coef_indices"]; + std::vector> leaf_summary = loaded_forest["leaf_summary"]; + if(tree_type == TREE_SURVIVAL){ - if(tree_type == TREE_SURVIVAL){ + std::vector> leaf_pred_indx = loaded_forest["leaf_pred_indx"]; + std::vector> leaf_pred_prob = loaded_forest["leaf_pred_prob"]; + std::vector> leaf_pred_chaz = loaded_forest["leaf_pred_chaz"]; - std::vector> leaf_pred_indx = loaded_forest["leaf_pred_indx"]; - std::vector> leaf_pred_prob = loaded_forest["leaf_pred_prob"]; - std::vector> leaf_pred_chaz = loaded_forest["leaf_pred_chaz"]; + auto& temp = dynamic_cast(*forest); + temp.load(n_tree, cutpoint, child_left, coef_values, coef_indices, + leaf_pred_indx, leaf_pred_prob, leaf_pred_chaz, leaf_summary); - auto& temp = dynamic_cast(*forest); - temp.load(n_tree, cutpoint, child_left, coef_values, coef_indices, - leaf_pred_indx, leaf_pred_prob, leaf_pred_chaz, leaf_summary); + } } - } + if(run_forest){ forest->run(oobag); } - if(run_forest){ - forest->run(false, oobag); - } + if(pred_mode){ + result.push_back(forest->get_predictions(), "pred_new"); - if(pred_mode){ + } else { - result.push_back(forest->get_predictions(), "pred_new"); + if (oobag) result.push_back(forest->get_predictions(), "pred_oobag"); - } else { + List eval_oobag; + eval_oobag.push_back(forest->get_oobag_eval(), "stat_values"); + eval_oobag.push_back(oobag_eval_type_R, "stat_type"); + result.push_back(eval_oobag, "eval_oobag"); - if (oobag) result.push_back(forest->get_predictions(), "pred_oobag"); + } - List eval_oobag_out; - eval_oobag_out.push_back(forest->get_oobag_eval(), "stat_values"); - eval_oobag_out.push_back(oobag_eval_type_R, "stat_type"); - result.push_back(eval_oobag_out, "eval_oobag"); + if(write_forest){ + + List forest_out; + forest_out.push_back(forest->get_rows_oobag(), "rows_oobag"); + forest_out.push_back(forest->get_cutpoint(), "cutpoint"); + forest_out.push_back(forest->get_child_left(), "child_left"); + forest_out.push_back(forest->get_coef_indices(), "coef_indices"); + forest_out.push_back(forest->get_coef_values(), "coef_values"); + forest_out.push_back(forest->get_leaf_summary(), "leaf_summary"); + + if(tree_type == TREE_SURVIVAL){ + auto& temp = dynamic_cast(*forest); + forest_out.push_back(temp.get_leaf_pred_indx(), "leaf_pred_indx"); + forest_out.push_back(temp.get_leaf_pred_prob(), "leaf_pred_prob"); + forest_out.push_back(temp.get_leaf_pred_chaz(), "leaf_pred_chaz"); + } - } + result.push_back(forest_out, "forest"); + } + if(vi_type != VI_NONE){ - if(write_forest){ + vec vi_output; - List forest_out; - forest_out.push_back(forest->get_rows_oobag(), "rows_oobag"); - forest_out.push_back(forest->get_cutpoint(), "cutpoint"); - forest_out.push_back(forest->get_child_left(), "child_left"); - forest_out.push_back(forest->get_coef_indices(), "coef_indices"); - forest_out.push_back(forest->get_coef_values(), "coef_values"); - forest_out.push_back(forest->get_leaf_summary(), "leaf_summary"); + if(run_forest){ - if(tree_type == TREE_SURVIVAL){ - auto& temp = dynamic_cast(*forest); - forest_out.push_back(temp.get_leaf_pred_indx(), "leaf_pred_indx"); - forest_out.push_back(temp.get_leaf_pred_prob(), "leaf_pred_prob"); - forest_out.push_back(temp.get_leaf_pred_chaz(), "leaf_pred_chaz"); - // consider dropping unique_event_times; is it needed after grow()? - // result.push_back(forest->get_unique_event_times(), "unique_event_times"); - } + if(vi_type == VI_ANOVA){ - result.push_back(forest_out, "forest"); + vi_output = forest->get_vi_numer() / forest->get_vi_denom(); - } + } else { - if(vi_type != VI_NONE){ + vi_output = forest->get_vi_numer() / n_tree; - vec vi_output; + } - if(run_forest){ - if(vi_type == VI_ANOVA){ - vi_output = forest->get_vi_numer() / forest->get_vi_denom(); - } else { - vi_output = forest->get_vi_numer() / n_tree; } - } - - result.push_back(vi_output, "importance"); - - } + result.push_back(vi_output, "importance"); + } - return(result); + return(result); } diff --git a/src/utility.cpp b/src/utility.cpp index 77c72d3b..0b612ad9 100644 --- a/src/utility.cpp +++ b/src/utility.cpp @@ -25,7 +25,7 @@ if(x.n_cols < max_cols) ncol_print = x.n_cols-1; - Rcout << " ---- view of " << label << " ---- " << std::endl << std::endl; + Rcout << " -- " << label << std::endl << std::endl; Rcout << x.submat(0, 0, nrow_print, ncol_print); Rcout << std::endl << std::endl; @@ -38,7 +38,7 @@ uword n_print = max_elem-1; if(x.size() <= n_print) n_print = x.size()-1; - Rcout << " ---- view of " << label << " ---- " << std::endl << std::endl; + Rcout << " -- " << label << std::endl << std::endl; if(x.size() == 0){ Rcout << " empty vector"; @@ -57,7 +57,7 @@ uword n_print = max_elem-1; if(x.size() <= n_print) n_print = x.size()-1; - Rcout << " ---- view of " << label << " ---- " << std::endl << std::endl; + Rcout << " -- " << label << std::endl << std::endl; if(x.size() == 0){ Rcout << " empty vector"; diff --git a/tests/testthat/test-orsf.R b/tests/testthat/test-orsf.R index f189d28a..29824240 100644 --- a/tests/testthat/test-orsf.R +++ b/tests/testthat/test-orsf.R @@ -817,8 +817,9 @@ test_that( n_split = 1, n_retry = 0, mtry = 3, - leaf_min_events = 1, - leaf_min_obs = c(5, 10), + leaf_min_events = 5, + leaf_min_obs = c(10), + split_rule = c("logrank", "cstat"), split_min_events = 5, split_min_obs = 15, oobag_pred_type = c('none', 'risk', 'surv', 'chf', 'mort'), @@ -865,31 +866,50 @@ test_that( mtry = inputs$mtry[i], leaf_min_events = inputs$leaf_min_events[i], leaf_min_obs = inputs$leaf_min_obs[i], + split_rule = inputs$split_rule[i], split_min_events = inputs$split_min_events[i], split_min_obs = inputs$split_min_obs[i], oobag_pred_type = inputs$oobag_pred_type[i], oobag_pred_horizon = pred_horizon) - expect_s3_class(fit_cph, class = 'orsf_fit') - expect_equal(get_n_tree(fit_cph), inputs$n_tree[i]) - expect_equal(get_n_split(fit_cph), inputs$n_split[i]) - expect_equal(get_n_retry(fit_cph), inputs$n_retry[i]) - expect_equal(get_mtry(fit_cph), inputs$mtry[i]) - expect_equal(get_leaf_min_events(fit_cph), inputs$leaf_min_events[i]) - expect_equal(get_leaf_min_obs(fit_cph), inputs$leaf_min_obs[i]) - expect_equal(get_split_min_events(fit_cph), inputs$split_min_events[i]) - expect_equal(get_split_min_obs(fit_cph), inputs$split_min_obs[i]) - expect_equal(fit_cph$pred_horizon, pred_horizon) + expect_s3_class(fit, class = 'orsf_fit') + expect_equal(get_n_tree(fit), inputs$n_tree[i]) + expect_equal(get_n_split(fit), inputs$n_split[i]) + expect_equal(get_n_retry(fit), inputs$n_retry[i]) + expect_equal(get_mtry(fit), inputs$mtry[i]) + expect_equal(get_leaf_min_events(fit), inputs$leaf_min_events[i]) + expect_equal(get_leaf_min_obs(fit), inputs$leaf_min_obs[i]) + expect_equal(get_split_min_events(fit), inputs$split_min_events[i]) + expect_equal(get_split_min_obs(fit), inputs$split_min_obs[i]) + expect_equal(fit$pred_horizon, pred_horizon) - expect_length(fit_cph$forest$rows_oobag, n = get_n_tree(fit_cph)) + expect_length(fit$forest$rows_oobag, n = get_n_tree(fit)) if(inputs$oobag_pred_type[i] != 'none'){ - expect_length(fit_cph$eval_oobag$stat_values, length(pred_horizon)) - expect_equal(nrow(fit_cph$pred_oobag), get_n_obs(fit_cph)) + expect_length(fit$eval_oobag$stat_values, length(pred_horizon)) + expect_equal(nrow(fit$pred_oobag), get_n_obs(fit)) + + # these lengths should match for n_tree=1 + # b/c only the oobag rows of the first tree + # will get a prediction value. Note that the + # vectors themselves aren't equal b/c rows_oobag + # corresponds to the sorted version of the data. + expect_equal( + length(which(complete.cases(fit$pred_oobag))), + length(fit$forest$rows_oobag[[1]]) + ) + + oobag_preds <- na.omit(fit$pred_oobag) + + expect_true(all(oobag_preds >= 0)) + + if(inputs$oobag_pred_type[i] %in% c("risk", "surv")){ + expect_true(all(oobag_preds <= 1)) + } } else { - expect_equal(dim(fit_cph$eval_oobag$stat_values), c(0, 0)) + expect_equal(dim(fit$eval_oobag$stat_values), c(0, 0)) } } diff --git a/vignettes/oobag.Rmd b/vignettes/oobag.Rmd index bb08eafe..95a542b9 100644 --- a/vignettes/oobag.Rmd +++ b/vignettes/oobag.Rmd @@ -121,13 +121,13 @@ Second, you can pass your function into `orsf()`, and it will be used in place o fit <- orsf(data = pbc_orsf, formula = Surv(time, status) ~ . - id, - n_tree = 500, + n_tree = 50, oobag_pred_horizon = 2000, oobag_fun = oobag_fun_brier, - oobag_eval_every = 25) + oobag_eval_every = 1) plot( - x = seq(25, 500, by = 25), + x = seq(1, 50, by = 1), y = fit$eval_oobag$stat_values, main = 'Out-of-bag error computed after each new tree is grown.', sub = 'For the Brier score, lower values indicate more accurate predictions', @@ -165,7 +165,7 @@ plot( y = fit$eval_oobag$stat_values, main = 'Out-of-bag time-dependent AUC\ncomputed after each new tree is grown.', xlab = 'Number of trees grown', - ylab = "AUC at t = 3,500" + ylab = "AUC at t = 2,000" ) ```