diff --git a/DESCRIPTION b/DESCRIPTION index a0209353..444f25f3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: aorsf Title: Accelerated Oblique Random Survival Forests -Version: 0.1.1 +Version: 0.1.1.9001 Authors@R: c( person(given = "Byron", family = "Jaeger", diff --git a/NEWS.md b/NEWS.md index 42380d32..ae4f0f8e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# aorsf 0.1.1.9000 + +* optimization implemented for matrix multiplication during prediction (https://github.com/ropensci/aorsf/pull/20) + # aorsf 0.1.1 * fixed an uninitialized value for `pd_type` diff --git a/R/infer.R b/R/infer.R index 727865b2..075c5bd7 100644 --- a/R/infer.R +++ b/R/infer.R @@ -34,3 +34,42 @@ infer_pred_horizon <- function(object, pred_type, pred_horizon){ pred_horizon } + + +#' helper for guessing outcome type +#' +#' @param names_y_data character vector of outcome names +#' @param data dataset containing outcomes +#' +#' @return character value: 'survival', 'regression' or 'classification' +#' +#' @examples +#' +#' infer_outcome_type('bili', pbc_orsf) +#' infer_outcome_type('sex', pbc_orsf) +#' infer_outcome_type(c('time', 'status'), pbc_orsf) +#' infer_outcome_type(Surv(pbc_orsf$time, pbc_orsf$status), pbc_orsf) +#' +#' @noRd +infer_outcome_type <- function(names_y_data, data){ + + if(length(names_y_data) > 2){ + stop("formula should have at most two variables as the response", + call. = FALSE) + } + + if(length(names_y_data) == 2) { + return("survival") + } + + if(is.factor(data[[names_y_data]])){ + return("classification") + } else if(inherits(data[[names_y_data]], 'Surv')) { + return("survival") + } else { + return("regression") + } + + stop("could not infer outcome type", call. = FALSE) + +} diff --git a/R/orsf.R b/R/orsf.R index 24a8df11..453fef60 100644 --- a/R/orsf.R +++ b/R/orsf.R @@ -491,7 +491,6 @@ orsf <- function(data, net_alpha <- control_net$net_alpha net_df_target <- control_net$net_df_target - formula_terms <- suppressWarnings(stats::terms(formula, data=data)) if(attr(formula_terms, 'response') == 0) @@ -499,20 +498,14 @@ orsf <- function(data, names_y_data <- all.vars(formula[[2]]) - if(length(names_y_data) == 1){ - # this is fine if the response is a Surv object, - if(!inherits(data[[names_y_data]], 'Surv')){ - # otherwise it will be a problem - stop("formula must have two variables (time & status) as the response", - call. = FALSE) - } + outcome_type <- infer_outcome_type(names_y_data, data) - } + if(outcome_type %in% c('regression', 'classification')) stop("not ready yet") - if(length(names_y_data) > 2){ - stop("formula must have two variables (time & status) as the response", - call. = FALSE) - } + tree_type_R = switch(outcome_type, + 'classification' = 1, + 'regression'= 2, + 'survival' = 3) types_y_data <- vector(mode = 'character', length = length(names_y_data)) @@ -741,9 +734,7 @@ orsf <- function(data, tree_seeds <- sample(x = n_tree*10, size = n_tree, replace = FALSE) } - vi_max_pvalue = 0.01 - tree_type_R = 3 orsf_out <- orsf_cpp(x = x_sort, y = y_sort, diff --git a/src/orsf_oop.cpp b/src/orsf_oop.cpp index ec9d836f..cf90e38a 100644 --- a/src/orsf_oop.cpp +++ b/src/orsf_oop.cpp @@ -283,6 +283,15 @@ // re-cast integer inputs from R into enumerations // see globals.h for definitions. TreeType tree_type = (TreeType) tree_type_R; + + if(tree_type == TREE_CLASSIFICATION || + tree_type == TREE_PROBABILITY || + tree_type == TREE_REGRESSION){ + + stop("that tree type is not ready yet"); + + } + VariableImportance vi_type = (VariableImportance) vi_type_R; SplitRule split_rule = (SplitRule) split_rule_R; LinearCombo lincomb_type = (LinearCombo) lincomb_type_R; diff --git a/tests/testthat/test-infer.R b/tests/testthat/test-infer.R index 0b6e76dd..efb1c428 100644 --- a/tests/testthat/test-infer.R +++ b/tests/testthat/test-infer.R @@ -38,4 +38,16 @@ test_that( ) +test_that( + desc = 'inferred outcome type is correct', + code = { + + pbc$surv_y <- Surv(pbc_orsf$time, pbc_orsf$status) + expect_equal(infer_outcome_type(c('time', 'status'), pbc), 'survival') + expect_equal(infer_outcome_type('surv_y', pbc), 'survival') + expect_equal(infer_outcome_type('age', pbc), 'regression') + expect_equal(infer_outcome_type('sex', pbc), 'classification') + + } +) diff --git a/tests/testthat/test-orsf_formula.R b/tests/testthat/test-orsf_formula.R index 7303c31c..d7e2ba04 100644 --- a/tests/testthat/test-orsf_formula.R +++ b/tests/testthat/test-orsf_formula.R @@ -28,10 +28,7 @@ test_that( 'not_right') expect_error(orsf(pbc_orsf, Surv(start, time, status) ~ .), - 'must have two variables') - - expect_error(orsf(pbc_orsf, Surv(time, time) ~ . - id), - 'must have two variables') + 'should have at most two variables') expect_error(orsf(pbc_orsf, Surv(time, id) ~ . -id), 'detected >1 event type')