From 8f4ac35652bf8aa472a0ac4b7a80e8feade234aa Mon Sep 17 00:00:00 2001 From: bcjaeger Date: Sat, 4 May 2024 17:37:57 -0400 Subject: [PATCH] propagate na's for cart --- R/orsf_R6.R | 17 ++++++++++++ tests/testthat/test-na_action.R | 46 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/R/orsf_R6.R b/R/orsf_R6.R index e7352e68..0f9bf4b1 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -467,6 +467,7 @@ ObliqueForest <- R6::R6Class( self$verbose_progress <- verbose_progress self$pred_aggregate <- pred_aggregate + out <- try( expr = { @@ -4262,6 +4263,8 @@ ObliqueForestClassification <- R6::R6Class( out <- do.call(orsf_cpp, args = cpp_args)$pred_new + out <- private$clean_pred_new(out) + if(self$pred_type == 'prob' && self$pred_aggregate){ colnames(out) <- self$class_levels @@ -4285,6 +4288,12 @@ ObliqueForestClassification <- R6::R6Class( out + }, + + clean_pred_new_internal = function(preds){ + + preds + } ) @@ -4510,10 +4519,18 @@ ObliqueForestRegression <- R6::R6Class( out <- do.call(orsf_cpp, args = cpp_args)$pred_new + out <- private$clean_pred_new(out) + if(simplify) dim(out) <- NULL out + }, + + clean_pred_new_internal = function(preds){ + + preds + } ) diff --git a/tests/testthat/test-na_action.R b/tests/testthat/test-na_action.R index 3049fb2e..f3d637ad 100644 --- a/tests/testthat/test-na_action.R +++ b/tests/testthat/test-na_action.R @@ -111,3 +111,49 @@ test_that( }) +test_that( + desc = "na action of pass works with new preds", + code = { + + mtcars_orsf <- mtcars + mtcars_orsf$vs <- factor(mtcars_orsf$vs) + mtcars_na <- mtcars_orsf + + set_to_miss <- c(1, 4, 18) + mtcars_na$cyl[set_to_miss] <- NA + + aorsf_regr_fit <- orsf( + data = mtcars_orsf, + formula = mpg ~ ., + n_tree = n_tree_test, + tree_seeds = seeds_standard + ) + + aorsf_regr_pred <- predict( + aorsf_regr_fit, + new_data = mtcars_na, + na_action = 'pass' + ) + + expect_equal(which(is.na(aorsf_regr_pred)), set_to_miss) + + aorsf_clsf_fit <- aorsf::orsf( + data = mtcars_orsf, + formula = vs ~ ., + n_tree = n_tree_test, + tree_seeds = seeds_standard + ) + + aorsf_clsf_pred <- predict( + aorsf_clsf_fit, + new_data = mtcars_na, + na_action = 'pass' + ) + + expect_equal(which(is.na(aorsf_clsf_pred[,1])), set_to_miss) + expect_equal(which(is.na(aorsf_clsf_pred[,2])), set_to_miss) + + # this test is already done for survival in test-orsf_predict + + } +)