Skip to content

Commit

Permalink
solid progress but need to fix the data restore mechanics
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 1, 2023
1 parent ba30364 commit ea632ac
Show file tree
Hide file tree
Showing 16 changed files with 7,286 additions and 44,590 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ compute_cstat_exported_uvec <- function(y, w, g, pred_is_risklike) {
.Call(`_aorsf_compute_cstat_exported_uvec`, y, w, g, pred_is_risklike)
}

orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_vals, pd_cols, pd_probs, n_thread, write_forest, run_forest, verbosity) {
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_vals, pd_cols, pd_probs, n_thread, write_forest, run_forest, verbosity)
# Thin R-level wrapper around the compiled routine `_aorsf_orsf_cpp`.
#
# Every argument is forwarded positionally, unchanged and in the order
# declared here, through `.Call()`; no validation, coercion or defaulting
# happens at this level, so callers must supply all arguments.
#
# The partial-dependence inputs are named `pd_x_vals` and `pd_x_cols`.
# NOTE(review): call sites appear to pass these as lists of matrices
# (e.g. `list(matrix(0, ncol=1, nrow=1))`) -- confirm against the C++
# signature before relying on that shape.
#
# Returns whatever the compiled routine returns.
orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity) {
# Argument order below must mirror the function signature exactly.
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity)
}

8 changes: 4 additions & 4 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,8 @@ orsf <- function(data,
'user' = 2),
oobag_eval_every = oobag_eval_every,
pd_type_R = 0,
pd_vals = matrix(0, ncol=1, nrow=1),
pd_cols = matrix(1L, ncol=1, nrow=1),
pd_x_vals = list(matrix(0, ncol=1, nrow=1)),
pd_x_cols = list(matrix(1L, ncol=1, nrow=1)),
pd_probs = c(0),
n_thread = n_thread,
write_forest = TRUE,
Expand Down Expand Up @@ -1134,8 +1134,8 @@ orsf_train_ <- function(object,
'user' = 2),
oobag_eval_every = oobag_eval_every,
pd_type_R = 0,
pd_vals = matrix(0, ncol=1, nrow=1),
pd_cols = matrix(1L, ncol=1, nrow=1),
pd_x_vals = list(matrix(0, ncol=1, nrow=1)),
pd_x_cols = list(matrix(1L, ncol=1, nrow=1)),
pd_probs = c(0),
n_thread = get_n_thread(object),
write_forest = TRUE,
Expand Down
212 changes: 190 additions & 22 deletions R/orsf_pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,6 @@ orsf_pred_dependence <- function(object,
pred_type, " predictions.", call. = FALSE)
}

type_input <- if(expand_grid) 'grid' else 'loop'

names_x_data <- intersect(get_names_x(object), names(pd_data))

cc <- which(stats::complete.cases(select_cols(pd_data, names_x_data)))
Expand All @@ -367,19 +365,184 @@ orsf_pred_dependence <- function(object,
pred_spec[[i]] <- (pred_spec[[i]] - means[i]) / standard_deviations[i]
}

if(is.data.frame(pred_spec)) type_input <- 'grid'

pd_fun <- switch(type_input, 'grid' = pd_grid, 'loop' = pd_loop)

pred_type_R <- switch(pred_type,
"risk" = 1, "surv" = 2,
"chf" = 3, "mort" = 4)

browser()
fi <- get_fctr_info(object)

