diff --git a/R/orsf_pd.R b/R/orsf_pd.R index c242bd48..ce425474 100644 --- a/R/orsf_pd.R +++ b/R/orsf_pd.R @@ -430,7 +430,7 @@ orsf_pred_dependence <- function(object, } - x_cols[[i]] <- match(names(pred_spec_new[[i]]), colnames(x_new)) + x_cols[[i]] <- match(names(pred_spec_new[[i]]), colnames(x_new)) - 1 pred_spec_new[[i]] <- as.matrix(pred_spec_new[[i]]) } @@ -442,153 +442,152 @@ orsf_pred_dependence <- function(object, pred_horizon_order <- order(pred_horizon) pred_horizon_ordered <- pred_horizon[pred_horizon_order] - results <- list() - - for(i in seq_along(pred_spec_new)){ - - results_i <- list() - - x_pd <- x_new - - for(j in seq(nrow(pred_spec_new[[i]]))){ - - x_pd[, x_cols[[i]]] <- pred_spec_new[[i]][j, ] - - results_i[[j]] <- orsf_cpp( - x = x_pd, - y = matrix(1, ncol=2), - w = rep(1, nrow(x_new)), - tree_type_R = get_tree_type(object), - tree_seeds = get_tree_seeds(object), - loaded_forest = object$forest, - n_tree = get_n_tree(object), - mtry = get_mtry(object), - sample_with_replacement = get_sample_with_replacement(object), - sample_fraction = get_sample_fraction(object), - vi_type_R = 0, - vi_max_pvalue = get_vi_max_pvalue(object), - oobag_R_function = get_f_oobag_eval(object), - leaf_min_events = get_leaf_min_events(object), - leaf_min_obs = get_leaf_min_obs(object), - split_rule_R = switch(get_split_rule(object), - "logrank" = 1, - "cstat" = 2), - split_min_events = get_split_min_events(object), - split_min_obs = get_split_min_obs(object), - split_min_stat = get_split_min_stat(object), - split_max_cuts = get_n_split(object), - split_max_retry = get_n_retry(object), - lincomb_R_function = control$lincomb_R_function, - lincomb_type_R = switch(control$lincomb_type, - 'glm' = 1, - 'random' = 2, - 'net' = 3, - 'custom' = 4), - lincomb_eps = control$lincomb_eps, - lincomb_iter_max = control$lincomb_iter_max, - lincomb_scale = control$lincomb_scale, - lincomb_alpha = control$lincomb_alpha, - lincomb_df_target = control$lincomb_df_target, - lincomb_ties_method = switch(tolower(control$lincomb_ties_method), - 'breslow' = 0, - 'efron' = 1), - pred_type_R = pred_type_R, - pred_mode = TRUE, - pred_aggregate = TRUE, - pred_horizon = pred_horizon_ordered, - oobag = oobag, - oobag_eval_type_R = 0, - oobag_eval_every = get_n_tree(object), - pd_type_R = 0, - pd_x_vals = list(matrix(0, ncol=1, nrow=1)), - pd_x_cols = list(matrix(1L, ncol=1, nrow=1)), - pd_probs = c(0), - n_thread = n_thread, - write_forest = FALSE, - run_forest = TRUE, - verbosity = 0)$pred_new - - } - - if(type_output == 'smry'){ - results_i <- lapply( - results_i, - function(x) { - apply(x, 2, function(x_col){ - as.numeric( - c(mean(x_col, na.rm = TRUE), - quantile(x_col, probs = prob_values, na.rm = TRUE)) - ) - }) - } - ) - } - - - results[[i]] <- results_i - - } - + # results <- list() + # + # for(i in seq_along(pred_spec_new)){ + # + # results_i <- list() + # + # x_pd <- x_new + # + # for(j in seq(nrow(pred_spec_new[[i]]))){ + # + # x_pd[, x_cols[[i]]] <- pred_spec_new[[i]][j, ] + # + # results_i[[j]] <- orsf_cpp( + # x = x_pd, + # y = matrix(1, ncol=2), + # w = rep(1, nrow(x_new)), + # tree_type_R = get_tree_type(object), + # tree_seeds = get_tree_seeds(object), + # loaded_forest = object$forest, + # n_tree = get_n_tree(object), + # mtry = get_mtry(object), + # sample_with_replacement = get_sample_with_replacement(object), + # sample_fraction = get_sample_fraction(object), + # vi_type_R = 0, + # vi_max_pvalue = get_vi_max_pvalue(object), + # oobag_R_function = get_f_oobag_eval(object), + # leaf_min_events = get_leaf_min_events(object), + # leaf_min_obs = get_leaf_min_obs(object), + # split_rule_R = switch(get_split_rule(object), + # "logrank" = 1, + # "cstat" = 2), + # split_min_events = get_split_min_events(object), + # split_min_obs = get_split_min_obs(object), + # split_min_stat = get_split_min_stat(object), + # split_max_cuts = get_n_split(object), + # split_max_retry = get_n_retry(object), + # lincomb_R_function = control$lincomb_R_function, + # lincomb_type_R = switch(control$lincomb_type, + # 'glm' = 1, + # 'random' = 2, + # 'net' = 3, + # 'custom' = 4), + # lincomb_eps = control$lincomb_eps, + # lincomb_iter_max = control$lincomb_iter_max, + # lincomb_scale = control$lincomb_scale, + # lincomb_alpha = control$lincomb_alpha, + # lincomb_df_target = control$lincomb_df_target, + # lincomb_ties_method = switch(tolower(control$lincomb_ties_method), + # 'breslow' = 0, + # 'efron' = 1), + # pred_type_R = pred_type_R, + # pred_mode = TRUE, + # pred_aggregate = TRUE, + # pred_horizon = pred_horizon_ordered, + # oobag = oobag, + # oobag_eval_type_R = 0, + # oobag_eval_every = get_n_tree(object), + # pd_type_R = 0, + # pd_x_vals = list(matrix(0, ncol=1, nrow=1)), + # pd_x_cols = list(matrix(1L, ncol=1, nrow=1)), + # pd_probs = c(0), + # n_thread = n_thread, + # write_forest = FALSE, + # run_forest = TRUE, + # verbosity = 0)$pred_new + # + # } + # + # if(type_output == 'smry'){ + # results_i <- lapply( + # results_i, + # function(x) { + # apply(x, 2, function(x_col){ + # as.numeric( + # c(mean(x_col, na.rm = TRUE), + # quantile(x_col, probs = prob_values, na.rm = TRUE)) + # ) + # }) + # } + # ) + # } + # + # + # results[[i]] <- results_i + # + # } + # + # pd_vals <- results # browser() - # orsf_out <- orsf_cpp(x = x_new, - # y = matrix(1, ncol=2), - # w = rep(1, nrow(x_new)), - # tree_type_R = get_tree_type(object), - # tree_seeds = get_tree_seeds(object), - # loaded_forest = object$forest, - # n_tree = get_n_tree(object), - # mtry = get_mtry(object), - # sample_with_replacement = get_sample_with_replacement(object), - # sample_fraction = get_sample_fraction(object), - # vi_type_R = 0, - # vi_max_pvalue = get_vi_max_pvalue(object), - # oobag_R_function = get_f_oobag_eval(object), - # leaf_min_events = get_leaf_min_events(object), - # leaf_min_obs = get_leaf_min_obs(object), - # split_rule_R = switch(get_split_rule(object), - # "logrank" = 1, - # "cstat" = 2), - # split_min_events = get_split_min_events(object), - # split_min_obs = get_split_min_obs(object), - # split_min_stat = get_split_min_stat(object), - # split_max_cuts = get_n_split(object), - # split_max_retry = get_n_retry(object), - # lincomb_R_function = control$lincomb_R_function, - # lincomb_type_R = switch(control$lincomb_type, - # 'glm' = 1, - # 'random' = 2, - # 'net' = 3, - # 'custom' = 4), - # lincomb_eps = control$lincomb_eps, - # lincomb_iter_max = control$lincomb_iter_max, - # lincomb_scale = control$lincomb_scale, - # lincomb_alpha = control$lincomb_alpha, - # lincomb_df_target = control$lincomb_df_target, - # lincomb_ties_method = switch(tolower(control$lincomb_ties_method), - # 'breslow' = 0, - # 'efron' = 1), - # pred_type_R = pred_type_R, - # pred_mode = FALSE, - # pred_aggregate = TRUE, - # pred_horizon = pred_horizon, - # oobag = oobag, - # oobag_eval_type_R = 0, - # oobag_eval_every = get_n_tree(object), - # pd_type_R = switch(type_output, - # "smry" = 1L, - # "ice" = 2L), - # pd_x_vals = pred_spec_new, - # pd_x_cols = lapply(x_cols, function(x) x - 1), - # pd_probs = prob_values, - # n_thread = n_thread, - # write_forest = FALSE, - # run_forest = TRUE, - # verbosity = 0) - # - # pd_vals <- orsf_out$pd_values + orsf_out <- orsf_cpp(x = x_new, + y = matrix(1, ncol=2), + w = rep(1, nrow(x_new)), + tree_type_R = get_tree_type(object), + tree_seeds = get_tree_seeds(object), + loaded_forest = object$forest, + n_tree = get_n_tree(object), + mtry = get_mtry(object), + sample_with_replacement = get_sample_with_replacement(object), + sample_fraction = get_sample_fraction(object), + vi_type_R = 0, + vi_max_pvalue = get_vi_max_pvalue(object), + oobag_R_function = get_f_oobag_eval(object), + leaf_min_events = get_leaf_min_events(object), + leaf_min_obs = get_leaf_min_obs(object), + split_rule_R = switch(get_split_rule(object), + "logrank" = 1, + "cstat" = 2), + split_min_events = get_split_min_events(object), + split_min_obs = get_split_min_obs(object), + split_min_stat = get_split_min_stat(object), + split_max_cuts = get_n_split(object), + split_max_retry = get_n_retry(object), + lincomb_R_function = control$lincomb_R_function, + lincomb_type_R = switch(control$lincomb_type, + 'glm' = 1, + 'random' = 2, + 'net' = 3, + 'custom' = 4), + lincomb_eps = control$lincomb_eps, + lincomb_iter_max = control$lincomb_iter_max, + lincomb_scale = control$lincomb_scale, + lincomb_alpha = control$lincomb_alpha, + lincomb_df_target = control$lincomb_df_target, + lincomb_ties_method = switch(tolower(control$lincomb_ties_method), + 'breslow' = 0, + 'efron' = 1), + pred_type_R = pred_type_R, + pred_mode = FALSE, + pred_aggregate = TRUE, + pred_horizon = pred_horizon, + oobag = oobag, + oobag_eval_type_R = 0, + oobag_eval_every = get_n_tree(object), + pd_type_R = switch(type_output, + "smry" = 1L, + "ice" = 2L), + pd_x_vals = pred_spec_new, + pd_x_cols = x_cols, + pd_probs = prob_values, + n_thread = n_thread, + write_forest = FALSE, + run_forest = TRUE, + verbosity = 0) - pd_vals <- results + pd_vals <- orsf_out$pd_values for(i in seq_along(pd_vals)){ diff --git a/tests/testthat/test-orsf_pd.R b/tests/testthat/test-orsf_pd.R index 826d58c4..fab0d11c 100644 --- a/tests/testthat/test-orsf_pd.R +++ b/tests/testthat/test-orsf_pd.R @@ -57,9 +57,6 @@ test_that( ) funs <- list( - # ice_new = orsf_ice_new, - # ice_inb = orsf_ice_inb, - # ice_oob = orsf_ice_oob, pd_new = orsf_pd_new, pd_inb = orsf_pd_inb, pd_oob = orsf_pd_oob @@ -87,8 +84,7 @@ for(i in seq_along(funs)){ formals <- setdiff(names(formals(funs[[i]])), '...') - for(pred_type in c('mort')){ - # for(pred_type in setdiff(pred_types_surv, c('leaf', 'mort'))){ + for(pred_type in setdiff(pred_types_surv, c('leaf'))){ args_grid$pred_type = pred_type args_loop$pred_type = pred_type @@ -147,43 +143,40 @@ for(i in seq_along(funs)){ } -# pd_vals_ice <- orsf_ice_new( -# fit, -# new_data = pbc_orsf, -# pred_spec = list(bili = 1:4), -# pred_horizon = 1000 -# ) -# -pd_vals_smry <- orsf_pd_new( +pd_vals_ice <- orsf_ice_new( fit, - new_data = pbc_orsf, + new_data = pbc_test, pred_spec = list(bili = 1:4), pred_horizon = 1000 ) -# -# test_that( -# 'ice values summarized are the same as pd values', -# code = { -# -# pd_vals_check <- pd_vals_ice[, .(medn = median(pred)), by = id_variable] -# -# expect_equal( -# pd_vals_check$medn, -# pd_vals_smry$medn -# ) -# -# } -# ) +pd_vals_smry <- orsf_pd_new( + fit, + new_data = pbc_test, + pred_spec = list(bili = 1:4), + pred_horizon = 1000 +) + test_that( - 'No missing values in output', + 'ice values summarized are the same as pd values', code = { - # expect_false(any(is.na(pd_vals_ice))) - # expect_false(any(is.nan(as.matrix(pd_vals_ice)))) - # expect_false(any(is.infinite(as.matrix(pd_vals_ice)))) + grps <- split(pd_vals_ice, pd_vals_ice$id_variable) + pd_vals_check <- sapply(grps, function(x) median(x$pred)) + + expect_equal( + as.numeric(pd_vals_check), + pd_vals_smry$medn + ) + + } +) + +test_that( + 'No missing values in summary output', + code = { expect_false(any(is.na(pd_vals_smry))) expect_false(any(is.nan(as.matrix(pd_vals_smry)))) expect_false(any(is.infinite(as.matrix(pd_vals_smry)))) @@ -200,10 +193,9 @@ test_that( pred_horizon = c(1000, 2000, 3000) ) - # risk must increase or remain steady over time + # risk monotonically increases expect_lte(pd_smry_multi_horiz$mean[1], pd_smry_multi_horiz$mean[2]) expect_lte(pd_smry_multi_horiz$mean[2], pd_smry_multi_horiz$mean[3]) - expect_lte(pd_smry_multi_horiz$medn[1], pd_smry_multi_horiz$medn[2]) expect_lte(pd_smry_multi_horiz$medn[2], pd_smry_multi_horiz$medn[3]) @@ -213,9 +205,11 @@ test_that( pred_horizon = c(1000, 2000, 3000) ) - ice_check <- pd_ice_multi_horiz[, .(m = mean(pred, na.rm=TRUE)), by = pred_horizon] + grps <- split(pd_ice_multi_horiz, pd_ice_multi_horiz$pred_horizon) + + ice_check <- sapply(grps, function(x) mean(x$pred, na.rm=TRUE)) - expect_equal(ice_check$m, pd_smry_multi_horiz$mean) + expect_equal(as.numeric(ice_check), pd_smry_multi_horiz$mean) }