diff --git a/R/orsf_pd.R b/R/orsf_pd.R index 48caa608..c242bd48 100644 --- a/R/orsf_pd.R +++ b/R/orsf_pd.R @@ -430,7 +430,7 @@ orsf_pred_dependence <- function(object, } - x_cols[[i]] <- match(names(pred_spec_new[[i]]), colnames(x_new))-1 + x_cols[[i]] <- match(names(pred_spec_new[[i]]), colnames(x_new)) pred_spec_new[[i]] <- as.matrix(pred_spec_new[[i]]) } @@ -439,62 +439,156 @@ orsf_pred_dependence <- function(object, control <- get_control(object) - orsf_out <- orsf_cpp(x = x_new, - y = matrix(1, ncol=2), - w = rep(1, nrow(x_new)), - tree_type_R = get_tree_type(object), - tree_seeds = get_tree_seeds(object), - loaded_forest = object$forest, - n_tree = get_n_tree(object), - mtry = get_mtry(object), - sample_with_replacement = get_sample_with_replacement(object), - sample_fraction = get_sample_fraction(object), - vi_type_R = 0, - vi_max_pvalue = get_vi_max_pvalue(object), - oobag_R_function = get_f_oobag_eval(object), - leaf_min_events = get_leaf_min_events(object), - leaf_min_obs = get_leaf_min_obs(object), - split_rule_R = switch(get_split_rule(object), - "logrank" = 1, - "cstat" = 2), - split_min_events = get_split_min_events(object), - split_min_obs = get_split_min_obs(object), - split_min_stat = get_split_min_stat(object), - split_max_cuts = get_n_split(object), - split_max_retry = get_n_retry(object), - lincomb_R_function = control$lincomb_R_function, - lincomb_type_R = switch(control$lincomb_type, - 'glm' = 1, - 'random' = 2, - 'net' = 3, - 'custom' = 4), - lincomb_eps = control$lincomb_eps, - lincomb_iter_max = control$lincomb_iter_max, - lincomb_scale = control$lincomb_scale, - lincomb_alpha = control$lincomb_alpha, - lincomb_df_target = control$lincomb_df_target, - lincomb_ties_method = switch(tolower(control$lincomb_ties_method), - 'breslow' = 0, - 'efron' = 1), - pred_type_R = pred_type_R, - pred_mode = FALSE, - pred_aggregate = TRUE, - pred_horizon = pred_horizon, - oobag = oobag, - oobag_eval_type_R = 0, - oobag_eval_every = get_n_tree(object), - 
pd_type_R = switch(type_output, - "smry" = 1L, - "ice" = 2L), - pd_x_vals = pred_spec_new, - pd_x_cols = x_cols, - pd_probs = prob_values, - n_thread = n_thread, - write_forest = FALSE, - run_forest = TRUE, - verbosity = 0) + pred_horizon_order <- order(pred_horizon) + pred_horizon_ordered <- pred_horizon[pred_horizon_order] + + results <- list() + + for(i in seq_along(pred_spec_new)){ + + results_i <- list() + + x_pd <- x_new + + for(j in seq(nrow(pred_spec_new[[i]]))){ + + x_pd[, x_cols[[i]]] <- pred_spec_new[[i]][j, ] + + results_i[[j]] <- orsf_cpp( + x = x_pd, + y = matrix(1, ncol=2), + w = rep(1, nrow(x_new)), + tree_type_R = get_tree_type(object), + tree_seeds = get_tree_seeds(object), + loaded_forest = object$forest, + n_tree = get_n_tree(object), + mtry = get_mtry(object), + sample_with_replacement = get_sample_with_replacement(object), + sample_fraction = get_sample_fraction(object), + vi_type_R = 0, + vi_max_pvalue = get_vi_max_pvalue(object), + oobag_R_function = get_f_oobag_eval(object), + leaf_min_events = get_leaf_min_events(object), + leaf_min_obs = get_leaf_min_obs(object), + split_rule_R = switch(get_split_rule(object), + "logrank" = 1, + "cstat" = 2), + split_min_events = get_split_min_events(object), + split_min_obs = get_split_min_obs(object), + split_min_stat = get_split_min_stat(object), + split_max_cuts = get_n_split(object), + split_max_retry = get_n_retry(object), + lincomb_R_function = control$lincomb_R_function, + lincomb_type_R = switch(control$lincomb_type, + 'glm' = 1, + 'random' = 2, + 'net' = 3, + 'custom' = 4), + lincomb_eps = control$lincomb_eps, + lincomb_iter_max = control$lincomb_iter_max, + lincomb_scale = control$lincomb_scale, + lincomb_alpha = control$lincomb_alpha, + lincomb_df_target = control$lincomb_df_target, + lincomb_ties_method = switch(tolower(control$lincomb_ties_method), + 'breslow' = 0, + 'efron' = 1), + pred_type_R = pred_type_R, + pred_mode = TRUE, + pred_aggregate = TRUE, + pred_horizon = pred_horizon_ordered, 
+ oobag = oobag, + oobag_eval_type_R = 0, + oobag_eval_every = get_n_tree(object), + pd_type_R = 0, + pd_x_vals = list(matrix(0, ncol=1, nrow=1)), + pd_x_cols = list(matrix(1L, ncol=1, nrow=1)), + pd_probs = c(0), + n_thread = n_thread, + write_forest = FALSE, + run_forest = TRUE, + verbosity = 0)$pred_new + + } + + if(type_output == 'smry'){ + results_i <- lapply( + results_i, + function(x) { + apply(x, 2, function(x_col){ + as.numeric( + c(mean(x_col, na.rm = TRUE), + quantile(x_col, probs = prob_values, na.rm = TRUE)) + ) + }) + } + ) + } - pd_vals <- orsf_out$pd_values + + results[[i]] <- results_i + + } + + # browser() + + # orsf_out <- orsf_cpp(x = x_new, + # y = matrix(1, ncol=2), + # w = rep(1, nrow(x_new)), + # tree_type_R = get_tree_type(object), + # tree_seeds = get_tree_seeds(object), + # loaded_forest = object$forest, + # n_tree = get_n_tree(object), + # mtry = get_mtry(object), + # sample_with_replacement = get_sample_with_replacement(object), + # sample_fraction = get_sample_fraction(object), + # vi_type_R = 0, + # vi_max_pvalue = get_vi_max_pvalue(object), + # oobag_R_function = get_f_oobag_eval(object), + # leaf_min_events = get_leaf_min_events(object), + # leaf_min_obs = get_leaf_min_obs(object), + # split_rule_R = switch(get_split_rule(object), + # "logrank" = 1, + # "cstat" = 2), + # split_min_events = get_split_min_events(object), + # split_min_obs = get_split_min_obs(object), + # split_min_stat = get_split_min_stat(object), + # split_max_cuts = get_n_split(object), + # split_max_retry = get_n_retry(object), + # lincomb_R_function = control$lincomb_R_function, + # lincomb_type_R = switch(control$lincomb_type, + # 'glm' = 1, + # 'random' = 2, + # 'net' = 3, + # 'custom' = 4), + # lincomb_eps = control$lincomb_eps, + # lincomb_iter_max = control$lincomb_iter_max, + # lincomb_scale = control$lincomb_scale, + # lincomb_alpha = control$lincomb_alpha, + # lincomb_df_target = control$lincomb_df_target, + # lincomb_ties_method = 
# Split a list of partial-dependence value matrices into one-row matrices.
#
# Each row of each matrix in `x_vals` becomes its own 1-row matrix in the
# output, keeping the column names of the matrix it came from. The entry of
# `x_cols` at the same list position is repeated once per row, so that
# position k of the returned `x_vals` lines up with position k of the
# returned `x_cols`.
#
# x_vals: list of matrices (each element must have a row dimension).
# x_cols: list with at least one entry for every element of `x_vals`
#   that has rows; entry i is paired with every row of `x_vals[[i]]`.
#
# Returns a list with two unnamed lists of equal length:
#   x_vals - 1-row matrices, in input order (matrix by matrix, row by row).
#   x_cols - the matching entries of the input `x_cols`.
pd_list_split <- function(x_vals, x_cols) {

  # Row counts are known up front, so the outputs can be allocated once
  # rather than grown one element at a time inside the loops.
  n_rows <- vapply(x_vals, nrow, integer(1))
  n_out <- sum(n_rows)

  x_vals_out <- vector(mode = "list", length = n_out)
  x_cols_out <- vector(mode = "list", length = n_out)

  counter <- 0L

  for (i in seq_along(x_vals)) {

    if (n_rows[i] == 0L) next

    vals_i <- x_vals[[i]]
    cols_i <- x_cols[[i]]

    for (j in seq_len(n_rows[i])) {

      counter <- counter + 1L

      # drop = FALSE keeps the 1-row matrix shape and the column names
      row_j <- vals_i[j, , drop = FALSE]
      # row names are not carried into the split pieces
      rownames(row_j) <- NULL

      x_vals_out[[counter]] <- row_j
      # `[<-` with list() so a NULL entry occupies its slot instead of
      # deleting it, which would misalign x_vals and x_cols
      x_cols_out[counter] <- list(cols_i)

    }

  }

  list(
    x_vals = x_vals_out,
    x_cols = x_cols_out
  )

}