if(expand_grid){

if(!is.data.frame(pred_spec))
pred_spec <- expand.grid(pred_spec, stringsAsFactors = TRUE)

for(i in seq_along(fi$cols)){

ii <- fi$cols[i]

if(is.character(pred_spec[[ii]]) && !fi$ordr[i]){

pred_spec[[ii]] <- factor(pred_spec[[ii]], levels = fi$lvls[[ii]])

}

}

check_new_data_fctrs(new_data = pred_spec,
names_x = get_names_x(object),
fi_ref = fi,
label_new = "pred_spec")

pred_spec_new <- ref_code(x_data = pred_spec,
fi = get_fctr_info(object),
names_x_data = names(pred_spec))

x_cols <- list(match(names(pred_spec_new), colnames(x_new)) - 1)

pred_spec_new <- list(as.matrix(pred_spec_new))

pd_bind <- list(pred_spec)

} else {

pred_spec_new <- pd_bind <- x_cols <- list()

for(i in seq_along(pred_spec)){

pred_spec_new[[i]] <- as.data.frame(pred_spec[i])
pd_name <- names(pred_spec)[i]

pd_bind[[i]] <- data.frame(
variable = pd_name,
value = rep(NA_real_, length(pred_spec[[i]])),
level = rep(NA_character_, length(pred_spec[[i]]))
)

if(pd_name %in% fi$cols) {

pd_bind[[i]]$level <- as.character(pred_spec[[i]])

pred_spec_new[[i]] <- ref_code(pred_spec_new[[i]],
fi = fi,
names_x_data = pd_name)

} else {

pd_bind[[i]]$value <- pred_spec[[i]]

}

out <- pd_fun(object, x_new, pred_spec, pred_horizon,
type_output, prob_values, prob_labels,
n_thread, oobag, pred_type_R)
x_cols[[i]] <- match(names(pred_spec_new[[i]]), colnames(x_new))
pred_spec_new[[i]] <- as.matrix(pred_spec_new[[i]])

}

}

orsf_out <- orsf_cpp(x = x_new,
y = matrix(1, ncol=2),
w = rep(1, nrow(x_new)),
tree_type_R = get_tree_type(object),
tree_seeds = get_tree_seeds(object),
loaded_forest = object$forest,
n_tree = get_n_tree(object),
mtry = get_mtry(object),
vi_type_R = 0,
vi_max_pvalue = get_vi_max_pvalue(object),
lincomb_R_function = get_f_beta(object),
oobag_R_function = get_f_oobag_eval(object),
leaf_min_events = get_leaf_min_events(object),
leaf_min_obs = get_leaf_min_obs(object),
split_rule_R = switch(get_split_rule(object),
"logrank" = 1,
"cstat" = 2),
split_min_events = get_split_min_events(object),
split_min_obs = get_split_min_obs(object),
split_min_stat = get_split_min_stat(object),
split_max_cuts = get_n_split(object),
split_max_retry = get_n_retry(object),
lincomb_type_R = switch(get_orsf_type(object),
'fast' = 1,
'cph' = 1,
'random' = 2,
'net' = 3,
'custom' = 4),
lincomb_eps = get_cph_eps(object),
lincomb_iter_max = get_cph_iter_max(object),
lincomb_scale = get_cph_do_scale(object),
lincomb_alpha = get_net_alpha(object),
lincomb_df_target = get_net_df_target(object),
lincomb_ties_method = switch(
tolower(get_cph_method(object)),
'breslow' = 0,
'efron' = 1
),
pred_type_R = pred_type_R,
pred_mode = FALSE,
pred_aggregate = TRUE,
pred_horizon = pred_horizon,
oobag = oobag,
oobag_eval_type_R = 0,
oobag_eval_every = get_n_tree(object),
pd_type_R = switch(type_output,
"smry" = 1L,
"ice" = 2L),
pd_x_vals = pred_spec_new,
pd_x_cols = x_cols,
pd_probs = prob_values,
n_thread = n_thread,
write_forest = FALSE,
run_forest = TRUE,
verbosity = 0)

pd_vals <- orsf_out$pd_values

for(i in seq_along(pd_vals)){

pd_bind[[i]]$id_variable <- seq(nrow(pd_bind[[i]]))

for(j in seq_along(pd_vals[[i]])){

pd_vals[[i]][[j]] <- matrix(pd_vals[[i]][[j]],
nrow=length(pred_horizon),
byrow = T)

rownames(pd_vals[[i]][[j]]) <- pred_horizon

if(type_output=='smry')
colnames(pd_vals[[i]][[j]]) <- c('mean', prob_labels)
else
colnames(pd_vals[[i]][[j]]) <- c(paste(1:nrow(x_new)))

pd_vals[[i]][[j]] <- as.data.table(pd_vals[[i]][[j]],
keep.rownames = 'pred_horizon')

if(type_output == 'ice')
pd_vals[[i]][[j]] <- melt(data = pd_vals[[i]][[j]],
id.vars = 'pred_horizon',
variable.name = 'id_row',
value.name = 'pred_value')

}

pd_vals[[i]] <- rbindlist(pd_vals[[i]], idcol = 'id_variable')

pd_vals[[i]] <- merge(pd_vals[[i]],
as.data.table(pd_bind[[i]]),
by = 'id_variable')

}


