Skip to content

Commit

Permalink
Merge pull request #57 from ropensci/issue47
Browse files Browse the repository at this point in the history
Issue47
  • Loading branch information
bcjaeger authored May 2, 2024
2 parents 22c1360 + 8c00d72 commit c1fba42
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 19 deletions.
30 changes: 30 additions & 0 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -732,11 +732,13 @@ ObliqueForest <- R6::R6Class(
pred_horizon = NULL,
pred_type = NULL,
importance_type = NULL,
class = NULL,
verbose_progress = FALSE){

# check incoming values if they were specified.
private$check_n_variables(n_variables)
private$check_verbose_progress(verbose_progress)
private$check_class(class)

if(!is.null(pred_horizon)){
private$check_pred_horizon(pred_horizon, boundary_checks = TRUE)
Expand Down Expand Up @@ -838,6 +840,13 @@ ObliqueForest <- R6::R6Class(

if(self$tree_type == 'classification'){
new_order <- insert_vals(new_order, 2, 'class')
if(!is.null(class)){
.class <- class # prevents mix-up with class in dt
pd_output <- pd_output[class == .class]
} else {
# put the highest level class on top
pd_output <- pd_output[order(-class)]
}
}

setcolorder(pd_output, new_order)
Expand Down Expand Up @@ -2153,6 +2162,27 @@ ObliqueForest <- R6::R6Class(

},

check_class = function(class = NULL){

if(!is.null(class)){

check_arg_is(arg_value = class,
arg_name = "class",
expected_class = "character")

check_arg_length(arg_value = class,
arg_name = "class",
expected_length = 1L)

check_arg_is_valid(arg_value = class,
arg_name = "class",
valid_options = self$class_levels)

}


},

# runs checks and sets defaults where needed.
# data is NULL when we are creating a new forest,
# but may be non-NULL if we update an existing one
Expand Down
17 changes: 17 additions & 0 deletions R/orsf_summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
#' - `r roxy_importance_negate()`
#' - `r roxy_importance_permute()`
#'
#' @param class (_character_) only relevant for classification forests.
#' If `NULL` (the default), summary statistics are returned for all
#' classes in the outcome, and printed summaries will show the last
#' class in the class levels. To specify a single class to summarize,
#' indicate the name of the class with `class`. E.g., if the categorical
#' outcome has class levels A, B, and C, then using `class = "A"` will
#' restrict output to class A.
#'
#' For details on these methods, see [orsf_vi].
#'
#' @return an object of class 'orsf_summary', which includes data on
Expand Down Expand Up @@ -53,12 +61,20 @@
#'
#' orsf_summarize_uni(object, n_variables = 2, importance = 'negate')
#'
#' # for multi-category fits, you can specify which class
#' # you want to summarize:
#'
#' object = orsf(species ~ ., data = penguins_orsf, n_tree = 25)
#'
#' orsf_summarize_uni(object, class = "Adelie", n_variables = 1)
#'
#'
orsf_summarize_uni <- function(object,
n_variables = NULL,
pred_horizon = NULL,
pred_type = NULL,
importance = NULL,
class = NULL,
verbose_progress = FALSE,
...){

Expand All @@ -72,6 +88,7 @@ orsf_summarize_uni <- function(object,
pred_horizon = pred_horizon,
pred_type = pred_type,
importance_type = importance,
class = class,
verbose_progress = verbose_progress)

}
Expand Down
48 changes: 30 additions & 18 deletions man/orsf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 17 additions & 1 deletion man/orsf_summarize_uni.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/testthat/test-orsf_summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,19 @@ test_that(
}
)

test_that(
desc = "single class can be specified",
code = {

object <- fit_standard_penguin_species$fast

smry_all <- orsf_summarize_uni(object, n_variables = 1)

expect_true(all(object$class_levels %in% smry_all$dt$class))

smry_adelie <- orsf_summarize_uni(object, class = "Adelie", n_variables = 1)

expect_true(all(smry_adelie$dt$class == "Adelie"))

}
)

0 comments on commit c1fba42

Please sign in to comment.