Skip to content

Commit

Permalink
Merge pull request #62 from ropensci/issue61
Browse files Browse the repository at this point in the history
Issue61
  • Loading branch information
bcjaeger authored May 4, 2024
2 parents cea2342 + 8bb5ec8 commit fd1a69c
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
17 changes: 17 additions & 0 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ ObliqueForest <- R6::R6Class(
self$verbose_progress <- verbose_progress
self$pred_aggregate <- pred_aggregate


out <- try(
expr = {

Expand Down Expand Up @@ -4262,6 +4263,8 @@ ObliqueForestClassification <- R6::R6Class(

out <- do.call(orsf_cpp, args = cpp_args)$pred_new

out <- private$clean_pred_new(out)

if(self$pred_type == 'prob' && self$pred_aggregate){

colnames(out) <- self$class_levels
Expand All @@ -4285,6 +4288,12 @@ ObliqueForestClassification <- R6::R6Class(

out

},

clean_pred_new_internal = function(preds){

preds

}

)
Expand Down Expand Up @@ -4510,10 +4519,18 @@ ObliqueForestRegression <- R6::R6Class(

out <- do.call(orsf_cpp, args = cpp_args)$pred_new

out <- private$clean_pred_new(out)

if(simplify) dim(out) <- NULL

out

},

clean_pred_new_internal = function(preds){

preds

}

)
Expand Down
46 changes: 46 additions & 0 deletions tests/testthat/test-na_action.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,49 @@ test_that(

})

test_that(
desc = "na action of pass works with new preds",
code = {

mtcars_orsf <- mtcars
mtcars_orsf$vs <- factor(mtcars_orsf$vs)
mtcars_na <- mtcars_orsf

set_to_miss <- c(1, 4, 18)
mtcars_na$cyl[set_to_miss] <- NA

aorsf_regr_fit <- orsf(
data = mtcars_orsf,
formula = mpg ~ .,
n_tree = n_tree_test,
tree_seeds = seeds_standard
)

aorsf_regr_pred <- predict(
aorsf_regr_fit,
new_data = mtcars_na,
na_action = 'pass'
)

expect_equal(which(is.na(aorsf_regr_pred)), set_to_miss)

aorsf_clsf_fit <- aorsf::orsf(
data = mtcars_orsf,
formula = vs ~ .,
n_tree = n_tree_test,
tree_seeds = seeds_standard
)

aorsf_clsf_pred <- predict(
aorsf_clsf_fit,
new_data = mtcars_na,
na_action = 'pass'
)

expect_equal(which(is.na(aorsf_clsf_pred[,1])), set_to_miss)
expect_equal(which(is.na(aorsf_clsf_pred[,2])), set_to_miss)

# this test is already done for survival in test-orsf_predict

}
)

0 comments on commit fd1a69c

Please sign in to comment.