Skip to content

Commit

Permalink
Merge pull request #22 from ropensci/issue21
Browse files Browse the repository at this point in the history
Issue21
  • Loading branch information
bcjaeger authored Oct 15, 2023
2 parents 60b0a36 + 46cb7b3 commit a0eb263
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 20 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: aorsf
Title: Accelerated Oblique Random Survival Forests
Version: 0.1.1
Version: 0.1.1.9001
Authors@R: c(
person(given = "Byron",
family = "Jaeger",
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# aorsf 0.1.1.9000

* optimization implemented for matrix multiplication during prediction (https://github.com/ropensci/aorsf/pull/20)

# aorsf 0.1.1

* fixed an uninitialized value for `pd_type`
Expand Down
39 changes: 39 additions & 0 deletions R/infer.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,42 @@ infer_pred_horizon <- function(object, pred_type, pred_horizon){
pred_horizon

}


#' helper for guessing outcome type
#'
#' Looks at the response variable name(s) taken from a formula and decides
#' what kind of outcome they describe.
#'
#' @param names_y_data character vector of outcome names (one or two names).
#' @param data dataset containing outcomes; columns are accessed by name
#'   with `data[[name]]`.
#'
#' @return character value: 'survival', 'regression' or 'classification'
#'
#' @details
#' Two names are always treated as 'survival'. A single name is
#' 'classification' when the column is a factor, 'survival' when the column
#' inherits from 'Surv', and 'regression' otherwise. An error is raised if
#' zero names or more than two names are supplied.
#'
#' @examples
#'
#' infer_outcome_type('bili', pbc_orsf)
#' infer_outcome_type('sex', pbc_orsf)
#' infer_outcome_type(c('time', 'status'), pbc_orsf)
#'
#' # a Surv outcome is referenced by the name of the column that holds it
#' pbc_orsf$surv_y <- Surv(pbc_orsf$time, pbc_orsf$status)
#' infer_outcome_type('surv_y', pbc_orsf)
#'
#' @noRd
infer_outcome_type <- function(names_y_data, data){

  n_names <- length(names_y_data)

  # without this guard, data[[character(0)]] fails with an opaque
  # subscript error instead of a message about the formula
  if(n_names == 0){
    stop("formula should have at least one variable as the response",
         call. = FALSE)
  }

  if(n_names > 2){
    stop("formula should have at most two variables as the response",
         call. = FALSE)
  }

  # two response variables are taken to be a survival outcome
  if(n_names == 2){
    return("survival")
  }

  # NOTE(review): a name that is not in `data` gives NULL here and is
  # classified as 'regression' -- confirm that callers validate names first
  y <- data[[names_y_data]]

  if(is.factor(y)){
    return("classification")
  }

  if(inherits(y, 'Surv')){
    return("survival")
  }

  "regression"

}
21 changes: 6 additions & 15 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -491,28 +491,21 @@ orsf <- function(data,
net_alpha <- control_net$net_alpha
net_df_target <- control_net$net_df_target


formula_terms <- suppressWarnings(stats::terms(formula, data=data))

if(attr(formula_terms, 'response') == 0)
stop("formula must have a response", call. = FALSE)

names_y_data <- all.vars(formula[[2]])

if(length(names_y_data) == 1){
# this is fine if the response is a Surv object,
if(!inherits(data[[names_y_data]], 'Surv')){
# otherwise it will be a problem
stop("formula must have two variables (time & status) as the response",
call. = FALSE)
}
outcome_type <- infer_outcome_type(names_y_data, data)

}
if(outcome_type %in% c('regression', 'classification')) stop("not ready yet")

if(length(names_y_data) > 2){
stop("formula must have two variables (time & status) as the response",
call. = FALSE)
}
tree_type_R = switch(outcome_type,
'classification' = 1,
'regression'= 2,
'survival' = 3)

types_y_data <- vector(mode = 'character',
length = length(names_y_data))
Expand Down Expand Up @@ -741,9 +734,7 @@ orsf <- function(data,
tree_seeds <- sample(x = n_tree*10, size = n_tree, replace = FALSE)
}


vi_max_pvalue = 0.01
tree_type_R = 3

orsf_out <- orsf_cpp(x = x_sort,
y = y_sort,
Expand Down
9 changes: 9 additions & 0 deletions src/orsf_oop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,15 @@
// re-cast integer inputs from R into enumerations
// see globals.h for definitions.
TreeType tree_type = (TreeType) tree_type_R;

if(tree_type == TREE_CLASSIFICATION ||
tree_type == TREE_PROBABILITY ||
tree_type == TREE_REGRESSION){

stop("that tree type is not ready yet");

}

VariableImportance vi_type = (VariableImportance) vi_type_R;
SplitRule split_rule = (SplitRule) split_rule_R;
LinearCombo lincomb_type = (LinearCombo) lincomb_type_R;
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test-infer.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,16 @@ test_that(
)


test_that(
  desc = 'inferred outcome type is correct',
  code = {

    # store a Surv object as a column so that a single response name
    # can refer to a survival outcome
    # NOTE(review): `pbc` is defined outside this block -- presumably it has
    # the same rows as pbc_orsf; confirm against the top of the test file
    pbc$surv_y <- Surv(pbc_orsf$time, pbc_orsf$status)

    # response name(s) paired with the outcome type they should produce
    cases <- list(
      list(names_y = c('time', 'status'), type = 'survival'),
      list(names_y = 'surv_y', type = 'survival'),
      list(names_y = 'age', type = 'regression'),
      list(names_y = 'sex', type = 'classification')
    )

    for(case in cases){
      expect_equal(infer_outcome_type(case$names_y, pbc), case$type)
    }

  }
)
5 changes: 1 addition & 4 deletions tests/testthat/test-orsf_formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ test_that(
'not_right')

expect_error(orsf(pbc_orsf, Surv(start, time, status) ~ .),
'must have two variables')

expect_error(orsf(pbc_orsf, Surv(time, time) ~ . - id),
'must have two variables')
'should have at most two variables')

expect_error(orsf(pbc_orsf, Surv(time, id) ~ . -id),
'detected >1 event type')
Expand Down

0 comments on commit a0eb263

Please sign in to comment.