Skip to content

Commit

Permalink
migrating control tests to orsf so that it's easier to do family-speci…
Browse files Browse the repository at this point in the history
…fic control tests
  • Loading branch information
bcjaeger committed Nov 16, 2023
1 parent 6ebb48a commit 6e6b1a2
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 96 deletions.
163 changes: 162 additions & 1 deletion R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,7 @@ ObliqueForest <- R6::R6Class(


private$init_oobag_eval_function()
private$init_lincomb_R_function()
private$init_oobag_pred_mode()
private$init_tree_seeds()
private$init_internal()
Expand Down Expand Up @@ -1226,6 +1227,16 @@ ObliqueForest <- R6::R6Class(

},

# Validate the user-supplied linear combination function, but only
# when the control object requests a 'custom' linear combination type.
init_lincomb_R_function = function(){

 uses_custom_lincomb <- self$control$lincomb_type == 'custom'

 if(uses_custom_lincomb){

  # the control object carries the function the user passed in
  user_fun <- self$control$lincomb_R_function

  private$check_lincomb_R_function(user_fun)

 }

},

# checkers
check_data = function(data = NULL, new = FALSE){

Expand Down Expand Up @@ -2010,6 +2021,60 @@ ObliqueForest <- R6::R6Class(
}

},


# Validate a user-supplied function that finds linear combinations of
# predictors. The checks run in this order and stop at the first failure:
#   1. the input is a function,
#   2. it has exactly three arguments,
#   3. the arguments are named x_node, y_node, and w_node (in that order),
#   4. when run on the family-specific test case built by
#      private$check_lincomb_R_function_internal(), it returns a matrix
#      with 1 column and 1 row per column of x_node.
#
# lincomb_R_function: the function to check. When NULL, the value of
#   self$lincomb_R_function is checked instead.
#   NOTE(review): init_lincomb_R_function reads the function from
#   self$control$lincomb_R_function; confirm the NULL fallback used here
#   points at the intended field.
#
# Called for its side effect: an error is raised when a check fails.
check_lincomb_R_function = function(lincomb_R_function = NULL){

 input <- lincomb_R_function %||% self$lincomb_R_function

 # formals() on a non-function either returns NULL or tries to look
 # up a function by name, so fail early with a clear message instead.
 if(!is.function(input)) stop(
  "input should be a function but instead has type ",
  class(input)[1],
  call. = FALSE
 )

 args <- names(formals(input))

 if(length(args) != 3) stop(
  "input should have 3 input arguments but instead has ",
  length(args),
  call. = FALSE
 )

 arg_names_expected <- c("x_node",
                         "y_node",
                         "w_node")

 arg_names_refer <- c('first', 'second', 'third')

 # report the first argument whose name does not match
 for(i in seq_along(arg_names_expected)){
  if(args[i] != arg_names_expected[i])
   stop(
    "the ", arg_names_refer[i], " input argument of input ",
    "should be named '", arg_names_expected[i],"' ",
    "but is instead named '", args[i], "'",
    call. = FALSE
   )
 }

 # the family-specific helper runs the function on a small test case
 # and raises its own error if the function fails outright
 test_output <- private$check_lincomb_R_function_internal(input)

 if(!is.matrix(test_output)) stop(
  "user-supplied function should return a matrix output ",
  "but instead returns output of type ", class(test_output)[1],
  call. = FALSE
 )

 if(ncol(test_output) != 1) stop(
  "user-supplied function should return a matrix with 1 column ",
  "but instead returns a matrix with ", ncol(test_output), " columns.",
  call. = FALSE
 )

 # every family-specific test case uses an x_node with 3 columns
 if(nrow(test_output) != 3L) stop(
  "user-supplied function should return a matrix with 1 row for each ",
  "column in x_node but instead returns a matrix with ",
  nrow(test_output), " rows ", "in a testing case where x_node has ",
  3L, " columns",
  call. = FALSE
 )

},

check_verbose_progress = function(verbose_progress = NULL){

input <- verbose_progress %||% self$verbose_progress
Expand Down Expand Up @@ -2433,7 +2498,9 @@ ObliqueForest <- R6::R6Class(

},

sort_inputs = function(){
# Default implementation: does nothing and returns NULL. The arguments
# are accepted but not used here.
# NOTE(review): the arguments are ordered (sort_y, sort_x, sort_w) here
# but the survival forest's sort_inputs lists (sort_x, sort_y, sort_w);
# confirm callers pass these by name rather than by position.
sort_inputs = function(sort_y = NULL,
sort_x = NULL,
sort_w = NULL){
NULL
},

Expand Down Expand Up @@ -2734,6 +2801,40 @@ ObliqueForestSurvival <- R6::R6Class(

},

# Run a user-supplied linear combination function on a small survival
# test case and return whatever the function returns.
#
# The test case has 100 rows: x_node is a 100 x 3 matrix of random
# normal values, y_node is a two-column matrix (time, status), and
# w_node is a one-column matrix of weights.
#
# lincomb_R_function: the function to test. When NULL, the value of
#   self$lincomb_R_function is used instead.
#
# Returns the output of the function on the test case. If the function
# throws an error, a new error is raised that shows how to rebuild the
# test case.
check_lincomb_R_function_internal = function(lincomb_R_function = NULL){

 input <- lincomb_R_function %||% self$lincomb_R_function

 test_time <- seq(from = 1, to = 5, length.out = 100)
 test_status <- rep(c(0,1), each = 50)

 .x_node <- matrix(rnorm(300), ncol = 3)
 .y_node <- cbind(time = test_time, status = test_status)
 .w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)

 # silent = FALSE so the user's original error message is still shown
 out <- try(input(.x_node, .y_node, .w_node), silent = FALSE)

 if(is_error(out)){

  # the code printed below mirrors the test case constructed above
  stop("user-supplied function encountered an error when it was tested. ",
       "Please make sure the function works for this case:\n\n",
       "test_time <- seq(from = 1, to = 5, length.out = 100)\n",
       "test_status <- rep(c(0,1), each = 50)\n\n",
       ".x_node <- matrix(rnorm(300), ncol = 3)\n",
       ".y_node <- cbind(time = test_time, status = test_status)\n",
       ".w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)\n\n",
       "test_output <- user_function(.x_node, .y_node, .w_node)\n\n",
       "test_output should be a numeric matrix with 1 column and",
       " with nrow(test_output) = ncol(.x_node)",
       call. = FALSE)

 }

 out

},

sort_inputs = function(sort_x = TRUE,
sort_y = TRUE,
sort_w = TRUE){
Expand Down Expand Up @@ -3131,6 +3232,34 @@ ObliqueForestClassification <- R6::R6Class(

},

# Try a user-supplied linear combination function on a small
# classification test case (100 rows, 3 predictors, binary outcome)
# and hand back its output.
#
# lincomb_R_function: the function to test. When NULL, the value of
#   self$lincomb_R_function is used instead.
#
# Returns the output of the function on the test case, or raises an
# error showing how to rebuild the test case if the function fails.
check_lincomb_R_function_internal = function(lincomb_R_function = NULL){

 input <- lincomb_R_function %||% self$lincomb_R_function

 # random draws are made in the same order: predictors, then outcome
 x_test <- matrix(rnorm(300), ncol = 3)
 y_test <- matrix(rbinom(100, size = 1, prob = 1/2), ncol = 1)
 w_test <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)

 out <- try(input(x_test, y_test, w_test), silent = FALSE)

 # nothing went wrong: return the function's output as-is
 if(!is_error(out)) return(out)

 stop(
  "user-supplied function encountered an error when it was tested. ",
  "Please make sure the function works for this case:\n\n",
  ".x_node <- matrix(rnorm(300), ncol = 3)\n",
  ".y_node <- matrix(rbinom(100, size = 1, prob = 1/2), ncol = 1)\n",
  ".w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)\n",
  "test_output <- your_function(.x_node, .y_node, .w_node)\n\n",
  "test_output should be a numeric matrix with 1 column and",
  " with nrow(test_output) = ncol(.x_node)",
  call. = FALSE
 )

},

init_control = function(){

self$control <- orsf_control_classification(method = 'glm',
Expand Down Expand Up @@ -3215,6 +3344,8 @@ ObliqueForestClassification <- R6::R6Class(

y <- as.numeric(y) - 1

if(min(y) > 0) browser()

private$y <- expand_y_clsf(as_matrix(y), n_class)

},
Expand Down Expand Up @@ -3320,6 +3451,36 @@ ObliqueForestRegression <- R6::R6Class(

},

# Run a user-supplied linear combination function on a small regression
# test case and return whatever the function returns.
#
# The test case has 100 rows: x_node is a 100 x 3 matrix of random
# normal values, y_node is a one-column matrix of random normal values,
# and w_node is a one-column matrix of weights.
#
# lincomb_R_function: the function to test. When NULL, the value of
#   self$lincomb_R_function is used instead.
#
# Returns the output of the function on the test case. If the function
# throws an error, a new error is raised that shows how to rebuild the
# test case.
check_lincomb_R_function_internal = function(lincomb_R_function = NULL){

 input <- lincomb_R_function %||% self$lincomb_R_function

 .x_node <- matrix(rnorm(300), ncol = 3)
 .y_node <- matrix(rnorm(100), ncol = 1)
 .w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)

 # silent = FALSE so the user's original error message is still shown
 out <- try(input(.x_node, .y_node, .w_node), silent = FALSE)

 if(is_error(out)){

  # the code printed below mirrors the regression test case above
  # (a single-column numeric outcome, not a time/status pair)
  stop("user-supplied function encountered an error when it was tested. ",
       "Please make sure the function works for this case:\n\n",
       ".x_node <- matrix(rnorm(300), ncol = 3)\n",
       ".y_node <- matrix(rnorm(100), ncol = 1)\n",
       ".w_node <- matrix(rep(c(1,2,3,4), each = 25), ncol = 1)\n\n",
       "test_output <- beta_fun(.x_node, .y_node, .w_node)\n\n",
       "test_output should be a numeric matrix with 1 column and",
       " with nrow(test_output) = ncol(.x_node)",
       call. = FALSE)

 }

 out

},

init_control = function(){

self$control <- orsf_control_regression(method = 'glm',
Expand Down
6 changes: 2 additions & 4 deletions R/orsf_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -366,12 +366,10 @@ orsf_control <- function(tree_type,
arg_name = 'method',
expected_length = 1)

} else {

check_beta_fun(method)

}

# checking of custom functions is done when orsf object is initialized

check_arg_type(arg_value = scale_x,
arg_name = 'scale_x',
expected_type = 'logical')
Expand Down
14 changes: 7 additions & 7 deletions man/orsf_control_custom.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 59 additions & 2 deletions src/TreeClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,71 @@

}

arma::uword TreeClassification::find_safe_mtry(){
// Pick the mtry search that matches the outcome: the binary routine
// when `binary` is set, otherwise the multiclass routine.
uword TreeClassification::find_safe_mtry(){

  return binary ? find_safe_mtry_binary()
                : find_safe_mtry_multiclass();

}

// Find an mtry value that is safe to use for a binary outcome.
//
// A split is attempted only when both columns of y_node sum to at
// least 3 (column 0 is counted as controls, column 1 as cases).
// When that holds, the single entry of splittable_y_cols is set to 1,
// y_col_split is set to 1, and, unless lincomb_type is LC_GLMNET,
// mtry is reduced until the smaller of the two sums divided by mtry
// is at least 3.
//
// Returns the (possibly reduced) mtry, or 0 when the node cannot be
// split.
uword TreeClassification::find_safe_mtry_binary(){

  // conditions to split a column:
  //   >= 3 events per predictor
  //   >= 3 non-events per predictor

  const double n_ctrls = sum(y_node.col(0));
  const double n_cases = sum(y_node.col(1));

  if(verbosity > 3){
    Rcout << "   -- Y sums (unweighted): ";
    Rcout << n_cases << " cases, ";
    Rcout << n_ctrls << " controls" << std::endl;
  }

  // start from "nothing is splittable" and flip the flag below
  splittable_y_cols.zeros(1);

  const bool enough_of_each = (n_cases >= 3) && (n_ctrls >= 3);

  if(!enough_of_each){

    if(verbosity > 3){
      Rcout << "   -- No y columns are splittable";
      Rcout << std::endl << std::endl;
    }

    return 0;

  }

  splittable_y_cols[0] = 1;
  y_col_split = 1;

  // the rarer class limits how many predictors can be used
  const double n_rarest = (n_cases > n_ctrls) ? n_ctrls : n_cases;

  double mtry_candidate = mtry;

  if(lincomb_type != LC_GLMNET){

    while(n_rarest / mtry_candidate < 3){
      mtry_candidate -= 1;
    }

  }

  return static_cast<uword>(mtry_candidate);

}
uword TreeClassification::find_safe_mtry_multiclass(){

// conditions to split a column:
// >= 3 events per predictor
// >= 3 non-events per predictor

double safer_mtry = mtry;

double n = y_node.n_rows;
vec y_sum_cases = sum(y_node, 0).t();
vec y_sum_ctrls = n - y_sum_cases;
Expand Down Expand Up @@ -286,7 +343,7 @@

// glmnet can handle higher dimension x,
// but other methods probably cannot.
if(lincomb_type != LC_GLM){
if(lincomb_type != LC_GLMNET){

while (best_count / safer_mtry < 3){
--safer_mtry;
Expand Down
2 changes: 2 additions & 0 deletions src/TreeClassification.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
bool oobag) override;

arma::uword find_safe_mtry() override;
arma::uword find_safe_mtry_binary();
arma::uword find_safe_mtry_multiclass();

double compute_prediction_accuracy_internal(arma::mat& preds) override;

Expand Down
11 changes: 10 additions & 1 deletion src/utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,16 @@
return(result);
}

vec beta_var = diagvec(inv(-hessian));
mat inv_hess;

bool invertible = inv(inv_hess, -hessian);

if(!invertible){
mat result(x_node.n_cols, 2, fill::zeros);
return(result);
}

vec beta_var = inv_hess.diag();

if(do_scale) unscale_outputs(x_node, beta, beta_var, x_transforms);

Expand Down
Loading

0 comments on commit 6e6b1a2

Please sign in to comment.