Skip to content

Commit

Permalink
deprecate the old controls
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Nov 10, 2023
1 parent e0a79fb commit f47de2d
Show file tree
Hide file tree
Showing 34 changed files with 749 additions and 260 deletions.
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Description: Fit, interpret, and make predictions with oblique random survival f
License: MIT + file LICENSE
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE, roclets = c ("namespace", "rd"))
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
LinkingTo:
Rcpp,
Expand All @@ -38,8 +38,9 @@ Imports:
Rcpp,
data.table,
utils,
collapse,
R6
collapse,
R6,
lifecycle
URL: https://github.com/ropensci/aorsf,
https://docs.ropensci.org/aorsf/
BugReports: https://github.com/ropensci/aorsf/issues/
Expand Down
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@ S3method(predict,ObliqueForest)
S3method(print,ObliqueForest)
S3method(print,orsf_summary_uni)
export(orsf)
export(orsf_control)
export(orsf_control_classification)
export(orsf_control_cph)
export(orsf_control_custom)
export(orsf_control_fast)
export(orsf_control_net)
export(orsf_control_regression)
export(orsf_control_survival)
export(orsf_ice_inb)
export(orsf_ice_new)
export(orsf_ice_oob)
Expand All @@ -29,4 +33,5 @@ import(R6)
import(data.table)
importFrom(Rcpp,sourceCpp)
importFrom(collapse,"%==%")
importFrom(lifecycle,deprecated)
useDynLib(aorsf, .registration = TRUE)
3 changes: 2 additions & 1 deletion R/aorsf-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
# The following block is used by usethis to automatically manage
# roxygen namespace tags. Modify with care!
## usethis namespace: start
#' @importFrom Rcpp sourceCpp
#' @importFrom collapse %==%
#' @importFrom lifecycle deprecated
#' @importFrom Rcpp sourceCpp
#' @useDynLib aorsf, .registration = TRUE
## usethis namespace: end
NULL
67 changes: 48 additions & 19 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ ObliqueForest <- R6::R6Class(

private$check_pred_horizon(pred_horizon, boundary_checks, pred_type)

if(is.null(pred_horizon)) pred_horizon <- 1
pred_horizon <- pred_horizon %||% self$pred_horizon %||% 1

pred_horizon_order <- order(pred_horizon)
pred_horizon_ordered <- pred_horizon[pred_horizon_order]

Expand Down Expand Up @@ -435,11 +436,11 @@ ObliqueForest <- R6::R6Class(
private$prep_x()
# y and w do not need to be prepped for prediction,
# but they need to match orsf_cpp()'s expectations
private$y <- matrix(0, nrow = nrow(private$x), ncol = 1)
private$prep_y(placeholder = TRUE)
private$w <- rep(1, nrow(private$x))


if(oobag){ private$sort_inputs() }
if(oobag){ private$sort_inputs(sort_y = FALSE) }

# the values in pred_spec need to be centered & scaled to match x_new,
# which is also centered and scaled
Expand Down Expand Up @@ -660,6 +661,8 @@ ObliqueForest <- R6::R6Class(

if(self$tree_type == 'classification'){
setnames(out, old = 'pred_horizon', new = 'class')
out[, class := factor(class, levels = self$class_levels)]
setkey(out, class)
}

if(self$tree_type == 'survival' && pred_type != 'mort')
Expand Down Expand Up @@ -878,13 +881,14 @@ ObliqueForest <- R6::R6Class(
# To avoid this: include a DT[] after the last := in your function.
pd_output[]

setcolorder(pd_output, c('variable',
'importance',
'value',
'mean',
'medn',
'lwr',
'upr'))
new_order <- c('variable', 'importance', 'value',
'mean', 'medn', 'lwr', 'upr')

if(self$tree_type == 'classification'){
new_order <- insert_vals(new_order, 2, 'class')
}

setcolorder(pd_output, new_order)

structure(
.Data = list(dt = pd_output,
Expand Down Expand Up @@ -2316,14 +2320,14 @@ ObliqueForest <- R6::R6Class(

},

prep_y = function(){
prep_y = function(placeholder = FALSE){

private$y <- select_cols(self$data, private$data_names$y)

if(self$na_action == 'omit')
if(self$na_action == 'omit' && !placeholder)
private$y <- private$y[private$data_rows_complete, ]

private$prep_y_internal()
private$prep_y_internal(placeholder)

},

Expand Down Expand Up @@ -2686,18 +2690,27 @@ ObliqueForestSurvival <- R6::R6Class(

},

sort_inputs = function(){
sort_inputs = function(sort_x = TRUE,
sort_y = TRUE,
sort_w = TRUE){

private$x <- private$x[private$data_row_sort, , drop = FALSE]
private$y <- private$y[private$data_row_sort, , drop = FALSE]
private$w <- private$w[private$data_row_sort]
if(sort_x)
private$x <- private$x[private$data_row_sort, , drop = FALSE]
if(sort_y)
private$y <- private$y[private$data_row_sort, , drop = FALSE]
if(sort_w)
private$w <- private$w[private$data_row_sort]

},

init_internal = function(){

self$tree_type <- "survival"

if(!is.function(self$control$lincomb_R_function) &&
self$control$lincomb_type == 'net'){
self$control$lincomb_R_function <- penalized_cph
}

self$split_rule <- self$split_rule %||% 'logrank'
self$pred_type <- self$pred_type %||% 'surv'
Expand Down Expand Up @@ -2761,7 +2774,13 @@ ObliqueForestSurvival <- R6::R6Class(
}

},
prep_y_internal = function(){
prep_y_internal = function(placeholder = FALSE){


if(placeholder){
private$y <- matrix(0, ncol = 2, nrow = 1)
return()
}

y <- private$y
cols <- names(y)
Expand Down Expand Up @@ -2979,6 +2998,11 @@ ObliqueForestClassification <- R6::R6Class(

self$tree_type <- "classification"

if(!is.function(self$control$lincomb_R_function) &&
self$control$lincomb_type == 'net'){
self$control$lincomb_R_function <- penalized_logreg
}

self$split_rule <- self$split_rule %||% 'gini'
self$pred_type <- self$pred_type %||% 'prob'
self$split_min_stat <- self$split_min_stat %||%
Expand All @@ -3003,7 +3027,12 @@ ObliqueForestClassification <- R6::R6Class(

},

prep_y_internal = function(){
prep_y_internal = function(placeholder = FALSE){

if(placeholder){
private$y <- matrix(0, ncol = self$n_class-1, nrow = 1)
return()
}

# y is always 1 column for classification (right?)
y <- private$y[[1]]
Expand Down
Loading

0 comments on commit f47de2d

Please sign in to comment.