Skip to content

Commit

Permalink
merging new predict with old
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Sep 25, 2023
1 parent cb4476c commit 9c56aac
Show file tree
Hide file tree
Showing 14 changed files with 429 additions and 230 deletions.
9 changes: 6 additions & 3 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ orsf <- function(data,
)

if(importance %in% c("permute", "negate") && !oobag_pred){
oobag_pred <- TRUE # Should I add a warning?
# oobag_pred <- TRUE # Should I add a warning?
oobag_pred_type <- 'surv'
}

Expand Down Expand Up @@ -690,11 +690,12 @@ orsf <- function(data,
tree_seeds <- sample(x = n_tree*2, size = n_tree, replace = FALSE)

vi_max_pvalue = 0.01
tree_type_R = 3

orsf_out <- orsf_cpp(x = x_sort,
y = y_sort,
w = w_sort,
tree_type_R = 3,
tree_type_R = tree_type_R,
tree_seeds = as.integer(tree_seeds),
loaded_forest = list(),
n_tree = n_tree,
Expand Down Expand Up @@ -773,10 +774,11 @@ orsf <- function(data,
"1" = "Harrell's C-statistic",
"2" = "User-specified function")


#' @srrstats {G2.10} *drop = FALSE for type consistency*
orsf_out$pred_oobag <- orsf_out$pred_oobag[unsorted, , drop = FALSE]

orsf_out$pred_oobag[is.nan(orsf_out$pred_oobag)] <- NA_real_

}

orsf_out$pred_horizon <- oobag_pred_horizon
Expand Down Expand Up @@ -833,6 +835,7 @@ orsf <- function(data,
attr(orsf_out, 'vi_max_pvalue') <- vi_max_pvalue
attr(orsf_out, 'split_rule') <- split_rule
attr(orsf_out, 'n_thread') <- n_thread
attr(orsf_out, 'tree_type') <- tree_type_R

attr(orsf_out, 'tree_seeds') <- tree_seeds

Expand Down
1 change: 1 addition & 0 deletions R/orsf_attr.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ get_verbose_progress <- function(object) attr(object, 'verbose_progress')
get_vi_max_pvalue <- function(object) attr(object, 'vi_max_pvalue')
get_split_rule <- function(object) attr(object, 'split_rule')
get_n_thread <- function(object) attr(object, 'n_thread')
get_tree_type <- function(object) attr(object, 'tree_type')


#' ORSF status
Expand Down
78 changes: 56 additions & 22 deletions R/orsf_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ predict.orsf_fit <- function(object,
pred_type = 'risk',
na_action = 'fail',
boundary_checks = TRUE,
n_thread = 1,
...){

# catch any arguments that didn't match and got relegated to ...
Expand Down Expand Up @@ -129,22 +130,64 @@ predict.orsf_fit <- function(object,
# names_x_data = names_x_data)
# )

pred_type_cpp <- switch(
pred_type_R <- switch(
pred_type,
"risk" = "R",
"surv" = "S",
"chf" = "H",
"mort" = "M"
"risk" = 1,
"surv" = 2,
"chf" = 3,
"mort" = 4
)

out_values <-
if(pred_type_cpp == "M"){
orsf_pred_mort(object, x_new)
} else if (length(pred_horizon) == 1L) {
orsf_pred_uni(object$forest, x_new, pred_horizon_ordered, pred_type_cpp)
} else {
orsf_pred_multi(object$forest, x_new, pred_horizon_ordered, pred_type_cpp)
}
orsf_out <- orsf_cpp(x = x_new,
y = matrix(1, ncol=2),
w = rep(1, nrow(x_new)),
tree_type_R = get_tree_type(object),
tree_seeds = get_tree_seeds(object),
loaded_forest = object$forest,
n_tree = get_n_tree(object),
mtry = get_mtry(object),
vi_type_R = 0,
vi_max_pvalue = get_vi_max_pvalue(object),
lincomb_R_function = get_f_beta(object),
oobag_R_function = get_f_oobag_eval(object),
leaf_min_events = get_leaf_min_events(object),
leaf_min_obs = get_leaf_min_obs(object),
split_rule_R = switch(get_split_rule(object),
"logrank" = 1,
"cstat" = 2),
split_min_events = get_split_min_events(object),
split_min_obs = get_split_min_obs(object),
split_min_stat = get_split_min_stat(object),
split_max_cuts = get_n_split(object),
split_max_retry = get_n_retry(object),
lincomb_type_R = switch(get_orsf_type(object),
'fast' = 1,
'cph' = 1,
'random' = 2,
'net' = 3,
'custom' = 4),
lincomb_eps = get_cph_eps(object),
lincomb_iter_max = get_cph_iter_max(object),
lincomb_scale = get_cph_do_scale(object),
lincomb_alpha = get_net_alpha(object),
lincomb_df_target = get_net_df_target(object),
lincomb_ties_method = switch(
tolower(get_cph_method(object)),
'breslow' = 0,
'efron' = 1
),
pred_type_R = pred_type_R,
pred_mode = TRUE,
pred_horizon = pred_horizon_ordered,
oobag = FALSE,
oobag_eval_type_R = 0,
oobag_eval_every = get_n_tree(object),
n_thread = n_thread,
write_forest = FALSE,
run_forest = TRUE,
verbosity = 4)

out_values <- orsf_out$pred_new

if(na_action == "pass"){

Expand All @@ -164,15 +207,6 @@ predict.orsf_fit <- function(object,

}

orsf_pred_mort <- function(object, x_new){

pred_mat <- orsf_pred_multi(object$forest,
x_new = x_new,
time_vec = get_event_times(object),
pred_type = 'H')

matrix(apply(pred_mat, MARGIN = 1, FUN = sum), ncol = 1)

}


15 changes: 8 additions & 7 deletions R/orsf_vi.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,14 @@ orsf_vi_ <- function(object, group_factors, type_vi, oobag_fun = NULL){
#'
orsf_vi_oobag_ <- function(object, type_vi, oobag_fun){

if(!contains_oobag(object)){
stop("cannot compute ",
switch(type_vi, 'negate' = 'negation', 'permute' = 'permutation'),
" importance if the orsf_fit object does not have out-of-bag error",
" (see oobag_pred in ?orsf).",
call. = FALSE)
}
# can remove this b/c prediction accuracy is now computed at tree level
# if(!contains_oobag(object)){
# stop("cannot compute ",
# switch(type_vi, 'negate' = 'negation', 'permute' = 'permutation'),
# " importance if the orsf_fit object does not have out-of-bag error",
# " (see oobag_pred in ?orsf).",
# call. = FALSE)
# }

if(contains_vi(object) &&
is.null(oobag_fun) &&
Expand Down
19 changes: 10 additions & 9 deletions scratch.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@ library(tidyverse)
library(riskRegression)
library(survival)

sink("orsf-output.txt")
fit <- orsf(pbc_orsf, Surv(time, status) ~ . - id,
n_tree = 2,
n_tree = 3,
tree_seeds = 1:3,
n_thread = 1,
mtry = 2,
oobag_pred_type = 'mort',
split_rule = 'logrank',
importance = 'negate',
split_min_stat = 3,
verbose_progress = 4)
sink()
orsf_vi(fit)
oobag_pred_type = 'surv',
split_rule = 'cstat',
importance = 'none',
split_min_stat = 0.4,
verbose_progress = 1)

sink("orsf-output.txt")
prd <- predict(fit, new_data = pbc_orsf, pred_horizon = 1000, pred_type = 'risk')
sink()

library(randomForestSRC)

Expand Down
Loading

0 comments on commit 9c56aac

Please sign in to comment.