Patch summary (text before the first "diff --git" header is ignored by git apply).

Validation of user-supplied linear combination functions moves from
orsf_control() to forest initialization, where each forest type supplies its
own simulated test node. TreeClassification::find_safe_mtry() gains a
dedicated path for binary outcomes, and utility.cpp returns a zero matrix
instead of failing when the negative Hessian cannot be inverted.

Review fixes folded into this version of the patch:
  * removed a leftover browser() call from ObliqueForestClassification
  * error messages now describe the test case the code actually runs
  * checkers fall back to self$control$lincomb_R_function, matching init
  * base sort_inputs() argument order matches the survival override

diff --git a/R/orsf_R6.R b/R/orsf_R6.R
index 1ab03401..a2825366 100644
--- a/R/orsf_R6.R
+++ b/R/orsf_R6.R
@@ -1028,6 +1028,7 @@ ObliqueForest <- R6::R6Class(
 
 
    private$init_oobag_eval_function()
+   private$init_lincomb_R_function()
    private$init_oobag_pred_mode()
    private$init_tree_seeds()
    private$init_internal()
@@ -1226,6 +1227,16 @@ ObliqueForest <- R6::R6Class(
 
   },
 
+  init_lincomb_R_function = function(){
+   # only user-supplied ('custom') functions need to be vetted
+   if(self$control$lincomb_type == 'custom'){
+
+    private$check_lincomb_R_function(self$control$lincomb_R_function)
+
+   }
+
+  },
+
   # checkers
 
   check_data = function(data = NULL, new = FALSE){
@@ -2010,6 +2021,60 @@ ObliqueForest <- R6::R6Class(
    }
 
   },
+
+  # vet a user-supplied linear combination function before it is used
+  check_lincomb_R_function = function(lincomb_R_function = NULL){
+
+   input <- lincomb_R_function %||% self$control$lincomb_R_function
+
+   args <- names(formals(input))
+
+   if(length(args) != 3) stop(
+    "input should have 3 input arguments but instead has ",
+    length(args),
+    call. = FALSE
+   )
+
+   arg_names_expected <- c("x_node",
+                           "y_node",
+                           "w_node")
+
+   arg_names_refer <- c('first', 'second', 'third')
+   # names are checked in order: x_node, then y_node, then w_node
+   for(i in seq_along(arg_names_expected)){
+    if(args[i] != arg_names_expected[i])
+     stop(
+      "the ", arg_names_refer[i], " input argument of input ",
+      "should be named '", arg_names_expected[i],"' ",
+      "but is instead named '", args[i], "'",
+      call. = FALSE
+     )
+   }
+   # trial run on simulated data; each forest subclass supplies its own
+   test_output <- private$check_lincomb_R_function_internal(input)
+
+   if(!is.matrix(test_output)) stop(
+    "user-supplied function should return a matrix output ",
+    "but instead returns output of type ", class(test_output)[1],
+    call. = FALSE
+   )
+
+   if(ncol(test_output) != 1) stop(
+    "user-supplied function should return a matrix with 1 column ",
+    "but instead returns a matrix with ", ncol(test_output), " columns.",
+    call. = FALSE
+   )
+   # the simulated x_node has 3 columns, so 3 rows are expected
+   if(nrow(test_output) != 3L) stop(
+    "user-supplied function should return a matrix with 1 row for each ",
+    "column in x_node but instead returns a matrix with ",
+    nrow(test_output), " rows ", "in a testing case where x_node has ",
+    3L, " columns",
+    call. = FALSE
+   )
+
+  },
+
   check_verbose_progress = function(verbose_progress = NULL){
 
    input <- verbose_progress %||% self$verbose_progress
@@ -2433,7 +2498,9 @@ ObliqueForest <- R6::R6Class(
 
   },
 
-  sort_inputs = function(){
+  sort_inputs = function(sort_x = NULL,
+                         sort_y = NULL,
+                         sort_w = NULL){
    NULL
   },
 
@@ -2734,6 +2801,40 @@ ObliqueForestSurvival <- R6::R6Class(
 
   },
 
+  check_lincomb_R_function_internal = function(lincomb_R_function = NULL){
+   # run the user-supplied function on a simulated survival node
+   input <- lincomb_R_function %||% self$control$lincomb_R_function
+
+   test_time <- seq(from = 1, to = 5, length.out = 100)
+   test_status <- rep(c(0,1), each = 50)
+
+   .x_node <- matrix(rnorm(300), ncol = 3)
+   .y_node <- cbind(time = test_time, status = test_status)
+   .w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)
+
+   # the message below must describe the test case above exactly
+   out <- try(input(.x_node, .y_node, .w_node), silent = FALSE)
+
+   if(is_error(out)){
+
+    stop("user-supplied function encountered an error when it was tested. ",
+         "Please make sure the function works for this case:\n\n",
+         "test_time <- seq(from = 1, to = 5, length.out = 100)\n",
+         "test_status <- rep(c(0,1), each = 50)\n\n",
+         ".x_node <- matrix(rnorm(300), ncol = 3)\n",
+         ".y_node <- cbind(time = test_time, status = test_status)\n",
+         ".w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)\n\n",
+         "test_output <- user_function(.x_node, .y_node, .w_node)\n\n",
+         "test_output should be a numeric matrix with 1 column and",
+         " with nrow(test_output) = ncol(.x_node)",
+         call. = FALSE)
+
+   }
+
+   out
+
+  },
+
   sort_inputs = function(sort_x = TRUE,
                          sort_y = TRUE,
                          sort_w = TRUE){
@@ -3131,6 +3232,34 @@ ObliqueForestClassification <- R6::R6Class(
 
   },
 
+  check_lincomb_R_function_internal = function(lincomb_R_function = NULL){
+   # run the user-supplied function on a simulated binary outcome node
+   input <- lincomb_R_function %||% self$control$lincomb_R_function
+
+   .x_node <- matrix(rnorm(300), ncol = 3)
+   .y_node <- matrix(rbinom(100, size = 1, prob = 1/2), ncol = 1)
+   .w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)
+
+   out <- try(input(.x_node, .y_node, .w_node), silent = FALSE)
+
+   if(is_error(out)){
+
+    stop("user-supplied function encountered an error when it was tested. ",
+         "Please make sure the function works for this case:\n\n",
+         ".x_node <- matrix(rnorm(300), ncol = 3)\n",
+         ".y_node <- matrix(rbinom(100, size = 1, prob = 1/2), ncol = 1)\n",
+         ".w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)\n",
+         "test_output <- your_function(.x_node, .y_node, .w_node)\n\n",
+         "test_output should be a numeric matrix with 1 column and",
+         " with nrow(test_output) = ncol(.x_node)",
+         call. = FALSE)
+
+   }
+
+   out
+
+  },
+
   init_control = function(){
 
    self$control <- orsf_control_classification(method = 'glm',
@@ -3320,6 +3449,36 @@ ObliqueForestRegression <- R6::R6Class(
 
   },
 
+  check_lincomb_R_function_internal = function(lincomb_R_function = NULL){
+   # run the user-supplied function on a simulated regression node:
+   # a single continuous outcome column, with no time or status
+   # columns; the message below must describe this case exactly
+   input <- lincomb_R_function %||% self$control$lincomb_R_function
+
+   .x_node <- matrix(rnorm(300), ncol = 3)
+   .y_node <- matrix(rnorm(100), ncol = 1)
+   .w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)
+
+   out <- try(input(.x_node, .y_node, .w_node), silent = FALSE)
+
+   if(is_error(out)){
+
+    stop("user-supplied function encountered an error when it was tested. ",
+         "Please make sure the function works for this case:\n\n",
+         ".x_node <- matrix(rnorm(300), ncol = 3)\n",
+         ".y_node <- matrix(rnorm(100), ncol = 1)\n",
+         ".w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)\n\n",
+         "test_output <- beta_fun(.x_node, .y_node, .w_node)\n\n",
+         "test_output should be a numeric matrix with 1 column and",
+         " with nrow(test_output) = ncol(.x_node)",
+         call. = FALSE)
+
+   }
+
+   out
+
+  },
+
   init_control = function(){
 
    self$control <- orsf_control_regression(method = 'glm',
diff --git a/R/orsf_control.R b/R/orsf_control.R
index 0875371c..4e327c0b 100644
--- a/R/orsf_control.R
+++ b/R/orsf_control.R
@@ -366,12 +366,10 @@ orsf_control <- function(tree_type,
                      arg_name = 'method',
                      expected_length = 1)
 
-  } else {
-
-    check_beta_fun(method)
-
   }
 
+  # checking of custom functions is done when orsf object is initialized
+
   check_arg_type(arg_value = scale_x,
                  arg_name = 'scale_x',
                  expected_type = 'logical')
diff --git a/man/orsf_control_custom.Rd b/man/orsf_control_custom.Rd
index 64cdc2a5..a1e5096e 100644
--- a/man/orsf_control_custom.Rd
+++ b/man/orsf_control_custom.Rd
@@ -67,10 +67,10 @@ fit_rando
 ## N trees: 500
 ## N predictors total: 17
 ## N predictors per node: 5
-## Average leaves per tree: 19.76
+## Average leaves per tree: 19.682
 ## Min observations in leaf: 5
 ## Min events in leaf: 1
-## OOB stat value: 0.84
+## OOB stat value: 0.83
 ## OOB stat type: Harrell's C-index
 ## Variable importance: anova
 ##
@@ -135,15 +135,15 @@ The PCA ORSF does quite well! (higher IPA is better)
 ##
 ## model times Brier lower upper IPA
 ## 1: Null model 1788 20.479 18.090 22.868 0.000
-## 2: rando 1788 11.808 9.729 13.888 42.339
-## 3: pca 1788 12.861 10.833 14.889 37.199
+## 2: rando 1788 11.872 9.771 13.972 42.031
+## 3: pca 1788 12.990 10.971 15.009 36.569
 ##
 ## Results of model comparisons:
 ##
 ## times model reference delta.Brier lower upper p
-## 1: 1788 rando Null model -8.671 -10.813 -6.528 2.142609e-15
-## 2: 1788 pca Null model -7.618 -9.379 -5.857 2.300807e-17
-## 3: 1788 pca rando 1.053 0.241 1.865 1.106784e-02
+## 1: 1788 rando Null model -8.607 -10.809 -6.406 1.832790e-14
+## 2: 1788 pca Null model -7.489 -9.213 -5.765 1.664802e-17
+## 3: 1788 pca rando 1.118 0.258 1.979 1.087482e-02
 ##
 ## NOTE: Values are multiplied by 100 and given in \%.
 
diff --git a/src/TreeClassification.cpp b/src/TreeClassification.cpp
index bdfae897..cf1ac6c5 100644
--- a/src/TreeClassification.cpp
+++ b/src/TreeClassification.cpp
@@ -209,14 +209,71 @@
 
  }
 
- arma::uword TreeClassification::find_safe_mtry(){
+ uword TreeClassification::find_safe_mtry(){
+  // dispatch on outcome type; binary outcomes have their own routine
+  if(binary){ return find_safe_mtry_binary(); }
+
+  return find_safe_mtry_multiclass();
+
+
+ }
+
+ uword TreeClassification::find_safe_mtry_binary(){
+
+  // conditions to split a column:
+  // >= 3 events per predictor
+  // >= 3 non-events per predictor
 
   double safer_mtry = mtry;
+  double y_sum_ctrls = sum(y_node.col(0));
+  double y_sum_cases = sum(y_node.col(1));
+
+  if(verbosity > 3){
+   Rcout << " -- Y sums (unweighted): ";
+   Rcout << y_sum_cases << " cases, ";
+   Rcout << y_sum_ctrls << " controls" << std::endl;
+  }
+
+  splittable_y_cols.zeros(1);
+
+  if(y_sum_cases >= 3 && y_sum_ctrls >= 3){
+
+   splittable_y_cols[0] = 1;
+   y_col_split = 1;
+   // the smaller of the two class counts limits the safe mtry
+   double min_count = y_sum_cases;
+   if(y_sum_cases > y_sum_ctrls) min_count = y_sum_ctrls;
+
+   if(lincomb_type != LC_GLMNET){
+    // min_count >= 3 here, so this stops at safer_mtry >= 1
+    while (min_count / safer_mtry < 3){
+     --safer_mtry;
+    }
+
+   }
+
+   uword out = safer_mtry;
+
+   return(out);
+
+  }
+
+  if(verbosity > 3){
+   Rcout << " -- No y columns are splittable";
+   Rcout << std::endl << std::endl;
+  }
+
+  return 0;
+
+ }
+ uword TreeClassification::find_safe_mtry_multiclass(){
 
   // conditions to split a column:
   // >= 3 events per predictor
   // >= 3 non-events per predictor
 
+  double safer_mtry = mtry;
+
   double n = y_node.n_rows;
   vec y_sum_cases = sum(y_node, 0).t();
   vec y_sum_ctrls = n - y_sum_cases;
@@ -286,7 +343,7 @@
 
   // glmnet can handle higher dimension x,
   // but other methods probably cannot.
-  if(lincomb_type != LC_GLM){
+  if(lincomb_type != LC_GLMNET){
 
    while (best_count / safer_mtry < 3){
     --safer_mtry;
diff --git a/src/TreeClassification.h b/src/TreeClassification.h
index 5d13973f..2b0050e3 100644
--- a/src/TreeClassification.h
+++ b/src/TreeClassification.h
@@ -49,6 +49,8 @@
                        bool oobag) override;
 
  arma::uword find_safe_mtry() override;
+ arma::uword find_safe_mtry_binary();
+ arma::uword find_safe_mtry_multiclass();
 
  double compute_prediction_accuracy_internal(arma::mat& preds) override;
 
diff --git a/src/utility.cpp b/src/utility.cpp
index a7543dbd..a44396d0 100644
--- a/src/utility.cpp
+++ b/src/utility.cpp
@@ -725,7 +725,16 @@
   return(result);
  }
 
- vec beta_var = diagvec(inv(-hessian));
+ mat inv_hess;
+
+ bool invertible = inv(inv_hess, -hessian);
+
+ if(!invertible){
+  mat result(x_node.n_cols, 2, fill::zeros);
+  return(result);
+ }
+
+ vec beta_var = inv_hess.diag();
 
  if(do_scale) unscale_outputs(x_node, beta, beta_var, x_transforms);
 
diff --git a/tests/testthat/test-orsf.R b/tests/testthat/test-orsf.R
index c28e3e65..6270e6d4 100644
--- a/tests/testthat/test-orsf.R
+++ b/tests/testthat/test-orsf.R
@@ -341,8 +341,13 @@ test_that(
 
   expect_equal_leaf_summary(fits_surv$pbc_status_12, fit_standard_pbc$fast)
 
-  expect_equal_oobag_eval(fits_surv$pbc_scaled, fit_standard_pbc$fast)
-  expect_equal_oobag_eval(fits_surv$pbc_noised, fit_standard_pbc$fast)
+  expect_equal_oobag_eval(fits_surv$pbc_scaled,
+                          fit_standard_pbc$fast,
+                          tolerance = .01)
+
+  expect_equal_oobag_eval(fits_surv$pbc_noised,
+                          fit_standard_pbc$fast,
+                          tolerance = .01)
 
   fits_clsf <- lapply(data_list_penguins[-1], function(data){
    orsf(data,
@@ -352,8 +357,13 @@ test_that(
         tree_seeds = seeds_standard)
   })
 
-  expect_equal_oobag_eval(fits_clsf$penguins_scaled, fit_standard_penguins$fast)
-  expect_equal_oobag_eval(fits_clsf$penguins_noised, fit_standard_penguins$fast)
+  expect_equal_oobag_eval(fits_clsf$penguins_scaled,
+                          fit_standard_penguins$fast,
+                          tolerance = .01)
+
+  expect_equal_oobag_eval(fits_clsf$penguins_noised,
+                          fit_standard_penguins$fast,
+                          tolerance = .01)
 
   fits_regr <- lapply(data_list_mtcars[-1], function(data){
    orsf(data,
@@ -363,8 +373,13 @@ test_that(
         tree_seeds = seeds_standard)
   })
 
-  expect_equal_oobag_eval(fits_regr$mtcars_scaled, fit_standard_mtcars$fast)
-  expect_equal_oobag_eval(fits_regr$mtcars_noised, fit_standard_mtcars$fast)
+  expect_equal_oobag_eval(fits_regr$mtcars_scaled,
+                          fit_standard_mtcars$fast,
+                          tolerance = .01)
+
+  expect_equal_oobag_eval(fits_regr$mtcars_noised,
+                          fit_standard_mtcars$fast,
+                          tolerance = .01)
 
  }
 )
@@ -487,6 +502,110 @@ test_that(
  }
 )
 
+test_that(
+ desc = 'weights work as intended',
+ code = {
+
+  fit_unwtd <- orsf(pbc_orsf,
+                    Surv(time, status) ~ . - id,
+                    n_tree = n_tree_test)
+
+  fit_wtd <- orsf(pbc_orsf,
+                  Surv(time, status) ~ . - id,
+                  weights = rep(2, nrow(pbc_orsf)),
+                  n_tree = n_tree_test)
+
+  # using weights should make the trees much deeper:
+  expect_gt(fit_wtd$get_mean_leaves_per_tree(),
+            fit_unwtd$get_mean_leaves_per_tree())
+
+ }
+)
+
+
+
+test_that(
+ desc = 'user-supplied beta functions are vetted',
+ code = {
+
+  f_bad_1 <- function(a_node, y_node, w_node){ 1 }
+  f_bad_2 <- function(x_node, a_node, w_node){ 1 }
+  f_bad_3 <- function(x_node, y_node, a_node){ 1 }
+  f_bad_4 <- function(x_node, y_node){ 1 }
+
+  f_bad_5 <- function(x_node, y_node, w_node) {
+   stop("an expected error occurred")
+  }
+
+  f_bad_6 <- function(x_node, y_node, w_node){
+   return(matrix(0, ncol = 2, nrow = ncol(x_node)))
+  }
+
+  f_bad_7 <- function(x_node, y_node, w_node){
+   return(matrix(0, ncol = 1, nrow = 2))
+  }
+
+  f_bad_8 <- function(x_node, y_node, w_node) {runif(n = ncol(x_node))}
+
+  expect_error(
+   orsf(pbc, time + status ~ .,
+        control = orsf_control_survival(method = f_bad_1)),
+   'x_node'
+  )
+  expect_error(
+   orsf(pbc, time + status ~ .,
+        control = orsf_control_survival(method = f_bad_2)),
+   'y_node'
+  )
+  expect_error(
+   orsf(pbc, time + status ~ .,
+        control = orsf_control_survival(method = f_bad_3)),
+   'w_node'
+  )
+  expect_error(
+   orsf(pbc, time + status ~ .,
+        control = orsf_control_survival(method = f_bad_4)),
+   'should have 3'
+  )
+  expect_error(
+   orsf(pbc, time + status ~ .,
+        control = orsf_control_survival(method = f_bad_5)),
+   'encountered an error'
+  )
+  expect_error(
+   orsf(pbc, time + status ~ .,
+        control = orsf_control_survival(method = f_bad_6)),
+   'with 1 column'
+  )
+  expect_error(
+   orsf(pbc, time + status ~ .,
+        control = orsf_control_survival(method = f_bad_7)),
+   'with 1 row for each'
+  )
+  expect_error(
+   orsf(pbc, time + status ~ .,
+        control = orsf_control_survival(method = f_bad_8)),
+   'matrix output'
+  )
+
+ }
+)
+
+test_that(
+ desc = "user supplied beta functions are applied correctly",
+ code = {
+
+  fit_pca <- orsf(pbc,
+                  Surv(time, status) ~ .,
+                  tree_seeds = seeds_standard,
+                  control = orsf_control_survival(method = f_pca),
+                  n_tree = n_tree_test)
+
+  expect_gt(fit_pca$eval_oobag$stat_values, .785)
+
+ }
+)
+
 test_that(
  desc = 'oblique survival forests run as intended for valid inputs',
  code = {
@@ -846,32 +965,3 @@ test_that(
 
  }
 )
-
-test_that(
- desc = 'weights work as intended',
- code = {
-
-  fit_unwtd <- orsf(pbc_orsf,
-                    Surv(time, status) ~ . - id,
-                    n_tree = n_tree_test)
-
-  fit_wtd <- orsf(pbc_orsf,
-                  Surv(time, status) ~ . - id,
-                  weights = rep(2, nrow(pbc_orsf)),
-                  n_tree = n_tree_test)
-
-  # using weights should make the trees much deeper:
-  expect_gt(fit_wtd$get_mean_leaves_per_tree(),
-            fit_unwtd$get_mean_leaves_per_tree())
-
- }
-)
-
-
-
-
-
-
-
-
-
diff --git a/tests/testthat/test-orsf_control.R b/tests/testthat/test-orsf_control.R
index a0c931e6..9c68579f 100644
--- a/tests/testthat/test-orsf_control.R
+++ b/tests/testthat/test-orsf_control.R
@@ -14,34 +14,6 @@ test_that(
 
   expect_error(orsf_control_survival(net_mix = 32), 'should be <= 1')
 
-  f_bad_1 <- function(a_node, y_node, w_node){ 1 }
-  f_bad_2 <- function(x_node, a_node, w_node){ 1 }
-  f_bad_3 <- function(x_node, y_node, a_node){ 1 }
-  f_bad_4 <- function(x_node, y_node){ 1 }
-
-  f_bad_5 <- function(x_node, y_node, w_node) {
-   stop("an expected error occurred")
-  }
-
-  f_bad_6 <- function(x_node, y_node, w_node){
-   return(matrix(0, ncol = 2, nrow = ncol(x_node)))
-  }
-
-  f_bad_7 <- function(x_node, y_node, w_node){
-   return(matrix(0, ncol = 1, nrow = 2))
-  }
-
-  f_bad_8 <- function(x_node, y_node, w_node) {runif(n = ncol(x_node))}
-
-  expect_error(orsf_control_survival(method = f_bad_1), 'x_node')
-  expect_error(orsf_control_survival(method = f_bad_2), 'y_node')
-  expect_error(orsf_control_survival(method = f_bad_3), 'w_node')
-  expect_error(orsf_control_survival(method = f_bad_4), 'should have 3')
-  expect_error(orsf_control_survival(method = f_bad_5), 'encountered an error')
-  expect_error(orsf_control_survival(method = f_bad_6), 'with 1 column')
-  expect_error(orsf_control_survival(method = f_bad_7), 'with 1 row for each')
-  expect_error(orsf_control_survival(method = f_bad_8), 'matrix output')
-
   f_rando <- function(x_node, y_node, w_node) { matrix(runif(ncol(x_node)), ncol=1) }
 
   expect_s3_class(orsf_control_survival(method = f_rando), 'orsf_control')
@@ -49,21 +21,3 @@ test_that(
 
  }
 )
-
-
-test_that(
- desc = 'custom orsf_control predictions are good',
- code = {
-
-  fit_pca = orsf(pbc,
-                 Surv(time, status) ~ .,
-                 tree_seeds = seeds_standard,
-                 control = orsf_control_survival(method = f_pca),
-                 n_tree = n_tree_test)
-
-  expect_gt(fit_pca$eval_oobag$stat_values, .65)
-
- }
-)
-
-