Skip to content

Commit

Permalink
deprecate fast control
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Nov 11, 2023
1 parent f47de2d commit 8ccb638
Show file tree
Hide file tree
Showing 13 changed files with 93 additions and 93 deletions.
9 changes: 7 additions & 2 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@

orsf <- function(data,
formula,
control = orsf_control_fast(),
control = NULL,
weights = NULL,
n_tree = 500,
n_split = 5,
Expand Down Expand Up @@ -380,7 +380,12 @@ orsf <- function(data,

if(length(formula) < 3) stop("formula must be two sided", call. = FALSE)

tree_type <- infer_tree_type(all.vars(formula[[2]]), args$data)
if(is.null(control)){
tree_type <- infer_tree_type(all.vars(formula[[2]]), args$data)
} else {
tree_type <- control$tree_type
}


object <- switch(
tree_type,
Expand Down
28 changes: 26 additions & 2 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,10 @@ ObliqueForest <- R6::R6Class(
paste0(' N predictors per node: ', self$mtry ),
paste0(' Average leaves per tree: ', private$mean_leaves ),
paste0('Min observations in leaf: ', self$leaf_min_obs ),
paste0(' Min events in leaf: ', self$leaf_min_events ),

if(self$tree_type == 'survival')
paste0(' Min events in leaf: ', self$leaf_min_events ),

paste0(' OOB stat value: ', oobag_stat ),
paste0(' OOB stat type: ', oobag_type ),
paste0(' Variable importance: ', self$importance_type ),
Expand Down Expand Up @@ -987,11 +990,16 @@ ObliqueForest <- R6::R6Class(
private$check_data()
private$check_formula()

if(is.null(self$control)){
private$init_control()
} else {
private$check_control()
}

private$init_data()
private$init_mtry()
private$init_weights()

private$check_control()
private$check_n_tree()
private$check_n_split()
private$check_n_retry()
Expand Down Expand Up @@ -2703,6 +2711,14 @@ ObliqueForestSurvival <- R6::R6Class(

},

init_control = function(){

self$control <- orsf_control_survival(method = 'glm',
scale_x = FALSE,
max_iter = 1)

},

init_internal = function(){

self$tree_type <- "survival"
Expand Down Expand Up @@ -2994,6 +3010,14 @@ ObliqueForestClassification <- R6::R6Class(

},

init_control = function(){

self$control <- orsf_control_classification(method = 'glm',
scale_x = FALSE,
max_iter = 1)

},

init_internal = function(){

self$tree_type <- "classification"
Expand Down
18 changes: 12 additions & 6 deletions R/orsf_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ orsf_control_fast <- function(method = 'efron',
do_scale = TRUE,
...){

lifecycle::deprecate_warn(
when = "0.1.2",
"orsf_control_fast()",
details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = 'fast')`"
)

check_dots(list(...), orsf_control_fast)

method <- tolower(method)
Expand Down Expand Up @@ -122,8 +128,8 @@ orsf_control_cph <- function(method = 'efron',

lifecycle::deprecate_warn(
when = "0.1.2",
"orsf_control_custom()",
details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`"
"orsf_control_cph()",
details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = 'glm')`"
)

method <- tolower(method)
Expand Down Expand Up @@ -197,8 +203,8 @@ orsf_control_net <- function(alpha = 1/2,

lifecycle::deprecate_warn(
when = "0.1.2",
"orsf_control_custom()",
details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`"
"orsf_control_net()",
details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = 'net')`"
)

check_dots(list(...), orsf_control_net)
Expand Down Expand Up @@ -249,7 +255,7 @@ orsf_control_custom <- function(beta_fun, ...){
lifecycle::deprecate_warn(
when = "0.1.2",
"orsf_control_custom()",
details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`"
details = "Please use the appropriate survival, classification, or regression control function instead. E.g., `orsf_control_survival(method = your_function)`, noting that your_function is a function object and not a character value"
)

check_dots(list(...), .f = orsf_control_custom)
Expand Down Expand Up @@ -374,7 +380,7 @@ orsf_control <- function(tree_type,

check_arg_is_valid(arg_value = method,
arg_name = 'method',
valid_options = c("glm", "net"))
valid_options = c("glm", "net", "fast"))

check_arg_length(arg_value = method,
arg_name = 'method',
Expand Down
32 changes: 0 additions & 32 deletions Rmd/orsf_control_cph_content_1.Rmd

This file was deleted.

16 changes: 0 additions & 16 deletions Rmd/orsf_control_cph_examples.Rmd

This file was deleted.

4 changes: 2 additions & 2 deletions Rmd/orsf_control_custom_content_1.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ f_rando <- function(x_node, y_node, w_node){
```

We can plug `f_rando` into `orsf_control_custom()`, and then pass the result into `orsf()`:
We can plug `f_rando` into `orsf_control_survival()`, and then pass the result into `orsf()`:

```{r}
library(aorsf)
fit_rando <- orsf(pbc_orsf,
Surv(time, status) ~ . - id,
control = orsf_control_custom(beta_fun = f_rando),
control = orsf_control_survival(method = f_rando),
n_tree = 500)
fit_rando
Expand Down
4 changes: 2 additions & 2 deletions Rmd/orsf_control_custom_content_2.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ f_pca <- function(x_node, y_node, w_node) {
```

Then plug the function into `orsf_control_custom()` and pass the result into `orsf()`:
Then plug the function into `orsf_control_survival()` and pass the result into `orsf()`:

```{r}
fit_pca <- orsf(pbc_orsf,
Surv(time, status) ~ . - id,
control = orsf_control_custom(beta_fun = f_pca),
control = orsf_control_survival(method = f_pca),
n_tree = 500)
```
Expand Down
8 changes: 5 additions & 3 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,16 @@ mat_list_surv <- list(pbc = pbc_mats,
seeds_standard <- 329
n_tree_test <- 5

controls <- list(
fast = orsf_control_fast(),
controls_surv <- list(
fast = orsf_control_survival(method = 'glm',
scale_x = FALSE,
max_iter = 1),
net = orsf_control_survival(method = 'net'),
custom = orsf_control_survival(method = f_pca)
)

fit_standard_pbc <- lapply(
controls,
controls_surv,
function(cntrl){
orsf(pbc,
formula = time + status ~ .,
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ test_that(
fit_dt <- orsf(as.data.table(pbc),
formula = time + status ~ .,
n_tree = n_tree_test,
control = controls$fast,
control = controls_surv$fast,
tree_seed = seeds_standard)

expect_equal_leaf_summary(fit_dt, fit_standard_pbc$fast)
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/test-orsf_formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ test_that(

fit_long <- orsf(pbc_orsf,
formula = f_long,
control = controls[[i]],
control = controls_surv[[i]],
n_tree = n_tree_test,
tree_seeds = seeds_standard)

Expand All @@ -83,13 +83,13 @@ test_that(

pbc_surv_data <- cbind(pbc_orsf, surv_object = pbc_surv)

for(i in seq_along(controls)){
for(i in seq_along(controls_surv)){

fit_surv <- orsf(
pbc_surv_data,
formula = surv_object ~ . - id - time - status,
n_tree = n_tree_test,
control = controls[[i]],
control = controls_surv[[i]],
tree_seed = seeds_standard
)

Expand Down Expand Up @@ -118,7 +118,7 @@ test_that(
# fit_status_modified <- orsf(pbc_orsf,
# time + status ~ . - id,
# n_tree = n_tree_test,
# control = controls[[j]],
# control = controls_surv[[j]],
# tree_seeds = seeds_standard)
#
# expect_equal_leaf_summary(fit_status_modified, fit_standard_pbc[[j]])
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-orsf_train.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ test_that(
'pbc_status_12',
'pbc_scaled')])){

for(j in seq_along(controls)){
for(j in seq_along(controls_surv)){

fit_untrained <- orsf(data_list_pbc[[i]],
formula = time + status ~ . - id,
n_tree = n_tree_test,
control = controls[[j]],
control = controls_surv[[j]],
tree_seed = seeds_standard,
no_fit = TRUE)

Expand Down
Loading

0 comments on commit 8ccb638

Please sign in to comment.