Skip to content

Commit

Permalink
working but not tested
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Nov 9, 2023
1 parent ab8ef57 commit 9ebf6f2
Show file tree
Hide file tree
Showing 13 changed files with 276 additions and 310 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: aorsf
Title: Accelerated Oblique Random Survival Forests
Version: 0.1.1.9002
Version: 0.1.1.9003
Authors@R: c(
person(given = "Byron",
family = "Jaeger",
Expand Down
5 changes: 3 additions & 2 deletions R/data-penguins_orsf.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#' Size measurements for adult foraging penguins near Palmer Station, Antarctica
#'
#' These data are copied and lightly modified from the `palmerpenguins`
#' `penguins` data. The only modification is removal of rows
#' These data are copied and lightly modified from the `penguins` data in
#' the [palmerpenguins](https://allisonhorst.github.io/palmerpenguins/) R
#' package. The only modification is removal of rows
#' with missing data. The data include measurements for penguin species,
#' island in Palmer Archipelago, size (flipper length, body mass, bill
#' dimensions), and sex.
Expand Down
186 changes: 149 additions & 37 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ ObliqueForest <- R6::R6Class(
private$check_pred_aggregate(pred_aggregate)

if(self$tree_type == 'survival')
private$check_pred_horizon(boundary_checks, pred_horizon)
private$check_pred_horizon(pred_horizon, boundary_checks)

self$data <- new_data
self$pred_horizon <- pred_horizon
Expand Down Expand Up @@ -319,12 +319,6 @@ ObliqueForest <- R6::R6Class(

},

predict_internal = function(){

stop("this method should only be called from derived classes")

},

compute_vi = function(type_vi,
oobag_fun,
n_thread,
Expand Down Expand Up @@ -383,8 +377,8 @@ ObliqueForest <- R6::R6Class(
oobag,
type_output){

public_state <- list(data = self$data,
na_action = self$na_action)
public_state <- list(data = self$data,
na_action = self$na_action)

private_state <- list(data_rows_complete = private$data_rows_complete)

Expand Down Expand Up @@ -556,32 +550,6 @@ ObliqueForest <- R6::R6Class(

pd_vals <- results

# denominator issue in this (i think)
# cpp_args = private$prep_cpp_args(x = private$x,
# y = private$y,
# w = private$w,
# importance_type = 'none',
# pred_type = pred_type,
# pred_aggregate = TRUE,
# pred_horizon = pred_horizon_ordered,
# oobag = oobag,
# oobag_eval_type = 'none',
# pred_mode = FALSE,
# pred_aggregate = TRUE,
# write_forest = FALSE,
# run_forest = TRUE,
# pd_type_R = switch(type_output,
# "smry" = 1L,
# "ice" = 2L),
# pd_x_vals = pred_spec_new,
# pd_x_cols = x_cols,
# pd_probs = prob_values,
# verbosity = 0)
#
# orsf_out <- do.call(orsf_cpp, args = cpp_args)

# pd_vals <- orsf_out$pd_values

for(i in seq_along(pd_vals)){

pd_bind[[i]]$id_variable <- seq(nrow(pd_bind[[i]]))
Expand Down Expand Up @@ -769,6 +737,116 @@ ObliqueForest <- R6::R6Class(

},

summarize_uni = function(n_variables = NULL,
pred_horizon = NULL,
pred_type = NULL,
importance_type = NULL){

# check incoming values if they were specified.
private$check_n_variables(n_variables)
private$check_pred_horizon(pred_horizon, boundary_checks = TRUE)
private$check_pred_type(pred_type, oobag = FALSE)
private$check_importance_type(importance_type)

names_x <- private$data_names$x_original

# use existing values if incoming ones were not specified
n_variables <- n_variables %||% length(names_x)
pred_horizon <- pred_horizon %||% self$pred_horizon
pred_type <- pred_type %||% self$pred_type
importance_type <- importance_type %||% self$importance_type

# bindings for CRAN check
value <- NULL
level <- NULL

# TODO: make this go away. Just sort alphabetically if no importance
if(importance_type == 'none' && is_empty(self$importance_type))
stop("importance cannot be 'none' if object does not have variable",
" importance values.", call. = FALSE)

vi <- switch(
importance_type,
'anova' = orsf_vi_anova(self, group_factors = TRUE),
'negate' = orsf_vi_negate(self, group_factors = TRUE),
'permute' = orsf_vi_permute(self, group_factors = TRUE),
'none' = NULL
)

bounds <- private$data_bounds
fctrs <- private$data_fctrs
n_obs <- self$n_obs

names_vi <- names(vi) %||% names_x

pred_spec <- list_init(names_vi)[seq(n_variables)]

for(i in names(pred_spec)){

if(i %in% colnames(bounds)){

pred_spec[[i]] <- unique(
as.numeric(bounds[c('25%','50%','75%'), i])
)

} else if (i %in% fctrs$cols) {

pred_spec[[i]] <- fctrs$lvls[[i]]

}

}

pd_output <- orsf_pd_oob(object = self,
pred_spec = pred_spec,
expand_grid = FALSE,
pred_type = pred_type,
prob_values = c(0.25, 0.50, 0.75),
pred_horizon = pred_horizon)

fctrs_unordered <- c()

# did the orsf have factor variables?
if(!is_empty(fctrs$cols)){
fctrs_unordered <- fctrs$cols[!fctrs$ordr]
}

# some cart-wheels here for backward compatibility.
f <- as.factor(pd_output$variable)

name_rep <- rle(as.integer(f))

pd_output$importance <- rep(vi[levels(f)[name_rep$values]],
times = name_rep$lengths)

pd_output[, value := fifelse(test = is.na(value),
yes = as.character(level),
no = round_magnitude(value))]

# if a := is used inside a function with no DT[] before the end of the
# function, then the next time DT or print(DT) is typed at the prompt,
# nothing will be printed. A repeated DT or print(DT) will print.
# To avoid this: include a DT[] after the last := in your function.
pd_output[]

setcolorder(pd_output, c('variable',
'importance',
'value',
'mean',
'medn',
'lwr',
'upr'))

structure(
.Data = list(dt = pd_output,
pred_type = pred_type,
pred_horizon = pred_horizon),
class = 'orsf_summary_uni'
)


},

# getters

get_names_x = function(ref_coded = FALSE){
Expand Down Expand Up @@ -1387,6 +1465,39 @@ ObliqueForest <- R6::R6Class(
expected_length = 1)

},

check_n_variables = function(n_variables = NULL){

# n_variables is not a field of ObliqueForest,
# so it is only checked as an incoming input.

if(!is.null(n_variables)){

check_arg_type(arg_value = n_variables,
arg_name = 'n_variables',
expected_type = 'numeric')

check_arg_is_integer(arg_value = n_variables,
arg_name = 'n_variables')

check_arg_gteq(arg_value = n_variables,
arg_name = 'n_variables',
bound = 1)

check_arg_lteq(arg_value = n_variables,
arg_name = 'n_variables',
bound = length(private$data_names$x_original),
append_to_msg = "(total number of predictors)")


check_arg_length(arg_value = n_variables,
arg_name = 'n_variables',
expected_length = 1)

}

},

check_mtry = function(mtry = NULL){

input <- mtry %||% self$mtry
Expand Down Expand Up @@ -2352,7 +2463,7 @@ ObliqueForestSurvival <- R6::R6Class(

},

check_pred_horizon = function(boundary_checks = TRUE, pred_horizon = NULL){
check_pred_horizon = function(pred_horizon = NULL, boundary_checks = TRUE){

input <- pred_horizon %||% self$pred_horizon

Expand Down Expand Up @@ -2451,6 +2562,7 @@ ObliqueForestSurvival <- R6::R6Class(

self$tree_type <- "survival"


self$split_rule <- self$split_rule %||% 'logrank'
self$pred_type <- self$pred_type %||% 'surv'
self$split_min_stat <- self$split_min_stat %||%
Expand Down Expand Up @@ -2501,7 +2613,7 @@ ObliqueForestSurvival <- R6::R6Class(
if(is.null(self$pred_horizon)){
self$pred_horizon <- collapse::fmedian(y[, 1])
} else {
private$check_pred_horizon(boundary_checks = TRUE)
private$check_pred_horizon(self$pred_horizon, boundary_checks = TRUE)
}

private$check_leaf_min_events()
Expand Down
Loading

0 comments on commit 9ebf6f2

Please sign in to comment.