diff --git a/R/orsf_R6.R b/R/orsf_R6.R index b8b9df9b..05cdfabf 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -194,12 +194,19 @@ ObliqueForest <- R6::R6Class( # the standard init() function is called. If they were specified, # the standard check function is called. - if(!is.null(control)){ + # this allows someone to set control to NULL and revert + # to using a default control for the given forest. + if(is.null(control)){ + private$user_specified$control <- FALSE + } else { self$control <- control private$user_specified$control <- TRUE } - if(!is.null(weights)){ + + if(is.null(weights)){ + private$user_specified$weights <- FALSE + } else { self$weights <- weights private$user_specified$weights <- TRUE } @@ -216,7 +223,9 @@ ObliqueForest <- R6::R6Class( if(!is.null(n_thread)) self$n_thread <- n_thread - if(!is.null(mtry)){ + if(is.null(mtry)){ + private$user_specified$mtry <- FALSE + } else { self$mtry <- mtry private$user_specified$mtry <- TRUE } @@ -233,7 +242,10 @@ ObliqueForest <- R6::R6Class( if(!is.null(leaf_min_obs)) self$leaf_min_obs <- leaf_min_obs - if(!is.null(split_rule)){ + + if(is.null(split_rule)){ + private$user_specified$split_rule <- FALSE + } else { self$split_rule <- split_rule private$user_specified$split_rule <- TRUE } @@ -244,28 +256,37 @@ ObliqueForest <- R6::R6Class( if(!is.null(split_min_obs)) self$split_min_obs <- split_min_obs - if(!is.null(split_min_stat)){ + if(is.null(split_min_stat)){ + private$user_specified$split_min_stat <- FALSE + } else { self$split_min_stat <- split_min_stat private$user_specified$split_min_stat <- TRUE } - if(!is.null(pred_type)){ + if(is.null(pred_type)){ + private$user_specified$pred_type <- FALSE + } else { self$pred_type <- pred_type private$user_specified$pred_type <- TRUE } - - if(!is.null(oobag_pred_horizon)){ + if(is.null(oobag_pred_horizon)){ + private$user_specified$pred_horizon <- FALSE + } else { self$pred_horizon <- oobag_pred_horizon private$user_specified$pred_horizon <- TRUE } - if(!is.null(oobag_eval_every)){ - private$user_specified$oobag_eval_every <- TRUE + if(is.null(oobag_eval_every)){ + private$user_specified$oobag_eval_every <- FALSE + } else { self$oobag_eval_every <- oobag_eval_every + private$user_specified$oobag_eval_every <- TRUE } - if(!is.null(oobag_fun)){ + if(is.null(oobag_fun)){ + private$user_specified$oobag_eval_function <- FALSE + } else { self$oobag_eval_function <- oobag_fun private$user_specified$oobag_eval_function <- TRUE } @@ -279,7 +300,9 @@ ObliqueForest <- R6::R6Class( if(!is.null(group_factors)) self$importance_group_factors <- group_factors - if(!is.null(tree_seeds)){ + if(is.null(tree_seeds)){ + private$user_specified$tree_seeds <- FALSE + } else { self$tree_seeds <- tree_seeds private$user_specified$tree_seeds <- TRUE } diff --git a/tests/testthat/test-orsf_update.R b/tests/testthat/test-orsf_update.R index ac2991a7..42115edd 100644 --- a/tests/testthat/test-orsf_update.R +++ b/tests/testthat/test-orsf_update.R @@ -77,3 +77,19 @@ test_that( } ) + +test_that( + desc = "setting a default null field to null reverts that field to the default value", + code = { + + fit_control <- fit_standard_penguin_bills$custom + + fit_control_null <- orsf_update(fit_control, control = NULL) + + expect_equal_leaf_summary(fit_standard_penguin_bills$fast, + fit_control_null) + + } +) + +