Skip to content

Commit

Permalink
make vignettes go faster
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 23, 2023
1 parent 1ab010d commit 754b431
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 158 deletions.
37 changes: 6 additions & 31 deletions vignettes/aorsf.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ library(aorsf)
set.seed(329)
orsf_fit <- orsf(data = pbc_orsf,
n_tree = 5,
formula = Surv(time, status) ~ . - id)
orsf_fit
Expand All @@ -54,7 +55,8 @@ library(dplyr)
orsf_fit <- pbc_orsf |>
select(-id) |>
orsf(formula = Surv(time, status) ~ .)
orsf(formula = Surv(time, status) ~ .,
n_tree = 5)
```

Expand Down Expand Up @@ -107,43 +109,16 @@ For more on ICE, see the [vignette](https://docs.ropensci.org/aorsf/articles/pd.
The original ORSF (i.e., `obliqueRSF`) used `glmnet` to find linear combinations of inputs. `aorsf` allows users to implement this approach using the `orsf_control_net()` function:
```{r}
```{r, eval=FALSE}
orsf_net <- orsf(data = pbc_orsf,
formula = Surv(time, status) ~ . - id,
control = orsf_control_net(),
n_tree = 50)
control = orsf_control_net())
```

`net` forests fit a lot faster than the original ORSF function in `obliqueRSF`. However, `net` forests are still much slower than `cph` ones:

```{r}
# tracking how long it takes to fit 50 glmnet trees
print(
t1 <- system.time(
orsf(data = pbc_orsf,
formula = Surv(time, status) ~ . - id,
control = orsf_control_net(),
n_tree = 50)
)
)
`net` forests fit a lot faster than the original ORSF function in `obliqueRSF`. However, `net` forests are still much slower than `cph` ones.

# and how long it takes to fit 50 cph trees
print(
t2 <- system.time(
orsf(data = pbc_orsf,
formula = Surv(time, status) ~ . - id,
control = orsf_control_cph(),
n_tree = 50)
)
)
t1['elapsed'] / t2['elapsed']
```

## aorsf and other machine learning software

Expand Down
85 changes: 28 additions & 57 deletions vignettes/fast.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,7 @@ library(aorsf)

## Go faster

Analyses can slow to a crawl when models need hours to run. In this article you will find a few tricks to prevent this bottleneck when using `orsf()`. We'll use the `flchain` data from `survival` to demonstrate.

```{r}
data("flchain", package = 'survival')
flc <- flchain
# do this to avoid orsf() throwing an error about time to event = 0
flc <- flc[flc$futime > 0, ]
# modify names
names(flc)[names(flc) == 'futime'] <- 'time'
names(flc)[names(flc) == 'death'] <- 'status'
```

Our `flc` data has `r nrow(flc)` rows and `r ncol(flc)` columns:

```{r}
head(flc)
```

Analyses can slow to a crawl when models need hours to run. In this article you will find a few tricks to prevent this bottleneck when using `orsf()`.

## Use `orsf_control_fast()`

