Skip to content

Commit

Permalink
ready to merge
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 1, 2023
1 parent 79d0727 commit 1d6efed
Show file tree
Hide file tree
Showing 26 changed files with 551 additions and 13,293 deletions.
12 changes: 10 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@ compute_cstat_exported_uvec <- function(y, w, g, pred_is_risklike) {
.Call(`_aorsf_compute_cstat_exported_uvec`, y, w, g, pred_is_risklike)
}

orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity) {
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity)
compute_logrank_exported <- function(y, w, g) {
.Call(`_aorsf_compute_logrank_exported`, y, w, g)
}

cph_scale <- function(x, w) {
.Call(`_aorsf_cph_scale`, x, w)
}

orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity) {
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity)
}

19 changes: 17 additions & 2 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ orsf <- function(data,
n_retry = 3,
n_thread = 1, # TODO: add docs+checks
mtry = NULL,
sample_with_replacement = TRUE, # TODO: add docs+checks
sample_fraction = 0.632, # TODO: add docs+checks
leaf_min_events = 1,
leaf_min_obs = 5,
split_rule = 'logrank', # TODO: add docs+checks
Expand Down Expand Up @@ -372,6 +374,14 @@ orsf <- function(data,

oobag_pred <- oobag_pred_type != 'none'

if(sample_fraction == 1 && oobag_pred){
stop(
"cannot compute out-of-bag predictions if no samples are out-of-bag.",
"To resolve this, set sample_fraction < 1 or oobag_pred_type = 'none'.",
call. = FALSE
)
}

orsf_type <- attr(control, 'type')

switch(
Expand Down Expand Up @@ -710,6 +720,8 @@ orsf <- function(data,
loaded_forest = list(),
n_tree = n_tree,
mtry = mtry,
sample_with_replacement = sample_with_replacement,
sample_fraction = sample_fraction,
vi_type_R = switch(importance,
"none" = 0,
"negate" = 1,
Expand Down Expand Up @@ -863,8 +875,9 @@ orsf <- function(data,
attr(orsf_out, 'split_rule') <- split_rule
attr(orsf_out, 'n_thread') <- n_thread
attr(orsf_out, 'tree_type') <- tree_type_R

attr(orsf_out, 'tree_seeds') <- tree_seeds
attr(orsf_out, 'tree_seeds') <- tree_seeds
attr(orsf_out, 'sample_with_replacement') <- sample_with_replacement
attr(orsf_out, 'sample_fraction') <- sample_fraction

#' @srrstats {ML5.0a} *orsf output has its own class*
class(orsf_out) <- "orsf_fit"
Expand Down Expand Up @@ -1084,6 +1097,8 @@ orsf_train_ <- function(object,
loaded_forest = list(),
n_tree = n_tree,
mtry = get_mtry(object),
sample_with_replacement = get_sample_with_replacement(object),
sample_fraction = get_sample_fraction(object),
vi_type_R = switch(get_importance(object),
"none" = 0,
"negate" = 1,
Expand Down
2 changes: 2 additions & 0 deletions R/orsf_attr.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ get_vi_max_pvalue <- function(object) attr(object, 'vi_max_pvalue')
get_split_rule <- function(object) attr(object, 'split_rule')
get_n_thread <- function(object) attr(object, 'n_thread')
get_tree_type <- function(object) attr(object, 'tree_type')
get_sample_with_replacement <- function(object) attr(object, 'sample_with_replacement')
get_sample_fraction <- function(object) attr(object, 'sample_fraction')


#' ORSF status
Expand Down
Loading

0 comments on commit 1d6efed

Please sign in to comment.