From 6ec32d87b76845bc3915c680deb625c5835ca766 Mon Sep 17 00:00:00 2001 From: bcjaeger Date: Wed, 1 May 2024 07:49:13 -0400 Subject: [PATCH 1/4] allow selection of class to summarize --- R/orsf_R6.R | 30 ++++++++++++++++++++++++++++++ R/orsf_summary.R | 9 +++++++++ 2 files changed, 39 insertions(+) diff --git a/R/orsf_R6.R b/R/orsf_R6.R index 864293e9..8c832576 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -732,11 +732,13 @@ ObliqueForest <- R6::R6Class( pred_horizon = NULL, pred_type = NULL, importance_type = NULL, + class = NULL, verbose_progress = FALSE){ # check incoming values if they were specified. private$check_n_variables(n_variables) private$check_verbose_progress(verbose_progress) + private$check_class(class) if(!is.null(pred_horizon)){ private$check_pred_horizon(pred_horizon, boundary_checks = TRUE) @@ -838,6 +840,13 @@ ObliqueForest <- R6::R6Class( if(self$tree_type == 'classification'){ new_order <- insert_vals(new_order, 2, 'class') + if(!is.null(class)){ + .class <- class # prevents mix-up with class in dt + pd_output <- pd_output[class == .class] + } else { + # put the highest level class on top + pd_output <- pd_output[order(-class)] + } } setcolorder(pd_output, new_order) @@ -2153,6 +2162,27 @@ ObliqueForest <- R6::R6Class( }, + check_class = function(class = NULL){ + + if(!is.null(class)){ + + check_arg_is(arg_value = class, + arg_name = "class", + expected_class = "character") + + check_arg_length(arg_value = class, + arg_name = "class", + expected_length = 1L) + + check_arg_is_valid(arg_value = class, + arg_name = "class", + valid_options = self$class_levels) + + } + + + }, + # runs checks and sets defaults where needed. 
# data is NULL when we are creating a new forest, # but may be non-NULL if we update an existing one diff --git a/R/orsf_summary.R b/R/orsf_summary.R index a2024ffb..5680d94b 100644 --- a/R/orsf_summary.R +++ b/R/orsf_summary.R @@ -53,12 +53,20 @@ #' #' orsf_summarize_uni(object, n_variables = 2, importance = 'negate') #' +#' # for multi-category fits, you can specify which class +#' # you want to summarize: +#' +#' fit = orsf(species ~ ., data = penguins_orsf, n_tree = 25) +#' orsf_summarize_uni(fit, class = "Adelie", n_variables = 1) +#' orsf_summarize_uni(fit, class = "Gentoo", n_variables = 1) +#' #' orsf_summarize_uni <- function(object, n_variables = NULL, pred_horizon = NULL, pred_type = NULL, importance = NULL, + class = NULL, verbose_progress = FALSE, ...){ @@ -72,6 +80,7 @@ orsf_summarize_uni <- function(object, pred_horizon = pred_horizon, pred_type = pred_type, importance_type = importance, + class = class, verbose_progress = verbose_progress) } From 9b15153bb2f160812ed64a1d23a64d354ad07726 Mon Sep 17 00:00:00 2001 From: bcjaeger Date: Wed, 1 May 2024 08:11:19 -0400 Subject: [PATCH 2/4] tests and examples --- R/orsf_summary.R | 6 +++--- tests/testthat/test-orsf_summary.R | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/R/orsf_summary.R b/R/orsf_summary.R index 5680d94b..05afc526 100644 --- a/R/orsf_summary.R +++ b/R/orsf_summary.R @@ -56,9 +56,9 @@ #' # for multi-category fits, you can specify which class #' # you want to summarize: #' -#' fit = orsf(species ~ ., data = penguins_orsf, n_tree = 25) -#' orsf_summarize_uni(fit, class = "Adelie", n_variables = 1) -#' orsf_summarize_uni(fit, class = "Gentoo", n_variables = 1) +#' object = orsf(species ~ ., data = penguins_orsf, n_tree = 25) +#' +#' orsf_summarize_uni(object, class = "Adelie", n_variables = 1) #' #' orsf_summarize_uni <- function(object, diff --git a/tests/testthat/test-orsf_summary.R b/tests/testthat/test-orsf_summary.R index 5b423daa..f1429a83 100644 --- 
--- a/tests/testthat/test-orsf_summary.R +++ b/tests/testthat/test-orsf_summary.R @@ -74,3 +74,19 @@ test_that( } ) +test_that( + desc = "single class can be specified", + code = { + + object <- fit_standard_penguin_species$fast + + smry_all <- orsf_summarize_uni(object, n_variables = 1) + + expect_true(all(object$class_levels %in% smry_all$dt$class)) + + smry_adelie <- orsf_summarize_uni(object, class = "Adelie", n_variables = 1) + + expect_true(all(smry_adelie$dt$class == "Adelie")) + + } +) From 99898bbf6a21ae5eeca76a8157725f27fc0001a8 Mon Sep 17 00:00:00 2001 From: bcjaeger Date: Wed, 1 May 2024 08:16:35 -0400 Subject: [PATCH 3/4] update docs for new class param --- R/orsf_summary.R | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/R/orsf_summary.R b/R/orsf_summary.R index 05afc526..292d25b2 100644 --- a/R/orsf_summary.R +++ b/R/orsf_summary.R @@ -14,6 +14,14 @@ #' - `r roxy_importance_negate()` #' - `r roxy_importance_permute()` #' +#' @param class (_character_) only relevant for classification forests. +#' If `NULL` (the default), summary statistics are returned for all +#' classes in the outcome, and printed summaries will show the last +#' class in the class levels. To specify a single class to summarize, +#' indicate the name of the class with `class`. E.g., if the categorical +#' outcome has class levels A, B, and C, then using `class = "A"` will +#' restrict output to class A. +#' #' For details on these methods, see [orsf_vi].
#' #' @return an object of class 'orsf_summary', which includes data on From 8c00d729ceea5516f44cb398e574e27f1465b651 Mon Sep 17 00:00:00 2001 From: bcjaeger Date: Wed, 1 May 2024 10:20:56 -0400 Subject: [PATCH 4/4] doc update --- man/orsf.Rd | 48 ++++++++++++++++++++++++--------------- man/orsf_summarize_uni.Rd | 18 ++++++++++++++- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/man/orsf.Rd b/man/orsf.Rd index 5466277d..dc217bce 100644 --- a/man/orsf.Rd +++ b/man/orsf.Rd @@ -366,6 +366,18 @@ data that were not used to train it, i.e., testing data. library(magrittr) # for \%>\% }\if{html}{\out{</div>}} +\if{html}{\out{
<div class="sourceCode">}}\preformatted{## +## Attaching package: 'magrittr' + +## The following object is masked from 'package:tidyr': +## +## extract + +## The following objects are masked from 'package:testthat': +## +## equals, is_less_than, not +}\if{html}{\out{
</div>}} + \code{orsf()} is the entry-point of the \code{aorsf} package. It can be used to fit classification, regression, and survival forests. @@ -388,9 +400,9 @@ penguin_fit ## N trees: 5 ## N predictors total: 7 ## N predictors per node: 3 -## Average leaves per tree: 6.4 +## Average leaves per tree: 4.6 ## Min observations in leaf: 5 -## OOB stat value: 0.98 +## OOB stat value: 0.99 ## OOB stat type: AUC-ROC ## Variable importance: anova ## @@ -415,9 +427,9 @@ bill_fit ## N trees: 5 ## N predictors total: 7 ## N predictors per node: 3 -## Average leaves per tree: 49.2 +## Average leaves per tree: 51 ## Min observations in leaf: 5 -## OOB stat value: 0.75 +## OOB stat value: 0.70 ## OOB stat type: RSQ ## Variable importance: anova ## @@ -447,10 +459,10 @@ pbc_fit ## N trees: 5 ## N predictors total: 17 ## N predictors per node: 5 -## Average leaves per tree: 19.4 +## Average leaves per tree: 22.2 ## Min observations in leaf: 5 ## Min events in leaf: 1 -## OOB stat value: 0.77 +## OOB stat value: 0.78 ## OOB stat type: Harrell's C-index ## Variable importance: anova ## @@ -497,7 +509,7 @@ take to fit the forest before you commit to it: orsf_time_to_train() }\if{html}{\out{</div>}} -\if{html}{\out{
<div class="sourceCode">}}\preformatted{## Time difference of 2.32199 secs +\if{html}{\out{
<div class="sourceCode">}}\preformatted{## Time difference of 2.429678 secs }\if{html}{\out{
</div>}} \enumerate{ \item If fitting multiple forests, use the blueprint along with @@ -568,12 +580,12 @@ brier_scores \if{html}{\out{
<div class="sourceCode">}}\preformatted{## # A tibble: 6 x 4 ## .metric .estimator .eval_time .estimate ## <chr> <chr> <dbl> <dbl> -## 1 brier_survival standard 500 0.0661 -## 2 brier_survival standard 1000 0.0999 -## 3 brier_survival standard 1500 0.110 -## 4 brier_survival standard 2000 0.0789 -## 5 brier_survival standard 2500 0.127 -## 6 brier_survival standard 3000 0.194 +## 1 brier_survival standard 500 0.0597 +## 2 brier_survival standard 1000 0.0943 +## 3 brier_survival standard 1500 0.0883 +## 4 brier_survival standard 2000 0.102 +## 5 brier_survival standard 2500 0.137 +## 6 brier_survival standard 3000 0.153 }\if{html}{\out{
</div>}} \if{html}{\out{
<div class="sourceCode r">}}\preformatted{roc_scores <- test_pred \%>\% @@ -585,11 +597,11 @@ roc_scores \if{html}{\out{
<div class="sourceCode">}}\preformatted{## # A tibble: 6 x 4 ## .metric .estimator .eval_time .estimate ## <chr> <chr> <dbl> <dbl> -## 1 roc_auc_survival standard 500 0.941 -## 2 roc_auc_survival standard 1000 0.920 -## 3 roc_auc_survival standard 1500 0.925 -## 4 roc_auc_survival standard 2000 0.967 -## 5 roc_auc_survival standard 2500 0.937 +## 1 roc_auc_survival standard 500 0.957 +## 2 roc_auc_survival standard 1000 0.912 +## 3 roc_auc_survival standard 1500 0.935 +## 4 roc_auc_survival standard 2000 0.931 +## 5 roc_auc_survival standard 2500 0.907 ## 6 roc_auc_survival standard 3000 0.889 }\if{html}{\out{
</div>}} } diff --git a/man/orsf_summarize_uni.Rd b/man/orsf_summarize_uni.Rd index 6a2746a6..aa071e22 100644 --- a/man/orsf_summarize_uni.Rd +++ b/man/orsf_summarize_uni.Rd @@ -10,6 +10,7 @@ orsf_summarize_uni( pred_horizon = NULL, pred_type = NULL, importance = NULL, + class = NULL, verbose_progress = FALSE, ... ) @@ -53,7 +54,15 @@ For regression: \item 'anova': compute analysis of variance (ANOVA) importance \item 'negate': compute negation importance \item 'permute': compute permutation importance -} +}} + +\item{class}{(\emph{character}) only relevant for classification forests. +If \code{NULL} (the default), summary statistics are returned for all +classes in the outcome, and printed summaries will show the last +class in the class levels. To specify a single class to summarize, +indicate the name of the class with \code{class}. E.g., if the categorical +outcome has class levels A, B, and C, then using \code{class = "A"} will +restrict output to class A. For details on these methods, see \link{orsf_vi}.} @@ -100,6 +109,13 @@ orsf_summarize_uni(object, n_variables = 2) orsf_summarize_uni(object, n_variables = 2, importance = 'negate') +# for multi-category fits, you can specify which class +# you want to summarize: + +object = orsf(species ~ ., data = penguins_orsf, n_tree = 25) + +orsf_summarize_uni(object, class = "Adelie", n_variables = 1) + } \seealso{