diff --git a/DESCRIPTION b/DESCRIPTION index 2dfebbcd..59206104 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,7 +29,7 @@ Description: Fit, interpret, and make predictions with oblique random survival f License: MIT + file LICENSE Encoding: UTF-8 LazyData: true -Roxygen: list(markdown = TRUE, roclets = c ("namespace", "rd")) +Roxygen: list(markdown = TRUE) RoxygenNote: 7.2.3 LinkingTo: Rcpp, @@ -38,8 +38,9 @@ Imports: Rcpp, data.table, utils, - collapse, - R6 + collapse, + R6, + lifecycle URL: https://github.com/ropensci/aorsf, https://docs.ropensci.org/aorsf/ BugReports: https://github.com/ropensci/aorsf/issues/ diff --git a/NAMESPACE b/NAMESPACE index 9d2fee38..617adbae 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -5,10 +5,14 @@ S3method(predict,ObliqueForest) S3method(print,ObliqueForest) S3method(print,orsf_summary_uni) export(orsf) +export(orsf_control) +export(orsf_control_classification) export(orsf_control_cph) export(orsf_control_custom) export(orsf_control_fast) export(orsf_control_net) +export(orsf_control_regression) +export(orsf_control_survival) export(orsf_ice_inb) export(orsf_ice_new) export(orsf_ice_oob) @@ -29,4 +33,5 @@ import(R6) import(data.table) importFrom(Rcpp,sourceCpp) importFrom(collapse,"%==%") +importFrom(lifecycle,deprecated) useDynLib(aorsf, .registration = TRUE) diff --git a/R/aorsf-package.R b/R/aorsf-package.R index 40ce1538..8199ff7e 100644 --- a/R/aorsf-package.R +++ b/R/aorsf-package.R @@ -6,8 +6,9 @@ # The following block is used by usethis to automatically manage # roxygen namespace tags. Modify with care! ## usethis namespace: start -#' @importFrom Rcpp sourceCpp #' @importFrom collapse %==% +#' @importFrom lifecycle deprecated +#' @importFrom Rcpp sourceCpp #' @useDynLib aorsf, .registration = TRUE ## usethis namespace: end NULL diff --git a/R/orsf_R6.R b/R/orsf_R6.R index 159f3182..761c7012 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -407,7 +407,8 @@ ObliqueForest <- R6::R6Class( private$check_pred_horizon(pred_horizon, boundary_checks, pred_type) - if(is.null(pred_horizon)) pred_horizon <- 1 + pred_horizon <- pred_horizon %||% self$pred_horizon %||% 1 + pred_horizon_order <- order(pred_horizon) pred_horizon_ordered <- pred_horizon[pred_horizon_order] @@ -435,11 +436,11 @@ ObliqueForest <- R6::R6Class( private$prep_x() # y and w do not need to be prepped for prediction, # but they need to match orsf_cpp()'s expectations - private$y <- matrix(0, nrow = nrow(private$x), ncol = 1) + private$prep_y(placeholder = TRUE) private$w <- rep(1, nrow(private$x)) - if(oobag){ private$sort_inputs() } + if(oobag){ private$sort_inputs(sort_y = FALSE) } # the values in pred_spec need to be centered & scaled to match x_new, # which is also centered and scaled @@ -660,6 +661,8 @@ ObliqueForest <- R6::R6Class( if(self$tree_type == 'classification'){ setnames(out, old = 'pred_horizon', new = 'class') + out[, class := factor(class, levels = self$class_levels)] + setkey(out, class) } if(self$tree_type == 'survival' && pred_type != 'mort') @@ -878,13 +881,14 @@ ObliqueForest <- R6::R6Class( # To avoid this: include a DT[] after the last := in your function. pd_output[] - setcolorder(pd_output, c('variable', - 'importance', - 'value', - 'mean', - 'medn', - 'lwr', - 'upr')) + new_order <- c('variable', 'importance', 'value', + 'mean', 'medn', 'lwr', 'upr') + + if(self$tree_type == 'classification'){ + new_order <- insert_vals(new_order, 2, 'class') + } + + setcolorder(pd_output, new_order) structure( .Data = list(dt = pd_output, @@ -2316,14 +2320,14 @@ ObliqueForest <- R6::R6Class( }, - prep_y = function(){ + prep_y = function(placeholder = FALSE){ private$y <- select_cols(self$data, private$data_names$y) - if(self$na_action == 'omit') + if(self$na_action == 'omit' && !placeholder) private$y <- private$y[private$data_rows_complete, ] - private$prep_y_internal() + private$prep_y_internal(placeholder) }, @@ -2686,11 +2690,16 @@ ObliqueForestSurvival <- R6::R6Class( }, - sort_inputs = function(){ + sort_inputs = function(sort_x = TRUE, + sort_y = TRUE, + sort_w = TRUE){ - private$x <- private$x[private$data_row_sort, , drop = FALSE] - private$y <- private$y[private$data_row_sort, , drop = FALSE] - private$w <- private$w[private$data_row_sort] + if(sort_x) + private$x <- private$x[private$data_row_sort, , drop = FALSE] + if(sort_y) + private$y <- private$y[private$data_row_sort, , drop = FALSE] + if(sort_w) + private$w <- private$w[private$data_row_sort] }, @@ -2698,6 +2707,10 @@ ObliqueForestSurvival <- R6::R6Class( self$tree_type <- "survival" + if(!is.function(self$control$lincomb_R_function) && + self$control$lincomb_type == 'net'){ + self$control$lincomb_R_function <- penalized_cph + } self$split_rule <- self$split_rule %||% 'logrank' self$pred_type <- self$pred_type %||% 'surv' @@ -2761,7 +2774,13 @@ ObliqueForestSurvival <- R6::R6Class( } }, - prep_y_internal = function(){ + prep_y_internal = function(placeholder = FALSE){ + + + if(placeholder){ + private$y <- matrix(0, ncol = 2, nrow = 1) + return() + } y <- private$y cols <- names(y) @@ -2979,6 +2998,11 @@ ObliqueForestClassification <- R6::R6Class( self$tree_type <- "classification" + if(!is.function(self$control$lincomb_R_function) && + self$control$lincomb_type == 'net'){ + self$control$lincomb_R_function <- penalized_logreg + } + self$split_rule <- self$split_rule %||% 'gini' self$pred_type <- self$pred_type %||% 'prob' self$split_min_stat <- self$split_min_stat %||% @@ -3003,7 +3027,12 @@ ObliqueForestClassification <- R6::R6Class( }, - prep_y_internal = function(){ + prep_y_internal = function(placeholder = FALSE){ + + if(placeholder){ + private$y <- matrix(0, ncol = self$n_class-1, nrow = 1) + return() + } # y is always 1 column for classification (right?) y <- private$y[[1]] diff --git a/R/orsf_control.R b/R/orsf_control.R index b6908047..b85e5e72 100644 --- a/R/orsf_control.R +++ b/R/orsf_control.R @@ -72,6 +72,8 @@ orsf_control_fast <- function(method = 'efron', #' to create linear combinations of predictor variables #' while fitting an [orsf] model. #' +#' `r lifecycle::badge('superseded')` +#' #' @inheritParams orsf_control_fast #' #' @param eps (_double_) When using Newton Raphson scoring to identify @@ -85,6 +87,7 @@ orsf_control_fast <- function(method = 'efron', #' (see `eps` above) or the number of attempted iterations is equal to #' `iter_max`. #' +#' #' @return an object of class `'orsf_control'`, which should be used as #' an input for the `control` argument of [orsf]. #' @@ -117,6 +120,11 @@ orsf_control_cph <- function(method = 'efron', iter_max = 20, ...){ + lifecycle::deprecate_warn( + when = "0.1.2", + "orsf_control_custom()", + details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`" + ) method <- tolower(method) @@ -144,6 +152,8 @@ orsf_control_cph <- function(method = 'efron', #' Use regularized Cox proportional hazard models to identify linear #' combinations of input variables while fitting an [orsf] model. #' +#' `r lifecycle::badge('superseded')` +#' #' @param alpha (_double_) The elastic net mixing parameter. A value of 1 gives the #' lasso penalty, and a value of 0 gives the ridge penalty. If multiple #' values of alpha are given, then a penalized model is fit using each @@ -153,6 +163,7 @@ orsf_control_cph <- function(method = 'efron', #' #' @param ... `r roxy_dots()` #' +#' #' @inherit orsf_control_cph return #' #' @details @@ -169,8 +180,6 @@ orsf_control_cph <- function(method = 'efron', #' #' `r roxy_cite_simon_2011()` #' -#' -#' #' @examples #' #' # orsf_control_net() is considerably slower than orsf_control_cph(), @@ -186,6 +195,12 @@ orsf_control_net <- function(alpha = 1/2, df_target = NULL, ...){ + lifecycle::deprecate_warn( + when = "0.1.2", + "orsf_control_custom()", + details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`" + ) + check_dots(list(...), orsf_control_net) check_control_net(alpha, df_target) @@ -202,6 +217,8 @@ orsf_control_net <- function(alpha = 1/2, #' Custom ORSF control #' +#' `r lifecycle::badge('superseded')` +#' #' @param beta_fun (_function_) a function to define coefficients used #' in linear combinations of predictor variables. `beta_fun` must accept #' three inputs named `x_node`, `y_node` and `w_node`, and should expect @@ -219,14 +236,22 @@ orsf_control_net <- function(alpha = 1/2, #' #' @inherit orsf_control_cph return #' +#' #' @export #' +#' #' @family orsf_control #' #' @includeRmd Rmd/orsf_control_custom_examples.Rmd orsf_control_custom <- function(beta_fun, ...){ + lifecycle::deprecate_warn( + when = "0.1.2", + "orsf_control_custom()", + details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`" + ) + check_dots(list(...), .f = orsf_control_custom) check_beta_fun(beta_fun) @@ -305,7 +330,7 @@ orsf_control_custom <- function(beta_fun, ...){ #' #' @param ... `r roxy_dots()` #' -#' @noRd +#' @family orsf_control #' #' @details #' @@ -327,7 +352,7 @@ orsf_control_custom <- function(beta_fun, ...){ #' - `lincomb_ties_method`: method for ties in survival time #' - `lincomb_R_function`: R function for custom splits #' -#' +#' @export #' orsf_control <- function(tree_type, method, @@ -339,14 +364,118 @@ orsf_control <- function(tree_type, epsilon, ...){ + check_arg_type(arg_value = method, + arg_name = 'method', + expected_type = c('character', 'function')) + custom <- is.function(method) + if(!custom){ + + check_arg_is_valid(arg_value = method, + arg_name = 'method', + valid_options = c("glm", "net")) + + check_arg_length(arg_value = method, + arg_name = 'method', + expected_length = 1) + + } else { + + check_beta_fun(method) + + } + + check_arg_type(arg_value = scale_x, + arg_name = 'scale_x', + expected_type = 'logical') + + check_arg_length(arg_value = scale_x, + arg_name = 'scale_x', + expected_length = 1) + + if(!is.null(ties)){ + + check_arg_type(arg_value = ties, + arg_name = 'ties', + expected_type = 'character') + + check_arg_is_valid(arg_value = ties, + arg_name = 'ties', + valid_options = c("breslow", "efron")) + + } + + check_arg_type(arg_value = net_mix, + arg_name = 'net_mix', + expected_type = 'numeric') + + check_arg_gteq(arg_value = net_mix, + arg_name = 'net_mix', + bound = 0) + + check_arg_lteq(arg_value = net_mix, + arg_name = 'net_mix', + bound = 1) + + check_arg_length(arg_value = net_mix, + arg_name = 'net_mix', + expected_length = 1) + + if(!is.null(target_df)){ + + check_arg_type(arg_value = target_df, + arg_name = 'target_df', + expected_type = 'numeric') + + check_arg_is_integer(arg_value = target_df, + arg_name = 'target_df') + + } + + + check_arg_type(arg_value = max_iter, + arg_name = 'max_iter', + expected_type = 'numeric') + + check_arg_is_integer(arg_value = max_iter, + arg_name = 'max_iter') + + check_arg_gteq(arg_value = max_iter, + arg_name = 'max_iter', + bound = 1) + + check_arg_length(arg_value = max_iter, + arg_name = 'max_iter', + expected_length = 1) + + check_arg_type(arg_value = epsilon, + arg_name = 'epsilon', + expected_type = 'numeric') + + check_arg_gt(arg_value = epsilon, + arg_name = 'epsilon', + bound = 0) + + check_arg_length(arg_value = epsilon, + arg_name = 'epsilon', + expected_length = 1) + if(custom){ + lincomb_R_function <- method + } else if (method == 'net') { - lincomb_R_function <- penalized_cph + + lincomb_R_function <- switch(tree_type, + 'survival' = penalized_cph, + 'classification' = penalized_logreg, + 'unknown' = 'unknown') + } else { + lincomb_R_function <- function(x) x + } structure( @@ -367,7 +496,8 @@ orsf_control <- function(tree_type, } - +#' @rdname orsf_control +#' @export orsf_control_classification <- function(method = 'glm', scale_x = TRUE, net_mix = 0.5, @@ -390,7 +520,8 @@ orsf_control_classification <- function(method = 'glm', } - +#' @rdname orsf_control +#' @export orsf_control_regression <- function(method = 'glm', scale_x = TRUE, net_mix = 0.5, @@ -413,7 +544,8 @@ orsf_control_regression <- function(method = 'glm', } - +#' @rdname orsf_control +#' @export orsf_control_survival <- function(method = 'glm', scale_x = TRUE, ties = 'efron', diff --git a/R/orsf_print.R b/R/orsf_print.R index 77c2eb69..29e98b4c 100644 --- a/R/orsf_print.R +++ b/R/orsf_print.R @@ -35,7 +35,6 @@ print.ObliqueForest <- function(x, ...){ x$print() - invisible(x) } diff --git a/R/orsf_summary.R b/R/orsf_summary.R index ed60d6df..32da105a 100644 --- a/R/orsf_summary.R +++ b/R/orsf_summary.R @@ -109,10 +109,15 @@ print.orsf_summary_uni <- function(x, n_variables = NULL, ...){ 'prob' = "Probability" ) - msg_btm <- paste("Predicted", tolower(pred_label), - "at time t =", x$pred_horizon, - "for top", n_variables, - "predictors") + extra_surv_text <- "" + + if(!is.null(x$pred_horizon)) + extra_surv_text <- paste0("at time t = ", x$pred_horizon, " ") + + msg_btm <- paste0( + "Predicted ", tolower(pred_label), " ", extra_surv_text, + "for top ", n_variables, " predictors" + ) .sd_orig <- c("value", "mean", diff --git a/R/penalized_cph.R b/R/penalized_cph.R index 733d09c5..33663b0a 100644 --- a/R/penalized_cph.R +++ b/R/penalized_cph.R @@ -22,6 +22,19 @@ #' alpha = 1/2, #' df_target = 2 #' ) +#' +#' penalized_logreg( +#' x_node = as.matrix(penguins_orsf[, c('bill_length_mm', +#' 'bill_depth_mm', +#' 'flipper_length_mm', +#' 'body_mass_g')]), +#' y_node = as.matrix(as.numeric(penguins_orsf$species == 'Adelie')), +#' w_node = rep(1, nrow(penguins_orsf)), +#' alpha = 1/2, +#' df_target = 2 +#' ) + + penalized_cph <- function(x_node, y_node, @@ -31,13 +44,47 @@ penalized_cph <- function(x_node, colnames(y_node) <- c('time', 'status') + penalized_fitter(x_node = x_node, + y_node = y_node, + w_node = w_node, + alpha = alpha, + df_target = df_target, + family = "cox") + +} + +penalized_logreg <- function(x_node, + y_node, + w_node, + alpha, + df_target){ + + y_node <- as.factor(y_node) + w_node <- as.numeric(w_node) + + penalized_fitter(x_node = x_node, + y_node = y_node, + w_node = w_node, + alpha = alpha, + df_target = df_target, + family = "binomial") + +} + +penalized_fitter <- function(x_node, + y_node, + w_node, + alpha, + df_target, + family){ + suppressWarnings( fit <- try( glmnet::glmnet(x = x_node, y = y_node, weights = w_node, alpha = alpha, - family = "cox"), + family = family), silent = TRUE ) ) diff --git a/R/round_magnitude.R b/R/round_magnitude.R index ba46e8e9..a39fb3ff 100644 --- a/R/round_magnitude.R +++ b/R/round_magnitude.R @@ -19,7 +19,7 @@ round_magnitude <- function(x){ # take absolute value to round based on magnitude x_abs <- abs(x) - breaks <- c(0, 1, 10, Inf) + breaks <- c(0, 10, 100, Inf) decimals <- c(2, 1, 0) # x_cuts create boundary categories for rounding diff --git a/man/figures/lifecycle-archived.svg b/man/figures/lifecycle-archived.svg new file mode 100644 index 00000000..745ab0c7 --- /dev/null +++ b/man/figures/lifecycle-archived.svg @@ -0,0 +1,21 @@ + + lifecycle: archived + + + + + + + + + + + + + + + lifecycle + + archived + + diff --git a/man/figures/lifecycle-defunct.svg b/man/figures/lifecycle-defunct.svg new file mode 100644 index 00000000..d5c9559e --- /dev/null +++ b/man/figures/lifecycle-defunct.svg @@ -0,0 +1,21 @@ + + lifecycle: defunct + + + + + + + + + + + + + + + lifecycle + + defunct + + diff --git a/man/figures/lifecycle-deprecated.svg b/man/figures/lifecycle-deprecated.svg new file mode 100644 index 00000000..b61c57c3 --- /dev/null +++ b/man/figures/lifecycle-deprecated.svg @@ -0,0 +1,21 @@ + + lifecycle: deprecated + + + + + + + + + + + + + + + lifecycle + + deprecated + + diff --git a/man/figures/lifecycle-experimental.svg b/man/figures/lifecycle-experimental.svg new file mode 100644 index 00000000..5d88fc2c --- /dev/null +++ b/man/figures/lifecycle-experimental.svg @@ -0,0 +1,21 @@ + + lifecycle: experimental + + + + + + + + + + + + + + + lifecycle + + experimental + + diff --git a/man/figures/lifecycle-maturing.svg b/man/figures/lifecycle-maturing.svg new file mode 100644 index 00000000..897370ec --- /dev/null +++ b/man/figures/lifecycle-maturing.svg @@ -0,0 +1,21 @@ + + lifecycle: maturing + + + + + + + + + + + + + + + lifecycle + + maturing + + diff --git a/man/figures/lifecycle-questioning.svg b/man/figures/lifecycle-questioning.svg new file mode 100644 index 00000000..7c1721d0 --- /dev/null +++ b/man/figures/lifecycle-questioning.svg @@ -0,0 +1,21 @@ + + lifecycle: questioning + + + + + + + + + + + + + + + lifecycle + + questioning + + diff --git a/man/figures/lifecycle-soft-deprecated.svg b/man/figures/lifecycle-soft-deprecated.svg new file mode 100644 index 00000000..9c166ff3 --- /dev/null +++ b/man/figures/lifecycle-soft-deprecated.svg @@ -0,0 +1,21 @@ + + lifecycle: soft-deprecated + + + + + + + + + + + + + + + lifecycle + + soft-deprecated + + diff --git a/man/figures/lifecycle-stable.svg b/man/figures/lifecycle-stable.svg new file mode 100644 index 00000000..9bf21e76 --- /dev/null +++ b/man/figures/lifecycle-stable.svg @@ -0,0 +1,29 @@ + + lifecycle: stable + + + + + + + + + + + + + + + + lifecycle + + + + stable + + + diff --git a/man/figures/lifecycle-superseded.svg b/man/figures/lifecycle-superseded.svg new file mode 100644 index 00000000..db8d757f --- /dev/null +++ b/man/figures/lifecycle-superseded.svg @@ -0,0 +1,21 @@ + + lifecycle: superseded + + + + + + + + + + + + + + + lifecycle + + superseded + + diff --git a/man/orsf_control.Rd b/man/orsf_control.Rd new file mode 100644 index 00000000..8d41dde5 --- /dev/null +++ b/man/orsf_control.Rd @@ -0,0 +1,147 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/orsf_control.R +\name{orsf_control} +\alias{orsf_control} +\alias{orsf_control_classification} +\alias{orsf_control_regression} +\alias{orsf_control_survival} +\title{Oblique random forest control} +\usage{ +orsf_control( + tree_type, + method, + scale_x, + ties, + net_mix, + target_df, + max_iter, + epsilon, + ... +) + +orsf_control_classification( + method = "glm", + scale_x = TRUE, + net_mix = 0.5, + target_df = NULL, + max_iter = 20, + epsilon = 1e-09, + ... +) + +orsf_control_regression( + method = "glm", + scale_x = TRUE, + net_mix = 0.5, + target_df = NULL, + max_iter = 20, + epsilon = 1e-09, + ... +) + +orsf_control_survival( + method = "glm", + scale_x = TRUE, + ties = "efron", + net_mix = 0.5, + target_df = NULL, + max_iter = 20, + epsilon = 1e-09, + ... +) +} +\arguments{ +\item{tree_type}{(\emph{character}) the type of tree. Valid options are +\itemize{ +\item "classification", i.e., categorical outcomes +\item "regression", i.e., continuous outcomes +\item "survival", i.e., time-to event outcomes +}} + +\item{method}{(\emph{character} or \emph{function}) how to identify linear +linear combinations of predictors. If \code{method} is a character value, +it must be one of: +\itemize{ +\item 'glm': linear, logistic, and cox regression +\item 'net': same as 'glm' but with penalty terms +\item 'pca': principal component analysis +\item 'random': random draw from uniform distribution +} + +If \code{method} is a \emph{function}, it will be used to identify linear +combinations of predictor variables. \code{method} must in this case accept +three inputs named \code{x_node}, \code{y_node} and \code{w_node}, and should expect +the following types and dimensions: +\itemize{ +\item \code{x_node} (\emph{matrix}; \code{n} rows, \code{p} columns) +\item \code{y_node} (\emph{matrix}; \code{n} rows, \code{2} columns) +\item \code{w_node} (\emph{matrix}; \code{n} rows, \code{1} column) +} + +In addition, \code{method} must return a matrix with p rows and 1 column.} + +\item{scale_x}{(\emph{logical}) if \code{TRUE}, values of predictors will be +scaled prior to each instance of finding a linear combination of +predictors, using summary values from the data in the current node +of the decision tree.} + +\item{ties}{(\emph{character}) a character string specifying the method +for tie handling. Only relevant when modeling survival outcomes +and using a method that engages with tied outcome times. +If there are no ties, all the methods are equivalent. Valid options +are 'breslow' and 'efron'. The Efron approximation is the default +because it is more accurate when dealing with tied event times and +has similar computational efficiency compared to the Breslow method.} + +\item{net_mix}{(\emph{double}) The elastic net mixing parameter. A value of 1 +gives the lasso penalty, and a value of 0 gives the ridge penalty. If +multiple values of alpha are given, then a penalized model is fit using +each alpha value prior to splitting a node.} + +\item{target_df}{(\emph{integer}) Preferred number of variables used in each +linear combination. For example, with \code{mtry} of 5 and \code{target_df} of 3, +we sample 5 predictors and look for the best linear combination using +3 of them.} + +\item{max_iter}{(\emph{integer}) iteration continues until convergence +(see \code{eps} above) or the number of attempted iterations is equal to +\code{iter_max}.} + +\item{epsilon}{(\emph{double}) When using most modeling based method, +iteration continues in the algorithm until the relative change in +some kind of objective is less than \code{epsilon}, or the absolute +change is less than \code{sqrt(epsilon)}.} + +\item{...}{Further arguments passed to or from other methods (not currently used).} +} +\value{ +an object of class \code{'orsf_control'}, which should be used as +an input for the \code{control} argument of \link{orsf}. Components are: +\itemize{ +\item \code{tree_type}: type of trees to fit +\item \code{lincomb_type}: method for linear combinations +\item \code{lincomb_eps}: epsilon for convergence +\item \code{lincomb_iter_max}: max iterations +\item \code{lincomb_scale}: to scale or not. +\item \code{lincomb_alpha}: mixing parameter +\item \code{lincomb_df_target}: target degrees of freedom +\item \code{lincomb_ties_method}: method for ties in survival time +\item \code{lincomb_R_function}: R function for custom splits +} +} +\description{ +Oblique random forest control +} +\details{ +Adjust \code{scale_x} \emph{at your own risk}. Setting \code{scale_x = FALSE} will +reduce computation time but will also make the \code{orsf} model dependent +on the scale of your data, which is why the default value is \code{TRUE}. +} +\seealso{ +linear combination control functions +\code{\link{orsf_control_cph}()}, +\code{\link{orsf_control_custom}()}, +\code{\link{orsf_control_fast}()}, +\code{\link{orsf_control_net}()} +} +\concept{orsf_control} diff --git a/man/orsf_control_cph.Rd b/man/orsf_control_cph.Rd index 73d527a5..c1a38efb 100644 --- a/man/orsf_control_cph.Rd +++ b/man/orsf_control_cph.Rd @@ -37,6 +37,8 @@ to create linear combinations of predictor variables while fitting an \link{orsf} model. } \details{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#superseded}{\figure{lifecycle-superseded.svg}{options: alt='[Superseded]'}}}{\strong{[Superseded]}} + code from the \href{https://github.com/therneau/survival/blob/master/src/coxfit6.c}{survival package} was modified to make this routine. @@ -59,6 +61,7 @@ Springer, New York, NY. DOI: 10.1007/978-1-4757-3294-8_3 linear combination control functions \code{\link{orsf_control_custom}()}, \code{\link{orsf_control_fast}()}, -\code{\link{orsf_control_net}()} +\code{\link{orsf_control_net}()}, +\code{\link{orsf_control}()} } \concept{orsf_control} diff --git a/man/orsf_control_custom.Rd b/man/orsf_control_custom.Rd index 93126e5e..921ae356 100644 --- a/man/orsf_control_custom.Rd +++ b/man/orsf_control_custom.Rd @@ -28,7 +28,7 @@ an object of class \code{'orsf_control'}, which should be used as an input for the \code{control} argument of \link{orsf}. } \description{ -Custom ORSF control +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#superseded}{\figure{lifecycle-superseded.svg}{options: alt='[Superseded]'}}}{\strong{[Superseded]}} } \section{Examples}{ Two customized functions to identify linear combinations of predictors @@ -55,8 +55,17 @@ fit_rando <- orsf(pbc_orsf, Surv(time, status) ~ . - id, control = orsf_control_custom(beta_fun = f_rando), n_tree = 500) +}\if{html}{\out{}} + +\if{html}{\out{
}}\preformatted{## Warning: `orsf_control_custom()` was deprecated in aorsf 0.1.2. +## i Please use the appropriate survival, classification, or regression control +## function instead. E.g., `orsf_control_survival(method = your_function)` +## This warning is displayed once every 8 hours. +## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was +## generated. +}\if{html}{\out{
}} -fit_rando +\if{html}{\out{
}}\preformatted{fit_rando }\if{html}{\out{
}} \if{html}{\out{
}}\preformatted{## ---------- Oblique random survival forest @@ -157,6 +166,7 @@ The PCA ORSF does quite well! (higher IPA is better) linear combination control functions \code{\link{orsf_control_cph}()}, \code{\link{orsf_control_fast}()}, -\code{\link{orsf_control_net}()} +\code{\link{orsf_control_net}()}, +\code{\link{orsf_control}()} } \concept{orsf_control} diff --git a/man/orsf_control_fast.Rd b/man/orsf_control_fast.Rd index 7d92ed11..9440bb97 100644 --- a/man/orsf_control_fast.Rd +++ b/man/orsf_control_fast.Rd @@ -47,6 +47,7 @@ orsf(data = pbc_orsf, linear combination control functions \code{\link{orsf_control_cph}()}, \code{\link{orsf_control_custom}()}, -\code{\link{orsf_control_net}()} +\code{\link{orsf_control_net}()}, +\code{\link{orsf_control}()} } \concept{orsf_control} diff --git a/man/orsf_control_net.Rd b/man/orsf_control_net.Rd index 1bc70865..427c7ad7 100644 --- a/man/orsf_control_net.Rd +++ b/man/orsf_control_net.Rd @@ -25,6 +25,8 @@ Use regularized Cox proportional hazard models to identify linear combinations of input variables while fitting an \link{orsf} model. } \details{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#superseded}{\figure{lifecycle-superseded.svg}{options: alt='[Superseded]'}}}{\strong{[Superseded]}} + \code{df_target} has to be less than \code{mtry}, which is a separate argument in \link{orsf} that indicates the number of variables chosen at random prior to finding a linear combination of those variables. @@ -47,6 +49,7 @@ Simon N, Friedman J, Hastie T, Tibshirani R. Regularization paths for Cox's prop linear combination control functions \code{\link{orsf_control_cph}()}, \code{\link{orsf_control_custom}()}, -\code{\link{orsf_control_fast}()} +\code{\link{orsf_control_fast}()}, +\code{\link{orsf_control}()} } \concept{orsf_control} diff --git a/man/orsf_ice_oob.Rd b/man/orsf_ice_oob.Rd index ad2a1968..e7788b77 100644 --- a/man/orsf_ice_oob.Rd +++ b/man/orsf_ice_oob.Rd @@ -112,60 +112,3 @@ You can compute individual conditional expectations three ways using a random fo See examples for more details } -\section{Examples}{ -Begin by fitting an ORSF ensemble - -\if{html}{\out{
}}\preformatted{library(aorsf) - -set.seed(329) - -fit <- orsf(data = pbc_orsf, formula = Surv(time, status) ~ . - id) - -fit -}\if{html}{\out{
}} - -\if{html}{\out{
}}\preformatted{## ---------- Oblique random survival forest -## -## Linear combinations: Accelerated Cox regression -## N observations: 276 -## N events: 111 -## N trees: 500 -## N predictors total: 17 -## N predictors per node: 5 -## Average leaves per tree: 21.026 -## Min observations in leaf: 5 -## Min events in leaf: 1 -## OOB stat value: 0.84 -## OOB stat type: Harrell's C-index -## Variable importance: anova -## -## ----------------------------------------- -}\if{html}{\out{
}} - -Use the ensemble to compute ICE values using out-of-bag predictions: - -\if{html}{\out{
}}\preformatted{pred_spec <- list(bili = seq(1, 10, length.out = 25)) - -ice_oob <- orsf_ice_oob(fit, pred_spec, boundary_checks = FALSE) - -ice_oob -}\if{html}{\out{
}} - -\if{html}{\out{
}}\preformatted{## id_variable id_row pred_horizon bili pred -## 1: 1 1 1 1 1 -## 2: 1 2 1 1 1 -## 3: 1 3 1 1 1 -## 4: 1 4 1 1 1 -## 5: 1 5 1 1 1 -## --- -## 6896: 25 272 1 10 1 -## 6897: 25 273 1 10 1 -## 6898: 25 274 1 10 1 -## 6899: 25 275 1 10 1 -## 6900: 25 276 1 10 1 -}\if{html}{\out{
}} - -Much more detailed examples are given in the -\href{https://docs.ropensci.org/aorsf/articles/pd.html#individual-conditional-expectations-ice}{vignette} -} - diff --git a/man/orsf_pd_oob.Rd b/man/orsf_pd_oob.Rd index 854b7bfb..f5a4672c 100644 --- a/man/orsf_pd_oob.Rd +++ b/man/orsf_pd_oob.Rd @@ -133,85 +133,6 @@ See examples for more details \details{ Partial dependence has a number of \href{https://christophm.github.io/interpretable-ml-book/pdp.html#disadvantages-5}{known limitations and assumptions} that users should be aware of (see Hooker, 2021). In particular, partial dependence is less intuitive when >2 predictors are examined jointly, and it is assumed that the feature(s) for which the partial dependence is computed are not correlated with other features (this is likely not true in many cases). Accumulated local effect plots can be used (see \href{https://christophm.github.io/interpretable-ml-book/ale.html}{here}) in the case where feature independence is not a valid assumption. } -\section{Examples}{ -Begin by fitting an ORSF ensemble: - -\if{html}{\out{
}}\preformatted{library(aorsf) - -set.seed(329730) - -index_train <- sample(nrow(pbc_orsf), 150) - -pbc_orsf_train <- pbc_orsf[index_train, ] -pbc_orsf_test <- pbc_orsf[-index_train, ] - -fit <- orsf(data = pbc_orsf_train, - formula = Surv(time, status) ~ . - id, - oobag_pred_horizon = 365.25 * 5) -}\if{html}{\out{
}} -\subsection{Three ways to compute PD and ICE}{ - -You can compute partial dependence and ICE three ways with \code{aorsf}: -\itemize{ -\item using in-bag predictions for the training data - -\if{html}{\out{
}}\preformatted{pd_train <- orsf_pd_inb(fit, pred_spec = list(bili = 1:5)) - -pd_train -}\if{html}{\out{
}} - -\if{html}{\out{
}}\preformatted{## pred_horizon bili mean lwr medn upr -## 1: 1 1 1 1 1 1 -## 2: 1 2 1 1 1 1 -## 3: 1 3 1 1 1 1 -## 4: 1 4 1 1 1 1 -## 5: 1 5 1 1 1 1 -}\if{html}{\out{
}} -\item using out-of-bag predictions for the training data - -\if{html}{\out{
}}\preformatted{pd_train <- orsf_pd_oob(fit, pred_spec = list(bili = 1:5)) - -pd_train -}\if{html}{\out{
}} - -\if{html}{\out{
}}\preformatted{## pred_horizon bili mean lwr medn upr -## 1: 1 1 1 1 1 1 -## 2: 1 2 1 1 1 1 -## 3: 1 3 1 1 1 1 -## 4: 1 4 1 1 1 1 -## 5: 1 5 1 1 1 1 -}\if{html}{\out{
}} -\item using predictions for a new set of data - -\if{html}{\out{
}}\preformatted{pd_test <- orsf_pd_new(fit, - new_data = pbc_orsf_test, - pred_spec = list(bili = 1:5)) - -pd_test -}\if{html}{\out{
}} - -\if{html}{\out{
}}\preformatted{## pred_horizon bili mean lwr medn upr -## 1: 1 1 1 1 1 1 -## 2: 1 2 1 1 1 1 -## 3: 1 3 1 1 1 1 -## 4: 1 4 1 1 1 1 -## 5: 1 5 1 1 1 1 -}\if{html}{\out{
}} -\item in-bag partial dependence indicates relationships that the model has -learned during training. This is helpful if your goal is to interpret -the model. -\item out-of-bag partial dependence indicates relationships that the model -has learned during training but using the out-of-bag data simulates -application of the model to new data. if you want to test your model’s -reliability or fairness in new data but you don’t have access to a -large testing set. -\item new data partial dependence shows how the model predicts outcomes for -observations it has not seen. This is helpful if you want to test your -model’s reliability or fairness. -} -} -} - \references{ Giles Hooker, Lucas Mentch, Siyu Zhou. Unrestricted Permutation forces Extrapolation: Variable Importance Requires at least One More Model, or There Is No Free Variable Importance. \emph{arXiv e-prints} 2021 Oct; arXiv-1905. URL: https://doi.org/10.48550/arXiv.1905.03151 } diff --git a/src/TreeClassification.cpp b/src/TreeClassification.cpp index 3fe5644f..351bcdf9 100644 --- a/src/TreeClassification.cpp +++ b/src/TreeClassification.cpp @@ -186,7 +186,8 @@ double safer_mtry = mtry; - if(lincomb_type == LC_GLM){ + if(lincomb_type == LC_GLM || + lincomb_type == LC_GLMNET){ // conditions to split a column: // >= 3 events per predictor @@ -234,15 +235,23 @@ } } - for (auto& i : splittable_y_cols){ + // glmnet can handle higher dimension x, + // but regular glm cannot. + if(lincomb_type == LC_GLM){ + + for (auto& i : splittable_y_cols){ + + while (y_sum_cases[i] / safer_mtry < 3 || + y_sum_ctrls[i] / safer_mtry < 3){ + --safer_mtry; + } - while (y_sum_cases[i] / safer_mtry < 3 || - y_sum_ctrls[i] / safer_mtry < 3){ - --safer_mtry; } } + + } uword out = safer_mtry; diff --git a/tests/testthat.R b/tests/testthat.R index 362d118b..8d22ab86 100644 --- a/tests/testthat.R +++ b/tests/testthat.R @@ -1,12 +1,4 @@ library(testthat) library(aorsf) -## force tests to be executed if in dev release which we define as -## having a sub-release, eg 0.9.15.5 is one whereas 0.9.16 is not -if (length(strsplit(packageDescription("aorsf")$Version, "\\.")[[1]]) > 3) { - Sys.setenv("run_all_aorsf_tests" = "yes") -} else { - -} - test_check("aorsf") diff --git a/tests/testthat/helper-orsf.R b/tests/testthat/helper-orsf.R index b4c921c8..c3515781 100644 --- a/tests/testthat/helper-orsf.R +++ b/tests/testthat/helper-orsf.R @@ -272,11 +272,17 @@ prep_test_matrices <- function(data, outcomes = c("time", "status")){ cc <- stats::complete.cases(data[, names_x_data]) data <- data[cc, ] - y <- prep_y_surv(data, names_y_data) + if(length(outcomes) > 1){ + y <- prep_y_surv(data, names_y_data) + sorted <- collapse::radixorder(y[, 1], -y[, 2]) + } else { + y <- prep_y_clsf(data, names_y_data) + sorted <- collapse::seq_row(data) + } + x <- prep_x(data, fi, names_x_data, means, standard_deviations) w <- sample(1:3, nrow(y), replace = TRUE) - sorted <- collapse::radixorder(y[, 1], -y[, 2]) return( list( diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 1807403f..5f5b1ee6 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -57,6 +57,32 @@ data_list_pbc <- list(pbc_standard = pbc, pbc_scaled = pbc_scale, pbc_noised = pbc_noise) +# penguins ---- + +penguins <- penguins_orsf + +penguins_scale <- penguins_noise <- penguins + + +vars <- c("bill_length_mm", + "bill_depth_mm", + "flipper_length_mm", + "body_mass_g") + +for(i in vars){ + penguins_noise[[i]] <- add_noise(penguins_noise[[i]]) + penguins_scale[[i]] <- change_scale(penguins_scale[[i]]) +} + +# make sorted x and y matrices for testing internal cpp functions +penguins_mats <- prep_test_matrices(penguins, outcomes = c("species")) + +# data lists ---- + +data_list_penguins <- list(penguins_standard = penguins, + penguins_scaled = penguins_scale, + penguins_noised = penguins_noise) + # matric lists ---- mat_list_surv <- list(pbc = pbc_mats, @@ -72,9 +98,8 @@ n_tree_test <- 5 controls <- list( fast = orsf_control_fast(), - cph = orsf_control_cph(), - net = orsf_control_net(), - custom = orsf_control_custom(beta_fun = f_pca) + net = orsf_control_survival(method = 'net'), + custom = orsf_control_survival(method = f_pca) ) fit_standard_pbc <- lapply( diff --git a/tests/testthat/test-lincomb_logreg.R b/tests/testthat/test-lincomb_logreg.R index b3701396..112a6e65 100644 --- a/tests/testthat/test-lincomb_logreg.R +++ b/tests/testthat/test-lincomb_logreg.R @@ -73,3 +73,14 @@ test_that( } ) + +# benchmark + +# microbenchmark::microbenchmark( +# cpp = logreg_fit_exported(X, Y, W, do_scale = FALSE, +# epsilon = control$epsilon, +# iter_max = control$maxit), +# fglm = fastglm::fastglmPure(X, Y, family = binomial(), weights = W), +# nmr = RcppNumerical::fastLR(cbind(1, X), Y) +# ) + diff --git a/tests/testthat/test-orsf.R b/tests/testthat/test-orsf.R index ea4b6361..bd79a7aa 100644 --- a/tests/testthat/test-orsf.R +++ b/tests/testthat/test-orsf.R @@ -29,21 +29,24 @@ test_that( desc = 'target_df too high is caught', code = { - cntrl <- orsf_control_net(df_target = 10) + cntrl <- orsf_control_survival(method = 'net', target_df = 10) expect_error(orsf(pbc_orsf, formula = f, control = cntrl), 'should be <=') } ) test_that( - desc = 'orsf runs with data.table and with net control', + desc = 'orsf runs the same with data.table vs. data.frame', code = { - expect_s3_class(orsf(as.data.table(pbc_orsf), f, n_tree = 1), 'ObliqueForest') + fit_dt <- orsf(as.data.table(pbc), + formula = time + status ~ ., + n_tree = n_tree_test, + control = controls$fast, + tree_seed = seeds_standard) + + expect_equal_leaf_summary(fit_dt, fit_standard_pbc$fast) - expect_s3_class(orsf(as.data.table(pbc_orsf), f, - control = orsf_control_net(), - n_tree = 1), 'ObliqueForest') } ) @@ -52,7 +55,7 @@ test_that( desc = "blank and non-standard names trigger an error", code = { - pbc_temp <- pbc_orsf + pbc_temp <- pbc pbc_temp$x1 <- rnorm(nrow(pbc_temp)) pbc_temp$x2 <- rnorm(nrow(pbc_temp)) @@ -64,7 +67,7 @@ test_that( ) - pbc_temp <- pbc_orsf + pbc_temp <- pbc pbc_temp$x1 <- rnorm(nrow(pbc_temp)) pbc_temp$x2 <- rnorm(nrow(pbc_temp)) @@ -79,44 +82,6 @@ test_that( ) -# just run locally. units seems to have memory leaks. -# test_that( -# 'orsf tracks meta data for units class variables', -# code = { -# -# # units may have memory leaks -# skip_on_cran() -# -# suppressMessages(library(units)) -# pbc_units <- pbc_orsf -# -# -# units(pbc_units$time) <- 'days' -# units(pbc_units$age) <- 'years' -# units(pbc_units$bili) <- 'mg/dl' -# -# fit_units <- orsf(pbc_units, Surv(time, status) ~ . - id, n_tree=1) -# -# expect_equal( -# fit_units$get_var_unit('time'), -# list( numerator = "d", denominator = character(0), label = "d") -# ) -# -# expect_equal( -# fit_units$get_var_unit('age'), -# list(numerator = "years", denominator = character(0), label = "years") -# ) -# -# expect_equal( -# fit_units$get_var_unit('bili'), -# list(numerator = "mg", denominator = "dl", label = "mg/dl") -# ) -# -# } -# -# ) - - test_that( desc = "algorithm grows more accurate with higher number of iterations", code = { @@ -491,9 +456,9 @@ test_that( '3' = c(1000, 2000, 3000)) control <- switch(inputs$orsf_control[i], - 'cph' = orsf_control_cph(), - 'net' = orsf_control_net(), - 'custom' = orsf_control_custom(beta_fun = f_pca)) + 'cph' = orsf_control_survival(method = 'glm'), + 'net' = orsf_control_survival(method = 'net'), + 'custom' = orsf_control_survival(method = f_pca)) if(inputs$sample_with_replacement[i]){ sample_fraction <- 0.632 diff --git a/tests/testthat/test-orsf_control.R b/tests/testthat/test-orsf_control.R index c7e3dcd0..a0c931e6 100644 --- a/tests/testthat/test-orsf_control.R +++ b/tests/testthat/test-orsf_control.R @@ -10,9 +10,9 @@ test_that( #' @srrstats {G5.2b} *Tests demonstrate conditions which trigger error messages.* - expect_error(orsf_control_cph(method = 'oh no'), "breslow or efron") + expect_error(orsf_control_survival(ties = 'oh no'), "breslow or efron") - expect_error(orsf_control_net(alpha = 32), 'should be <= 1') + expect_error(orsf_control_survival(net_mix = 32), 'should be <= 1') f_bad_1 <- function(a_node, y_node, w_node){ 1 } f_bad_2 <- function(x_node, a_node, w_node){ 1 } @@ -33,18 +33,18 @@ test_that( f_bad_8 <- function(x_node, y_node, w_node) {runif(n = ncol(x_node))} - expect_error(orsf_control_custom(f_bad_1), 'x_node') - expect_error(orsf_control_custom(f_bad_2), 'y_node') - expect_error(orsf_control_custom(f_bad_3), 'w_node') - expect_error(orsf_control_custom(f_bad_4), 'should have 3') - expect_error(orsf_control_custom(f_bad_5), 'encountered an error') - expect_error(orsf_control_custom(f_bad_6), 'with 1 column') - expect_error(orsf_control_custom(f_bad_7), 'with 1 row for each') - expect_error(orsf_control_custom(f_bad_8), 'matrix output') + expect_error(orsf_control_survival(method = f_bad_1), 'x_node') + expect_error(orsf_control_survival(method = f_bad_2), 'y_node') + expect_error(orsf_control_survival(method = f_bad_3), 'w_node') + expect_error(orsf_control_survival(method = f_bad_4), 'should have 3') + expect_error(orsf_control_survival(method = f_bad_5), 'encountered an error') + expect_error(orsf_control_survival(method = f_bad_6), 'with 1 column') + expect_error(orsf_control_survival(method = f_bad_7), 'with 1 row for each') + expect_error(orsf_control_survival(method = f_bad_8), 'matrix output') f_rando <- function(x_node, y_node, w_node) { matrix(runif(ncol(x_node)), ncol=1) } - expect_s3_class(orsf_control_custom(f_rando), 'orsf_control') + expect_s3_class(orsf_control_survival(method = f_rando), 'orsf_control') } @@ -58,7 +58,7 @@ test_that( fit_pca = orsf(pbc, Surv(time, status) ~ ., tree_seeds = seeds_standard, - control = orsf_control_custom(beta_fun = f_pca), + control = orsf_control_survival(method = f_pca), n_tree = n_tree_test) expect_gt(fit_pca$eval_oobag$stat_values, .65) diff --git a/tests/testthat/test-unit_info.R b/tests/testthat/test-unit_info.R index 51429bf3..ba9a0f72 100644 --- a/tests/testthat/test-unit_info.R +++ b/tests/testthat/test-unit_info.R @@ -4,6 +4,43 @@ # on CRAN because for some reason when I load and use the # units package it makes valgrind detect possible memory leaks. +test_that( + 'orsf tracks meta data for units class variables', + code = { + + # units may have memory leaks + skip_on_cran() + + suppressMessages(library(units)) + + pbc_units <- pbc_orsf + + units(pbc_units$time) <- 'days' + units(pbc_units$age) <- 'years' + units(pbc_units$bili) <- 'mg/dl' + + fit_units <- orsf(pbc_units, Surv(time, status) ~ . - id, n_tree=1) + + expect_equal( + fit_units$get_var_unit('time'), + list( numerator = "d", denominator = character(0), label = "d") + ) + + expect_equal( + fit_units$get_var_unit('age'), + list(numerator = "years", denominator = character(0), label = "years") + ) + + expect_equal( + fit_units$get_var_unit('bili'), + list(numerator = "mg", denominator = "dl", label = "mg/dl") + ) + + } + +) + + test_that("output has expected items", { skip_on_cran() diff --git a/vignettes/fast.Rmd b/vignettes/fast.Rmd index ff65f4d3..8684d18e 100644 --- a/vignettes/fast.Rmd +++ b/vignettes/fast.Rmd @@ -39,7 +39,7 @@ time_fast <- system.time( time_net <- system.time( expr = orsf(pbc_orsf, formula = time+status~. -id, - control = orsf_control_net(), + control = orsf_control_survival(method = 'net'), n_tree = 5) )