From f1b12e61b1ef090e8624c511a527c8e5b26aaf6d Mon Sep 17 00:00:00 2001
From: bjaeger
Date: Thu, 14 Dec 2023 09:33:57 -0500
Subject: [PATCH] protect from sparseness

get_var_bounds() now de-duplicates the stored bounds of a numeric
predictor. When fewer than 5 distinct bounds remain, it falls back to
the (up to) 5 most frequent values of that column in self$data,
returned in ascending order. Non-numeric predictors still return their
stored factor levels.

The automatic pred_spec bounds are additionally passed through
unique(), and orsf_vint() forwards object$na_action instead of the
hard-coded 'fail'. Three finished items are dropped from the TODO list.

Review notes: the fallback branch reads the frequency table through the
name it was assigned to (unis), and indexes it with seq_len() so that
an empty table yields an empty result rather than c(1, 0).

---
 R/orsf_R6.R   | 28 +++++++++++++++++++++-------
 R/orsf_vint.R |  2 +-
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/R/orsf_R6.R b/R/orsf_R6.R
index 156fa579..f38eeda1 100644
--- a/R/orsf_R6.R
+++ b/R/orsf_R6.R
@@ -1,10 +1,7 @@
 
 # TODO:
 # - add nocov to cpp
-# - automatic bounds for pd (better interface)
-# - tests for check_oobag_eval_function
 # - tests for survival forest w/no censored
-# - tests for check_oobag_eval_function_internal
 
 
 # ObliqueForest class ----
@@ -614,7 +611,9 @@ ObliqueForest <- R6::R6Class(
      pred_spec <- list_init(pred_spec)
 
      for(i in names(pred_spec)){
-      pred_spec[[i]] <- self$get_var_bounds(i)
+
+      pred_spec[[i]] <- unique(self$get_var_bounds(i))
+
      }
 
     } else if (inherits(pred_spec, 'pspec_intr')){
@@ -978,11 +977,26 @@ ObliqueForest <- R6::R6Class(
 
   get_var_bounds = function(.name){
 
-   if(.name %in% private$data_names$x_numeric)
-    return(as.numeric(private$data_bounds[, .name]))
-   else
+   if(.name %in% private$data_names$x_numeric){
+
+    out <- unique(as.numeric(private$data_bounds[, .name]))
+
+    if(length(out) < 5){
+     # too few unique values to use quantiles,
+     # so use the most common unique values instead.
+     unis <- sort(table(self$data[[.name]]), decreasing = TRUE)
+     n_items <- min(5, length(unis))
+     out <- sort(as.numeric(names(unis)[seq_len(n_items)]))
+    }
+
+    return(out)
+
+   } else {
+
     return(private$data_fctrs$lvls[[.name]])
 
+   }
+
   },
 
   get_var_type = function(.name){
diff --git a/R/orsf_vint.R b/R/orsf_vint.R
index 2e8e2b49..e3c6bf75 100644
--- a/R/orsf_vint.R
+++ b/R/orsf_vint.R
@@ -85,7 +85,7 @@ orsf_vint <- function(object,
                                  pred_spec = pspec,
                                  pred_horizon = NULL,
                                  pred_type = ptype,
-                                 na_action = 'fail',
+                                 na_action = object$na_action,
                                  expand_grid = FALSE,
                                  prob_values = NULL,
                                  prob_labels = NULL,