Skip to content

Commit

Permalink
let updates use null
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Dec 29, 2023
1 parent 1c2ccca commit 55e3c01
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 12 deletions.
47 changes: 35 additions & 12 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,19 @@ ObliqueForest <- R6::R6Class(
# the standard init() function is called. If they were specified,
# the standard check function is called.

if(!is.null(control)){
# this allows someone to set control to NULL and revert
# to using a default control for the given forest.
if(is.null(control)){
private$user_specified$control <- FALSE
} else {
self$control <- control
private$user_specified$control <- TRUE
}

if(!is.null(weights)){

if(is.null(weights)){
private$user_specified$weights <- FALSE
} else {
self$weights <- weights
private$user_specified$weights <- TRUE
}
Expand All @@ -216,7 +223,9 @@ ObliqueForest <- R6::R6Class(
if(!is.null(n_thread))
self$n_thread <- n_thread

if(!is.null(mtry)){
if(is.null(mtry)){
private$user_specified$mtry <- FALSE
} else {
self$mtry <- mtry
private$user_specified$mtry <- TRUE
}
Expand All @@ -233,7 +242,10 @@ ObliqueForest <- R6::R6Class(
if(!is.null(leaf_min_obs))
self$leaf_min_obs <- leaf_min_obs

if(!is.null(split_rule)){

if(is.null(split_rule)){
private$user_specified$split_rule <- FALSE
} else {
self$split_rule <- split_rule
private$user_specified$split_rule <- TRUE
}
Expand All @@ -244,28 +256,37 @@ ObliqueForest <- R6::R6Class(
if(!is.null(split_min_obs))
self$split_min_obs <- split_min_obs

if(!is.null(split_min_stat)){
if(is.null(split_min_stat)){
private$user_specified$split_min_stat <- FALSE
} else {
self$split_min_stat <- split_min_stat
private$user_specified$split_min_stat <- TRUE
}

if(!is.null(pred_type)){
if(is.null(pred_type)){
private$user_specified$pred_type <- FALSE
} else {
self$pred_type <- pred_type
private$user_specified$pred_type <- TRUE
}


if(!is.null(oobag_pred_horizon)){
if(is.null(oobag_pred_horizon)){
private$user_specified$pred_horizon <- FALSE
} else {
self$pred_horizon <- oobag_pred_horizon
private$user_specified$pred_horizon <- TRUE
}

if(!is.null(oobag_eval_every)){
private$user_specified$oobag_eval_every <- TRUE
if(is.null(oobag_eval_every)){
private$user_specified$oobag_eval_every <- FALSE
} else {
self$oobag_eval_every <- oobag_eval_every
private$user_specified$oobag_eval_every <- TRUE
}

if(!is.null(oobag_fun)){
if(is.null(oobag_fun)){
private$user_specified$oobag_eval_function <- FALSE
} else {
self$oobag_eval_function <- oobag_fun
private$user_specified$oobag_eval_function <- TRUE
}
Expand All @@ -279,7 +300,9 @@ ObliqueForest <- R6::R6Class(
if(!is.null(group_factors))
self$importance_group_factors <- group_factors

if(!is.null(tree_seeds)){
if(is.null(tree_seeds)){
private$user_specified$tree_seeds <- FALSE
} else {
self$tree_seeds <- tree_seeds
private$user_specified$tree_seeds <- TRUE
}
Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/test-orsf_update.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,19 @@ test_that(

}
)

test_that(
desc = "setting a default null field to null reverts that field to the default value",
code = {

fit_control <- fit_standard_penguin_bills$custom

fit_control_null <- orsf_update(fit_control, control = NULL)

expect_equal_leaf_summary(fit_standard_penguin_bills$fast,
fit_control_null)

}
)


0 comments on commit 55e3c01

Please sign in to comment.