Skip to content

Commit

Permalink
Merge pull request #31 from ropensci/pd-threads
Browse files Browse the repository at this point in the history
Pd threads
  • Loading branch information
bcjaeger authored Oct 30, 2023
2 parents 4ba36ae + be854f3 commit af8de1b
Show file tree
Hide file tree
Showing 10 changed files with 446 additions and 62 deletions.
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ x_submat_mult_beta_exported <- function(x, y, w, x_rows, x_cols, beta) {
.Call(`_aorsf_x_submat_mult_beta_exported`, x, y, w, x_rows, x_cols, beta)
}

# Thin wrapper: forwards all eight arguments, unchanged and in order, to the
# native symbol `_aorsf_x_submat_mult_beta_pd_exported` through `.Call()`.
# NOTE(review): this looks like generated Rcpp glue -- confirm whether manual
# edits survive regeneration.
x_submat_mult_beta_pd_exported <- function(x, y, w, x_rows, x_cols, beta, pd_x_vals, pd_x_cols) {
  .Call(
    `_aorsf_x_submat_mult_beta_pd_exported`,
    x, y, w,
    x_rows, x_cols,
    beta,
    pd_x_vals, pd_x_cols
  )
}

# Thin wrapper: forwards `x` and `w`, unchanged, to the native symbol
# `_aorsf_scale_x_exported` through `.Call()` and returns its result.
scale_x_exported <- function(x, w) {
  .Call(
    `_aorsf_scale_x_exported`,
    x,
    w
  )
}
Expand Down
168 changes: 168 additions & 0 deletions R/infer.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@

#' Null-coalescing operator (copied from rlang)
#'
#' Returns `x` unless it is `NULL`; otherwise returns `y`. Because R
#' arguments are lazy, `y` is only evaluated when `x` is `NULL`.

`%||%` <- function(x, y) {
  if (!is.null(x)) {
    return(x)
  }
  y
}

#' helper for guessing pred_horizon input
#'
Expand Down Expand Up @@ -73,3 +80,164 @@ infer_outcome_type <- function(names_y_data, data){
stop("could not infer outcome type", call. = FALSE)

}


