Skip to content

Commit

Permalink
protect from sparseness
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Dec 14, 2023
1 parent e048c28 commit f1b12e6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
28 changes: 21 additions & 7 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@

# TODO:
# - add nocov to cpp
# - automatic bounds for pd (better interface)
# - tests for check_oobag_eval_function
# - tests for survival forest w/no censored
# - tests for check_oobag_eval_function_internal


# ObliqueForest class ----
Expand Down Expand Up @@ -614,7 +611,9 @@ ObliqueForest <- R6::R6Class(
pred_spec <- list_init(pred_spec)

for(i in names(pred_spec)){
pred_spec[[i]] <- self$get_var_bounds(i)

pred_spec[[i]] <- unique(self$get_var_bounds(i))

}

} else if (inherits(pred_spec, 'pspec_intr')){
Expand Down Expand Up @@ -978,11 +977,26 @@ ObliqueForest <- R6::R6Class(

get_var_bounds = function(.name){

if(.name %in% private$data_names$x_numeric)
return(as.numeric(private$data_bounds[, .name]))
else
if(.name %in% private$data_names$x_numeric){

out <- unique(as.numeric(private$data_bounds[, .name]))

if(length(out) < 5){
# too few unique values to use quantiles,
# so use the most common unique values instead.
unis <- sort(table(self$data[[.name]]), decreasing = TRUE)
n_items <- min(5, length(unis))
out <- sort(as.numeric(names(all_unis)[seq(n_items)]))
}

return(out)

} else {

return(private$data_fctrs$lvls[[.name]])

}

},

get_var_type = function(.name){
Expand Down
2 changes: 1 addition & 1 deletion R/orsf_vint.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ orsf_vint <- function(object,
pred_spec = pspec,
pred_horizon = NULL,
pred_type = ptype,
na_action = 'fail',
na_action = object$na_action,
expand_grid = FALSE,
prob_values = NULL,
prob_labels = NULL,
Expand Down

0 comments on commit f1b12e6

Please sign in to comment.