Expand All @@ -50,13 +30,17 @@ This is the default `control` value for `orsf()` and its run-time compared to ot
```{r}
time_fast <- system.time(
expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode',
control = orsf_control_fast(), n_tree = 10)
expr = orsf(pbc_orsf,
formula = time+status~. -id,
control = orsf_control_fast(),
n_tree = 5)
)
time_net <- system.time(
expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode',
control = orsf_control_net(), n_tree = 10)
expr = orsf(pbc_orsf,
formula = time+status~. -id,
control = orsf_control_net(),
n_tree = 5)
)
# control_fast() is much faster
Expand All @@ -70,25 +54,12 @@ The `n_thread` argument uses multi-threading to run `aorsf` functions in paralle

```{r}
time_1_thread <- system.time(
expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode',
n_thread = 1, n_tree = 500)
)
# automatically pick number of threads based on amount available
time_5_thread <- system.time(
expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode',
n_thread = 5, n_tree = 500)
)
time_auto_thread <- system.time(
expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode',
n_thread = 0, n_tree = 500)
)
# 5 threads and auto thread are both about 3 times faster than one thread
time_1_thread['elapsed'] / time_5_thread['elapsed']
time_1_thread['elapsed'] / time_auto_thread['elapsed']
orsf(pbc_orsf,
formula = time+status~. -id,
n_tree = 5,
n_thread = 0)
```

Expand All @@ -112,16 +83,17 @@ Applying these tips:

```{r}
time_lightweight <- system.time(
expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode',
n_thread = 0, n_tree = 500, n_retry = 0,
oobag_pred_type = 'none', importance = 'none',
split_min_events = 20, leaf_min_events = 10,
split_min_stat = 10)
)
# about two times faster than auto thread with defaults
time_auto_thread['elapsed'] / time_lightweight['elapsed']
orsf(pbc_orsf,
formula = time+status~.,
na_action = 'na_impute_meanmode',
n_thread = 0,
n_tree = 5,
n_retry = 0,
oobag_pred_type = 'none',
importance = 'none',
split_min_events = 20,
leaf_min_events = 10,
split_min_stat = 10)
```

Expand All @@ -133,10 +105,9 @@ Setting `verbose_progress = TRUE` doesn't make anything run faster, but it can h

```{r}
verbose_fit <- orsf(flc, time+status~.,
na_action = 'na_impute_meanmode',
n_thread = 0,
n_tree = 500,
verbose_fit <- orsf(pbc_orsf,
formula = time+status~. -id,
n_tree = 5,
verbose_progress = TRUE)
```
Expand Down
85 changes: 15 additions & 70 deletions vignettes/oobag.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Let's fit an oblique random survival forest and plot the distribution of the ens
fit <- orsf(data = pbc_orsf,
formula = Surv(time, status) ~ . - id,
oobag_pred_type = 'surv',
n_tree = 5,
oobag_pred_horizon = 2000)
hist(fit$pred_oobag,
Expand Down Expand Up @@ -68,22 +69,25 @@ As each out-of-bag data set contains about one-third of the training set, the ou
fit <- orsf(data = pbc_orsf,
formula = Surv(time, status) ~ . - id,
n_tree = 50,
n_tree = 20,
tree_seeds = 2,
oobag_pred_type = 'surv',
oobag_pred_horizon = 2000,
oobag_eval_every = 1)
plot(
x = seq(1, 50, by = 1),
x = seq(1, 20, by = 1),
y = fit$eval_oobag$stat_values,
main = 'Out-of-bag C-statistic computed after each new tree is grown.',
xlab = 'Number of trees grown',
ylab = fit$eval_oobag$stat_type
)
lines(x=seq(1, 20), y = fit$eval_oobag$stat_values)
```

In general, at least 500 trees are recommended for a random forest fit. We're just using 50 in this case for better illustration of the out-of-bag error curve. Also, it helps to make run-times low whenever I need to re-compile the package vignettes.
In general, at least 500 trees are recommended for a random forest fit. We're just using 10 for illustration.

## User-supplied out-of-bag evaluation functions

Expand Down Expand Up @@ -121,52 +125,22 @@ Second, you can pass your function into `orsf()`, and it will be used in place o
fit <- orsf(data = pbc_orsf,
formula = Surv(time, status) ~ . - id,
n_tree = 50,
n_tree = 20,
tree_seeds = 2,
oobag_pred_horizon = 2000,
oobag_fun = oobag_fun_brier,
oobag_eval_every = 1)
plot(
x = seq(1, 50, by = 1),
x = seq(1, 20, by = 1),
y = fit$eval_oobag$stat_values,
main = 'Out-of-bag error computed after each new tree is grown.',
sub = 'For the Brier score, lower values indicate more accurate predictions',
xlab = 'Number of trees grown',
ylab = "Brier score"
)
```

We can also compute a time-dependent C-statistic instead of Harrell's C-statistic (the default oob function):

```{r}
oobag_fun_tdep_cstat <- function(y_mat, w_vec, s_vec){
as.numeric(
SurvMetrics::Cindex(
object = Surv(time = y_mat[, 1], event = y_mat[, 2]),
predicted = s_vec,
t_star = 2000
)
)
}
fit <- orsf(data = pbc_orsf,
formula = Surv(time, status) ~ . - id,
n_tree = 50,
oobag_pred_horizon = 2000,
oobag_fun = oobag_fun_tdep_cstat,
oobag_eval_every = 1)
plot(
x = seq(50),
y = fit$eval_oobag$stat_values,
main = 'Out-of-bag time-dependent AUC\ncomputed after each new tree is grown.',
xlab = 'Number of trees grown',
ylab = "AUC at t = 2,000"
)
lines(x=seq(1, 20), y = fit$eval_oobag$stat_values)
```

Expand All @@ -193,11 +167,11 @@ y_mat <- cbind(time = test_time, status = test_status)
s_vec <- seq(0.9, 0.1, length.out = 100)
# see 1 in the checklist above
names(formals(oobag_fun_tdep_cstat)) == c("y_mat", "w_vec", "s_vec")
names(formals(oobag_fun_brier)) == c("y_mat", "w_vec", "s_vec")
test_output <- oobag_fun_tdep_cstat(y_mat = y_mat,
w_vec = w_vec,
s_vec = s_vec)
test_output <- oobag_fun_brier(y_mat = y_mat,
w_vec = w_vec,
s_vec = s_vec)
# test output should be numeric
is.numeric(test_output)
Expand All @@ -206,35 +180,6 @@ length(test_output) == 1
```

## User-supplied functions for negation importance.

Negation importance is based on the out-of-bag error, so of course you may be curious about what negation importance would be if it were computed using different statistics. The workflow for doing this is exactly the same as the example above, except for two things:

1. We have to specify `importance = 'negate'` when we fit our model.

2. We want to use a modified version of the C-stat, specifically 1 - the C-stat, because of how `aorsf` computes variable importance.

```{r}
oobag_fun_tdep_cstat_inverse <- function(y_mat, w_vec, s_vec){
1 - oobag_fun_tdep_cstat(y_mat, w_vec, s_vec)
}
```

Also, to speed up computations, I am not going to monitor out-of-bag error here.

```{r}
fit_tdep_cstat <- orsf(data = pbc_orsf,
formula = Surv(time, status) ~ . - id,
n_tree = 100,
oobag_pred_horizon = 2000,
oobag_fun = oobag_fun_tdep_cstat_inverse,
importance = 'negate')
fit_tdep_cstat$importance
```

## Notes

When evaluating out-of-bag error:
Expand Down

0 comments on commit 754b431

Please sign in to comment.