out <- rbindlist(pd_vals)

ids <- c('id_variable', if(type_output == 'ice') 'id_row')

mid <- setdiff(names(out), c(ids, 'mean', prob_labels, 'pred_value'))

end <- setdiff(names(out), c(ids, mid))

setcolorder(out, neworder = c(ids, mid, end))

out[, pred_horizon := as.numeric(pred_horizon)]

Expand Down Expand Up @@ -515,25 +678,30 @@ pd_grid <- function(object,
pd_type_R = switch(type_output,
"smry" = 1L,
"ice" = 2L),
pd_vals = as.matrix(pred_spec_new),
pd_cols = x_cols-1L,
pd_x_vals = list(as.matrix(pred_spec_new)),
pd_x_cols = list(x_cols-1L),
pd_probs = prob_values,
n_thread = n_thread,
write_forest = FALSE,
run_forest = TRUE,
verbosity = 0)
verbosity = 4)

pd_vals <- orsf_out$pd_values

if(type_output == 'smry'){

pd_vals <- lapply(pd_vals, function(x){
m <- matrix(x, nrow=length(pred_horizon), byrow = T)
rownames(m) <- pred_horizon
colnames(m) <- c('mean', prob_labels)
data.table(m, keep.rownames = 'pred_horizon')
})

for(i in seq_along(pd_vals)){
for(j in seq_along(pd_vals[[i]])){
pd_vals[[i]][[j]] <- matrix(pd_vals[[i]][[j]],
nrow=length(pred_horizon),
byrow = T)
rownames(pd_vals[[i]][[j]]) <- pred_horizon
colnames(pd_vals[[i]][[j]]) <- c('mean', prob_labels)
}
pd_vals[[i]] <- as.data.table(
do.call(rbind, pd_vals[[i]]), keep.rownames = 'pred_horizon'
)
}

} else if(type_output == 'ice'){

Expand Down Expand Up @@ -673,8 +841,8 @@ pd_loop <- function(object,
oobag_eval_type_R = 0,
oobag_eval_every = get_n_tree(object),
pd_type_R = 1,
pd_vals = as.matrix(pd_new),
pd_cols = x_cols-1L,
pd_x_vals = as.matrix(pd_new),
pd_x_cols = x_cols-1L,
pd_probs = prob_values,
n_thread = n_thread,
write_forest = FALSE,
Expand Down
4 changes: 2 additions & 2 deletions R/orsf_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ predict.orsf_fit <- function(object,
oobag_eval_type_R = 0,
oobag_eval_every = get_n_tree(object),
pd_type_R = 0,
pd_vals = matrix(0, ncol=1, nrow=1),
pd_cols = matrix(1L, ncol=1, nrow=1),
pd_x_vals = list(matrix(0, ncol=1, nrow=1)),
pd_x_cols = list(matrix(1L, ncol=1, nrow=1)),
pd_probs = c(0),
n_thread = n_thread,
write_forest = FALSE,
Expand Down
4 changes: 2 additions & 2 deletions R/orsf_vi.R
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ orsf_vi_oobag_ <- function(object,
'user' = 2),
oobag_eval_every = get_n_tree(object),
pd_type_R = 0,
pd_vals = matrix(0, ncol=1, nrow=1),
pd_cols = matrix(1L, ncol=1, nrow=1),
pd_x_vals = list(matrix(0, ncol=1, nrow=1)),
pd_x_cols = list(matrix(1L, ncol=1, nrow=1)),
pd_probs = c(0),
n_thread = n_thread,
write_forest = FALSE,
Expand Down
Loading

0 comments on commit ea632ac

Please sign in to comment.