#' Build a complete, named list of forest arguments
#'
#' Every setting is resolved in the same order: a value supplied through
#' `...` wins, then the corresponding attribute of `object` (read through the
#' `get_*()` accessors), then a hard-coded default. Entries whose name ends in
#' `_R` are the numeric codes produced by the `switch()` calls below.
#'
#' @param x predictor matrix; its column count sets the default `mtry` and its
#'   row count sets the default length of `w`.
#' @param y outcome matrix; when `tree_type` is `'survival'` the median of its
#'   first column is the default `pred_horizon`.
#' @param w observation weights, one per row of `x`.
#' @param ... named overrides (e.g. `n_tree`, `mtry`, `pred_type`, `pd_type`).
#'   Names are matched exactly.
#' @param object optional object whose attributes (and `forest` element)
#'   provide fallback values; `NULL` means "use defaults only".
#'
#' @return a named list of resolved arguments.
#'
#' @noRd
infer_orsf_args <- function(x,
                            y = matrix(1, ncol = 2),
                            w = rep(1, nrow(x)),
                            ...,
                            object = NULL) {

  .dots <- list(...)

  # Exact-name lookup into `...`. `$` on a list matches names partially, so
  # `.dots$oobag_pred` could silently return a supplied `oobag_pred_type`.
  # `[[` with a character index matches exactly and yields NULL when absent.
  dot <- function(name) .dots[[name]]

  control <- dot("control") %||%
    get_control(object) %||%
    orsf_control_fast()

  n_tree <- dot("n_tree") %||%
    get_n_tree(object) %||%
    500L

  tree_type <- dot("tree_type") %||%
    get_tree_type(object) %||%
    'survival'

  split_rule <- dot("split_rule") %||%
    get_split_rule(object) %||%
    'logrank'

  # Rule-specific default; NA_real_ is kept as the fallback for any other
  # split rule so the entry is never NULL.
  split_min_stat <- dot("split_min_stat") %||%
    get_split_min_stat(object) %||%
    switch(split_rule, "logrank" = 3.841459, "cstat" = 0.50, NA_real_)

  mtry <- dot("mtry") %||%
    get_mtry(object) %||%
    ceiling(sqrt(ncol(x)))

  # The override is named `pred_type` in `...`, while the fallback reads the
  # out-of-bag prediction type stored on `object`.
  oobag_pred_type <- dot("pred_type") %||%
    get_oobag_pred_type(object) %||%
    "surv"

  oobag_pred <- dot("oobag_pred") %||%
    get_oobag_pred(object) %||%
    (oobag_pred_type != 'none')

  pred_horizon <- dot("pred_horizon") %||%
    get_oobag_pred_horizon(object) %||%
    (if (tree_type == 'survival') stats::median(y[, 1]) else 1)

  # Out-of-bag evaluation is only resolved when out-of-bag predictions are on.
  oobag_eval_type <- 'none'

  if (oobag_pred) {

    oobag_eval_type <- dot("oobag_eval_type") %||%
      get_oobag_eval_type(object) %||%
      "cstat"

  }

  vi_type <- dot("vi_type") %||%
    get_importance(object) %||%
    "none"

  pd_type <- dot("pd_type") %||% 'none'

  list(
    x = x,
    y = y,
    w = w,
    tree_type_R = switch(tree_type,
                         'classification' = 1,
                         'regression' = 2,
                         'survival' = 3),
    tree_seeds = dot("tree_seeds") %||%
      get_tree_seeds(object) %||%
      329,
    loaded_forest = object$forest %||% list(),
    n_tree = n_tree,
    mtry = mtry,
    sample_with_replacement = dot("sample_with_replacement") %||%
      get_sample_with_replacement(object) %||%
      TRUE,
    sample_fraction = dot("sample_fraction") %||%
      get_sample_fraction(object) %||%
      0.632,
    vi_type_R = switch(vi_type,
                       "none" = 0,
                       "negate" = 1,
                       "permute" = 2,
                       "anova" = 3),
    vi_max_pvalue = dot("vi_max_pvalue") %||%
      get_vi_max_pvalue(object) %||%
      0.01,
    leaf_min_events = dot("leaf_min_events") %||%
      get_leaf_min_events(object) %||%
      1,
    leaf_min_obs = dot("leaf_min_obs") %||%
      get_leaf_min_obs(object) %||%
      5,
    split_rule_R = switch(split_rule, "logrank" = 1, "cstat" = 2),
    # `split_min_event` (no trailing "s") was the name previously looked up;
    # it is still honoured so existing callers keep working.
    split_min_events = dot("split_min_events") %||%
      dot("split_min_event") %||%
      get_split_min_events(object) %||%
      5,
    split_min_obs = dot("split_min_obs") %||%
      get_split_min_obs(object) %||%
      10,
    # Use the value resolved above so the rule-specific default applies.
    split_min_stat = split_min_stat,
    split_max_cuts = dot("split_max_cuts") %||%
      get_n_split(object) %||%
      5,
    split_max_retry = dot("split_max_retry") %||%
      get_n_retry(object) %||%
      3,
    lincomb_R_function = control$lincomb_R_function,
    lincomb_type_R = switch(control$lincomb_type,
                            'glm' = 1,
                            'random' = 2,
                            'net' = 3,
                            'custom' = 4),
    lincomb_eps = control$lincomb_eps,
    lincomb_iter_max = control$lincomb_iter_max,
    lincomb_scale = control$lincomb_scale,
    lincomb_alpha = control$lincomb_alpha,
    lincomb_df_target = control$lincomb_df_target %||% mtry,
    lincomb_ties_method = switch(tolower(control$lincomb_ties_method),
                                 'breslow' = 0,
                                 'efron' = 1),
    pred_type_R = switch(oobag_pred_type,
                         "none" = 0,
                         "risk" = 1,
                         "surv" = 2,
                         "chf" = 3,
                         "mort" = 4,
                         "leaf" = 8),
    pred_mode = dot("pred_mode") %||% FALSE,
    pred_aggregate = dot("pred_aggregate") %||% (oobag_pred_type != 'leaf'),
    pred_horizon = pred_horizon,
    oobag = oobag_pred,
    oobag_R_function = dot("oobag_R_function") %||%
      get_f_oobag_eval(object) %||%
      function(x) x,
    oobag_eval_type_R = switch(oobag_eval_type,
                               'none' = 0,
                               'cstat' = 1,
                               'user' = 2),
    oobag_eval_every = dot("oobag_eval_every") %||%
      get_oobag_eval_every(object) %||%
      n_tree,
    pd_type_R = switch(pd_type, "none" = 0L, "smry" = 1L, "ice" = 2L),
    pd_x_vals = dot("pd_x_vals") %||% list(matrix(0, ncol = 0, nrow = 0)),
    pd_x_cols = dot("pd_x_cols") %||% list(matrix(0, ncol = 0, nrow = 0)),
    pd_probs = dot("pd_probs") %||% 0,
    n_thread = dot("n_thread") %||% get_n_thread(object) %||% 1,
    write_forest = dot("write_forest") %||% TRUE,
    run_forest = dot("run_forest") %||% TRUE,
    verbosity = dot("verbosity") %||%
      get_verbose_progress(object) %||%
      FALSE
  )

}
2 changes: 1 addition & 1 deletion R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ orsf_train_ <- function(object,
pred_aggregate = get_oobag_pred_type(object) != 'leaf',
pred_horizon = get_oobag_pred_horizon(object),
oobag = get_oobag_pred(object),
oobag_eval_type_R = switch(get_type_oobag_eval(object),
oobag_eval_type_R = switch(get_oobag_eval_type(object),
'none' = 0,
'cstat' = 1,
'user' = 2),
Expand Down
3 changes: 1 addition & 2 deletions R/orsf_attr.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ get_modes <- function(object) attr(object, 'modes')
# Attribute accessors. Each one returns the attribute of `object` with the
# quoted name, or NULL when that attribute is not set.
get_standard_deviations <- function(object) {
  attr(object, which = "standard_deviations")
}
get_n_retry <- function(object) {
  attr(object, which = "n_retry")
}
get_f_oobag_eval <- function(object) {
  attr(object, which = "f_oobag_eval")
}
get_type_oobag_eval <- function(object) {
  attr(object, which = "type_oobag_eval")
}
get_oobag_fun <- function(object) {
  attr(object, which = "oobag_fun")
}
get_oobag_pred <- function(object) {
  attr(object, which = "oobag_pred")
}
get_oobag_pred_type <- function(object) {
  attr(object, which = "oobag_pred_type")
}
Expand All @@ -44,7 +43,7 @@ get_importance <- function(object) attr(object, 'importance')
# Attribute accessors. Each one returns the attribute of `object` with the
# quoted name, or NULL when that attribute is not set. Note that
# `get_type_oobag_eval()` and `get_oobag_eval_type()` both read the
# 'type_oobag_eval' attribute.
get_importance_values <- function(object) {
  attr(object, which = "importance_values")
}
get_group_factors <- function(object) {
  attr(object, which = "group_factors")
}
get_f_oobag_eval <- function(object) {
  attr(object, which = "f_oobag_eval")
}
get_type_oobag_eval <- function(object) {
  attr(object, which = "type_oobag_eval")
}
get_oobag_eval_type <- function(object) {
  attr(object, which = "type_oobag_eval")
}
get_tree_seeds <- function(object) {
  attr(object, which = "tree_seeds")
}
get_weights_user <- function(object) {
  attr(object, which = "weights_user")
}
get_event_times <- function(object) {
  attr(object, which = "event_times")
}
Expand Down
1 change: 1 addition & 0 deletions R/orsf_pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -726,3 +726,4 @@ pd_list_split <- function(x_vals, x_cols){
)

}

Loading

0 comments on commit af8de1b

Please sign in to comment.