Skip to content

Commit

Permalink
orsf_update use named input only
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Dec 30, 2023
1 parent 185839e commit 308d494
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 303 deletions.
181 changes: 54 additions & 127 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -156,163 +156,90 @@ ObliqueForest <- R6::R6Class(
},

# Update: re-initialize if dynamic args were unspecified
update = function(data = NULL,
formula = NULL,
control = NULL,
weights = NULL,
n_tree = NULL,
n_split = NULL,
n_retry = NULL,
n_thread = NULL,
mtry = NULL,
sample_with_replacement = NULL,
sample_fraction = NULL,
leaf_min_events = NULL,
leaf_min_obs = NULL,
split_rule = NULL,
split_min_events = NULL,
split_min_obs = NULL,
split_min_stat = NULL,
pred_type = NULL,
oobag_pred_horizon = NULL,
oobag_eval_every = NULL,
oobag_fun = NULL,
importance = NULL,
importance_max_pvalue = NULL,
group_factors = NULL,
tree_seeds = NULL,
na_action = NULL,
verbose_progress = NULL) {

if(!is.null(formula)){
terms_old <- terms(self$formula, data = data %||% self$data)
self$formula <- stats::update(as.formula(terms_old), new = formula)
}
update = function(args) {

# for args with default values of NULL, keep track of whether
# user specified them or not. If they were un-specified, then
# the standard init() function is called. If they were specified,
# the standard check function is called.

# this allows someone to set control to NULL and revert
# to using a default control for the given forest.
if(is.null(control)){
private$user_specified$control <- FALSE
} else {
self$control <- control
private$user_specified$control <- TRUE
}


if(is.null(weights)){
private$user_specified$weights <- FALSE
} else {
self$weights <- weights
private$user_specified$weights <- TRUE
}

if(!is.null(n_tree))
self$n_tree <- n_tree
data <- args$data %||% self$data

if(!is.null(n_split))
self$n_split <- n_split

if(!is.null(n_retry))
self$n_retry <- n_retry

if(!is.null(n_thread))
self$n_thread <- n_thread

if(is.null(mtry)){
private$user_specified$mtry <- FALSE
} else {
self$mtry <- mtry
private$user_specified$mtry <- TRUE
if("formula" %in% names(args)){
formula <- args[['formula']]
terms_old <- terms(self$formula, data = data)
self$formula <- stats::update(as.formula(terms_old), new = formula)
}

if(!is.null(sample_with_replacement))
self$sample_with_replacement <- sample_with_replacement
null_defaults <- c(
control = "control",
weights = "weights",
mtry = "mtry",
split_rule = "split_rule",
split_min_stat = "split_min_stat",
pred_type = "oobag_pred_type",
pred_horizon = "oobag_pred_horizon",
oobag_eval_every = "oobag_eval_every",
oobag_eval_function = "oobag_fun",
tree_seeds = "tree_seeds"
)

if(!is.null(sample_fraction))
self$sample_fraction <- sample_fraction
hard_defaults <- c(
n_tree = "n_tree",
n_split = "n_split",
n_retry = "n_retry",
n_thread = "n_thread",
sample_with_replacement = "sample_with_replacement",
sample_fraction = "sample_fraction",
leaf_min_events = "leaf_min_events",
leaf_min_obs = "leaf_min_obs",
split_min_events = "split_min_events",
split_min_obs = "split_min_obs",
importance = "importance",
importance_max_pvalue = "importance_max_pvalue",
importance_group_factors = "group_factors",
na_action = "na_action",
verbose_progress = "verbose_progress"
)

if(!is.null(leaf_min_events))
self$leaf_min_events <- leaf_min_events
for( i in seq_along(null_defaults) ){

if(!is.null(leaf_min_obs))
self$leaf_min_obs <- leaf_min_obs
input_name <- null_defaults[i]

if(input_name %in% names(args)){

if(is.null(split_rule)){
private$user_specified$split_rule <- FALSE
} else {
self$split_rule <- split_rule
private$user_specified$split_rule <- TRUE
}
input <- args[[input_name]]

if(!is.null(split_min_events))
self$split_min_events <- split_min_events
r6_name <- names(null_defaults)[i]

if(!is.null(split_min_obs))
self$split_min_obs <- split_min_obs
if(is.null(input)){
private$user_specified[[r6_name]] <- FALSE
} else {
self[[r6_name]] <- input
private$user_specified[[r6_name]] <- TRUE
}

if(is.null(split_min_stat)){
private$user_specified$split_min_stat <- FALSE
} else {
self$split_min_stat <- split_min_stat
private$user_specified$split_min_stat <- TRUE
}
}

if(is.null(pred_type)){
private$user_specified$pred_type <- FALSE
} else {
self$pred_type <- pred_type
private$user_specified$pred_type <- TRUE
}

if(is.null(oobag_pred_horizon)){
private$user_specified$pred_horizon <- FALSE
} else {
self$pred_horizon <- oobag_pred_horizon
private$user_specified$pred_horizon <- TRUE
}
for(i in seq_along(hard_defaults)){

if(is.null(oobag_eval_every)){
private$user_specified$oobag_eval_every <- FALSE
} else {
self$oobag_eval_every <- oobag_eval_every
private$user_specified$oobag_eval_every <- TRUE
}
input_name <- hard_defaults[i]

if(is.null(oobag_fun)){
private$user_specified$oobag_eval_function <- FALSE
} else {
self$oobag_eval_function <- oobag_fun
private$user_specified$oobag_eval_function <- TRUE
}
if(input_name %in% names(args)){

if(!is.null(importance))
self$importance_type <- importance
input <- args[[input_name]]
r6_name <- names(hard_defaults)[i]

if(!is.null(importance_max_pvalue))
self$importance_max_pvalue <- importance_max_pvalue
self[[r6_name]] <- input

if(!is.null(group_factors))
self$importance_group_factors <- group_factors
}

if(is.null(tree_seeds)){
private$user_specified$tree_seeds <- FALSE
} else {
self$tree_seeds <- tree_seeds
private$user_specified$tree_seeds <- TRUE
}

if(!is.null(na_action))
self$na_action <- na_action

if(!is.null(verbose_progress))
self$verbose_progress <- verbose_progress

private$init(data = data)

},
Expand Down
2 changes: 1 addition & 1 deletion R/orsf_update.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ orsf_update <- function(object,

}

object_new$update(...)
object_new$update(.dots)

if(no_fit){

Expand Down
34 changes: 17 additions & 17 deletions man/orsf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 13 additions & 11 deletions man/orsf_control.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 308d494

Please sign in to comment.