Skip to content

Commit

Permalink
Merge pull request #27 from ropensci/issue23
Browse files Browse the repository at this point in the history
Issue23
  • Loading branch information
bcjaeger authored Oct 19, 2023
2 parents e406f2f + 3004869 commit c87fc58
Show file tree
Hide file tree
Showing 16 changed files with 450 additions and 108 deletions.
16 changes: 14 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

coxph_fit_exported <- function(x_node, y_node, w_node, method, cph_eps, cph_iter_max) {
.Call(`_aorsf_coxph_fit_exported`, x_node, y_node, w_node, method, cph_eps, cph_iter_max)
coxph_fit_exported <- function(x_node, y_node, w_node, method, epsilon, iter_max) {
.Call(`_aorsf_coxph_fit_exported`, x_node, y_node, w_node, method, epsilon, iter_max)
}

linreg_fit_exported <- function(x_node, y_node, w_node, do_scale, epsilon, iter_max) {
.Call(`_aorsf_linreg_fit_exported`, x_node, y_node, w_node, do_scale, epsilon, iter_max)
}

logreg_fit_exported <- function(x_node, y_node, w_node, do_scale, epsilon, iter_max) {
.Call(`_aorsf_logreg_fit_exported`, x_node, y_node, w_node, do_scale, epsilon, iter_max)
}

compute_cstat_exported_vec <- function(y, w, p, pred_is_risklike) {
Expand Down Expand Up @@ -37,6 +45,10 @@ x_submat_mult_beta_exported <- function(x, y, w, x_rows, x_cols, beta) {
.Call(`_aorsf_x_submat_mult_beta_exported`, x, y, w, x_rows, x_cols, beta)
}

scale_x_exported <- function(x, w) {
.Call(`_aorsf_scale_x_exported`, x, w)
}

cph_scale <- function(x, w) {
.Call(`_aorsf_cph_scale`, x, w)
}
Expand Down
37 changes: 0 additions & 37 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,6 @@
#'
#' Fit an oblique random survival forest
#'
#' @srrstats {G1.4} *documented with Roxygen*
#' @srrstats {G1.1} *aorsf is an improvement of the ORSF algorithm implemented in obliqueRSF, which was an extension of Hemant Ishwaran's random survival forest.*
#' @srrstats {G1.3} *linear combinations of inputs defined.*
#' @srrstats {G1.5} *orsf() will be used in publications to benchmark performance of the aorsf package in computation speed and prediction accuracy.*
#' @srrstats {G1.6} *orsf() should be used to compare performance claims with other packages.*
#' @srrstats {G2.1} *Inputs have indication of type in parentheticals. This format is used in all exported functions.*
#' @srrstats {G5.2a} *messages produced here (e.g., with `stop()`, `warning()`, `message()`) are unique and make effort to highlight the specific data elements that cause the error*
#' @srrstats {G2.0a} *secondary documentation of arg lengths. When an input has length 1, a parenthetical gives the specific type of value it should be and uses a singular description (e.g., an integer). When inputs have length > 1, a vector description is used (e.g., integer vector)*
#' @srrstats {ML1.0} *Documentation includes a subsection that makes clear conceptual distinction between train and test data*
#' @srrstats {ML3.3} *Properties and behaviours of aorsf models are explicitly compared with objects produced by other ML software in the "Introduction to aorsf" vignette.*
#' @srrstats {ML4.0} *orsf() is a unified single-function interface to model training. orsf_train() is able to receive as input an untrained model specified by orsf() when no_fit = TRUE. Models with categorically different specifications are able to be submitted to the same model training function.*
#' @srrstats {ML5.2, ML5.2a} *The structure and functionality of trained aorsf objects is documented through vignettes. In particular, basic functionality extending from the aorsf class is explicitly described in the "Introduction to aorsf" vignette, and additional functionality is documented in the "Out-of-bag predictions and evaluation" and "Compute partial dependence with ORSF" vignettes. Each vignettes demonstrates functionality clearly with example code.*
#' @srrstats {ML5.3} *Assessment of model performance is implemented through out-of-bag error, which is finalized after a model is trained*
#' @srrstats {ML5.4} *The "Out-of-bag predictions and evaluation" vignette shows how to implement built-in or user-specified functions for this functionality.*
#' @srrstats {ML1.1} *Training data are labelled as "train".*
#' @srrstats {G2.5} *factors used as predictors can be ordered and un-ordered.*
#' @srrstats {ML4.1b} *The value of out-of-bag error can be returned for every oobag_eval_every step.*
#' @srrstats {ML4.2} *The extraction of out-of-bag error is explicitly documented with example code in the "Out-of-bag predictions and evaluation" vignette.*
#' @srrstats {ML3.5b} *Users can specify the kind of loss function to assess distance between model estimates and desired output. This is discussed in detail in the "Out-of-bag predictions and evaluation" vignette.*
#' @srrstats {ML5.4a} *Harrell's C-statistic, an internally utilized metric for model performance, is clearly and distinctly documented and cited.*
#' @srrstats {ML5.4b} *It is possible to submit custom metrics to a model assessment function, and the ability to do so is clearly documented. The "Out-of-bag predictions and evaluation" vignette provides example code.*
#' @srrstats {ML2.0, ML2.0b} *orsf() enables pre-processing steps to be defined and parametrized without fitting a model when no_fit is TRUE, returning an object with a defined class minimally intended to implement a default `print` method which summarizes the model specifications.*
#' @srrstats {ML3.0} *Model specification can be implemented prior to actual model fitting or training*
#' @srrstats {ML3.0a} *As pre-processing, model specification, and training are controlled by the orsf() function, an input parameter (no_fit) enables models to be specified yet not fitted.*
#' @srrstats {ML3.0c} *when no_fit=TRUE, orsf() will return an object that can be directly trained using orsf_train().*
#' @srrstats {ML1.6a} *Explain why missing values are not admitted.*
#' @srrstats {G1.0} *Jaeger et al describes the ORSF algorithm that aorsf is based on. Note: aorsf uses a different approach to create linear combinations of inputs for speed reasons, but orsf_control_net() allows users to make ensembles that are very similar to obliqueRSF::ORSF().*
#' @srrstats {ML1.6b} *Explicit example showing how missing values may be imputed rather than discarded.*
#' @srrstats {ML6.0} *Reference section explicitly links to aorsf-bench, which includes training and testing stages, and which clearly indicates a need for distinct training and test data sets.*
#' @srrstats {ML6.1} *clearly document how aorsf can be embedded within a typical full ML workflow.*
#' @srrstats {ML6.1a} *Embed aorsf within a full workflow using tidymodels and tidyverse*
#' @srrstats {ML5.2b} *Documentation includes examples of how to save and re-load trained model objects for their re-use.*
#' @srrstats {ML2.3} *Values associated with transformations are recorded in the object returned by orsf()*
#' @srrstats {ML1.3} *Input data are partitioned as training (in-bag) and test (out-of-bag) data within orsf_fit().*
#' @srrstats {ML4.1} *orsf_fit() retains information on model-internal parameters.*
#' @srrstats {ML4.1a} *orsf_fit() output includes all model-internal parameters, specifically the linear combination coefficients.*
#'
#' @param data a `r roxy_data_allowed()` that contains the
#' relevant variables.
#'
Expand Down
8 changes: 7 additions & 1 deletion src/Coxph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,8 @@
// invert vmat
cholesky_invert(vmat);

vec pvalues(beta_current.size());

for (i=0; i < n_vars; i++) {

beta_current[i] = beta_new[i];
Expand All @@ -662,6 +664,10 @@
vmat.at(i, i) = 1.0;
}

pvalues[i] = R::pchisq(
pow(beta_current[i], 2) / vmat.at(i, i), 1, false, false
);

if(do_scale){
// return beta and variance to original scales
beta_current.at(i) *= scales[i];
Expand All @@ -673,7 +679,7 @@

}

return(join_horiz(beta_current, vmat.diag()));
return(join_horiz(beta_current, pvalues));

}

Expand Down
1 change: 0 additions & 1 deletion src/Data.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "globals.h"

using namespace arma;
using namespace Rcpp;

namespace aorsf {

Expand Down
69 changes: 34 additions & 35 deletions src/Forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "Tree.h"

using namespace arma;
using namespace Rcpp;

namespace aorsf {

Expand Down Expand Up @@ -36,7 +35,7 @@ void Forest::init(std::unique_ptr<Data> input_data,
double lincomb_alpha,
arma::uword lincomb_df_target,
arma::uword lincomb_ties_method,
RObject lincomb_R_function,
Rcpp::RObject lincomb_R_function,
// predictions
PredType pred_type,
bool pred_mode,
Expand Down Expand Up @@ -101,12 +100,12 @@ void Forest::init(std::unique_ptr<Data> input_data,
// # nocov start
if(verbosity > 1){

Rcout << "------------ input data dimensions ------------" << std::endl;
Rcout << "N observations total: " << data->get_n_rows() << std::endl;
Rcout << "N columns total: " << data->get_n_cols() << std::endl;
Rcout << "-----------------------------------------------";
Rcout << std::endl;
Rcout << std::endl;
Rcpp::Rcout << "------------ input data dimensions ------------" << std::endl;
Rcpp::Rcout << "N observations total: " << data->get_n_rows() << std::endl;
Rcpp::Rcout << "N columns total: " << data->get_n_cols() << std::endl;
Rcpp::Rcout << "-----------------------------------------------";
Rcpp::Rcout << std::endl;
Rcpp::Rcout << std::endl;

}
// # nocov end
Expand Down Expand Up @@ -262,9 +261,9 @@ void Forest::grow_single_thread(vec* vi_numer_ptr,
for (uint i = 0; i < n_tree; ++i) {

if(verbosity > 1){
Rcout << "------------ Growing tree " << i << " --------------";
Rcout << std::endl;
Rcout << std::endl;
Rcpp::Rcout << "------------ Growing tree " << i << " --------------";
Rcpp::Rcout << std::endl;
Rcpp::Rcout << std::endl;
}

trees[i]->grow(vi_numer_ptr, vi_denom_ptr);
Expand All @@ -282,15 +281,15 @@ void Forest::grow_single_thread(vec* vi_numer_ptr,
seconds time_from_start = duration_cast<seconds>(steady_clock::now() - start_time);
uint remaining_time = (1 / relative_progress - 1) * time_from_start.count();

Rcout << "Growing trees: ";
Rcout << round(100 * relative_progress) << "%. ";
Rcpp::Rcout << "Growing trees: ";
Rcpp::Rcout << round(100 * relative_progress) << "%. ";

if(progress < max_progress){
Rcout << "~ time remaining: ";
Rcout << beautifyTime(remaining_time) << ".";
Rcpp::Rcout << "~ time remaining: ";
Rcpp::Rcout << beautifyTime(remaining_time) << ".";
}

Rcout << std::endl;
Rcpp::Rcout << std::endl;

last_time = steady_clock::now();

Expand Down Expand Up @@ -408,15 +407,15 @@ void Forest::compute_oobag_vi_single_thread(vec* vi_numer_ptr) {
seconds time_from_start = duration_cast<seconds>(steady_clock::now() - start_time);
uint remaining_time = (1 / relative_progress - 1) * time_from_start.count();

Rcout << "Computing importance: ";
Rcout << round(100 * relative_progress) << "%. ";
Rcpp::Rcout << "Computing importance: ";
Rcpp::Rcout << round(100 * relative_progress) << "%. ";

if(progress < max_progress){
Rcout << "~ time remaining: ";
Rcout << beautifyTime(remaining_time) << ".";
Rcpp::Rcout << "~ time remaining: ";
Rcpp::Rcout << beautifyTime(remaining_time) << ".";
}

Rcout << std::endl;
Rcpp::Rcout << std::endl;

last_time = steady_clock::now();

Expand Down Expand Up @@ -660,12 +659,12 @@ void Forest::predict_single_thread(Data* prediction_data,

if(verbosity > 1){
if(oobag){
Rcout << "--- Computing oobag predictions: tree " << i << " ---";
Rcpp::Rcout << "--- Computing oobag predictions: tree " << i << " ---";
} else {
Rcout << "------ Computing predictions: tree " << i << " -----";
Rcpp::Rcout << "------ Computing predictions: tree " << i << " -----";
}
Rcout << std::endl;
Rcout << std::endl;
Rcpp::Rcout << std::endl;
Rcpp::Rcout << std::endl;
}

trees[i]->predict_leaf(prediction_data, oobag);
Expand Down Expand Up @@ -698,15 +697,15 @@ void Forest::predict_single_thread(Data* prediction_data,
seconds time_from_start = duration_cast<seconds>(steady_clock::now() - start_time);
uint remaining_time = (1 / relative_progress - 1) * time_from_start.count();

Rcout << "Computing predictions: ";
Rcout << round(100 * relative_progress) << "%. ";
Rcpp::Rcout << "Computing predictions: ";
Rcpp::Rcout << round(100 * relative_progress) << "%. ";

if(progress < max_progress){
Rcout << "~ time remaining: ";
Rcout << beautifyTime(remaining_time) << ".";
Rcpp::Rcout << "~ time remaining: ";
Rcpp::Rcout << beautifyTime(remaining_time) << ".";
}

Rcout << std::endl;
Rcpp::Rcout << std::endl;

last_time = steady_clock::now();

Expand Down Expand Up @@ -842,15 +841,15 @@ void Forest::show_progress(std::string operation, size_t max_progress) {
seconds time_from_start = duration_cast<seconds>(steady_clock::now() - start_time);
uint remaining_time = (1 / relative_progress - 1) * time_from_start.count();

Rcout << operation << ": ";
Rcout << round(100 * relative_progress) << "%. ";
Rcpp::Rcout << operation << ": ";
Rcpp::Rcout << round(100 * relative_progress) << "%. ";

if(progress < max_progress){
Rcout << "~ time remaining: ";
Rcout << beautifyTime(remaining_time) << ".";
Rcpp::Rcout << "~ time remaining: ";
Rcpp::Rcout << beautifyTime(remaining_time) << ".";
}

Rcout << std::endl;
Rcpp::Rcout << std::endl;

last_time = steady_clock::now();

Expand Down
6 changes: 3 additions & 3 deletions src/Forest.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Forest {
double lincomb_alpha,
arma::uword lincomb_df_target,
arma::uword lincomb_ties_method,
RObject lincomb_R_function,
Rcpp::RObject lincomb_R_function,
// predictions
PredType pred_type,
bool pred_mode,
Expand Down Expand Up @@ -288,7 +288,7 @@ class Forest {
arma::uword lincomb_iter_max;
arma::uword lincomb_df_target;
arma::uword lincomb_ties_method;
RObject lincomb_R_function;
Rcpp::RObject lincomb_R_function;

bool grow_mode;

Expand All @@ -311,7 +311,7 @@ class Forest {
arma::mat oobag_eval;
EvalType oobag_eval_type;
arma::uword oobag_eval_every;
RObject oobag_R_function;
Rcpp::RObject oobag_R_function;


// multi-threading
Expand Down
57 changes: 52 additions & 5 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,50 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// coxph_fit_exported
List coxph_fit_exported(arma::mat& x_node, arma::mat& y_node, arma::vec& w_node, int method, double cph_eps, arma::uword cph_iter_max);
RcppExport SEXP _aorsf_coxph_fit_exported(SEXP x_nodeSEXP, SEXP y_nodeSEXP, SEXP w_nodeSEXP, SEXP methodSEXP, SEXP cph_epsSEXP, SEXP cph_iter_maxSEXP) {
List coxph_fit_exported(arma::mat& x_node, arma::mat& y_node, arma::vec& w_node, int method, double epsilon, arma::uword iter_max);
RcppExport SEXP _aorsf_coxph_fit_exported(SEXP x_nodeSEXP, SEXP y_nodeSEXP, SEXP w_nodeSEXP, SEXP methodSEXP, SEXP epsilonSEXP, SEXP iter_maxSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type x_node(x_nodeSEXP);
Rcpp::traits::input_parameter< arma::mat& >::type y_node(y_nodeSEXP);
Rcpp::traits::input_parameter< arma::vec& >::type w_node(w_nodeSEXP);
Rcpp::traits::input_parameter< int >::type method(methodSEXP);
Rcpp::traits::input_parameter< double >::type cph_eps(cph_epsSEXP);
Rcpp::traits::input_parameter< arma::uword >::type cph_iter_max(cph_iter_maxSEXP);
rcpp_result_gen = Rcpp::wrap(coxph_fit_exported(x_node, y_node, w_node, method, cph_eps, cph_iter_max));
Rcpp::traits::input_parameter< double >::type epsilon(epsilonSEXP);
Rcpp::traits::input_parameter< arma::uword >::type iter_max(iter_maxSEXP);
rcpp_result_gen = Rcpp::wrap(coxph_fit_exported(x_node, y_node, w_node, method, epsilon, iter_max));
return rcpp_result_gen;
END_RCPP
}
// linreg_fit_exported
arma::mat linreg_fit_exported(arma::mat& x_node, arma::mat& y_node, arma::vec& w_node, bool do_scale, double epsilon, arma::uword iter_max);
RcppExport SEXP _aorsf_linreg_fit_exported(SEXP x_nodeSEXP, SEXP y_nodeSEXP, SEXP w_nodeSEXP, SEXP do_scaleSEXP, SEXP epsilonSEXP, SEXP iter_maxSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type x_node(x_nodeSEXP);
Rcpp::traits::input_parameter< arma::mat& >::type y_node(y_nodeSEXP);
Rcpp::traits::input_parameter< arma::vec& >::type w_node(w_nodeSEXP);
Rcpp::traits::input_parameter< bool >::type do_scale(do_scaleSEXP);
Rcpp::traits::input_parameter< double >::type epsilon(epsilonSEXP);
Rcpp::traits::input_parameter< arma::uword >::type iter_max(iter_maxSEXP);
rcpp_result_gen = Rcpp::wrap(linreg_fit_exported(x_node, y_node, w_node, do_scale, epsilon, iter_max));
return rcpp_result_gen;
END_RCPP
}
// logreg_fit_exported
arma::mat logreg_fit_exported(arma::mat& x_node, arma::mat& y_node, arma::vec& w_node, bool do_scale, double epsilon, arma::uword iter_max);
RcppExport SEXP _aorsf_logreg_fit_exported(SEXP x_nodeSEXP, SEXP y_nodeSEXP, SEXP w_nodeSEXP, SEXP do_scaleSEXP, SEXP epsilonSEXP, SEXP iter_maxSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type x_node(x_nodeSEXP);
Rcpp::traits::input_parameter< arma::mat& >::type y_node(y_nodeSEXP);
Rcpp::traits::input_parameter< arma::vec& >::type w_node(w_nodeSEXP);
Rcpp::traits::input_parameter< bool >::type do_scale(do_scaleSEXP);
Rcpp::traits::input_parameter< double >::type epsilon(epsilonSEXP);
Rcpp::traits::input_parameter< arma::uword >::type iter_max(iter_maxSEXP);
rcpp_result_gen = Rcpp::wrap(logreg_fit_exported(x_node, y_node, w_node, do_scale, epsilon, iter_max));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -138,6 +170,18 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// scale_x_exported
List scale_x_exported(arma::mat& x, arma::vec& w);
RcppExport SEXP _aorsf_scale_x_exported(SEXP xSEXP, SEXP wSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type x(xSEXP);
Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
rcpp_result_gen = Rcpp::wrap(scale_x_exported(x, w));
return rcpp_result_gen;
END_RCPP
}
// cph_scale
List cph_scale(arma::mat& x, arma::vec& w);
RcppExport SEXP _aorsf_cph_scale(SEXP xSEXP, SEXP wSEXP) {
Expand Down Expand Up @@ -207,6 +251,8 @@ END_RCPP

static const R_CallMethodDef CallEntries[] = {
{"_aorsf_coxph_fit_exported", (DL_FUNC) &_aorsf_coxph_fit_exported, 6},
{"_aorsf_linreg_fit_exported", (DL_FUNC) &_aorsf_linreg_fit_exported, 6},
{"_aorsf_logreg_fit_exported", (DL_FUNC) &_aorsf_logreg_fit_exported, 6},
{"_aorsf_compute_cstat_exported_vec", (DL_FUNC) &_aorsf_compute_cstat_exported_vec, 4},
{"_aorsf_compute_cstat_exported_uvec", (DL_FUNC) &_aorsf_compute_cstat_exported_uvec, 4},
{"_aorsf_compute_logrank_exported", (DL_FUNC) &_aorsf_compute_logrank_exported, 3},
Expand All @@ -215,6 +261,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_aorsf_sprout_node_survival_exported", (DL_FUNC) &_aorsf_sprout_node_survival_exported, 2},
{"_aorsf_find_rows_inbag_exported", (DL_FUNC) &_aorsf_find_rows_inbag_exported, 2},
{"_aorsf_x_submat_mult_beta_exported", (DL_FUNC) &_aorsf_x_submat_mult_beta_exported, 6},
{"_aorsf_scale_x_exported", (DL_FUNC) &_aorsf_scale_x_exported, 2},
{"_aorsf_cph_scale", (DL_FUNC) &_aorsf_cph_scale, 2},
{"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 44},
{NULL, NULL, 0}
Expand Down
4 changes: 2 additions & 2 deletions src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@
// # nocov end
}

vec beta_var = beta.unsafe_col(1);
vec pvalues = beta.unsafe_col(1);

double pvalue;

Expand All @@ -1002,7 +1002,7 @@

if(beta_est[i] != 0){

pvalue = R::pchisq(pow(beta_est[i],2)/beta_var[i], 1, false, false);
pvalue = pvalues[i];

if(verbosity > 3){
// # nocov start
Expand Down
Loading

0 comments on commit c87fc58

Please sign in to comment.