diff --git a/R/orsf.R b/R/orsf.R index 30dfa0ce..0d5dcbab 100644 --- a/R/orsf.R +++ b/R/orsf.R @@ -304,7 +304,7 @@ orsf <- function(data, formula, - control = orsf_control_fast(), + control = NULL, weights = NULL, n_tree = 500, n_split = 5, @@ -380,7 +380,12 @@ orsf <- function(data, if(length(formula) < 3) stop("formula must be two sided", call. = FALSE) - tree_type <- infer_tree_type(all.vars(formula[[2]]), args$data) + if(is.null(control)){ + tree_type <- infer_tree_type(all.vars(formula[[2]]), args$data) + } else { + tree_type <- control$tree_type + } + object <- switch( tree_type, diff --git a/R/orsf_R6.R b/R/orsf_R6.R index 761c7012..c0d472bf 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -186,7 +186,10 @@ ObliqueForest <- R6::R6Class( paste0(' N predictors per node: ', self$mtry ), paste0(' Average leaves per tree: ', private$mean_leaves ), paste0('Min observations in leaf: ', self$leaf_min_obs ), - paste0(' Min events in leaf: ', self$leaf_min_events ), + + if(self$tree_type == 'survival') + paste0(' Min events in leaf: ', self$leaf_min_events ), + paste0(' OOB stat value: ', oobag_stat ), paste0(' OOB stat type: ', oobag_type ), paste0(' Variable importance: ', self$importance_type ), @@ -987,11 +990,16 @@ ObliqueForest <- R6::R6Class( private$check_data() private$check_formula() + if(is.null(self$control)){ + private$init_control() + } else { + private$check_control() + } + private$init_data() private$init_mtry() private$init_weights() - private$check_control() private$check_n_tree() private$check_n_split() private$check_n_retry() @@ -2703,6 +2711,14 @@ ObliqueForestSurvival <- R6::R6Class( }, + init_control = function(){ + + self$control <- orsf_control_survival(method = 'glm', + scale_x = FALSE, + max_iter = 1) + + }, + init_internal = function(){ self$tree_type <- "survival" @@ -2994,6 +3010,14 @@ ObliqueForestClassification <- R6::R6Class( }, + init_control = function(){ + + self$control <- orsf_control_classification(method = 'glm', + scale_x = FALSE, + max_iter = 1) + + }, + init_internal = function(){ self$tree_type <- "classification" diff --git a/R/orsf_control.R b/R/orsf_control.R index b85e5e72..95f3878e 100644 --- a/R/orsf_control.R +++ b/R/orsf_control.R @@ -46,6 +46,12 @@ orsf_control_fast <- function(method = 'efron', do_scale = TRUE, ...){ + lifecycle::deprecate_warn( + when = "0.1.2", + "orsf_control_fast()", + details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = 'fast')`" + ) + check_dots(list(...), orsf_control_fast) method <- tolower(method) @@ -122,8 +128,8 @@ orsf_control_cph <- function(method = 'efron', lifecycle::deprecate_warn( when = "0.1.2", - "orsf_control_custom()", - details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`" + "orsf_control_cph()", + details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = 'glm')`" ) method <- tolower(method) @@ -197,8 +203,8 @@ orsf_control_net <- function(alpha = 1/2, lifecycle::deprecate_warn( when = "0.1.2", - "orsf_control_custom()", - details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`" + "orsf_control_net()", + details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = 'net')`" ) check_dots(list(...), orsf_control_net) @@ -249,7 +255,7 @@ orsf_control_custom <- function(beta_fun, ...){ lifecycle::deprecate_warn( when = "0.1.2", "orsf_control_custom()", - details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`" + details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`, noting that your_function is a function object and not a character value" ) check_dots(list(...), .f = orsf_control_custom) @@ -374,7 +380,7 @@ orsf_control <- function(tree_type, check_arg_is_valid(arg_value = method, arg_name = 'method', - valid_options = c("glm", "net")) + valid_options = c("glm", "net", "fast")) check_arg_length(arg_value = method, arg_name = 'method', diff --git a/Rmd/orsf_control_cph_content_1.Rmd b/Rmd/orsf_control_cph_content_1.Rmd deleted file mode 100644 index 73573f09..00000000 --- a/Rmd/orsf_control_cph_content_1.Rmd +++ /dev/null @@ -1,32 +0,0 @@ -Two customized functions to identify linear combinations of predictors are shown here. - -- The first uses random coefficients -- The second derives coefficients from principal component analysis. - -## Random coefficients - -`f_rando()` is our function to get the random coefficients: - -```{r} - -f_rando <- function(x_node, y_node, w_node){ - matrix(runif(ncol(x_node)), ncol=1) -} - -``` - -We can plug `f_rando` into `orsf_control_custom()`, and then pass the result into `orsf()`: - -```{r} - -library(aorsf) - -fit_rando <- orsf(pbc_orsf, - Surv(time, status) ~ . - id, - control = orsf_control_custom(beta_fun = f_rando), - n_tree = 500) - -fit_rando - -``` - diff --git a/Rmd/orsf_control_cph_examples.Rmd b/Rmd/orsf_control_cph_examples.Rmd deleted file mode 100644 index a595087c..00000000 --- a/Rmd/orsf_control_cph_examples.Rmd +++ /dev/null @@ -1,16 +0,0 @@ - -# Examples - - -```{r, include=FALSE, echo=FALSE} - -docs <- c( - "orsf_control_cph_content_1.Rmd" -) - -``` - -```{r child = docs} - -``` - diff --git a/Rmd/orsf_control_custom_content_1.Rmd b/Rmd/orsf_control_custom_content_1.Rmd index 73573f09..1b65906e 100644 --- a/Rmd/orsf_control_custom_content_1.Rmd +++ b/Rmd/orsf_control_custom_content_1.Rmd @@ -15,7 +15,7 @@ f_rando <- function(x_node, y_node, w_node){ ``` -We can plug `f_rando` into `orsf_control_custom()`, and then pass the result into `orsf()`: +We can plug `f_rando` into `orsf_control_survival()`, and then pass the result into `orsf()`: ```{r} @@ -23,7 +23,7 @@ library(aorsf) fit_rando <- orsf(pbc_orsf, Surv(time, status) ~ . - id, - control = orsf_control_custom(beta_fun = f_rando), + control = orsf_control_survival(method = f_rando), n_tree = 500) fit_rando diff --git a/Rmd/orsf_control_custom_content_2.Rmd b/Rmd/orsf_control_custom_content_2.Rmd index 5850c5ca..2e0ab955 100644 --- a/Rmd/orsf_control_custom_content_2.Rmd +++ b/Rmd/orsf_control_custom_content_2.Rmd @@ -15,13 +15,13 @@ f_pca <- function(x_node, y_node, w_node) { ``` -Then plug the function into `orsf_control_custom()` and pass the result into `orsf()`: +Then plug the function into `orsf_control_survival()` and pass the result into `orsf()`: ```{r} fit_pca <- orsf(pbc_orsf, Surv(time, status) ~ . - id, - control = orsf_control_custom(beta_fun = f_pca), + control = orsf_control_survival(method = f_pca), n_tree = 500) ``` diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 5f5b1ee6..a250deb6 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -96,14 +96,16 @@ mat_list_surv <- list(pbc = pbc_mats, seeds_standard <- 329 n_tree_test <- 5 -controls <- list( - fast = orsf_control_fast(), +controls_surv <- list( + fast = orsf_control_survival(method = 'glm', + scale_x = FALSE, + max_iter = 1), net = orsf_control_survival(method = 'net'), custom = orsf_control_survival(method = f_pca) ) fit_standard_pbc <- lapply( - controls, + controls_surv, function(cntrl){ orsf(pbc, formula = time + status ~ ., diff --git a/tests/testthat/test-orsf.R b/tests/testthat/test-orsf.R index bd79a7aa..98e04cb1 100644 --- a/tests/testthat/test-orsf.R +++ b/tests/testthat/test-orsf.R @@ -42,7 +42,7 @@ test_that( fit_dt <- orsf(as.data.table(pbc), formula = time + status ~ ., n_tree = n_tree_test, - control = controls$fast, + control = controls_surv$fast, tree_seed = seeds_standard) expect_equal_leaf_summary(fit_dt, fit_standard_pbc$fast) diff --git a/tests/testthat/test-orsf_formula.R b/tests/testthat/test-orsf_formula.R index f966d3ba..9c4ddaea 100644 --- a/tests/testthat/test-orsf_formula.R +++ b/tests/testthat/test-orsf_formula.R @@ -59,7 +59,7 @@ test_that( fit_long <- orsf(pbc_orsf, formula = f_long, - control = controls[[i]], + control = controls_surv[[i]], n_tree = n_tree_test, tree_seeds = seeds_standard) @@ -83,13 +83,13 @@ test_that( pbc_surv_data <- cbind(pbc_orsf, surv_object = pbc_surv) - for(i in seq_along(controls)){ + for(i in seq_along(controls_surv)){ fit_surv <- orsf( pbc_surv_data, formula = surv_object ~ . - id - time - status, n_tree = n_tree_test, - control = controls[[i]], + control = controls_surv[[i]], tree_seed = seeds_standard ) @@ -118,7 +118,7 @@ test_that( # fit_status_modified <- orsf(pbc_orsf, # time + status ~ . - id, # n_tree = n_tree_test, -# control = controls[[j]], +# control = controls_surv[[j]], # tree_seeds = seeds_standard) # # expect_equal_leaf_summary(fit_status_modified, fit_standard_pbc[[j]]) diff --git a/tests/testthat/test-orsf_train.R b/tests/testthat/test-orsf_train.R index 81f56179..e0a588db 100644 --- a/tests/testthat/test-orsf_train.R +++ b/tests/testthat/test-orsf_train.R @@ -8,12 +8,12 @@ test_that( 'pbc_status_12', 'pbc_scaled')])){ - for(j in seq_along(controls)){ + for(j in seq_along(controls_surv)){ fit_untrained <- orsf(data_list_pbc[[i]], formula = time + status ~ . - id, n_tree = n_tree_test, - control = controls[[j]], + control = controls_surv[[j]], tree_seed = seeds_standard, no_fit = TRUE) diff --git a/vignettes/aorsf.Rmd b/vignettes/aorsf.Rmd index bfc444da..d831c6b1 100644 --- a/vignettes/aorsf.Rmd +++ b/vignettes/aorsf.Rmd @@ -21,39 +21,43 @@ knitr::opts_chunk$set( This article covers core features of the `aorsf` package. -## Background: ORSF +## Background -The oblique random survival forest (ORSF) is an extension of the axis-based RSF algorithm. +The oblique random forest (RF) is an extension of the axis-based RF. Instead of using a single variable to split data and grow new branches, trees in the oblique RF use a weighted combination of multiple variables. -- See [orsf](https://docs.ropensci.org/aorsf/reference/orsf.html) for more details on ORSFs. +## Oblique RFs for survival, classification, and regression -- see the [JCGS](https://doi.org/10.1080/10618600.2023.2231048) paper for more details on algorithms used specifically by `aorsf`. - -## Accelerated ORSF +The purpose of `aorsf` ('a' is short for accelerated) is to provide a unifying framework to fit oblique RFs that can scale adequately to large data sets. The fastest algorithms available in the package are used by default because they often have equivalent prediction accuracy to more computational approaches. -The purpose of `aorsf` ('a' is short for accelerated) is to provide routines to fit ORSFs that will scale adequately to large data sets. The fastest algorithm available in the package is the accelerated ORSF model, which is the default method used by `orsf()`: +Everything in `aorsf` begins with the `orsf()` function. Here we begin with an oblique RF for survival using the `pbc_orsf` data, an oblique RF for classification using the `penguins_orsf` data, and FILL IN FOR REGRESSION. Note that `n_tree` is 5 for convenience in these examples, but should be >= 500 in practice. ```{r} library(aorsf) -set.seed(329) - -orsf_fit <- orsf(data = pbc_orsf, +# An oblique survival RF +pbc_fit <- orsf(data = pbc_orsf, n_tree = 5, formula = Surv(time, status) ~ . - id) -orsf_fit +pbc_fit -``` +# An oblique classification RF +penguin_fit <- orsf(data = penguins_orsf, + n_tree = 5, + formula = species ~ .) + +penguin_fit +``` you may notice that the first input of `aorsf` is `data`. This is a design choice that makes it easier to use `orsf` with pipes (i.e., `%>%` or `|>`). For instance, ```{r, eval=FALSE} + library(dplyr) -orsf_fit <- pbc_orsf |> +pbc_fit <- pbc_orsf |> select(-id) |> orsf(formula = Surv(time, status) ~ ., n_tree = 5) @@ -72,7 +76,7 @@ orsf_fit <- pbc_orsf |> ```{r} - orsf_vi_negate(orsf_fit) + orsf_vi_negate(pbc_fit) ``` @@ -80,7 +84,7 @@ orsf_fit <- pbc_orsf |> ```{r} - orsf_vi_permute(orsf_fit) + orsf_vi_permute(pbc_fit) ``` @@ -88,7 +92,7 @@ orsf_fit <- pbc_orsf |> ```{r} - orsf_vi_anova(orsf_fit) + orsf_vi_anova(pbc_fit) ``` @@ -107,13 +111,13 @@ For more on ICE, see the [vignette](https://docs.ropensci.org/aorsf/articles/pd. ## What about the original ORSF? -The original ORSF (i.e., `obliqueRSF`) used `glmnet` to find linear combinations of inputs. `aorsf` allows users to implement this approach using the `orsf_control_net()` function: +The original ORSF (i.e., `obliqueRSF`) used `glmnet` to find linear combinations of inputs. `aorsf` allows users to implement this approach using the `orsf_control_survival(method = 'net')` function: ```{r, eval=FALSE} orsf_net <- orsf(data = pbc_orsf, formula = Surv(time, status) ~ . - id, - control = orsf_control_net()) + control = orsf_control_survival(method = 'net')) ``` @@ -125,3 +129,11 @@ orsf_net <- orsf(data = pbc_orsf, The unique feature of `aorsf` is its fast algorithms to fit ORSF ensembles. `RLT` and `obliqueRSF` both fit oblique random survival forests, but `aorsf` does so faster. `ranger` and `randomForestSRC` fit survival forests, but neither package supports oblique splitting. `obliqueRF` fits oblique random forests for classification and regression, but not survival. `PPforest` fits oblique random forests for classification but not survival. Note: The default prediction behavior for `aorsf` models is to produce predicted risk at a specific prediction horizon, which is not the default for `ranger` or `randomForestSRC`. I think this will change in the future, as computing time independent predictions with `aorsf` could be helpful. + +## Learning more + +`aorsf` began as a dedicated package for oblique random survival forests, and so most papers published so far have focused on survival analysis and risk prediction. However, the routines for regression and classification oblique RFs in `aorsf` have high overlap with the survival ones. + +- See [orsf](https://docs.ropensci.org/aorsf/reference/orsf.html) for more details on oblique random survival forests. + +- see the [JCGS](https://doi.org/10.1080/10618600.2023.2231048) paper for more details on algorithms used specifically by `aorsf`. diff --git a/vignettes/fast.Rmd b/vignettes/fast.Rmd index 8684d18e..47efb185 100644 --- a/vignettes/fast.Rmd +++ b/vignettes/fast.Rmd @@ -22,9 +22,9 @@ library(aorsf) Analyses can slow to a crawl when models need hours to run. In this article you will find a few tricks to prevent this bottleneck when using `orsf()`. -## Use `orsf_control_fast()` +## Don't specify a `control` -This is the default `control` value for `orsf()` and its run-time compared to other approaches can be striking. For example: +The default `control` for `orsf()` is `NULL` because, if unspecified, `orsf()` will pick the fastest possible `control` for you depending on the type of forest being grown. The default `control` run-time compared to other approaches can be striking. For example: ```{r} @@ -32,7 +32,6 @@ This is the default `control` value for `orsf()` and its run-time compared to ot time_fast <- system.time( expr = orsf(pbc_orsf, formula = time+status~. -id, - control = orsf_control_fast(), n_tree = 5) )