From c600614b68a9c42445137d33b9215dd33a68ceb3 Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Sun, 19 Nov 2023 15:55:31 -0500 Subject: [PATCH] clean up init, drop mtcars --- R/orsf.R | 16 +- R/orsf_R6.R | 395 +++++++++++++++++--------- R/orsf_control.R | 4 +- tests/testthat/test-impute_meanmode.R | 1 - tests/testthat/test-orsf.R | 63 ++-- 5 files changed, 293 insertions(+), 186 deletions(-) diff --git a/R/orsf.R b/R/orsf.R index 79dddafa..cc45cd68 100644 --- a/R/orsf.R +++ b/R/orsf.R @@ -36,16 +36,10 @@ #' @param weights (*numeric vector*) Optional. If given, this input should #' have length equal to `nrow(data)` for complete or imputed data and should #' have length equal to `nrow(na.omit(data))` if `na_action` is `"omit"`. -#' Values in `weights` are treated like replication weights, i.e., a value -#' of 2 is the same thing as having 2 observations in `data`, each -#' containing a copy of the corresponding person's data. -#' -#' *Use* `weights` *cautiously*, as `orsf` will count the number of -#' observations and events prior to growing a node for a tree, so higher -#' values in `weights` will lead to deeper trees. If you use this -#' input, it is highly recommended you scale the weights so that -#' `sum(weights) == nrow(data)`, as this will help make tree depth -#' consistent with the default `weights = rep(1, nrow(data))` +#' As the weights vector is used to count observations and events prior to +#' growing a node for a tree, `orsf()` scales `weights` so that +#' `sum(weights) == nrow(data)`. This helps to make tree depth consistent +#' between weighted and un-weighted fits. #' #' @param n_tree (*integer*) the number of trees to grow. #' Default is `n_tree = 500.` @@ -111,7 +105,7 @@ #' retries). Defaults are #' #' - 3.84 if `split_rule = 'logrank'` -#' - 0.50 if `split_rule = 'cstat'` (see first note below) +#' - 0.55 if `split_rule = 'cstat'` (see first note below) #' - 0.00 if `split_rule = 'gini'` (see second note below) #' - 0.00 if `split_rule = 'variance'` #' diff --git a/R/orsf_R6.R b/R/orsf_R6.R index 9d3b3ab6..c4ff1ac7 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -105,14 +105,16 @@ ObliqueForest <- R6::R6Class( # always be checked if a user wants to use update(). private$user_specified <- list( - control = !is.null(control), - weights = !is.null(weights), - mtry = !is.null(mtry), - split_rule = !is.null(split_rule), - split_min_stat = !is.null(split_min_stat), - pred_type = !is.null(pred_type), - pred_horizon = !is.null(oobag_pred_horizon), - tree_seeds = !is.null(tree_seeds) + control = !is.null(control), + lincomb_df_target = !is.null(control$lincomb_df_target), + weights = !is.null(weights), + mtry = !is.null(mtry), + split_rule = !is.null(split_rule), + split_min_stat = !is.null(split_min_stat), + pred_type = !is.null(pred_type), + pred_horizon = !is.null(oobag_pred_horizon), + oobag_eval_function = !is.null(oobag_fun), + tree_seeds = !is.null(tree_seeds) ) self$data <- data @@ -175,8 +177,6 @@ ObliqueForest <- R6::R6Class( na_action = NULL, verbose_progress = NULL) { - if(!is.null(data)) - self$data <- data if(!is.null(formula)) self$formula <- formula if(!is.null(control)) @@ -230,7 +230,7 @@ ObliqueForest <- R6::R6Class( if(!is.null(verbose_progress)) self$verbose_progress <- verbose_progress - private$init() + private$init(data = data) }, @@ -322,7 +322,7 @@ ObliqueForest <- R6::R6Class( # for survival, inputs should be sorted by time private$sort_inputs() - # allow re-training (b/c why not). + # allow re-training if(self$trained){ self$forest <- list() } cpp_args <- private$prep_cpp_args(...) @@ -1097,54 +1097,121 @@ ObliqueForest <- R6::R6Class( mean_leaves = 0, - # runs checks and sets defaults where needed - init = function() { + # runs checks and sets defaults where needed. + # data is NULL when we are creating a new forest, + # but may be non-NULL if we update an existing one + init = function(data = NULL) { - private$check_data() + # look for odd symbols in formula before you check variables in data private$check_formula() + # check & init data should be near first bc they set up other checks + private$check_data(data) + private$init_data(data) + + # if data is not null, it means we are updating an orsf spec + # and in that process applying it to a new dataset, so: + if(!is.null(data)) self$data <- data + - if(is.null(self$control)){ + + if(private$user_specified$control){ + private$check_control() + } else { private$init_control() + } + + + if(private$user_specified$mtry){ + private$check_mtry() } else { - private$check_control() + private$init_mtry() + } + + if(private$user_specified$lincomb_df_target){ + private$check_lincomb_df_target() + } else { + private$init_lincomb_df_target() + } + + if(private$user_specified$weights){ + private$check_weights() + } else { + private$init_weights() + } + + if(private$user_specified$pred_type){ + private$check_pred_type(oobag = TRUE) + } else { + private$init_pred_type() } - private$init_data() - private$init_mtry() - private$init_weights() + if(private$user_specified$split_rule){ + private$check_split_rule() + } else { + private$init_split_rule() + } + + if(private$user_specified$split_min_stat){ + private$check_split_min_stat() + } else { + private$init_split_min_stat() + } + + if(private$user_specified$oobag_eval_function){ + private$check_oobag_eval_function() + self$oobag_eval_type <- "User-specified function" + } else { + private$init_oobag_eval_function() + } + if(self$control$lincomb_type == 'custom'){ + private$check_lincomb_R_function() + } else if (is.null(self$control$lincomb_R_function)){ + private$init_lincomb_R_function() + } + + # arguments with hard defaults do not need an init option private$check_n_tree() private$check_n_split() private$check_n_retry() private$check_n_thread() - private$check_mtry() private$check_sample_with_replacement() private$check_sample_fraction() private$check_leaf_min_obs() - private$check_split_rule() private$check_split_min_obs() - private$check_split_min_stat() - private$check_pred_type(oobag = TRUE) private$check_oobag_eval_every() private$check_importance_type() private$check_importance_max_pvalue() private$check_importance_group_factors() - private$check_tree_seeds() private$check_na_action() + # args below depend on at least one upstream arg + + if(private$user_specified$tree_seeds){ + private$check_tree_seeds() + } else { + private$init_tree_seeds() + } + + if(length(self$tree_seeds) == 1 && self$n_tree > 1){ + private$plant_tree_seed() + } - private$init_oobag_eval_function() - private$init_lincomb_R_function() + # oobag_pred_mode depends on pred_type, which is checked above, + # so there is no reason to check it here. private$init_oobag_pred_mode() - private$init_tree_seeds() + # check if sample_fraction conflicts with oobag_pred_mode + private$check_oobag_pred_mode(self$oobag_pred_mode, + label = 'oobag_pred_mode', + sample_fraction = self$sample_fraction) + private$init_internal() }, - init_internal = function(){ - stop("this method should only be called from derived classes") - }, - init_data = function(){ + init_data = function(data = NULL){ + + if(!is.null(data)) self$data <- data formula_terms <- suppressWarnings( stats::terms(x = self$formula, data = self$data) @@ -1160,8 +1227,7 @@ ObliqueForest <- R6::R6Class( fctr_check(self$data, names_x_data) fctr_id_check(self$data, names_x_data) - - private$check_var_names(c(names_x_data, names_y_data)) + private$check_var_names(c(names_x_data, names_y_data), data = self$data) private$data_names <- list(y = names_y_data, x_original = names_x_data) @@ -1205,14 +1271,11 @@ ObliqueForest <- R6::R6Class( }, init_tree_seeds = function(){ + if(is.null(self$tree_seeds)){ self$tree_seeds <- sample(1e6, size = 1) } - if(length(self$tree_seeds) == 1 && self$n_tree > 1){ - set.seed(self$tree_seeds) - self$tree_seeds <- sample(self$n_tree*10, size = self$n_tree) - } }, init_numeric_names = function(){ @@ -1250,104 +1313,88 @@ ObliqueForest <- R6::R6Class( }, init_oobag_pred_mode = function(){ + # if pred_type is null when this is run, it means + # the user did not specify pred_type, which means the + # family-specific default will be used, which means + # pred_type will not be 'none', so it is safe to assume + # oobag_pred_mode is TRUE if pred_type is currently null + if(is.null(self$pred_type)){ self$oobag_pred_mode <- TRUE } else { self$oobag_pred_mode <- self$pred_type != "none" } - if(!self$oobag_pred_mode) self$oobag_eval_type <- "none" - - if(self$oobag_pred_mode && self$sample_fraction == 1){ - stop( - "cannot compute out-of-bag predictions if no samples are out-of-bag.", - " Try setting sample_fraction < 1 or pred_type = 'none'.", - call. = FALSE - ) - } - }, init_mtry = function(){ n_col_x <- length(private$data_names$x_ref_code) - if(is.null(self$mtry)){ + self$mtry <- ceiling(sqrt(n_col_x)) - self$mtry <- ceiling(sqrt(n_col_x)) + }, - } else { - check_arg_lteq( - arg_value = self$mtry, - arg_name = 'mtry', - bound = n_col_x, - append_to_msg = "(number of columns in the one-hot encoded x-matrix)" - ) - } - if(is.null(self$control$lincomb_df_target)){ + init_lincomb_df_target = function(mtry = NULL){ - self$control$lincomb_df_target <- self$mtry + mtry <- mtry %||% self$mtry - } else { + self$control$lincomb_df_target <- mtry - check_arg_lteq( - arg_value = self$control$lincomb_df_target, - arg_name = 'df_target', - bound = self$mtry, - append_to_msg = "(number of randomly selected predictors)" - ) + }, - } + check_lincomb_df_target = function(lincomb_df_target = NULL, + mtry = NULL){ + + input <- lincomb_df_target %||% self$control$lincomb_df_target + mtry <- mtry %||% self$mtry + + check_arg_lteq( + arg_value = input, + arg_name = 'df_target', + bound = mtry, + append_to_msg = "(number of randomly selected predictors)" + ) }, + init_weights = function(){ # set weights as 1 if user did not supply them. # length of weights depends on how missing are handled. - if(is.null(self$weights)){ - - private$w <- rep(1, self$n_obs) - - } else { - - private$check_weights() - private$w <- self$weights - - } + self$weights <- rep(1, self$n_obs) }, - init_oobag_eval_function = function(){ - - if(is.null(self$oobag_eval_function)){ - - self$oobag_eval_function <- function(y_mat, w_vec, s_vec){ - return(1) - } - } else { - private$check_oobag_eval_function() - self$oobag_eval_type <- "User-specified function" + init_oobag_eval_function = function(){ + self$oobag_eval_function <- function(y_mat, w_vec, s_vec){ + return(1) } }, init_lincomb_R_function = function(){ - if(self$control$lincomb_type == 'custom'){ + self$control$lincomb_R_function <- function(x) x - private$check_lincomb_R_function(self$control$lincomb_R_function) + }, - } + # use a starter seed to create n_tree seeds + plant_tree_seed = function(){ + + set.seed(self$tree_seeds) + self$tree_seeds <- sample(self$n_tree*10, size = self$n_tree) }, # checkers check_data = function(data = NULL, new = FALSE){ + # additional data checks are run during initialization. input <- data %||% self$data check_arg_is(arg_value = input, @@ -1421,9 +1468,11 @@ ObliqueForest <- R6::R6Class( }, - check_var_names = function(.names){ + check_var_names = function(.names, data = NULL){ - names_not_found <- setdiff(c(.names), names(self$data)) + data <- data %||% self$data + + names_not_found <- setdiff(c(.names), names(data)) if(!is_empty(names_not_found)){ msg <- paste0( @@ -1589,11 +1638,9 @@ ObliqueForest <- R6::R6Class( arg_name = 'weights', bound = 0) - check_arg_length( - arg_value = input, - arg_name = 'weights', - expected_length = self$n_obs - ) + check_arg_length(arg_value = input, + arg_name = 'weights', + expected_length = self$n_obs) }, check_n_tree = function(n_tree = NULL){ @@ -1713,6 +1760,8 @@ ObliqueForest <- R6::R6Class( input <- mtry %||% self$mtry + n_predictors <- length(private$data_names$x_ref_code) + # okay for this to be unspecified at startup if(!is.null(input)){ @@ -1731,6 +1780,15 @@ ObliqueForest <- R6::R6Class( arg_name = 'mtry', expected_length = 1) + if(!is.null(n_predictors)){ + check_arg_lteq( + arg_value = input, + arg_name = 'mtry', + bound = n_predictors, + append_to_msg = "(number of columns in the reference coded x-matrix)" + ) + } + } }, @@ -1924,11 +1982,6 @@ ObliqueForest <- R6::R6Class( }, - check_pred_type_internal = function(oobag, pred_type = NULL){ - - stop("this method should be defined in a derived class.") - - }, check_pred_aggregate = function(pred_aggregate = NULL){ input <- pred_aggregate %||% self$pred_aggregate @@ -2133,7 +2186,7 @@ ObliqueForest <- R6::R6Class( check_lincomb_R_function = function(lincomb_R_function = NULL){ - input <- lincomb_R_function %||% self$lincomb_R_function + input <- lincomb_R_function %||% self$control$lincomb_R_function args <- names(formals(input)) @@ -2378,7 +2431,10 @@ ObliqueForest <- R6::R6Class( }, - check_oobag_pred_mode = function(oobag_pred_mode, label){ + check_oobag_pred_mode = function(oobag_pred_mode, label, + sample_fraction = NULL){ + + sample_fraction <- sample_fraction %||% self$sample_fraction check_arg_type(arg_value = oobag_pred_mode, arg_name = label, @@ -2388,6 +2444,17 @@ ObliqueForest <- R6::R6Class( arg_name = label, expected_length = 1) + if(!is.null(sample_fraction)){ + + if(oobag_pred_mode && sample_fraction == 1){ + stop( + "cannot compute out-of-bag predictions if no samples are out-of-bag.", + " Try setting sample_fraction < 1 or oobag_pred_type = 'none'.", + call. = FALSE + ) + } + + } }, @@ -2396,6 +2463,11 @@ ObliqueForest <- R6::R6Class( compute_means = function(){ numeric_data <- select_cols(self$data, private$data_names$x_numeric) + + if(self$na_action == 'omit'){ + numeric_data <- collapse::fsubset(numeric_data, private$data_rows_complete) + } + private$data_means <- collapse::fmean(numeric_data, w = self$weights) }, @@ -2403,18 +2475,27 @@ ObliqueForest <- R6::R6Class( compute_modes = function(){ - private$data_modes <- vapply( - select_cols(self$data, private$data_fctrs$cols), - collapse::fmode, - FUN.VALUE = integer(1), - w = self$weights - ) + nominal_data <- select_cols(self$data, private$data_fctrs$cols) + + if(self$na_action == 'omit'){ + nominal_data <- collapse::fsubset(nominal_data, private$data_rows_complete) + } + + private$data_modes <- vapply(nominal_data, + collapse::fmode, + FUN.VALUE = integer(1), + w = self$weights) }, compute_stdev = function(){ numeric_data <- select_cols(self$data, private$data_names$x_numeric) + + if(self$na_action == 'omit'){ + numeric_data <- collapse::fsubset(numeric_data, private$data_rows_complete) + } + private$data_stdev <- collapse::fsd(numeric_data, w = self$weights) }, @@ -2423,6 +2504,10 @@ ObliqueForest <- R6::R6Class( numeric_data <- select_cols(self$data, private$data_names$x_numeric) + if(self$na_action == 'omit'){ + numeric_data <- collapse::fsubset(numeric_data, private$data_rows_complete) + } + private$data_bounds <- matrix( data = c( collapse::fnth(numeric_data, 0.10, w = self$weights), @@ -2513,8 +2598,8 @@ ObliqueForest <- R6::R6Class( prep_w = function(){ - # re-initialize - private$init_weights() + # re-scale so that sum(w) == nrow(data) + private$w <- self$weights * length(self$weights) / sum(self$weights) }, @@ -2964,6 +3049,24 @@ ObliqueForestSurvival <- R6::R6Class( }, + init_pred_type = function(){ + self$pred_type <- 'risk' + }, + + init_split_rule = function(){ + self$split_rule <- 'logrank' + }, + + init_split_min_stat = function(){ + + if(is.null(self$split_rule)) + stop("cannot init split_min_stat without split_rule", call. = FALSE) + + self$split_min_stat <- switch(self$split_rule, + 'logrank' = 3.841459, + 'cstat' = 0.55) + }, + init_internal = function(){ self$tree_type <- "survival" @@ -2973,11 +3076,6 @@ ObliqueForestSurvival <- R6::R6Class( self$control$lincomb_R_function <- penalized_cph } - self$split_rule <- self$split_rule %||% 'logrank' - self$pred_type <- self$pred_type %||% 'surv' - self$split_min_stat <- self$split_min_stat %||% - switch(self$split_rule, 'logrank' = 3.841459, 'cstat' = 0.50) - y <- select_cols(self$data, private$data_names$y) if(inherits(y[[1]], 'Surv')){ @@ -3020,15 +3118,17 @@ ObliqueForestSurvival <- R6::R6Class( # if pred_horizon is unspecified, provide sensible default # if it is specified, check for correctness - if(is.null(self$pred_horizon)){ - self$pred_horizon <- collapse::fmedian(y[, 1]) - } else { + if(private$user_specified$pred_horizon){ private$check_pred_horizon(self$pred_horizon, boundary_checks = TRUE) + } else { + self$pred_horizon <- collapse::fmedian(y[, 1]) } private$check_leaf_min_events() private$check_split_min_events() + if(!self$oobag_pred_mode) self$oobag_eval_type <- "none" + # use default if eval type was not specified by user if(self$oobag_pred_mode && is.null(self$oobag_eval_type)){ self$oobag_eval_type <- "Harrell's C-index" @@ -3376,6 +3476,25 @@ ObliqueForestClassification <- R6::R6Class( }, + init_pred_type = function(){ + self$pred_type <- 'prob' + }, + + init_split_rule = function(){ + self$split_rule <- 'gini' + }, + + init_split_min_stat = function(){ + + if(is.null(self$split_rule)) + stop("cannot init split_min_stat without split_rule", call. = FALSE) + + self$split_min_stat <- switch(self$split_rule, + 'gini' = 0, + 'cstat' = 0.55) + + }, + init_internal = function(){ self$tree_type <- "classification" @@ -3385,10 +3504,7 @@ ObliqueForestClassification <- R6::R6Class( self$control$lincomb_R_function <- penalized_logreg } - self$split_rule <- self$split_rule %||% 'gini' - self$pred_type <- self$pred_type %||% 'prob' - self$split_min_stat <- self$split_min_stat %||% - switch(self$split_rule, 'gini' = 0, 'cstat' = 0.50) + if(!self$oobag_pred_mode) self$oobag_eval_type <- "none" # use default if eval type was not specified by user if(self$oobag_pred_mode && is.null(self$oobag_eval_type)){ @@ -3452,7 +3568,7 @@ ObliqueForestClassification <- R6::R6Class( y <- as.numeric(y) - 1 - if(min(y) > 0) browser() + if(min(y) > 0) stop("y is less than 0") private$y <- expand_y_clsf(as_matrix(y), n_class) @@ -3595,6 +3711,23 @@ ObliqueForestRegression <- R6::R6Class( }, + init_pred_type = function(){ + self$pred_type <- 'mean' + }, + + init_split_rule = function(){ + self$split_rule <- 'variance' + }, + + init_split_min_stat = function(){ + + if(is.null(self$split_rule)) + stop("cannot init split_min_stat without split_rule", call. = FALSE) + + self$split_min_stat <- switch(self$split_rule, 'variance' = 0) + + }, + init_internal = function(){ self$tree_type <- "regression" @@ -3611,10 +3744,7 @@ ObliqueForestRegression <- R6::R6Class( self$control$lincomb_R_function <- penalized_linreg } - self$split_rule <- self$split_rule %||% 'variance' - self$pred_type <- self$pred_type %||% 'mean' - self$split_min_stat <- self$split_min_stat %||% - switch(self$split_rule, 'variance' = 0) + if(!self$oobag_pred_mode) self$oobag_eval_type <- "none" # use default if eval type was not specified by user if(self$oobag_pred_mode && is.null(self$oobag_eval_type)){ @@ -3649,9 +3779,8 @@ ObliqueForestRegression <- R6::R6Class( } # y is always 1 column for regression (for now) - y <- private$y[[1]] - - private$y <- as_matrix(y) + private$y <- as_matrix(private$y[[1]]) + colnames(private$y) <- private$data_names$y }, diff --git a/R/orsf_control.R b/R/orsf_control.R index 4e327c0b..8aa5cd4d 100644 --- a/R/orsf_control.R +++ b/R/orsf_control.R @@ -454,14 +454,16 @@ orsf_control <- function(tree_type, lincomb_R_function <- switch(tree_type, 'survival' = penalized_cph, 'classification' = penalized_logreg, + 'regression' = penalized_linreg, 'unknown' = 'unknown') } else { - lincomb_R_function <- function(x) x + lincomb_R_function <- NULL } + structure( .Data = list( tree_type = tree_type, diff --git a/tests/testthat/test-impute_meanmode.R b/tests/testthat/test-impute_meanmode.R index 65bc98eb..d8dfd761 100644 --- a/tests/testthat/test-impute_meanmode.R +++ b/tests/testthat/test-impute_meanmode.R @@ -97,7 +97,6 @@ test_that( } ) - test_that( desc = "imputation does not coerce columns to new types", code = { diff --git a/tests/testthat/test-orsf.R b/tests/testthat/test-orsf.R index e151f466..0d009a30 100644 --- a/tests/testthat/test-orsf.R +++ b/tests/testthat/test-orsf.R @@ -31,7 +31,7 @@ test_that( code = { - fit_regr <- orsf(mtcars, mpg ~ ., no_fit = TRUE) + fit_regr <- orsf(penguins, bill_length_mm ~ ., no_fit = TRUE) fit_clsf <- orsf(penguins, species ~ ., no_fit = TRUE) fit_surv <- orsf(pbc, time + status ~ ., no_fit = TRUE) @@ -52,8 +52,8 @@ test_that( ) expect_error( - orsf(mtcars, mpg ~., control = orsf_control_classification()), - "please convert mpg to a factor" + orsf(penguins, bill_length_mm ~., control = orsf_control_classification()), + "please convert bill_length_mm to a factor" ) } @@ -230,9 +230,9 @@ test_that( expect_lt(fit$eval_oobag$stat_values[1], last_value(fit$eval_oobag$stat_values)) - fit <- orsf(mtcars, - formula = mpg ~ ., - leaf_min_obs = 10, + fit <- orsf(penguins, + formula = bill_length_mm ~ ., + leaf_min_obs = 50, n_tree = n_tree, # just needs a bit extra tree_seeds = seeds_standard, oobag_eval_every = eval_every) @@ -306,25 +306,6 @@ test_that( } ) -test_that( - desc = 'missing data are imputed when na_action is impute_meanmode', - code = { - - mtcars_temp <- mtcars - mtcars_temp$disp[1] <- NA - - fit_impute <- orsf(mtcars_temp, mpg ~ ., - na_action = 'impute_meanmode') - - expect_equal(fit_impute$n_obs, nrow(mtcars_temp)) - - # users data are not modified by imputation - expect_true(is.na(mtcars_temp$disp[1])) - expect_identical(mtcars_temp, fit_impute$data) - - } -) - test_that( desc = 'robust to threading, outcome formats, scaling, and noising', @@ -392,14 +373,14 @@ test_that( fit <- orsf(data = pbc, formula = time + status ~ . -id, n_tree = n_tree_test, - oobag_fun = oobag_c_survival, + oobag_fun = oobag_c_risk, tree_seeds = seeds_standard) expect_equal_oobag_eval(fit, fit_standard_pbc$fast) # can also reproduce it from the oobag predictions expect_equal( - oobag_c_survival( + oobag_c_risk( y_mat = as.matrix(pbc_orsf[,c("time", "status")]), w_vec = rep(1, nrow(pbc_orsf)), s_vec = fit$pred_oobag @@ -503,21 +484,23 @@ test_that( ) test_that( - desc = 'weights work as intended', + desc = 'weights do not make trees grow more than intended', code = { - fit_unwtd <- orsf(pbc_orsf, - Surv(time, status) ~ . - id, - n_tree = n_tree_test) + fit_unwtd <- orsf(pbc, time + status ~ ., + n_tree = n_tree_test, + tree_seeds = seeds_standard) - fit_wtd <- orsf(pbc_orsf, - Surv(time, status) ~ . - id, + fit_wtd <- orsf(pbc, + time + status ~ ., weights = rep(2, nrow(pbc_orsf)), - n_tree = n_tree_test) + n_tree = n_tree_test, + tree_seeds = seeds_standard) - # using weights should make the trees much deeper: - expect_gt(fit_wtd$get_mean_leaves_per_tree(), - fit_unwtd$get_mean_leaves_per_tree()) + # using weights should not inadvertently make trees deeper. + expect_equal(fit_wtd$get_mean_leaves_per_tree(), + fit_unwtd$get_mean_leaves_per_tree(), + tolerance = 1/2) } ) @@ -897,8 +880,8 @@ test_that( sample_fraction <- runif(n = 1, min = .25, max = .75) } - fit <- orsf(data = data_fun(mtcars), - formula = mpg ~ ., + fit <- orsf(data = data_fun(penguins), + formula = bill_length_mm ~ ., control = control, sample_with_replacement = inputs$sample_with_replacement[i], sample_fraction = sample_fraction, @@ -916,7 +899,7 @@ test_that( expect_s3_class(fit, class = 'ObliqueForestRegression') # data are not unintentionally modified by reference, - expect_identical(data_fun(mtcars), fit$data) + expect_identical(data_fun(penguins), fit$data) expect_no_missing(fit$forest)