Skip to content

Commit

Permalink
oob functions allowed
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Sep 23, 2023
1 parent eeeb141 commit b9453e0
Show file tree
Hide file tree
Showing 16 changed files with 282 additions and 229 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ compute_cstat_exported_uvec <- function(y, w, g, pred_is_risklike) {
.Call(`_aorsf_compute_cstat_exported_uvec`, y, w, g, pred_is_risklike)
}

orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, oobag, oobag_eval_type_R, oobag_eval_every, n_thread, write_forest) {
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, oobag, oobag_eval_type_R, oobag_eval_every, n_thread, write_forest)
orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, oobag, oobag_eval_type_R, oobag_eval_every, n_thread, write_forest, run_forest) {
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, oobag, oobag_eval_type_R, oobag_eval_every, n_thread, write_forest, run_forest)
}

22 changes: 16 additions & 6 deletions R/check.R
Original file line number Diff line number Diff line change
Expand Up @@ -1616,8 +1616,8 @@ check_oobag_fun <- function(oobag_fun){

oobag_fun_args <- names(formals(oobag_fun))

if(length(oobag_fun_args) != 2) stop(
"oobag_fun should have 2 input arguments but instead has ",
if(length(oobag_fun_args) != 3) stop(
"oobag_fun should have 3 input arguments but instead has ",
length(oobag_fun_args),
call. = FALSE
)
Expand All @@ -1628,8 +1628,14 @@ check_oobag_fun <- function(oobag_fun){
call. = FALSE
)

if(oobag_fun_args[2] != 's_vec') stop(
"the second input argument of oobag_fun should be named 's_vec' ",
if(oobag_fun_args[2] != 'w_vec') stop(
"the second input argument of oobag_fun should be named 'w_vec' ",
"but is instead named '", oobag_fun_args[1], "'",
call. = FALSE
)

if(oobag_fun_args[3] != 's_vec') stop(
"the third input argument of oobag_fun should be named 's_vec' ",
"but is instead named '", oobag_fun_args[2], "'",
call. = FALSE
)
Expand All @@ -1638,9 +1644,12 @@ check_oobag_fun <- function(oobag_fun){
test_status <- rep(c(0,1), each = 50)

.y_mat <- cbind(time = test_time, status = test_status)
.w_vec <- rep(1, times = 100)
.s_vec <- seq(0.9, 0.1, length.out = 100)

test_output <- try(oobag_fun(y_mat = .y_mat, s_vec = .s_vec),
test_output <- try(oobag_fun(y_mat = .y_mat,
w_vec = .w_vec,
s_vec = .s_vec),
silent = FALSE)

if(is_error(test_output)){
Expand All @@ -1650,8 +1659,9 @@ check_oobag_fun <- function(oobag_fun){
"test_time <- seq(from = 1, to = 5, length.out = 100)\n",
"test_status <- rep(c(0,1), each = 50)\n\n",
"y_mat <- cbind(time = test_time, status = test_status)\n",
"w_vec <- rep(1, times = 100)\n",
"s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
"test_output <- oobag_fun(y_mat = y_mat, s_vec = s_vec)\n\n",
"test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
"test_output should be a numeric value of length 1",
call. = FALSE)

Expand Down
49 changes: 6 additions & 43 deletions R/oobag_c_harrell.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,50 +12,13 @@
#' @noRd
#'

oobag_c_harrell <- function(y_mat, s_vec){
oobag_c_survival <- function(y_mat, w_vec, s_vec){

sorted <- order(y_mat[, 1], -y_mat[, 2])
survival::concordancefit(
y = survival::Surv(y_mat),
x = s_vec
)$concordance

y_mat <- y_mat[sorted, ]
s_vec <- s_vec[sorted]

time = y_mat[, 1]
status = y_mat[, 2]
events = which(status == 1)

k = nrow(y_mat)

total <- 0
concordant <- 0

for(i in events){

if(i+1 <= k){

for(j in seq(i+1, k)){

if(time[j] > time[i]){

total <- total + 1

if(s_vec[j] > s_vec[i]){

concordant <- concordant + 1

} else if (s_vec[j] == s_vec[i]){

concordant <- concordant + 0.5

}

}

}

}

}
}

concordant / total

}
39 changes: 20 additions & 19 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -678,10 +678,11 @@ orsf <- function(data,
collapse::radixorder(y[, 1], # order this way for risk sets
-y[, 2]) # order this way for oob C-statistic.

if(is.null(weights)) weights <- rep(1, nrow(x))

x_sort <- x[sorted, , drop = FALSE]
y_sort <- y[sorted, , drop = FALSE]

if(is.null(weights)) weights <- rep(1, nrow(x))
w_sort <- weights[sorted]

if(length(tree_seeds) == 1) set.seed(tree_seeds)

Expand All @@ -690,13 +691,13 @@ orsf <- function(data,

vi_max_pvalue = 0.01

orsf_out <- orsf_cpp(x = x,
y = y,
w = weights,
orsf_out <- orsf_cpp(x = x_sort,
y = y_sort,
w = w_sort,
tree_type_R = 3,
tree_seeds = as.integer(tree_seeds),
loaded_forest = list(),
n_tree = if(no_fit) 0 else n_tree,
n_tree = n_tree,
mtry = mtry,
vi_type_R = switch(importance,
"none" = 0,
Expand Down Expand Up @@ -745,21 +746,20 @@ orsf <- function(data,
'user' = 2),
oobag_eval_every = oobag_eval_every,
n_thread = n_thread,
write_forest = TRUE)

# browser()
write_forest = TRUE,
run_forest = !no_fit)

# if someone says no_fit and also says don't attach the data,
# give them a warning but also do the right thing for them.
orsf_out$data <- if(attach_data) data else NULL

if(importance != 'none'){
if(importance != 'none' && !no_fit){
rownames(orsf_out$importance) <- colnames(x)
orsf_out$importance <-
rev(orsf_out$importance[order(orsf_out$importance), , drop=TRUE])
}

if(oobag_pred){
if(oobag_pred && !no_fit){

# put the oob predictions into the same order as the training data.
unsorted <- collapse::radixorder(sorted)
Expand Down Expand Up @@ -833,7 +833,7 @@ orsf <- function(data,
attr(orsf_out, 'split_rule') <- split_rule
attr(orsf_out, 'n_thread') <- n_thread

attr(orsf_out, 'tree_seeds') <- if(is.null(tree_seeds)) c() else tree_seeds
attr(orsf_out, 'tree_seeds') <- tree_seeds

#' @srrstats {ML5.0a} *orsf output has its own class*
class(orsf_out) <- "orsf_fit"
Expand Down Expand Up @@ -1037,17 +1037,17 @@ orsf_train_ <- function(object,
-y[, 2]) # order this way for oob C-statistic.
}

weights <- get_weights_user(object)

x_sort <- x[sorted, ]
y_sort <- y[sorted, ]
w_sort <- weights[sorted]

oobag_eval_every <- min(n_tree, get_oobag_eval_every(object))

weights <- get_weights_user(object)

orsf_out <- orsf_cpp(x = x,
y = y,
w = weights,
orsf_out <- orsf_cpp(x = x_sort,
y = y_sort,
w = w_sort,
tree_type_R = 3,
tree_seeds = get_tree_seeds(object),
loaded_forest = list(),
Expand Down Expand Up @@ -1100,9 +1100,10 @@ orsf_train_ <- function(object,
'none' = 0,
'cstat' = 1,
'user' = 2),
oobag_eval_every = get_oobag_eval_every(object),
oobag_eval_every = oobag_eval_every,
n_thread = get_n_thread(object),
write_forest = TRUE)
write_forest = TRUE,
run_forest = TRUE)


object$pred_oobag <- orsf_out$pred_oobag
Expand Down
2 changes: 1 addition & 1 deletion R/orsf_attr.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ is_trained <- function(object) attr(object, 'trained')
#'
#' @noRd
#'
contains_oobag <- function(object) {!is_empty(object$pred_oobag)}
contains_oobag <- function(object) {!is_empty(object$eval_oobag$stat_values)}

#' Determine whether object has variable importance estimates
#'
Expand Down
Loading

0 comments on commit b9453e0

Please sign in to comment.