Skip to content

Commit

Permalink
Merge pull request #25 from ropensci/issue24
Browse files (browse the repository at this point in the history)
Issue24
  • Loading branch information
bcjaeger authored Oct 15, 2023
2 parents a0eb263 + 7090100 commit 54da1a3
Show file tree
Hide file tree
Showing 22 changed files with 653 additions and 361 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: aorsf
Title: Accelerated Oblique Random Survival Forests
Version: 0.1.1.9001
Version: 0.1.1.9002
Authors@R: c(
person(given = "Byron",
family = "Jaeger",
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ S3method(predict,orsf_fit)
S3method(print,orsf_fit)
S3method(print,orsf_summary_uni)
export(orsf)
export(orsf_control_classification)
export(orsf_control_cph)
export(orsf_control_custom)
export(orsf_control_fast)
export(orsf_control_net)
export(orsf_control_regression)
export(orsf_control_survival)
export(orsf_ice_inb)
export(orsf_ice_new)
export(orsf_ice_oob)
Expand Down
160 changes: 65 additions & 95 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -415,51 +415,6 @@ orsf <- function(data,
)
}

orsf_type <- attr(control, 'type')

switch(
orsf_type,

'fast' = {

control_net <- orsf_control_net()
control_cph <- control
f_beta <- function(x) x

},

'cph' = {

control_net <- orsf_control_net()
control_cph <- control
f_beta <- function(x) x

},

'net' = {

if (!requireNamespace("glmnet", quietly = TRUE)) {
stop(
"Package \"glmnet\" must be installed to use",
" orsf_control_net() with orsf().",
call. = FALSE
)
}

control_net <- control
control_cph <- orsf_control_fast(do_scale = FALSE)
f_beta <- penalized_cph
},

"custom" = {

control_net <- orsf_control_net()
control_cph <- orsf_control_fast(do_scale = FALSE)
f_beta <- control$beta_fun

}

)

if(is.null(oobag_fun)){

Expand All @@ -484,13 +439,6 @@ orsf <- function(data,
if(oobag_pred_type == 'leaf') type_oobag_eval <- 'none'


cph_method <- control_cph$cph_method
cph_eps <- control_cph$cph_eps
cph_iter_max <- control_cph$cph_iter_max
cph_do_scale <- control_cph$cph_do_scale
net_alpha <- control_net$net_alpha
net_df_target <- control_net$net_df_target

formula_terms <- suppressWarnings(stats::terms(formula, data=data))

if(attr(formula_terms, 'response') == 0)
Expand All @@ -502,6 +450,28 @@ orsf <- function(data,

if(outcome_type %in% c('regression', 'classification')) stop("not ready yet")

if(control$tree_type == 'unknown'){

if(outcome_type == 'unknown'){
stop("could not determine outcome type", call. = FALSE)
}

control$tree_type <- outcome_type

}

if(control$lincomb_type == 'net'){

if (!requireNamespace("glmnet", quietly = TRUE)) {
stop(
"Package \"glmnet\" must be installed to use",
" orsf_control_net() with orsf().",
call. = FALSE
)
}

}

tree_type_R = switch(outcome_type,
'classification' = 1,
'regression'= 2,
Expand Down Expand Up @@ -660,14 +630,6 @@ orsf <- function(data,

if(is.null(mtry)) mtry <- ceiling(sqrt(ncol(x)))

if(is.null(net_df_target)) net_df_target <- mtry

# warn instead?
if(net_df_target > mtry)
stop("net_df_target = ", net_df_target,
" must be <= mtry, which is ", mtry,
call. = FALSE)

n_events <- collapse::fsum(y[, 2])

# some additional checks that are dependent on the outcome variable
Expand All @@ -679,6 +641,22 @@ orsf <- function(data,
append_to_msg = "(number of columns in the one-hot encoded x-matrix)"
)

if(is.null(control$lincomb_df_target)){

control$lincomb_df_target <- mtry

} else {

check_arg_lteq(
arg_value = control$lincomb_df_target,
arg_name = 'df_target',
bound = mtry,
append_to_msg = "(number of randomly selected predictors)"
)

}


check_arg_lteq(
arg_value = leaf_min_events,
arg_name = 'leaf_min_events',
Expand Down Expand Up @@ -718,8 +696,8 @@ orsf <- function(data,
}

sorted <-
collapse::radixorder(y[, 1], # order this way for risk sets
-y[, 2]) # order this way for oob C-statistic.
collapse::radixorder(y[, 1], # order for risk sets
-y[, 2]) # order for oob C-statistic.

if(is.null(weights)) weights <- rep(1, nrow(x))

Expand Down Expand Up @@ -752,8 +730,6 @@ orsf <- function(data,
"permute" = 2,
"anova" = 3),
vi_max_pvalue = vi_max_pvalue,
lincomb_R_function = f_beta,
oobag_R_function = f_oobag_eval,
leaf_min_events = leaf_min_events,
leaf_min_obs = leaf_min_obs,
split_rule_R = switch(split_rule,
Expand All @@ -764,18 +740,18 @@ orsf <- function(data,
split_min_stat = split_min_stat,
split_max_cuts = n_split,
split_max_retry = n_retry,
lincomb_type_R = switch(orsf_type,
'fast' = 1,
'cph' = 1,
lincomb_R_function = control$lincomb_R_function,
lincomb_type_R = switch(control$lincomb_type,
'glm' = 1,
'random' = 2,
'net' = 3,
'custom' = 4),
lincomb_eps = cph_eps,
lincomb_iter_max = cph_iter_max,
lincomb_scale = cph_do_scale,
lincomb_alpha = net_alpha,
lincomb_df_target = net_df_target,
lincomb_ties_method = switch(tolower(cph_method),
lincomb_eps = control$lincomb_eps,
lincomb_iter_max = control$lincomb_iter_max,
lincomb_scale = control$lincomb_scale,
lincomb_alpha = control$lincomb_alpha,
lincomb_df_target = control$lincomb_df_target,
lincomb_ties_method = switch(tolower(control$lincomb_ties_method),
'breslow' = 0,
'efron' = 1),
pred_type_R = switch(oobag_pred_type,
Expand All @@ -789,6 +765,7 @@ orsf <- function(data,
pred_aggregate = oobag_pred_type != 'leaf',
pred_horizon = oobag_pred_horizon,
oobag = oobag_pred,
oobag_R_function = f_oobag_eval,
oobag_eval_type_R = switch(type_oobag_eval,
'none' = 0,
'cstat' = 1,
Expand Down Expand Up @@ -877,20 +854,13 @@ orsf <- function(data,
attr(orsf_out, 'split_min_obs') <- split_min_obs
attr(orsf_out, 'split_min_stat') <- split_min_stat
attr(orsf_out, 'na_action') <- na_action
attr(orsf_out, 'cph_method') <- cph_method
attr(orsf_out, 'cph_eps') <- cph_eps
attr(orsf_out, 'cph_iter_max') <- cph_iter_max
attr(orsf_out, 'cph_do_scale') <- cph_do_scale
attr(orsf_out, 'net_alpha') <- net_alpha
attr(orsf_out, 'net_df_target') <- net_df_target
attr(orsf_out, 'control') <- control
attr(orsf_out, 'numeric_bounds') <- numeric_bounds
attr(orsf_out, 'means') <- means
attr(orsf_out, 'modes') <- modes
attr(orsf_out, 'standard_deviations') <- standard_deviations
attr(orsf_out, 'trained') <- !no_fit
attr(orsf_out, 'n_retry') <- n_retry
attr(orsf_out, 'orsf_type') <- orsf_type
attr(orsf_out, 'f_beta') <- f_beta
attr(orsf_out, 'f_oobag_eval') <- f_oobag_eval
attr(orsf_out, 'type_oobag_eval') <- type_oobag_eval
attr(orsf_out, 'oobag_pred') <- oobag_pred
Expand Down Expand Up @@ -1119,6 +1089,8 @@ orsf_train_ <- function(object,

oobag_eval_every <- min(n_tree, get_oobag_eval_every(object))

control <- get_control(object)

orsf_out <- orsf_cpp(x = x_sort,
y = y_sort,
w = w_sort,
Expand All @@ -1135,7 +1107,6 @@ orsf_train_ <- function(object,
"permute" = 2,
"anova" = 3),
vi_max_pvalue = get_vi_max_pvalue(object),
lincomb_R_function = get_f_beta(object),
oobag_R_function = get_f_oobag_eval(object),
leaf_min_events = get_leaf_min_events(object),
leaf_min_obs = get_leaf_min_obs(object),
Expand All @@ -1147,22 +1118,21 @@ orsf_train_ <- function(object,
split_min_stat = get_split_min_stat(object),
split_max_cuts = get_n_split(object),
split_max_retry = get_n_retry(object),
lincomb_type_R = switch(get_orsf_type(object),
'fast' = 1,
'cph' = 1,

lincomb_R_function = control$lincomb_R_function,
lincomb_type_R = switch(control$lincomb_type,
'glm' = 1,
'random' = 2,
'net' = 3,
'custom' = 4),
lincomb_eps = get_cph_eps(object),
lincomb_iter_max = get_cph_iter_max(object),
lincomb_scale = get_cph_do_scale(object),
lincomb_alpha = get_net_alpha(object),
lincomb_df_target = get_net_df_target(object),
lincomb_ties_method = switch(
tolower(get_cph_method(object)),
'breslow' = 0,
'efron' = 1
),
lincomb_eps = control$lincomb_eps,
lincomb_iter_max = control$lincomb_iter_max,
lincomb_scale = control$lincomb_scale,
lincomb_alpha = control$lincomb_alpha,
lincomb_df_target = control$lincomb_df_target,
lincomb_ties_method = switch(tolower(control$lincomb_ties_method),
'breslow' = 0,
'efron' = 1),
pred_type_R = switch(get_oobag_pred_type(object),
"none" = 0,
"risk" = 1,
Expand Down
8 changes: 0 additions & 8 deletions R/orsf_attr.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ get_split_min_events <- function(object) attr(object, 'split_min_events')
# Attribute accessors. Each returns the value stored under the named
# attribute of `object`, or NULL when that attribute is not set.
# `exact = TRUE` turns off attr()'s default partial matching, so a missing
# attribute is never silently replaced by another attribute whose name
# merely begins with the same characters.

# Value of the 'split_min_obs' attribute.
get_split_min_obs <- function(object) {
  attr(object, 'split_min_obs', exact = TRUE)
}

# Value of the 'split_min_stat' attribute.
get_split_min_stat <- function(object) {
  attr(object, 'split_min_stat', exact = TRUE)
}

# Value of the 'na_action' attribute.
get_na_action <- function(object) {
  attr(object, 'na_action', exact = TRUE)
}
get_cph_method <- function(object) attr(object, 'cph_method')
get_cph_eps <- function(object) attr(object, 'cph_eps')
get_cph_iter_max <- function(object) attr(object, 'cph_iter_max')
get_cph_do_scale <- function(object) attr(object, 'cph_do_scale')
get_net_alpha <- function(object) attr(object, 'net_alpha')
get_net_df_target <- function(object) attr(object, 'net_df_target')
# Attribute accessors. Each returns the value stored under the named
# attribute of `object`, or NULL when that attribute is not set.
# `exact = TRUE` turns off attr()'s default partial matching, so a missing
# attribute is never silently replaced by another attribute whose name
# merely begins with the same characters.

# Value of the 'numeric_bounds' attribute.
get_numeric_bounds <- function(object) {
  attr(object, 'numeric_bounds', exact = TRUE)
}

# Value of the 'means' attribute.
get_means <- function(object) {
  attr(object, 'means', exact = TRUE)
}

# Value of the 'modes' attribute.
get_modes <- function(object) {
  attr(object, 'modes', exact = TRUE)
}
Expand All @@ -49,8 +43,6 @@ get_oobag_eval_every <- function(object) attr(object, 'oobag_eval_every')
# Attribute accessors. Each returns the value stored under the named
# attribute of `object`, or NULL when that attribute is not set.
# `exact = TRUE` matters here: 'importance' is a prefix of
# 'importance_values', so with attr()'s default partial matching a lookup of
# 'importance' on an object that only carries 'importance_values' would
# return the importance values instead of NULL.

# Value of the 'importance' attribute.
get_importance <- function(object) {
  attr(object, 'importance', exact = TRUE)
}

# Value of the 'importance_values' attribute.
get_importance_values <- function(object) {
  attr(object, 'importance_values', exact = TRUE)
}

# Value of the 'group_factors' attribute.
get_group_factors <- function(object) {
  attr(object, 'group_factors', exact = TRUE)
}
get_f_beta <- function(object) attr(object, 'f_beta')
get_orsf_type <- function(object) attr(object, 'orsf_type')
# Attribute accessors. Each returns the value stored under the named
# attribute of `object`, or NULL when that attribute is not set.
# `exact = TRUE` turns off attr()'s default partial matching, so a missing
# attribute is never silently replaced by another attribute whose name
# merely begins with the same characters.

# Value of the 'f_oobag_eval' attribute.
get_f_oobag_eval <- function(object) {
  attr(object, 'f_oobag_eval', exact = TRUE)
}

# Value of the 'type_oobag_eval' attribute.
get_type_oobag_eval <- function(object) {
  attr(object, 'type_oobag_eval', exact = TRUE)
}

# Value of the 'tree_seeds' attribute.
get_tree_seeds <- function(object) {
  attr(object, 'tree_seeds', exact = TRUE)
}
Expand Down
Loading

0 comments on commit 54da1a3

Please sign in to comment.